fix: typing issue and race condition issue

This commit is contained in:
LoiTra 2025-07-15 11:43:21 +07:00 committed by loitragg
parent fa19f5f5ba
commit 2f04cc8a64
No known key found for this signature in database
GPG key ID: 96292BAF3E28CFF5
2 changed files with 13 additions and 4 deletions

View file

@ -16,6 +16,14 @@ from langchain.retrievers import ContextualCompressionRetriever, EnsembleRetriev
from langchain_community.retrievers import BM25Retriever
from langchain_core.documents import Document
# Add sentence_transformers import at module level to avoid threading issues
try:
from sentence_transformers import util as sentence_transformers_util
SENTENCE_TRANSFORMERS_AVAILABLE = True
except ImportError:
SENTENCE_TRANSFORMERS_AVAILABLE = False
sentence_transformers_util = None
from open_webui.config import VECTOR_DB
from open_webui.retrieval.vector.factory import VECTOR_DB_CLIENT
@ -1283,7 +1291,8 @@ class RerankCompressor(BaseDocumentCompressor):
if reranking:
scores = self.reranking_function(query, documents)
else:
from sentence_transformers import util
if not SENTENCE_TRANSFORMERS_AVAILABLE:
raise ImportError("sentence_transformers is not available. Please install it to use reranking functionality.")
query_embedding = await self.embedding_function(
query, RAG_EMBEDDING_QUERY_PREFIX
@ -1291,7 +1300,7 @@ class RerankCompressor(BaseDocumentCompressor):
document_embedding = await self.embedding_function(
[doc.page_content for doc in documents], RAG_EMBEDDING_CONTENT_PREFIX
)
scores = util.cos_sim(query_embedding, document_embedding)[0]
scores = sentence_transformers_util.cos_sim(query_embedding, document_embedding)[0]
if scores is not None:
docs_with_scores = list(

View file

@ -2195,6 +2195,7 @@ class QueryDocForm(BaseModel):
k_reranker: Optional[int] = None
r: Optional[float] = None
hybrid: Optional[bool] = None
hybrid_bm25_weight: Optional[float] = None
@router.post("/query/doc")
@ -2239,8 +2240,7 @@ async def query_doc_handler(
form_data.hybrid_bm25_weight
if form_data.hybrid_bm25_weight
else request.app.state.config.HYBRID_BM25_WEIGHT
),
user=user,
)
)
else:
query_embedding = await request.app.state.EMBEDDING_FUNCTION(