Unloads only if internal models are used.

This commit is contained in:
Marko Henning 2025-08-21 10:49:03 +02:00
parent cd02ff2e07
commit 6663fc3a6c

View file

@ -4,7 +4,7 @@ import mimetypes
import os
import shutil
import asyncio
import torch
import uuid
from datetime import datetime
@ -281,6 +281,14 @@ async def update_embedding_config(
log.info(
f"Updating embedding model: {request.app.state.config.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}"
)
if request.app.state.config.RAG_EMBEDDING_ENGINE == '':
# unloads current internal embedding model and clears VRAM cache
request.app.state.ef = None
request.app.state.EMBEDDING_FUNCTION = None
import gc
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
try:
request.app.state.config.RAG_EMBEDDING_ENGINE = form_data.embedding_engine
request.app.state.config.RAG_EMBEDDING_MODEL = form_data.embedding_model
@ -321,14 +329,6 @@ async def update_embedding_config(
form_data.embedding_batch_size
)
# unloads current embedding model and clears VRAM cache
request.app.state.ef = None
request.app.state.EMBEDDING_FUNCTION = None
import gc
gc.collect()
import torch
if torch.cuda.is_available():
torch.cuda.empty_cache()
request.app.state.ef = get_ef(
request.app.state.config.RAG_EMBEDDING_ENGINE,
request.app.state.config.RAG_EMBEDDING_MODEL,
@ -814,6 +814,14 @@ async def update_rag_config(
)
# Reranking settings
if request.app.state.config.RAG_RERANKING_ENGINE == '':
# Unloading the internal reranker and clear VRAM memory
request.app.state.rf = None
request.app.state.RERANKING_FUNCTION = None
import gc
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
request.app.state.config.RAG_RERANKING_ENGINE = (
form_data.RAG_RERANKING_ENGINE
if form_data.RAG_RERANKING_ENGINE is not None
@ -843,15 +851,6 @@ async def update_rag_config(
)
try:
# Unloading the reranker and clear VRAM memory.
if request.app.state.rf != None:
request.app.state.rf = None
request.app.state.RERANKING_FUNCTION = None
import gc
gc.collect()
import torch
if torch.cuda.is_available():
torch.cuda.empty_cache()
if request.app.state.config.ENABLE_RAG_HYBRID_SEARCH and not request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL:
request.app.state.rf = get_rf(
request.app.state.config.RAG_RERANKING_ENGINE,