perf: 50x performance improvement for external embeddings (#19296)

* Update utils.py (#77)

Co-authored-by: Claude <noreply@anthropic.com>

* refactor: address code review feedback for embedding performance improvements (#92)

Co-authored-by: Claude <noreply@anthropic.com>

* fix: prevent sentence transformers from blocking async event loop (#95)

Co-authored-by: Claude <noreply@anthropic.com>

---------

Co-authored-by: Claude <noreply@anthropic.com>
Commit: 902c6cfbea (parent: bb3e222e09)
Author: Classic298, 2025-11-23 02:54:59 +01:00, committed by GitHub
GPG key ID: B5690EEEBB952194 (no known key found for this signature in database)
4 changed files with 289 additions and 118 deletions
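At the core of the change, for the external (HTTP-based) embedding engines, embedding batches are now requested concurrently with aiohttp and asyncio.gather instead of one batch at a time. The sketch below distills that pattern outside the Open WebUI codebase; the endpoint constant, model name, and batch size are placeholders rather than the project's actual configuration:

    import asyncio
    import aiohttp

    EMBEDDINGS_URL = "https://api.openai.com/v1/embeddings"  # placeholder endpoint
    BATCH_SIZE = 32  # placeholder; the real value comes from the RAG settings

    async def embed_batch(session, texts, model, key):
        # One HTTP request embeds an entire batch of texts.
        async with session.post(
            EMBEDDINGS_URL,
            headers={"Authorization": f"Bearer {key}"},
            json={"input": texts, "model": model},
        ) as resp:
            resp.raise_for_status()
            data = await resp.json()
            return [item["embedding"] for item in data["data"]]

    async def embed_all(texts, model, key):
        # Fire all batch requests at once; gather preserves batch order.
        batches = [texts[i : i + BATCH_SIZE] for i in range(0, len(texts), BATCH_SIZE)]
        async with aiohttp.ClientSession() as session:
            results = await asyncio.gather(
                *(embed_batch(session, batch, model, key) for batch in batches)
            )
        # Flatten per-batch results back into one list of vectors.
        return [vec for batch in results for vec in batch]

With N batches, wall-clock time falls from N sequential round trips to roughly the slowest single request, which is where the reported speedup for external embeddings comes from.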


@@ -3,6 +3,8 @@ import os
 from typing import Optional, Union
 import requests
+import aiohttp
+import asyncio
 import hashlib
 from concurrent.futures import ThreadPoolExecutor
 import time

@@ -27,6 +29,7 @@ from open_webui.models.notes import Notes
 from open_webui.retrieval.vector.main import GetResult
 from open_webui.utils.access_control import has_access
+from open_webui.utils.headers import include_user_info_headers
 from open_webui.utils.misc import get_message_list
 from open_webui.retrieval.web.utils import get_web_loader

@@ -87,15 +90,16 @@ class VectorSearchRetriever(BaseRetriever):
     embedding_function: Any
     top_k: int

-    def _get_relevant_documents(
+    async def _aget_relevant_documents(
         self,
         query: str,
         *,
         run_manager: CallbackManagerForRetrieverRun,
     ) -> list[Document]:
+        embedding = await self.embedding_function(query, RAG_EMBEDDING_QUERY_PREFIX)
         result = VECTOR_DB_CLIENT.search(
             collection_name=self.collection_name,
-            vectors=[self.embedding_function(query, RAG_EMBEDDING_QUERY_PREFIX)],
+            vectors=[embedding],
             limit=self.top_k,
         )

@@ -186,7 +190,7 @@ def get_enriched_texts(collection_result: GetResult) -> list[str]:
     return enriched_texts

-def query_doc_with_hybrid_search(
+async def query_doc_with_hybrid_search(
     collection_name: str,
     collection_result: GetResult,
     query: str,

@@ -262,7 +266,7 @@ def query_doc_with_hybrid_search(
         base_compressor=compressor, base_retriever=ensemble_retriever
     )
-    result = compression_retriever.invoke(query)
+    result = await compression_retriever.ainvoke(query)
     distances = [d.metadata.get("score") for d in result]
     documents = [d.page_content for d in result]

@@ -381,7 +385,7 @@ def get_all_items_from_collections(collection_names: list[str]) -> dict:
     return merge_get_results(results)

-def query_collection(
+async def query_collection(
     collection_names: list[str],
     queries: list[str],
     embedding_function,

@@ -406,7 +410,7 @@ def query_collection(
            return None, e

     # Generate all query embeddings (in one call)
-    query_embeddings = embedding_function(queries, prefix=RAG_EMBEDDING_QUERY_PREFIX)
+    query_embeddings = await embedding_function(queries, prefix=RAG_EMBEDDING_QUERY_PREFIX)
     log.debug(
         f"query_collection: processing {len(queries)} queries across {len(collection_names)} collections"
     )

@@ -433,7 +437,7 @@ def query_collection(
     return merge_and_sort_query_results(results, k=k)

-def query_collection_with_hybrid_search(
+async def query_collection_with_hybrid_search(
     collection_names: list[str],
     queries: list[str],
     embedding_function,

@@ -465,9 +469,9 @@ def query_collection_with_hybrid_search(
         f"Starting hybrid search for {len(queries)} queries in {len(collection_names)} collections..."
     )
-    def process_query(collection_name, query):
+    async def process_query(collection_name, query):
         try:
-            result = query_doc_with_hybrid_search(
+            result = await query_doc_with_hybrid_search(
                 collection_name=collection_name,
                 collection_result=collection_results[collection_name],
                 query=query,

@@ -493,9 +497,8 @@ def query_collection_with_hybrid_search(
         for q in queries
     ]
-    with ThreadPoolExecutor() as executor:
-        future_results = [executor.submit(process_query, cn, q) for cn, q in tasks]
-        task_results = [future.result() for future in future_results]
+    # Run all queries in parallel using asyncio.gather
+    task_results = await asyncio.gather(*[process_query(cn, q) for cn, q in tasks])

     for result, err in task_results:
         if err is not None:

@@ -511,6 +514,146 @@ def query_collection_with_hybrid_search(
     return merge_and_sort_query_results(results, k=k)
+
+async def generate_openai_batch_embeddings_async(
+    model: str,
+    texts: list[str],
+    url: str = "https://api.openai.com/v1",
+    key: str = "",
+    prefix: str = None,
+    user: UserModel = None,
+) -> Optional[list[list[float]]]:
+    try:
+        log.debug(
+            f"generate_openai_batch_embeddings_async:model {model} batch size: {len(texts)}"
+        )
+        form_data = {"input": texts, "model": model}
+        if isinstance(RAG_EMBEDDING_PREFIX_FIELD_NAME, str) and isinstance(prefix, str):
+            form_data[RAG_EMBEDDING_PREFIX_FIELD_NAME] = prefix
+        headers = {
+            "Content-Type": "application/json",
+            "Authorization": f"Bearer {key}",
+        }
+        if ENABLE_FORWARD_USER_INFO_HEADERS and user:
+            headers = include_user_info_headers(headers, user)
+        async with aiohttp.ClientSession() as session:
+            async with session.post(f"{url}/embeddings", headers=headers, json=form_data) as r:
+                r.raise_for_status()
+                data = await r.json()
+                if "data" in data:
+                    return [item["embedding"] for item in data["data"]]
+                else:
+                    raise Exception("Something went wrong :/")
+    except Exception as e:
+        log.exception(f"Error generating openai batch embeddings: {e}")
+        return None
+
+async def generate_azure_openai_batch_embeddings_async(
+    model: str,
+    texts: list[str],
+    url: str,
+    key: str = "",
+    version: str = "",
+    prefix: str = None,
+    user: UserModel = None,
+) -> Optional[list[list[float]]]:
+    try:
+        log.debug(
+            f"generate_azure_openai_batch_embeddings_async:deployment {model} batch size: {len(texts)}"
+        )
+        form_data = {"input": texts}
+        if isinstance(RAG_EMBEDDING_PREFIX_FIELD_NAME, str) and isinstance(prefix, str):
+            form_data[RAG_EMBEDDING_PREFIX_FIELD_NAME] = prefix
+        full_url = f"{url}/openai/deployments/{model}/embeddings?api-version={version}"
+        headers = {
+            "Content-Type": "application/json",
+            "api-key": key,
+        }
+        if ENABLE_FORWARD_USER_INFO_HEADERS and user:
+            headers = include_user_info_headers(headers, user)
+        async with aiohttp.ClientSession() as session:
+            async with session.post(full_url, headers=headers, json=form_data) as r:
+                r.raise_for_status()
+                data = await r.json()
+                if "data" in data:
+                    return [item["embedding"] for item in data["data"]]
+                else:
+                    raise Exception("Something went wrong :/")
+    except Exception as e:
+        log.exception(f"Error generating azure openai batch embeddings: {e}")
+        return None
+
+async def generate_ollama_batch_embeddings_async(
+    model: str,
+    texts: list[str],
+    url: str,
+    key: str = "",
+    prefix: str = None,
+    user: UserModel = None,
+) -> Optional[list[list[float]]]:
+    try:
+        log.debug(
+            f"generate_ollama_batch_embeddings_async:model {model} batch size: {len(texts)}"
+        )
+        form_data = {"input": texts, "model": model}
+        if isinstance(RAG_EMBEDDING_PREFIX_FIELD_NAME, str) and isinstance(prefix, str):
+            form_data[RAG_EMBEDDING_PREFIX_FIELD_NAME] = prefix
+        headers = {
+            "Content-Type": "application/json",
+            "Authorization": f"Bearer {key}",
+        }
+        if ENABLE_FORWARD_USER_INFO_HEADERS and user:
+            headers = include_user_info_headers(headers, user)
+        async with aiohttp.ClientSession() as session:
+            async with session.post(f"{url}/api/embed", headers=headers, json=form_data) as r:
+                r.raise_for_status()
+                data = await r.json()
+                if "embeddings" in data:
+                    return data["embeddings"]
+                else:
+                    raise Exception("Something went wrong :/")
+    except Exception as e:
+        log.exception(f"Error generating ollama batch embeddings: {e}")
+        return None
+
+async def generate_multiple_async(query, prefix, user, func_async, embedding_batch_size):
+    if isinstance(query, list):
+        # Create batches
+        batches = [
+            query[i : i + embedding_batch_size]
+            for i in range(0, len(query), embedding_batch_size)
+        ]
+        log.debug(f"generate_multiple_async: Processing {len(batches)} batches in parallel")
+        # Execute all batches in parallel
+        tasks = [
+            func_async(batch, prefix=prefix, user=user)
+            for batch in batches
+        ]
+        batch_results = await asyncio.gather(*tasks)
+        # Flatten results
+        embeddings = []
+        for batch_embeddings in batch_results:
+            if isinstance(batch_embeddings, list):
+                embeddings.extend(batch_embeddings)
+        log.debug(f"generate_multiple_async: Generated {len(embeddings)} embeddings from {len(batches)} parallel batches")
+        return embeddings
+    else:
+        return await func_async([query], prefix=prefix, user=user)
+
 def get_embedding_function(
     embedding_engine,
     embedding_model,

@@ -521,40 +664,58 @@ def get_embedding_function(
     azure_api_version=None,
 ):
     if embedding_engine == "":
-        return lambda query, prefix=None, user=None: embedding_function.encode(
-            query, **({"prompt": prefix} if prefix else {})
-        ).tolist()
+        # Sentence transformers: CPU-bound sync operation
+        # Run in thread pool to prevent blocking async event loop
+        def sync_encode(query, prefix=None):
+            return embedding_function.encode(
+                query, **({"prompt": prefix} if prefix else {})
+            ).tolist()
+
+        async def async_wrapper(query, prefix=None, user=None):
+            # Run CPU-bound operation in thread pool
+            return await asyncio.to_thread(sync_encode, query, prefix)
+
+        return async_wrapper
     elif embedding_engine in ["ollama", "openai", "azure_openai"]:
-        func = lambda query, prefix=None, user=None: generate_embeddings(
-            engine=embedding_engine,
-            model=embedding_model,
-            text=query,
-            prefix=prefix,
-            url=url,
-            key=key,
-            user=user,
-            azure_api_version=azure_api_version,
-        )
+        # Create async function based on engine
+        if embedding_engine == "openai":
+            func_async = lambda texts, prefix=None, user=None: generate_openai_batch_embeddings_async(
+                model=embedding_model,
+                texts=texts,
+                url=url,
+                key=key,
+                prefix=prefix,
+                user=user,
+            )
+        elif embedding_engine == "azure_openai":
+            func_async = lambda texts, prefix=None, user=None: generate_azure_openai_batch_embeddings_async(
+                model=embedding_model,
+                texts=texts,
+                url=url,
+                key=key,
+                version=azure_api_version,
+                prefix=prefix,
+                user=user,
+            )
+        elif embedding_engine == "ollama":
+            func_async = lambda texts, prefix=None, user=None: generate_ollama_batch_embeddings_async(
+                model=embedding_model,
+                texts=texts,
+                url=url,
+                key=key,
+                prefix=prefix,
+                user=user,
+            )

-        def generate_multiple(query, prefix, user, func):
+        # Return async function for parallel batch processing (FastAPI-compatible)
+        async def embedding_wrapper(query, prefix=None, user=None):
             if isinstance(query, list):
-                embeddings = []
-                for i in range(0, len(query), embedding_batch_size):
-                    batch_embeddings = func(
-                        query[i : i + embedding_batch_size],
-                        prefix=prefix,
-                        user=user,
-                    )
-                    if isinstance(batch_embeddings, list):
-                        embeddings.extend(batch_embeddings)
-                return embeddings
+                return await generate_multiple_async(query, prefix, user, func_async, embedding_batch_size)
             else:
-                return func(query, prefix, user)
+                result = await func_async([query], prefix=prefix, user=user)
+                return result[0] if result else None

-        return lambda query, prefix=None, user=None: generate_multiple(
-            query, prefix, user, func
-        )
+        return embedding_wrapper
     else:
         raise ValueError(f"Unknown embedding engine: {embedding_engine}")

@@ -572,7 +733,7 @@ def get_reranking_function(reranking_engine, reranking_model, reranking_function
     )

-def get_sources_from_items(
+async def get_sources_from_items(
     request,
     items,
     queries,

@@ -800,7 +961,7 @@ def get_sources_from_items(
         query_result = None  # Initialize to None
         if hybrid_search:
             try:
-                query_result = query_collection_with_hybrid_search(
+                query_result = await query_collection_with_hybrid_search(
                     collection_names=collection_names,
                     queries=queries,
                     embedding_function=embedding_function,

@@ -818,7 +979,7 @@ def get_sources_from_items(
         # fallback to non-hybrid search
         if not hybrid_search and query_result is None:
-            query_result = query_collection(
+            query_result = await query_collection(
                 collection_names=collection_names,
                 queries=queries,
                 embedding_function=embedding_function,

@@ -1056,26 +1217,25 @@ def generate_embeddings(
     else:
         text = f"{prefix}{text}"
+    # Run async embedding generation synchronously
     if engine == "ollama":
-        embeddings = generate_ollama_batch_embeddings(
-            **{
-                "model": model,
-                "texts": text if isinstance(text, list) else [text],
-                "url": url,
-                "key": key,
-                "prefix": prefix,
-                "user": user,
-            }
-        )
+        embeddings = asyncio.run(generate_ollama_batch_embeddings_async(
+            model=model,
+            texts=text if isinstance(text, list) else [text],
+            url=url,
+            key=key,
+            prefix=prefix,
+            user=user,
+        ))
         return embeddings[0] if isinstance(text, str) else embeddings
     elif engine == "openai":
-        embeddings = generate_openai_batch_embeddings(
+        embeddings = asyncio.run(generate_openai_batch_embeddings_async(
             model, text if isinstance(text, list) else [text], url, key, prefix, user
-        )
+        ))
         return embeddings[0] if isinstance(text, str) else embeddings
     elif engine == "azure_openai":
         azure_api_version = kwargs.get("azure_api_version", "")
-        embeddings = generate_azure_openai_batch_embeddings(
+        embeddings = asyncio.run(generate_azure_openai_batch_embeddings_async(
             model,
             text if isinstance(text, list) else [text],
             url,

@@ -1083,7 +1243,7 @@ def generate_embeddings(
             azure_api_version,
             prefix,
             user,
-        )
+        ))
         return embeddings[0] if isinstance(text, str) else embeddings

@@ -1104,7 +1264,7 @@ class RerankCompressor(BaseDocumentCompressor):
         extra = "forbid"
         arbitrary_types_allowed = True

-    def compress_documents(
+    async def acompress_documents(
         self,
         documents: Sequence[Document],
         query: str,

@@ -1114,12 +1274,14 @@ class RerankCompressor(BaseDocumentCompressor):
         scores = None
         if reranking:
-            scores = self.reranking_function(query, documents)
+            scores = self.reranking_function(
+                [(query, doc.page_content) for doc in documents]
+            )
         else:
             from sentence_transformers import util
-            query_embedding = self.embedding_function(query, RAG_EMBEDDING_QUERY_PREFIX)
-            document_embedding = self.embedding_function(
+            query_embedding = await self.embedding_function(query, RAG_EMBEDDING_QUERY_PREFIX)
+            document_embedding = await self.embedding_function(
                 [doc.page_content for doc in documents], RAG_EMBEDDING_CONTENT_PREFIX
             )
             scores = util.cos_sim(query_embedding, document_embedding)[0]
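The sentence-transformers branch keeps the same awaitable calling convention but pushes the CPU-bound encode() into a worker thread via asyncio.to_thread so the event loop stays responsive. A stripped-down sketch of that wrapper, with a placeholder model name:

    import asyncio
    from sentence_transformers import SentenceTransformer

    model = SentenceTransformer("all-MiniLM-L6-v2")  # placeholder model

    async def embed(query, prefix=None):
        def _encode():
            # CPU-bound; runs on a worker thread from the default executor.
            return model.encode(query, **({"prompt": prefix} if prefix else {})).tolist()

        return await asyncio.to_thread(_encode)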


@@ -1,6 +1,7 @@
 from fastapi import APIRouter, Depends, HTTPException, Request
 from pydantic import BaseModel
 import logging
+import asyncio
 from typing import Optional
 from open_webui.models.memories import Memories, MemoryModel

@@ -17,7 +18,7 @@ router = APIRouter()
 @router.get("/ef")
 async def get_embeddings(request: Request):
-    return {"result": request.app.state.EMBEDDING_FUNCTION("hello world")}
+    return {"result": await request.app.state.EMBEDDING_FUNCTION("hello world")}
 ############################

@@ -51,15 +52,17 @@ async def add_memory(
 ):
     memory = Memories.insert_new_memory(user.id, form_data.content)
+    vector = await request.app.state.EMBEDDING_FUNCTION(
+        memory.content, user=user
+    )
     VECTOR_DB_CLIENT.upsert(
         collection_name=f"user-memory-{user.id}",
         items=[
             {
                 "id": memory.id,
                 "text": memory.content,
-                "vector": request.app.state.EMBEDDING_FUNCTION(
-                    memory.content, user=user
-                ),
+                "vector": vector,
                 "metadata": {"created_at": memory.created_at},
             }
         ],

@@ -86,9 +89,11 @@ async def query_memory(
     if not memories:
         raise HTTPException(status_code=404, detail="No memories found for user")
+    vector = await request.app.state.EMBEDDING_FUNCTION(form_data.content, user=user)
     results = VECTOR_DB_CLIENT.search(
         collection_name=f"user-memory-{user.id}",
-        vectors=[request.app.state.EMBEDDING_FUNCTION(form_data.content, user=user)],
+        vectors=[vector],
         limit=form_data.k,
     )

@@ -105,21 +110,26 @@ async def reset_memory_from_vector_db(
     VECTOR_DB_CLIENT.delete_collection(f"user-memory-{user.id}")
     memories = Memories.get_memories_by_user_id(user.id)
+    # Generate vectors in parallel
+    vectors = await asyncio.gather(*[
+        request.app.state.EMBEDDING_FUNCTION(memory.content, user=user)
+        for memory in memories
+    ])
     VECTOR_DB_CLIENT.upsert(
         collection_name=f"user-memory-{user.id}",
         items=[
             {
                 "id": memory.id,
                 "text": memory.content,
-                "vector": request.app.state.EMBEDDING_FUNCTION(
-                    memory.content, user=user
-                ),
+                "vector": vectors[idx],
                 "metadata": {
                     "created_at": memory.created_at,
                     "updated_at": memory.updated_at,
                 },
             }
-            for memory in memories
+            for idx, memory in enumerate(memories)
         ],
     )

@@ -164,15 +174,17 @@ async def update_memory_by_id(
         raise HTTPException(status_code=404, detail="Memory not found")
     if form_data.content is not None:
+        vector = await request.app.state.EMBEDDING_FUNCTION(
+            memory.content, user=user
+        )
         VECTOR_DB_CLIENT.upsert(
             collection_name=f"user-memory-{user.id}",
             items=[
                 {
                     "id": memory.id,
                     "text": memory.content,
-                    "vector": request.app.state.EMBEDDING_FUNCTION(
-                        memory.content, user=user
-                    ),
+                    "vector": vector,
                     "metadata": {
                         "created_at": memory.created_at,
                         "updated_at": memory.updated_at,


@@ -1467,11 +1467,12 @@ def save_docs_to_vector_db(
             ),
         )
-        embeddings = embedding_function(
+        # Run async embedding in sync context
+        embeddings = asyncio.run(embedding_function(
             list(map(lambda x: x.replace("\n", " "), texts)),
             prefix=RAG_EMBEDDING_CONTENT_PREFIX,
             user=user,
-        )
+        ))
         log.info(f"embeddings generated {len(embeddings)} for {len(texts)} items")
         items = [
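A caveat on the asyncio.run(...) bridge above: asyncio.run raises RuntimeError when an event loop is already running in the current thread, so this sync path is only safe where save_docs_to_vector_db executes off the request event loop (for example in a worker thread). A small illustration of that constraint, with illustrative function names:

    import asyncio

    async def embed_async(texts):
        # Stand-in for the awaitable embedding function returned by get_embedding_function().
        await asyncio.sleep(0)
        return [[0.0] for _ in texts]

    def embed_from_sync_code(texts):
        # Fine in plain sync code or a background thread: no loop is running here.
        return asyncio.run(embed_async(texts))

    async def embed_from_async_code(texts):
        # Inside a running loop, await directly; asyncio.run() here would raise RuntimeError.
        return await embed_async(texts)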
@@ -2262,7 +2263,7 @@ class QueryDocForm(BaseModel):
 @router.post("/query/doc")
-def query_doc_handler(
+async def query_doc_handler(
     request: Request,
     form_data: QueryDocForm,
     user=Depends(get_verified_user),

@@ -2275,7 +2276,7 @@ def query_doc_handler(
             collection_results[form_data.collection_name] = VECTOR_DB_CLIENT.get(
                 collection_name=form_data.collection_name
             )
-            return query_doc_with_hybrid_search(
+            return await query_doc_with_hybrid_search(
                 collection_name=form_data.collection_name,
                 collection_result=collection_results[form_data.collection_name],
                 query=form_data.query,

@@ -2285,8 +2286,8 @@ def query_doc_handler(
                 k=form_data.k if form_data.k else request.app.state.config.TOP_K,
                 reranking_function=(
                     (
-                        lambda sentences: request.app.state.RERANKING_FUNCTION(
-                            sentences, user=user
+                        lambda query, documents: request.app.state.RERANKING_FUNCTION(
+                            query, documents, user=user
                         )
                     )
                     if request.app.state.RERANKING_FUNCTION

@@ -2307,11 +2308,12 @@ def query_doc_handler(
                 user=user,
             )
         else:
+            query_embedding = await request.app.state.EMBEDDING_FUNCTION(
+                form_data.query, prefix=RAG_EMBEDDING_QUERY_PREFIX, user=user
+            )
             return query_doc(
                 collection_name=form_data.collection_name,
-                query_embedding=request.app.state.EMBEDDING_FUNCTION(
-                    form_data.query, prefix=RAG_EMBEDDING_QUERY_PREFIX, user=user
-                ),
+                query_embedding=query_embedding,
                 k=form_data.k if form_data.k else request.app.state.config.TOP_K,
                 user=user,
             )

@@ -2335,7 +2337,7 @@ class QueryCollectionsForm(BaseModel):
 @router.post("/query/collection")
-def query_collection_handler(
+async def query_collection_handler(
     request: Request,
     form_data: QueryCollectionsForm,
     user=Depends(get_verified_user),

@@ -2344,7 +2346,7 @@ def query_collection_handler(
         if request.app.state.config.ENABLE_RAG_HYBRID_SEARCH and (
             form_data.hybrid is None or form_data.hybrid
         ):
-            return query_collection_with_hybrid_search(
+            return await query_collection_with_hybrid_search(
                 collection_names=form_data.collection_names,
                 queries=[form_data.query],
                 embedding_function=lambda query, prefix: request.app.state.EMBEDDING_FUNCTION(

@@ -2379,7 +2381,7 @@ def query_collection_handler(
                 ),
             )
         else:
-            return query_collection(
+            return await query_collection(
                 collection_names=form_data.collection_names,
                 queries=[form_data.query],
                 embedding_function=lambda query, prefix: request.app.state.EMBEDDING_FUNCTION(

@@ -2461,7 +2463,7 @@ if ENV == "dev":
     @router.get("/ef/{text}")
     async def get_embeddings(request: Request, text: Optional[str] = "Hello World!"):
         return {
-            "result": request.app.state.EMBEDDING_FUNCTION(
+            "result": await request.app.state.EMBEDDING_FUNCTION(
                 text, prefix=RAG_EMBEDDING_QUERY_PREFIX
             )
         }


@@ -993,37 +993,32 @@ async def chat_completion_files_handler(
     queries = [get_last_user_message(body["messages"])]
     try:
-        # Offload get_sources_from_items to a separate thread
-        loop = asyncio.get_running_loop()
-        with ThreadPoolExecutor() as executor:
-            sources = await loop.run_in_executor(
-                executor,
-                lambda: get_sources_from_items(
-                    request=request,
-                    items=files,
-                    queries=queries,
-                    embedding_function=lambda query, prefix: request.app.state.EMBEDDING_FUNCTION(
-                        query, prefix=prefix, user=user
-                    ),
-                    k=request.app.state.config.TOP_K,
-                    reranking_function=(
-                        (
-                            lambda query, documents: request.app.state.RERANKING_FUNCTION(
-                                query, documents, user=user
-                            )
-                        )
-                        if request.app.state.RERANKING_FUNCTION
-                        else None
-                    ),
-                    k_reranker=request.app.state.config.TOP_K_RERANKER,
-                    r=request.app.state.config.RELEVANCE_THRESHOLD,
-                    hybrid_bm25_weight=request.app.state.config.HYBRID_BM25_WEIGHT,
-                    hybrid_search=request.app.state.config.ENABLE_RAG_HYBRID_SEARCH,
-                    full_context=all_full_context
-                    or request.app.state.config.RAG_FULL_CONTEXT,
-                    user=user,
-                ),
-            )
+        # Directly await async get_sources_from_items (no thread needed - fully async now)
+        sources = await get_sources_from_items(
+            request=request,
+            items=files,
+            queries=queries,
+            embedding_function=lambda query, prefix: request.app.state.EMBEDDING_FUNCTION(
+                query, prefix=prefix, user=user
+            ),
+            k=request.app.state.config.TOP_K,
+            reranking_function=(
+                (
+                    lambda query, documents: request.app.state.RERANKING_FUNCTION(
+                        query, documents, user=user
+                    )
+                )
+                if request.app.state.RERANKING_FUNCTION
+                else None
+            ),
+            k_reranker=request.app.state.config.TOP_K_RERANKER,
+            r=request.app.state.config.RELEVANCE_THRESHOLD,
+            hybrid_bm25_weight=request.app.state.config.HYBRID_BM25_WEIGHT,
+            hybrid_search=request.app.state.config.ENABLE_RAG_HYBRID_SEARCH,
+            full_context=all_full_context
+            or request.app.state.config.RAG_FULL_CONTEXT,
+            user=user,
+        )
     except Exception as e:
         log.exception(e)