mirror of
https://github.com/open-webui/open-webui.git
synced 2025-12-12 04:15:25 +00:00
adding RAG threshold logic
This commit is contained in:
parent
140605e660
commit
e146a5f613
4 changed files with 52 additions and 1 deletions
|
|
@ -2625,6 +2625,12 @@ RAG_FULL_CONTEXT = PersistentConfig(
|
|||
os.getenv("RAG_FULL_CONTEXT", "False").lower() == "true",
|
||||
)
|
||||
|
||||
RAG_TOKEN_THRESHOLD = PersistentConfig(
|
||||
"RAG_TOKEN_THRESHOLD",
|
||||
"rag.token_threshold",
|
||||
int(os.environ.get("RAG_TOKEN_THRESHOLD", "0")),
|
||||
)
|
||||
|
||||
RAG_FILE_MAX_COUNT = PersistentConfig(
|
||||
"RAG_FILE_MAX_COUNT",
|
||||
"rag.file.max_count",
|
||||
|
|
|
|||
|
|
@ -218,6 +218,7 @@ from open_webui.config import (
|
|||
RAG_TEMPLATE,
|
||||
DEFAULT_RAG_TEMPLATE,
|
||||
RAG_FULL_CONTEXT,
|
||||
RAG_TOKEN_THRESHOLD,
|
||||
BYPASS_EMBEDDING_AND_RETRIEVAL,
|
||||
RAG_EMBEDDING_MODEL,
|
||||
RAG_EMBEDDING_MODEL_AUTO_UPDATE,
|
||||
|
|
@ -840,6 +841,7 @@ app.state.config.FILE_IMAGE_COMPRESSION_HEIGHT = FILE_IMAGE_COMPRESSION_HEIGHT
|
|||
|
||||
|
||||
app.state.config.RAG_FULL_CONTEXT = RAG_FULL_CONTEXT
|
||||
app.state.config.RAG_TOKEN_THRESHOLD = RAG_TOKEN_THRESHOLD
|
||||
app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL = BYPASS_EMBEDDING_AND_RETRIEVAL
|
||||
app.state.config.ENABLE_RAG_HYBRID_SEARCH = ENABLE_RAG_HYBRID_SEARCH
|
||||
app.state.config.ENABLE_RAG_HYBRID_SEARCH_ENRICHED_TEXTS = (
|
||||
|
|
|
|||
|
|
@ -1025,9 +1025,15 @@ async def get_sources_from_items(
|
|||
"metadatas": [[{"url": item.get("url"), "name": item.get("url")}]],
|
||||
}
|
||||
elif item.get("type") == "file":
|
||||
file_bypassed_rag = False
|
||||
if item.get("id"):
|
||||
file_object = Files.get_file_by_id(item.get("id"))
|
||||
if file_object and file_object.meta:
|
||||
file_bypassed_rag = file_object.meta.get("bypass_rag", False)
|
||||
if (
|
||||
item.get("context") == "full"
|
||||
or request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL
|
||||
or file_bypassed_rag
|
||||
):
|
||||
if item.get("file", {}).get("data", {}).get("content", ""):
|
||||
# Manual Full Mode Toggle
|
||||
|
|
|
|||
|
|
@ -440,6 +440,7 @@ async def get_rag_config(request: Request, user=Depends(get_admin_user)):
|
|||
"TOP_K": request.app.state.config.TOP_K,
|
||||
"BYPASS_EMBEDDING_AND_RETRIEVAL": request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL,
|
||||
"RAG_FULL_CONTEXT": request.app.state.config.RAG_FULL_CONTEXT,
|
||||
"RAG_TOKEN_THRESHOLD": request.app.state.config.RAG_TOKEN_THRESHOLD,
|
||||
# Hybrid search settings
|
||||
"ENABLE_RAG_HYBRID_SEARCH": request.app.state.config.ENABLE_RAG_HYBRID_SEARCH,
|
||||
"ENABLE_RAG_HYBRID_SEARCH_ENRICHED_TEXTS": request.app.state.config.ENABLE_RAG_HYBRID_SEARCH_ENRICHED_TEXTS,
|
||||
|
|
@ -614,6 +615,7 @@ class ConfigForm(BaseModel):
|
|||
TOP_K: Optional[int] = None
|
||||
BYPASS_EMBEDDING_AND_RETRIEVAL: Optional[bool] = None
|
||||
RAG_FULL_CONTEXT: Optional[bool] = None
|
||||
RAG_TOKEN_THRESHOLD: Optional[int] = None
|
||||
|
||||
# Hybrid search settings
|
||||
ENABLE_RAG_HYBRID_SEARCH: Optional[bool] = None
|
||||
|
|
@ -707,6 +709,11 @@ async def update_rag_config(
|
|||
if form_data.RAG_FULL_CONTEXT is not None
|
||||
else request.app.state.config.RAG_FULL_CONTEXT
|
||||
)
|
||||
request.app.state.config.RAG_TOKEN_THRESHOLD = (
|
||||
form_data.RAG_TOKEN_THRESHOLD
|
||||
if form_data.RAG_TOKEN_THRESHOLD is not None
|
||||
else request.app.state.config.RAG_TOKEN_THRESHOLD
|
||||
)
|
||||
|
||||
# Hybrid search settings
|
||||
request.app.state.config.ENABLE_RAG_HYBRID_SEARCH = (
|
||||
|
|
@ -1591,7 +1598,37 @@ def process_file(
|
|||
hash = calculate_sha256_string(text_content)
|
||||
Files.update_file_hash_by_id(file.id, hash)
|
||||
|
||||
if request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL:
|
||||
should_bypass = request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL
|
||||
token_count = 0
|
||||
|
||||
if not should_bypass and request.app.state.config.RAG_TOKEN_THRESHOLD > 0:
|
||||
try:
|
||||
encoding = tiktoken.get_encoding(
|
||||
str(request.app.state.config.TIKTOKEN_ENCODING_NAME)
|
||||
)
|
||||
token_count = len(encoding.encode(text_content))
|
||||
|
||||
if token_count <= request.app.state.config.RAG_TOKEN_THRESHOLD:
|
||||
should_bypass = True
|
||||
log.info(
|
||||
f"File '{file.filename}': {token_count} tokens "
|
||||
f"(<= {request.app.state.config.RAG_TOKEN_THRESHOLD}), bypassing RAG"
|
||||
)
|
||||
else:
|
||||
log.info(
|
||||
f"File '{file.filename}': {token_count} tokens "
|
||||
f"(> {request.app.state.config.RAG_TOKEN_THRESHOLD}), using RAG"
|
||||
)
|
||||
except Exception as e:
|
||||
log.warning(f"Error counting tokens: {e}")
|
||||
|
||||
if should_bypass:
|
||||
Files.update_file_data_by_id(file.id, {"status": "completed"})
|
||||
|
||||
Files.update_file_metadata_by_id(
|
||||
file.id,
|
||||
{"bypass_rag": True}
|
||||
)
|
||||
Files.update_file_data_by_id(file.id, {"status": "completed"})
|
||||
return {
|
||||
"status": True,
|
||||
|
|
|
|||
Loading…
Reference in a new issue