mirror of
https://github.com/open-webui/open-webui.git
synced 2025-12-17 14:55:23 +00:00
Correctly unloads embedding/reranker models
This commit is contained in:
parent
b3a95f40fc
commit
39fe385017
2 changed files with 40 additions and 23 deletions
|
|
@ -924,7 +924,7 @@ try:
|
||||||
app.state.config.RAG_EMBEDDING_MODEL,
|
app.state.config.RAG_EMBEDDING_MODEL,
|
||||||
RAG_EMBEDDING_MODEL_AUTO_UPDATE,
|
RAG_EMBEDDING_MODEL_AUTO_UPDATE,
|
||||||
)
|
)
|
||||||
|
if ENABLE_RAG_HYBRID_SEARCH and not BYPASS_EMBEDDING_AND_RETRIEVAL:
|
||||||
app.state.rf = get_rf(
|
app.state.rf = get_rf(
|
||||||
app.state.config.RAG_RERANKING_ENGINE,
|
app.state.config.RAG_RERANKING_ENGINE,
|
||||||
app.state.config.RAG_RERANKING_MODEL,
|
app.state.config.RAG_RERANKING_MODEL,
|
||||||
|
|
@ -932,6 +932,8 @@ try:
|
||||||
app.state.config.RAG_EXTERNAL_RERANKER_API_KEY,
|
app.state.config.RAG_EXTERNAL_RERANKER_API_KEY,
|
||||||
RAG_RERANKING_MODEL_AUTO_UPDATE,
|
RAG_RERANKING_MODEL_AUTO_UPDATE,
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
app.state.rf = None
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.error(f"Error updating models: {e}")
|
log.error(f"Error updating models: {e}")
|
||||||
pass
|
pass
|
||||||
|
|
|
||||||
|
|
@ -321,6 +321,14 @@ async def update_embedding_config(
|
||||||
form_data.embedding_batch_size
|
form_data.embedding_batch_size
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# unloads current embedding model and clears VRAM cache
|
||||||
|
request.app.state.ef = None
|
||||||
|
request.app.state.EMBEDDING_FUNCTION = None
|
||||||
|
import gc
|
||||||
|
gc.collect()
|
||||||
|
import torch
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.empty_cache()
|
||||||
request.app.state.ef = get_ef(
|
request.app.state.ef = get_ef(
|
||||||
request.app.state.config.RAG_EMBEDDING_ENGINE,
|
request.app.state.config.RAG_EMBEDDING_ENGINE,
|
||||||
request.app.state.config.RAG_EMBEDDING_MODEL,
|
request.app.state.config.RAG_EMBEDDING_MODEL,
|
||||||
|
|
@ -653,9 +661,6 @@ async def update_rag_config(
|
||||||
if form_data.ENABLE_RAG_HYBRID_SEARCH is not None
|
if form_data.ENABLE_RAG_HYBRID_SEARCH is not None
|
||||||
else request.app.state.config.ENABLE_RAG_HYBRID_SEARCH
|
else request.app.state.config.ENABLE_RAG_HYBRID_SEARCH
|
||||||
)
|
)
|
||||||
# Free up memory if hybrid search is disabled
|
|
||||||
if not request.app.state.config.ENABLE_RAG_HYBRID_SEARCH:
|
|
||||||
request.app.state.rf = None
|
|
||||||
|
|
||||||
request.app.state.config.TOP_K_RERANKER = (
|
request.app.state.config.TOP_K_RERANKER = (
|
||||||
form_data.TOP_K_RERANKER
|
form_data.TOP_K_RERANKER
|
||||||
|
|
@ -838,6 +843,16 @@ async def update_rag_config(
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
# Unloading the reranker and clear VRAM memory.
|
||||||
|
if request.app.state.rf != None:
|
||||||
|
request.app.state.rf = None
|
||||||
|
request.app.state.RERANKING_FUNCTION = None
|
||||||
|
import gc
|
||||||
|
gc.collect()
|
||||||
|
import torch
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
if request.app.state.config.ENABLE_RAG_HYBRID_SEARCH and not request.app.state.BYPASS_EMBEDDING_AND_RETRIEVAL:
|
||||||
request.app.state.rf = get_rf(
|
request.app.state.rf = get_rf(
|
||||||
request.app.state.config.RAG_RERANKING_ENGINE,
|
request.app.state.config.RAG_RERANKING_ENGINE,
|
||||||
request.app.state.config.RAG_RERANKING_MODEL,
|
request.app.state.config.RAG_RERANKING_MODEL,
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue