Mirror of https://github.com/open-webui/open-webui.git (synced 2025-12-11 20:05:19 +00:00)
perf: 50x performance improvement for external embeddings (#19296)
* Update utils.py (#77)
* refactor: address code review feedback for embedding performance improvements (#92)
* fix: prevent sentence transformers from blocking async event loop (#95)

Co-authored-by: Claude <noreply@anthropic.com>
parent bb3e222e09
commit 902c6cfbea

4 changed files with 289 additions and 118 deletions
@@ -3,6 +3,8 @@ import os
 from typing import Optional, Union
 import requests
+import aiohttp
+import asyncio
 import hashlib
 from concurrent.futures import ThreadPoolExecutor
 import time

@@ -27,6 +29,7 @@ from open_webui.models.notes import Notes
 from open_webui.retrieval.vector.main import GetResult
 from open_webui.utils.access_control import has_access
+from open_webui.utils.headers import include_user_info_headers
 from open_webui.utils.misc import get_message_list

 from open_webui.retrieval.web.utils import get_web_loader

@@ -87,15 +90,16 @@ class VectorSearchRetriever(BaseRetriever):
     embedding_function: Any
     top_k: int

-    def _get_relevant_documents(
+    async def _aget_relevant_documents(
         self,
         query: str,
         *,
         run_manager: CallbackManagerForRetrieverRun,
     ) -> list[Document]:
+        embedding = await self.embedding_function(query, RAG_EMBEDDING_QUERY_PREFIX)
         result = VECTOR_DB_CLIENT.search(
             collection_name=self.collection_name,
-            vectors=[self.embedding_function(query, RAG_EMBEDDING_QUERY_PREFIX)],
+            vectors=[embedding],
             limit=self.top_k,
         )

@@ -186,7 +190,7 @@ def get_enriched_texts(collection_result: GetResult) -> list[str]:
     return enriched_texts


-def query_doc_with_hybrid_search(
+async def query_doc_with_hybrid_search(
     collection_name: str,
     collection_result: GetResult,
     query: str,

@@ -262,7 +266,7 @@ def query_doc_with_hybrid_search(
         base_compressor=compressor, base_retriever=ensemble_retriever
     )

-    result = compression_retriever.invoke(query)
+    result = await compression_retriever.ainvoke(query)

     distances = [d.metadata.get("score") for d in result]
     documents = [d.page_content for d in result]

@@ -381,7 +385,7 @@ def get_all_items_from_collections(collection_names: list[str]) -> dict:
     return merge_get_results(results)


-def query_collection(
+async def query_collection(
     collection_names: list[str],
     queries: list[str],
     embedding_function,

@@ -406,7 +410,7 @@ def query_collection(
             return None, e

     # Generate all query embeddings (in one call)
-    query_embeddings = embedding_function(queries, prefix=RAG_EMBEDDING_QUERY_PREFIX)
+    query_embeddings = await embedding_function(queries, prefix=RAG_EMBEDDING_QUERY_PREFIX)
     log.debug(
         f"query_collection: processing {len(queries)} queries across {len(collection_names)} collections"
     )

@@ -433,7 +437,7 @@ def query_collection(
     return merge_and_sort_query_results(results, k=k)


-def query_collection_with_hybrid_search(
+async def query_collection_with_hybrid_search(
     collection_names: list[str],
     queries: list[str],
     embedding_function,

@@ -465,9 +469,9 @@ def query_collection_with_hybrid_search(
         f"Starting hybrid search for {len(queries)} queries in {len(collection_names)} collections..."
     )

-    def process_query(collection_name, query):
+    async def process_query(collection_name, query):
         try:
-            result = query_doc_with_hybrid_search(
+            result = await query_doc_with_hybrid_search(
                 collection_name=collection_name,
                 collection_result=collection_results[collection_name],
                 query=query,

@@ -493,9 +497,8 @@ def query_collection_with_hybrid_search(
         for q in queries
     ]

-    with ThreadPoolExecutor() as executor:
-        future_results = [executor.submit(process_query, cn, q) for cn, q in tasks]
-        task_results = [future.result() for future in future_results]
+    # Run all queries in parallel using asyncio.gather
+    task_results = await asyncio.gather(*[process_query(cn, q) for cn, q in tasks])

     for result, err in task_results:
         if err is not None:

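For illustration only (not part of the commit): the hunk above replaces thread fan-out with coroutine fan-out. A minimal, self-contained sketch of the same asyncio.gather pattern, with a toy process_query standing in for the real hybrid-search call:

import asyncio
import time


async def process_query(collection_name: str, query: str) -> tuple[str, str]:
    # Stand-in for an I/O-bound call (vector DB / embedding API); sleeps instead of searching.
    await asyncio.sleep(0.5)
    return (collection_name, query)


async def main() -> None:
    tasks = [("docs", "q1"), ("notes", "q1"), ("web", "q1")]
    start = time.perf_counter()
    # All coroutines run concurrently on one event loop: ~0.5s total, not ~1.5s.
    results = await asyncio.gather(*[process_query(cn, q) for cn, q in tasks])
    print(results, f"{time.perf_counter() - start:.2f}s")


if __name__ == "__main__":
    asyncio.run(main())
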
@@ -511,6 +514,146 @@ def query_collection_with_hybrid_search(
     return merge_and_sort_query_results(results, k=k)


+async def generate_openai_batch_embeddings_async(
+    model: str,
+    texts: list[str],
+    url: str = "https://api.openai.com/v1",
+    key: str = "",
+    prefix: str = None,
+    user: UserModel = None,
+) -> Optional[list[list[float]]]:
+    try:
+        log.debug(
+            f"generate_openai_batch_embeddings_async:model {model} batch size: {len(texts)}"
+        )
+        form_data = {"input": texts, "model": model}
+        if isinstance(RAG_EMBEDDING_PREFIX_FIELD_NAME, str) and isinstance(prefix, str):
+            form_data[RAG_EMBEDDING_PREFIX_FIELD_NAME] = prefix
+
+        headers = {
+            "Content-Type": "application/json",
+            "Authorization": f"Bearer {key}",
+        }
+        if ENABLE_FORWARD_USER_INFO_HEADERS and user:
+            headers = include_user_info_headers(headers, user)
+
+        async with aiohttp.ClientSession() as session:
+            async with session.post(f"{url}/embeddings", headers=headers, json=form_data) as r:
+                r.raise_for_status()
+                data = await r.json()
+                if "data" in data:
+                    return [item["embedding"] for item in data["data"]]
+                else:
+                    raise Exception("Something went wrong :/")
+    except Exception as e:
+        log.exception(f"Error generating openai batch embeddings: {e}")
+        return None
+
+
+async def generate_azure_openai_batch_embeddings_async(
+    model: str,
+    texts: list[str],
+    url: str,
+    key: str = "",
+    version: str = "",
+    prefix: str = None,
+    user: UserModel = None,
+) -> Optional[list[list[float]]]:
+    try:
+        log.debug(
+            f"generate_azure_openai_batch_embeddings_async:deployment {model} batch size: {len(texts)}"
+        )
+        form_data = {"input": texts}
+        if isinstance(RAG_EMBEDDING_PREFIX_FIELD_NAME, str) and isinstance(prefix, str):
+            form_data[RAG_EMBEDDING_PREFIX_FIELD_NAME] = prefix
+
+        full_url = f"{url}/openai/deployments/{model}/embeddings?api-version={version}"
+
+        headers = {
+            "Content-Type": "application/json",
+            "api-key": key,
+        }
+        if ENABLE_FORWARD_USER_INFO_HEADERS and user:
+            headers = include_user_info_headers(headers, user)
+
+        async with aiohttp.ClientSession() as session:
+            async with session.post(full_url, headers=headers, json=form_data) as r:
+                r.raise_for_status()
+                data = await r.json()
+                if "data" in data:
+                    return [item["embedding"] for item in data["data"]]
+                else:
+                    raise Exception("Something went wrong :/")
+    except Exception as e:
+        log.exception(f"Error generating azure openai batch embeddings: {e}")
+        return None
+
+
+async def generate_ollama_batch_embeddings_async(
+    model: str,
+    texts: list[str],
+    url: str,
+    key: str = "",
+    prefix: str = None,
+    user: UserModel = None,
+) -> Optional[list[list[float]]]:
+    try:
+        log.debug(
+            f"generate_ollama_batch_embeddings_async:model {model} batch size: {len(texts)}"
+        )
+        form_data = {"input": texts, "model": model}
+        if isinstance(RAG_EMBEDDING_PREFIX_FIELD_NAME, str) and isinstance(prefix, str):
+            form_data[RAG_EMBEDDING_PREFIX_FIELD_NAME] = prefix
+
+        headers = {
+            "Content-Type": "application/json",
+            "Authorization": f"Bearer {key}",
+        }
+        if ENABLE_FORWARD_USER_INFO_HEADERS and user:
+            headers = include_user_info_headers(headers, user)
+
+        async with aiohttp.ClientSession() as session:
+            async with session.post(f"{url}/api/embed", headers=headers, json=form_data) as r:
+                r.raise_for_status()
+                data = await r.json()
+                if "embeddings" in data:
+                    return data["embeddings"]
+                else:
+                    raise Exception("Something went wrong :/")
+    except Exception as e:
+        log.exception(f"Error generating ollama batch embeddings: {e}")
+        return None
+
+
+async def generate_multiple_async(query, prefix, user, func_async, embedding_batch_size):
+    if isinstance(query, list):
+        # Create batches
+        batches = [
+            query[i : i + embedding_batch_size]
+            for i in range(0, len(query), embedding_batch_size)
+        ]
+
+        log.debug(f"generate_multiple_async: Processing {len(batches)} batches in parallel")
+
+        # Execute all batches in parallel
+        tasks = [
+            func_async(batch, prefix=prefix, user=user)
+            for batch in batches
+        ]
+        batch_results = await asyncio.gather(*tasks)
+
+        # Flatten results
+        embeddings = []
+        for batch_embeddings in batch_results:
+            if isinstance(batch_embeddings, list):
+                embeddings.extend(batch_embeddings)
+
+        log.debug(f"generate_multiple_async: Generated {len(embeddings)} embeddings from {len(batches)} parallel batches")
+        return embeddings
+    else:
+        return await func_async([query], prefix=prefix, user=user)
+
+
 def get_embedding_function(
     embedding_engine,
     embedding_model,

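The new helpers above pair non-blocking aiohttp requests with generate_multiple_async, which splits a large input into batches and awaits them all at once; with an external embedding API this concurrent fan-out is where most of the claimed speedup comes from, since batches no longer wait on each other. A rough standalone sketch of the batching pattern (outside the diff), using a fake embed_batch in place of the real HTTP calls:

import asyncio


async def embed_batch(texts: list[str]) -> list[list[float]]:
    # Hypothetical stand-in for the *_batch_embeddings_async helpers:
    # pretend each HTTP round trip takes 0.3s regardless of batch size.
    await asyncio.sleep(0.3)
    return [[float(len(t))] for t in texts]


async def generate_multiple(texts: list[str], batch_size: int) -> list[list[float]]:
    batches = [texts[i : i + batch_size] for i in range(0, len(texts), batch_size)]
    # Fire every batch at once; asyncio.gather keeps the batch order.
    results = await asyncio.gather(*[embed_batch(b) for b in batches])
    # Flatten the per-batch lists back into one list of embeddings.
    return [emb for batch in results for emb in batch]


if __name__ == "__main__":
    docs = [f"chunk {i}" for i in range(100)]
    embeddings = asyncio.run(generate_multiple(docs, batch_size=10))
    print(len(embeddings))  # 100, produced in ~0.3s instead of ~3s sequentially
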
@@ -521,40 +664,58 @@ def get_embedding_function(
     azure_api_version=None,
 ):
     if embedding_engine == "":
-        return lambda query, prefix=None, user=None: embedding_function.encode(
-            query, **({"prompt": prefix} if prefix else {})
-        ).tolist()
+        # Sentence transformers: CPU-bound sync operation
+        # Run in thread pool to prevent blocking async event loop
+        def sync_encode(query, prefix=None):
+            return embedding_function.encode(
+                query, **({"prompt": prefix} if prefix else {})
+            ).tolist()
+
+        async def async_wrapper(query, prefix=None, user=None):
+            # Run CPU-bound operation in thread pool
+            return await asyncio.to_thread(sync_encode, query, prefix)
+
+        return async_wrapper
     elif embedding_engine in ["ollama", "openai", "azure_openai"]:
-        func = lambda query, prefix=None, user=None: generate_embeddings(
-            engine=embedding_engine,
-            model=embedding_model,
-            text=query,
-            prefix=prefix,
-            url=url,
-            key=key,
-            user=user,
-            azure_api_version=azure_api_version,
-        )
+        # Create async function based on engine
+        if embedding_engine == "openai":
+            func_async = lambda texts, prefix=None, user=None: generate_openai_batch_embeddings_async(
+                model=embedding_model,
+                texts=texts,
+                url=url,
+                key=key,
+                prefix=prefix,
+                user=user,
+            )
+        elif embedding_engine == "azure_openai":
+            func_async = lambda texts, prefix=None, user=None: generate_azure_openai_batch_embeddings_async(
+                model=embedding_model,
+                texts=texts,
+                url=url,
+                key=key,
+                version=azure_api_version,
+                prefix=prefix,
+                user=user,
+            )
+        elif embedding_engine == "ollama":
+            func_async = lambda texts, prefix=None, user=None: generate_ollama_batch_embeddings_async(
+                model=embedding_model,
+                texts=texts,
+                url=url,
+                key=key,
+                prefix=prefix,
+                user=user,
+            )

-        def generate_multiple(query, prefix, user, func):
+        # Return async function for parallel batch processing (FastAPI-compatible)
+        async def embedding_wrapper(query, prefix=None, user=None):
             if isinstance(query, list):
-                embeddings = []
-                for i in range(0, len(query), embedding_batch_size):
-                    batch_embeddings = func(
-                        query[i : i + embedding_batch_size],
-                        prefix=prefix,
-                        user=user,
-                    )
-
-                    if isinstance(batch_embeddings, list):
-                        embeddings.extend(batch_embeddings)
-                return embeddings
+                return await generate_multiple_async(query, prefix, user, func_async, embedding_batch_size)
             else:
-                return func(query, prefix, user)
+                result = await func_async([query], prefix=prefix, user=user)
+                return result[0] if result else None

-        return lambda query, prefix=None, user=None: generate_multiple(
-            query, prefix, user, func
-        )
+        return embedding_wrapper
     else:
         raise ValueError(f"Unknown embedding engine: {embedding_engine}")

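For the local sentence-transformers engine there is nothing to parallelize over the network, so the hunk above instead wraps the blocking encode call in asyncio.to_thread to keep the event loop responsive. A minimal sketch of that wrapper (not the real model code), with a dummy CPU-bound encode standing in for SentenceTransformer.encode:

import asyncio
import hashlib


def encode(texts: list[str]) -> list[list[float]]:
    # Stand-in for a CPU-bound embedding model: derives 4 numbers from a hash of each text.
    return [[float(b) for b in hashlib.sha256(t.encode()).digest()[:4]] for t in texts]


async def async_encode(texts: list[str]) -> list[list[float]]:
    # asyncio.to_thread runs the blocking function on a worker thread,
    # so other coroutines keep running while the model crunches numbers.
    return await asyncio.to_thread(encode, texts)


if __name__ == "__main__":
    print(asyncio.run(async_encode(["hello", "world"])))
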
@@ -572,7 +733,7 @@ def get_reranking_function(reranking_engine, reranking_model, reranking_function
     )


-def get_sources_from_items(
+async def get_sources_from_items(
     request,
     items,
     queries,

@@ -800,7 +961,7 @@ def get_sources_from_items(
         query_result = None # Initialize to None
         if hybrid_search:
             try:
-                query_result = query_collection_with_hybrid_search(
+                query_result = await query_collection_with_hybrid_search(
                     collection_names=collection_names,
                     queries=queries,
                     embedding_function=embedding_function,

@@ -818,7 +979,7 @@ def get_sources_from_items(

         # fallback to non-hybrid search
         if not hybrid_search and query_result is None:
-            query_result = query_collection(
+            query_result = await query_collection(
                 collection_names=collection_names,
                 queries=queries,
                 embedding_function=embedding_function,

@@ -1056,26 +1217,25 @@ def generate_embeddings(
         else:
             text = f"{prefix}{text}"

+    # Run async embedding generation synchronously
     if engine == "ollama":
-        embeddings = generate_ollama_batch_embeddings(
-            **{
-                "model": model,
-                "texts": text if isinstance(text, list) else [text],
-                "url": url,
-                "key": key,
-                "prefix": prefix,
-                "user": user,
-            }
-        )
+        embeddings = asyncio.run(generate_ollama_batch_embeddings_async(
+            model=model,
+            texts=text if isinstance(text, list) else [text],
+            url=url,
+            key=key,
+            prefix=prefix,
+            user=user,
+        ))
         return embeddings[0] if isinstance(text, str) else embeddings
     elif engine == "openai":
-        embeddings = generate_openai_batch_embeddings(
+        embeddings = asyncio.run(generate_openai_batch_embeddings_async(
             model, text if isinstance(text, list) else [text], url, key, prefix, user
-        )
+        ))
         return embeddings[0] if isinstance(text, str) else embeddings
     elif engine == "azure_openai":
         azure_api_version = kwargs.get("azure_api_version", "")
-        embeddings = generate_azure_openai_batch_embeddings(
+        embeddings = asyncio.run(generate_azure_openai_batch_embeddings_async(
             model,
             text if isinstance(text, list) else [text],
             url,

@@ -1083,7 +1243,7 @@ def generate_embeddings(
             azure_api_version,
             prefix,
             user,
-        )
+        ))
         return embeddings[0] if isinstance(text, str) else embeddings

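generate_embeddings itself stays synchronous and bridges into the new coroutines with asyncio.run. A small sketch (outside the diff) of that bridge and its main caveat: asyncio.run spins up a fresh event loop, so it only works from plain synchronous code, never from inside an already-running loop, where the caller should await the coroutine directly:

import asyncio


async def fetch_embeddings(texts: list[str]) -> list[list[float]]:
    # Stand-in for an async batch-embedding call.
    await asyncio.sleep(0.1)
    return [[0.0] * 3 for _ in texts]


def generate_embeddings_sync(texts: list[str]) -> list[list[float]]:
    # Fine from plain synchronous code: creates a loop, runs the coroutine, tears the loop down.
    return asyncio.run(fetch_embeddings(texts))


async def caller_inside_a_loop(texts: list[str]) -> list[list[float]]:
    # Inside a running event loop, asyncio.run() would raise RuntimeError;
    # await the coroutine directly instead.
    return await fetch_embeddings(texts)


if __name__ == "__main__":
    print(generate_embeddings_sync(["a", "b"]))
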
@@ -1104,7 +1264,7 @@ class RerankCompressor(BaseDocumentCompressor):
         extra = "forbid"
         arbitrary_types_allowed = True

-    def compress_documents(
+    async def acompress_documents(
         self,
         documents: Sequence[Document],
         query: str,

@@ -1114,12 +1274,14 @@ class RerankCompressor(BaseDocumentCompressor):

         scores = None
         if reranking:
-            scores = self.reranking_function(query, documents)
+            scores = self.reranking_function(
+                [(query, doc.page_content) for doc in documents]
+            )
         else:
             from sentence_transformers import util

-            query_embedding = self.embedding_function(query, RAG_EMBEDDING_QUERY_PREFIX)
-            document_embedding = self.embedding_function(
+            query_embedding = await self.embedding_function(query, RAG_EMBEDDING_QUERY_PREFIX)
+            document_embedding = await self.embedding_function(
                 [doc.page_content for doc in documents], RAG_EMBEDDING_CONTENT_PREFIX
             )
             scores = util.cos_sim(query_embedding, document_embedding)[0]

@@ -1,6 +1,7 @@
 from fastapi import APIRouter, Depends, HTTPException, Request
 from pydantic import BaseModel
 import logging
+import asyncio
 from typing import Optional

 from open_webui.models.memories import Memories, MemoryModel

@@ -17,7 +18,7 @@ router = APIRouter()

 @router.get("/ef")
 async def get_embeddings(request: Request):
-    return {"result": request.app.state.EMBEDDING_FUNCTION("hello world")}
+    return {"result": await request.app.state.EMBEDDING_FUNCTION("hello world")}


 ############################

@@ -51,15 +52,17 @@ async def add_memory(
 ):
     memory = Memories.insert_new_memory(user.id, form_data.content)

+    vector = await request.app.state.EMBEDDING_FUNCTION(
+        memory.content, user=user
+    )
+
     VECTOR_DB_CLIENT.upsert(
         collection_name=f"user-memory-{user.id}",
         items=[
             {
                 "id": memory.id,
                 "text": memory.content,
-                "vector": request.app.state.EMBEDDING_FUNCTION(
-                    memory.content, user=user
-                ),
+                "vector": vector,
                 "metadata": {"created_at": memory.created_at},
             }
         ],

@@ -86,9 +89,11 @@ async def query_memory(
     if not memories:
         raise HTTPException(status_code=404, detail="No memories found for user")

+    vector = await request.app.state.EMBEDDING_FUNCTION(form_data.content, user=user)
+
     results = VECTOR_DB_CLIENT.search(
         collection_name=f"user-memory-{user.id}",
-        vectors=[request.app.state.EMBEDDING_FUNCTION(form_data.content, user=user)],
+        vectors=[vector],
         limit=form_data.k,
     )

@@ -105,21 +110,26 @@ async def reset_memory_from_vector_db(
     VECTOR_DB_CLIENT.delete_collection(f"user-memory-{user.id}")

     memories = Memories.get_memories_by_user_id(user.id)
+
+    # Generate vectors in parallel
+    vectors = await asyncio.gather(*[
+        request.app.state.EMBEDDING_FUNCTION(memory.content, user=user)
+        for memory in memories
+    ])
+
     VECTOR_DB_CLIENT.upsert(
         collection_name=f"user-memory-{user.id}",
         items=[
             {
                 "id": memory.id,
                 "text": memory.content,
-                "vector": request.app.state.EMBEDDING_FUNCTION(
-                    memory.content, user=user
-                ),
+                "vector": vectors[idx],
                 "metadata": {
                     "created_at": memory.created_at,
                     "updated_at": memory.updated_at,
                 },
             }
-            for memory in memories
+            for idx, memory in enumerate(memories)
         ],
     )

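The reset path above embeds every stored memory concurrently and then indexes into vectors[idx]; this is safe because asyncio.gather returns results in the order the awaitables were passed in, regardless of which one finishes first. A small sketch of that ordering guarantee (not the real embedding call), using a fake embed:

import asyncio
import random


async def embed(text: str) -> list[float]:
    # Hypothetical embedding call; finishing order is random, result order is not.
    await asyncio.sleep(random.uniform(0.01, 0.1))
    return [float(len(text))]


async def main() -> None:
    memories = ["buy milk", "dentist on friday", "project deadline"]
    vectors = await asyncio.gather(*[embed(m) for m in memories])
    # gather preserves input order, so pairing by index is safe.
    for idx, memory in enumerate(memories):
        print(memory, vectors[idx])


if __name__ == "__main__":
    asyncio.run(main())
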
@@ -164,15 +174,17 @@ async def update_memory_by_id(
         raise HTTPException(status_code=404, detail="Memory not found")

     if form_data.content is not None:
+        vector = await request.app.state.EMBEDDING_FUNCTION(
+            memory.content, user=user
+        )
+
         VECTOR_DB_CLIENT.upsert(
             collection_name=f"user-memory-{user.id}",
             items=[
                 {
                     "id": memory.id,
                     "text": memory.content,
-                    "vector": request.app.state.EMBEDDING_FUNCTION(
-                        memory.content, user=user
-                    ),
+                    "vector": vector,
                     "metadata": {
                         "created_at": memory.created_at,
                         "updated_at": memory.updated_at,

@@ -1467,11 +1467,12 @@ def save_docs_to_vector_db(
             ),
         )

-        embeddings = embedding_function(
+        # Run async embedding in sync context
+        embeddings = asyncio.run(embedding_function(
             list(map(lambda x: x.replace("\n", " "), texts)),
             prefix=RAG_EMBEDDING_CONTENT_PREFIX,
             user=user,
-        )
+        ))
         log.info(f"embeddings generated {len(embeddings)} for {len(texts)} items")

         items = [

@@ -2262,7 +2263,7 @@ class QueryDocForm(BaseModel):


 @router.post("/query/doc")
-def query_doc_handler(
+async def query_doc_handler(
     request: Request,
     form_data: QueryDocForm,
     user=Depends(get_verified_user),

@@ -2275,7 +2276,7 @@ def query_doc_handler(
             collection_results[form_data.collection_name] = VECTOR_DB_CLIENT.get(
                 collection_name=form_data.collection_name
             )
-            return query_doc_with_hybrid_search(
+            return await query_doc_with_hybrid_search(
                 collection_name=form_data.collection_name,
                 collection_result=collection_results[form_data.collection_name],
                 query=form_data.query,

@@ -2285,8 +2286,8 @@ def query_doc_handler(
                 k=form_data.k if form_data.k else request.app.state.config.TOP_K,
                 reranking_function=(
                     (
-                        lambda sentences: request.app.state.RERANKING_FUNCTION(
-                            sentences, user=user
+                        lambda query, documents: request.app.state.RERANKING_FUNCTION(
+                            query, documents, user=user
                         )
                     )
                     if request.app.state.RERANKING_FUNCTION

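The reranking callable wired into the handler now takes the query and the documents as two arguments instead of a single list of pre-paired sentences. A hedged sketch of a callable with the new shape; the word-overlap scoring here is a toy stand-in, not the real cross-encoder or external reranker:

def rerank(query: str, documents: list[str]) -> list[float]:
    # Toy scorer with the new (query, documents) signature: fraction of query words shared.
    # The real RERANKING_FUNCTION would call a cross-encoder or a remote reranking API here.
    query_terms = set(query.lower().split())
    return [
        len(query_terms & set(doc.lower().split())) / (len(query_terms) or 1)
        for doc in documents
    ]


if __name__ == "__main__":
    print(rerank("async embeddings", ["embeddings are now async", "unrelated text"]))
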
@@ -2307,11 +2308,12 @@ def query_doc_handler(
                 user=user,
             )
         else:
+            query_embedding = await request.app.state.EMBEDDING_FUNCTION(
+                form_data.query, prefix=RAG_EMBEDDING_QUERY_PREFIX, user=user
+            )
             return query_doc(
                 collection_name=form_data.collection_name,
-                query_embedding=request.app.state.EMBEDDING_FUNCTION(
-                    form_data.query, prefix=RAG_EMBEDDING_QUERY_PREFIX, user=user
-                ),
+                query_embedding=query_embedding,
                 k=form_data.k if form_data.k else request.app.state.config.TOP_K,
                 user=user,
             )

@@ -2335,7 +2337,7 @@ class QueryCollectionsForm(BaseModel):


 @router.post("/query/collection")
-def query_collection_handler(
+async def query_collection_handler(
     request: Request,
     form_data: QueryCollectionsForm,
     user=Depends(get_verified_user),

@@ -2344,7 +2346,7 @@ def query_collection_handler(
         if request.app.state.config.ENABLE_RAG_HYBRID_SEARCH and (
             form_data.hybrid is None or form_data.hybrid
         ):
-            return query_collection_with_hybrid_search(
+            return await query_collection_with_hybrid_search(
                 collection_names=form_data.collection_names,
                 queries=[form_data.query],
                 embedding_function=lambda query, prefix: request.app.state.EMBEDDING_FUNCTION(

@@ -2379,7 +2381,7 @@ def query_collection_handler(
                 ),
             )
         else:
-            return query_collection(
+            return await query_collection(
                 collection_names=form_data.collection_names,
                 queries=[form_data.query],
                 embedding_function=lambda query, prefix: request.app.state.EMBEDDING_FUNCTION(

@@ -2461,7 +2463,7 @@ if ENV == "dev":

     @router.get("/ef/{text}")
     async def get_embeddings(request: Request, text: Optional[str] = "Hello World!"):
         return {
-            "result": request.app.state.EMBEDDING_FUNCTION(
+            "result": await request.app.state.EMBEDDING_FUNCTION(
                 text, prefix=RAG_EMBEDDING_QUERY_PREFIX
             )
         }

@@ -993,37 +993,32 @@ async def chat_completion_files_handler(
         queries = [get_last_user_message(body["messages"])]

         try:
-            # Offload get_sources_from_items to a separate thread
-            loop = asyncio.get_running_loop()
-            with ThreadPoolExecutor() as executor:
-                sources = await loop.run_in_executor(
-                    executor,
-                    lambda: get_sources_from_items(
-                        request=request,
-                        items=files,
-                        queries=queries,
-                        embedding_function=lambda query, prefix: request.app.state.EMBEDDING_FUNCTION(
-                            query, prefix=prefix, user=user
-                        ),
-                        k=request.app.state.config.TOP_K,
-                        reranking_function=(
-                            (
-                                lambda query, documents: request.app.state.RERANKING_FUNCTION(
-                                    query, documents, user=user
-                                )
-                            )
-                            if request.app.state.RERANKING_FUNCTION
-                            else None
-                        ),
-                        k_reranker=request.app.state.config.TOP_K_RERANKER,
-                        r=request.app.state.config.RELEVANCE_THRESHOLD,
-                        hybrid_bm25_weight=request.app.state.config.HYBRID_BM25_WEIGHT,
-                        hybrid_search=request.app.state.config.ENABLE_RAG_HYBRID_SEARCH,
-                        full_context=all_full_context
-                        or request.app.state.config.RAG_FULL_CONTEXT,
-                        user=user,
-                    ),
-                )
+            # Directly await async get_sources_from_items (no thread needed - fully async now)
+            sources = await get_sources_from_items(
+                request=request,
+                items=files,
+                queries=queries,
+                embedding_function=lambda query, prefix: request.app.state.EMBEDDING_FUNCTION(
+                    query, prefix=prefix, user=user
+                ),
+                k=request.app.state.config.TOP_K,
+                reranking_function=(
+                    (
+                        lambda query, documents: request.app.state.RERANKING_FUNCTION(
+                            query, documents, user=user
+                        )
+                    )
+                    if request.app.state.RERANKING_FUNCTION
+                    else None
+                ),
+                k_reranker=request.app.state.config.TOP_K_RERANKER,
+                r=request.app.state.config.RELEVANCE_THRESHOLD,
+                hybrid_bm25_weight=request.app.state.config.HYBRID_BM25_WEIGHT,
+                hybrid_search=request.app.state.config.ENABLE_RAG_HYBRID_SEARCH,
+                full_context=all_full_context
+                or request.app.state.config.RAG_FULL_CONTEXT,
+                user=user,
+            )
         except Exception as e:
             log.exception(e)

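Once get_sources_from_items is a coroutine, the middleware can await it directly; the run_in_executor detour it replaces is only useful for blocking synchronous work. A compact standalone sketch contrasting the two shapes, with dummy functions in place of the real retrieval call:

import asyncio
from concurrent.futures import ThreadPoolExecutor


def blocking_get_sources() -> list[str]:
    # Old-style blocking helper: only this kind of call benefits from a thread pool.
    return ["source-a", "source-b"]


async def async_get_sources() -> list[str]:
    # New-style coroutine: awaiting it directly keeps everything on the event loop.
    await asyncio.sleep(0.05)
    return ["source-a", "source-b"]


async def handler() -> None:
    loop = asyncio.get_running_loop()
    with ThreadPoolExecutor() as executor:
        offloaded = await loop.run_in_executor(executor, blocking_get_sources)  # before
    direct = await async_get_sources()  # after: no executor, no extra thread
    print(offloaded, direct)


if __name__ == "__main__":
    asyncio.run(handler())
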