Change torch import to conditional import

Marko Henning 2025-08-21 13:19:24 +02:00
parent 6663fc3a6c
commit b3de3295d6


@@ -4,7 +4,6 @@ import mimetypes
 import os
 import shutil
 import asyncio
-import torch
 import uuid
 from datetime import datetime
@@ -287,8 +286,10 @@ async def update_embedding_config(
     request.app.state.EMBEDDING_FUNCTION = None
     import gc
     gc.collect()
-    if torch.cuda.is_available():
-        torch.cuda.empty_cache()
+    if DEVICE_TYPE == 'cuda':
+        import torch
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
     try:
         request.app.state.config.RAG_EMBEDDING_ENGINE = form_data.embedding_engine
         request.app.state.config.RAG_EMBEDDING_MODEL = form_data.embedding_model
@@ -820,8 +821,10 @@ async def update_rag_config(
     request.app.state.RERANKING_FUNCTION = None
     import gc
     gc.collect()
-    if torch.cuda.is_available():
-        torch.cuda.empty_cache()
+    if DEVICE_TYPE == 'cuda':
+        import torch
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
     request.app.state.config.RAG_RERANKING_ENGINE = (
         form_data.RAG_RERANKING_ENGINE
         if form_data.RAG_RERANKING_ENGINE is not None
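
For context, the lazy-import pattern applied in both hunks can be sketched on its own. This is a minimal, hypothetical example, not the actual module: the helper name free_model_memory is invented, and DEVICE_TYPE is read from an environment variable here purely for illustration, whereas in the real code it comes from the application's device configuration. The point of the pattern is that torch is only imported, and the CUDA cache only cleared, when the configured device is 'cuda', so CPU-only deployments never pay the cost of importing torch.

    # Minimal sketch of the conditional-import pattern from this commit.
    # DEVICE_TYPE is an assumption here (env var fallback to 'cpu'); in the
    # actual code it is provided by the app's configuration.
    import gc
    import os

    DEVICE_TYPE = os.environ.get("DEVICE_TYPE", "cpu")


    def free_model_memory() -> None:
        """Run the garbage collector and, on CUDA deployments, clear the CUDA cache."""
        gc.collect()
        if DEVICE_TYPE == "cuda":
            # torch is imported only when the device is CUDA, so installs
            # without torch (or without a GPU) never hit this import.
            import torch

            if torch.cuda.is_available():
                torch.cuda.empty_cache()

In the diff above, the same two calls (gc.collect() plus the guarded torch.cuda.empty_cache()) run after the embedding and reranking functions are dropped, which is when the old model's GPU memory can actually be released.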