refac
This commit is contained in:
parent 902c6cfbea
commit 9bfc414d26
1 changed file with 274 additions and 287 deletions
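In short, the "refac" commit restructures the embedding helpers: the old *_async functions are renamed to an agenerate_* convention, synchronous requests-based counterparts keep the plain generate_* names, the standalone generate_multiple_async helper is folded into get_embedding_function as inline batching, and the aiohttp sessions gain trust_env=True so proxy settings from the environment (HTTP_PROXY/HTTPS_PROXY) are honored.

A caller-side sketch of the resulting surface. This is a minimal sketch, assuming keyword names for the parameters that the hunks below elide (embedding_function, url); the model name and key are placeholders:

    import asyncio

    async def main():
        embed = get_embedding_function(
            embedding_engine="openai",
            embedding_model="text-embedding-3-small",  # placeholder model
            embedding_function=None,  # assumed name; local model, used only when engine == ""
            url="https://api.openai.com/v1",
            key="sk-...",  # placeholder key
            embedding_batch_size=16,
        )
        # The returned callable is awaitable and batches list inputs internally.
        vectors = await embed(["first text", "second text"], prefix=None, user=None)
        print(len(vectors))

    asyncio.run(main())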
@@ -1,6 +1,6 @@
 import logging
 import os
-from typing import Optional, Union
+from typing import Awaitable, Optional, Union

 import requests
 import aiohttp
@@ -410,7 +410,9 @@ async def query_collection(
         return None, e

     # Generate all query embeddings (in one call)
-    query_embeddings = await embedding_function(queries, prefix=RAG_EMBEDDING_QUERY_PREFIX)
+    query_embeddings = await embedding_function(
+        queries, prefix=RAG_EMBEDDING_QUERY_PREFIX
+    )
     log.debug(
         f"query_collection: processing {len(queries)} queries across {len(collection_names)} collections"
     )
@@ -491,14 +493,16 @@ async def query_collection_with_hybrid_search(
     # Prepare tasks for all collections and queries
     # Avoid running any tasks for collections that failed to fetch data (have assigned None)
     tasks = [
-        (cn, q)
-        for cn in collection_names
-        if collection_results[cn] is not None
-        for q in queries
+        (collection_name, query)
+        for collection_name in collection_names
+        if collection_results[collection_name] is not None
+        for query in queries
     ]

     # Run all queries in parallel using asyncio.gather
-    task_results = await asyncio.gather(*[process_query(cn, q) for cn, q in tasks])
+    task_results = await asyncio.gather(
+        *[process_query(collection_name, query) for collection_name, query in tasks]
+    )

     for result, err in task_results:
         if err is not None:
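The (result, err) pairs gathered above follow an errors-as-values convention: each process_query call returns its exception instead of raising it, so one failing collection/query pair cannot cancel the rest of the asyncio.gather. A minimal sketch of the convention, inferred from the "return None, e" context line (process_query's full body is not shown in this diff):

    import asyncio

    async def safe_call(coro):
        # Return (result, None) on success and (None, error) on failure,
        # so gather() always completes even when individual tasks fail.
        try:
            return await coro, None
        except Exception as e:
            return None, e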
@@ -514,7 +518,7 @@ async def query_collection_with_hybrid_search(
     return merge_and_sort_query_results(results, k=k)


-async def generate_openai_batch_embeddings_async(
+def generate_openai_batch_embeddings(
     model: str,
     texts: list[str],
     url: str = "https://api.openai.com/v1",
@@ -524,7 +528,52 @@ async def generate_openai_batch_embeddings_async(
 ) -> Optional[list[list[float]]]:
     try:
         log.debug(
-            f"generate_openai_batch_embeddings_async:model {model} batch size: {len(texts)}"
+            f"generate_openai_batch_embeddings:model {model} batch size: {len(texts)}"
+        )
+        json_data = {"input": texts, "model": model}
+        if isinstance(RAG_EMBEDDING_PREFIX_FIELD_NAME, str) and isinstance(prefix, str):
+            json_data[RAG_EMBEDDING_PREFIX_FIELD_NAME] = prefix
+
+        r = requests.post(
+            f"{url}/embeddings",
+            headers={
+                "Content-Type": "application/json",
+                "Authorization": f"Bearer {key}",
+                **(
+                    {
+                        "X-OpenWebUI-User-Name": quote(user.name, safe=" "),
+                        "X-OpenWebUI-User-Id": user.id,
+                        "X-OpenWebUI-User-Email": user.email,
+                        "X-OpenWebUI-User-Role": user.role,
+                    }
+                    if ENABLE_FORWARD_USER_INFO_HEADERS and user
+                    else {}
+                ),
+            },
+            json=json_data,
+        )
+        r.raise_for_status()
+        data = r.json()
+        if "data" in data:
+            return [elem["embedding"] for elem in data["data"]]
+        else:
+            raise Exception("Something went wrong :/")
+    except Exception as e:
+        log.exception(f"Error generating openai batch embeddings: {e}")
+        return None
+
+
+async def agenerate_openai_batch_embeddings(
+    model: str,
+    texts: list[str],
+    url: str = "https://api.openai.com/v1",
+    key: str = "",
+    prefix: str = None,
+    user: UserModel = None,
+) -> Optional[list[list[float]]]:
+    try:
+        log.debug(
+            f"agenerate_openai_batch_embeddings:model {model} batch size: {len(texts)}"
         )
         form_data = {"input": texts, "model": model}
         if isinstance(RAG_EMBEDDING_PREFIX_FIELD_NAME, str) and isinstance(prefix, str):
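The sync and async variants now coexist with identical parameters, so call sites pick by execution context. A usage sketch (model name and key are placeholders):

    import asyncio

    texts = ["hello", "world"]

    # Blocking path, e.g. scripts or worker threads (requests under the hood):
    vecs = generate_openai_batch_embeddings(
        "text-embedding-3-small", texts, url="https://api.openai.com/v1", key="sk-..."
    )

    # Non-blocking path for code already on the event loop (aiohttp under the hood):
    avecs = asyncio.run(
        agenerate_openai_batch_embeddings(
            "text-embedding-3-small", texts, url="https://api.openai.com/v1", key="sk-..."
        )
    )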
@@ -537,8 +586,10 @@ async def generate_openai_batch_embeddings_async(
         if ENABLE_FORWARD_USER_INFO_HEADERS and user:
             headers = include_user_info_headers(headers, user)

-        async with aiohttp.ClientSession() as session:
-            async with session.post(f"{url}/embeddings", headers=headers, json=form_data) as r:
+        async with aiohttp.ClientSession(trust_env=True) as session:
+            async with session.post(
+                f"{url}/embeddings", headers=headers, json=form_data
+            ) as r:
                 r.raise_for_status()
                 data = await r.json()
                 if "data" in data:
@@ -550,7 +601,7 @@ async def generate_openai_batch_embeddings_async(
         return None


-async def generate_azure_openai_batch_embeddings_async(
+def generate_azure_openai_batch_embeddings(
     model: str,
     texts: list[str],
     url: str,
@@ -561,7 +612,61 @@ async def generate_azure_openai_batch_embeddings_async(
 ) -> Optional[list[list[float]]]:
     try:
         log.debug(
-            f"generate_azure_openai_batch_embeddings_async:deployment {model} batch size: {len(texts)}"
+            f"generate_azure_openai_batch_embeddings:deployment {model} batch size: {len(texts)}"
+        )
+        json_data = {"input": texts}
+        if isinstance(RAG_EMBEDDING_PREFIX_FIELD_NAME, str) and isinstance(prefix, str):
+            json_data[RAG_EMBEDDING_PREFIX_FIELD_NAME] = prefix
+
+        url = f"{url}/openai/deployments/{model}/embeddings?api-version={version}"
+
+        for _ in range(5):
+            r = requests.post(
+                url,
+                headers={
+                    "Content-Type": "application/json",
+                    "api-key": key,
+                    **(
+                        {
+                            "X-OpenWebUI-User-Name": quote(user.name, safe=" "),
+                            "X-OpenWebUI-User-Id": user.id,
+                            "X-OpenWebUI-User-Email": user.email,
+                            "X-OpenWebUI-User-Role": user.role,
+                        }
+                        if ENABLE_FORWARD_USER_INFO_HEADERS and user
+                        else {}
+                    ),
+                },
+                json=json_data,
+            )
+            if r.status_code == 429:
+                retry = float(r.headers.get("Retry-After", "1"))
+                time.sleep(retry)
+                continue
+            r.raise_for_status()
+            data = r.json()
+            if "data" in data:
+                return [elem["embedding"] for elem in data["data"]]
+            else:
+                raise Exception("Something went wrong :/")
+        return None
+    except Exception as e:
+        log.exception(f"Error generating azure openai batch embeddings: {e}")
+        return None
+
+
+async def agenerate_azure_openai_batch_embeddings(
+    model: str,
+    texts: list[str],
+    url: str,
+    key: str = "",
+    version: str = "",
+    prefix: str = None,
+    user: UserModel = None,
+) -> Optional[list[list[float]]]:
+    try:
+        log.debug(
+            f"agenerate_azure_openai_batch_embeddings:deployment {model} batch size: {len(texts)}"
         )
         form_data = {"input": texts}
         if isinstance(RAG_EMBEDDING_PREFIX_FIELD_NAME, str) and isinstance(prefix, str):
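Note the asymmetry this hunk introduces: the synchronous Azure helper retries up to five times on HTTP 429, sleeping for the server-advertised Retry-After interval, while the async agenerate_* variant performs no retry. A non-blocking equivalent of the same retry loop might look roughly like this (illustrative, not part of the diff):

    import asyncio
    import aiohttp

    async def post_with_retry(session: aiohttp.ClientSession, url: str, payload: dict, attempts: int = 5):
        for _ in range(attempts):
            async with session.post(url, json=payload) as r:
                if r.status == 429:
                    # Honor the server's Retry-After hint without blocking the event loop.
                    await asyncio.sleep(float(r.headers.get("Retry-After", "1")))
                    continue
                r.raise_for_status()
                return await r.json()
        return None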
@@ -576,7 +681,7 @@ async def generate_azure_openai_batch_embeddings_async(
         if ENABLE_FORWARD_USER_INFO_HEADERS and user:
             headers = include_user_info_headers(headers, user)

-        async with aiohttp.ClientSession() as session:
+        async with aiohttp.ClientSession(trust_env=True) as session:
             async with session.post(full_url, headers=headers, json=form_data) as r:
                 r.raise_for_status()
                 data = await r.json()
@@ -589,7 +694,7 @@ async def generate_azure_openai_batch_embeddings_async(
         return None


-async def generate_ollama_batch_embeddings_async(
+def generate_ollama_batch_embeddings(
     model: str,
     texts: list[str],
     url: str,
@@ -599,7 +704,53 @@ async def generate_ollama_batch_embeddings_async(
 ) -> Optional[list[list[float]]]:
     try:
         log.debug(
-            f"generate_ollama_batch_embeddings_async:model {model} batch size: {len(texts)}"
+            f"generate_ollama_batch_embeddings:model {model} batch size: {len(texts)}"
+        )
+        json_data = {"input": texts, "model": model}
+        if isinstance(RAG_EMBEDDING_PREFIX_FIELD_NAME, str) and isinstance(prefix, str):
+            json_data[RAG_EMBEDDING_PREFIX_FIELD_NAME] = prefix
+
+        r = requests.post(
+            f"{url}/api/embed",
+            headers={
+                "Content-Type": "application/json",
+                "Authorization": f"Bearer {key}",
+                **(
+                    {
+                        "X-OpenWebUI-User-Name": quote(user.name, safe=" "),
+                        "X-OpenWebUI-User-Id": user.id,
+                        "X-OpenWebUI-User-Email": user.email,
+                        "X-OpenWebUI-User-Role": user.role,
+                    }
+                    if ENABLE_FORWARD_USER_INFO_HEADERS
+                    else {}
+                ),
+            },
+            json=json_data,
+        )
+        r.raise_for_status()
+        data = r.json()
+
+        if "embeddings" in data:
+            return data["embeddings"]
+        else:
+            raise Exception("Something went wrong :/")
+    except Exception as e:
+        log.exception(f"Error generating ollama batch embeddings: {e}")
+        return None
+
+
+async def agenerate_ollama_batch_embeddings(
+    model: str,
+    texts: list[str],
+    url: str,
+    key: str = "",
+    prefix: str = None,
+    user: UserModel = None,
+) -> Optional[list[list[float]]]:
+    try:
+        log.debug(
+            f"agenerate_ollama_batch_embeddings:model {model} batch size: {len(texts)}"
         )
         form_data = {"input": texts, "model": model}
         if isinstance(RAG_EMBEDDING_PREFIX_FIELD_NAME, str) and isinstance(prefix, str):
@@ -612,8 +763,10 @@ async def generate_ollama_batch_embeddings_async(
         if ENABLE_FORWARD_USER_INFO_HEADERS and user:
             headers = include_user_info_headers(headers, user)

-        async with aiohttp.ClientSession() as session:
-            async with session.post(f"{url}/api/embed", headers=headers, json=form_data) as r:
+        async with aiohttp.ClientSession(trust_env=True) as session:
+            async with session.post(
+                f"{url}/api/embed", headers=headers, json=form_data
+            ) as r:
                 r.raise_for_status()
                 data = await r.json()
                 if "embeddings" in data:
@@ -625,35 +778,6 @@ async def generate_ollama_batch_embeddings_async(
         return None


-async def generate_multiple_async(query, prefix, user, func_async, embedding_batch_size):
-    if isinstance(query, list):
-        # Create batches
-        batches = [
-            query[i : i + embedding_batch_size]
-            for i in range(0, len(query), embedding_batch_size)
-        ]
-
-        log.debug(f"generate_multiple_async: Processing {len(batches)} batches in parallel")
-
-        # Execute all batches in parallel
-        tasks = [
-            func_async(batch, prefix=prefix, user=user)
-            for batch in batches
-        ]
-        batch_results = await asyncio.gather(*tasks)
-
-        # Flatten results
-        embeddings = []
-        for batch_embeddings in batch_results:
-            if isinstance(batch_embeddings, list):
-                embeddings.extend(batch_embeddings)
-
-        log.debug(f"generate_multiple_async: Generated {len(embeddings)} embeddings from {len(batches)} parallel batches")
-        return embeddings
-    else:
-        return await func_async([query], prefix=prefix, user=user)
-
-
 def get_embedding_function(
     embedding_engine,
     embedding_model,
@@ -662,64 +786,117 @@ def get_embedding_function(
     key,
     embedding_batch_size,
     azure_api_version=None,
-):
+) -> Awaitable:
     if embedding_engine == "":
         # Sentence transformers: CPU-bound sync operation
-        # Run in thread pool to prevent blocking async event loop
-        def sync_encode(query, prefix=None):
-            return embedding_function.encode(
-                query, **({"prompt": prefix} if prefix else {})
-            ).tolist()
-
-        async def async_wrapper(query, prefix=None, user=None):
-            # Run CPU-bound operation in thread pool
-            return await asyncio.to_thread(sync_encode, query, prefix)
-
-        return async_wrapper
+        async def async_embedding_function(query, prefix=None, user=None):
+            return await asyncio.to_thread(
+                (
+                    lambda query, prefix=None: embedding_function.encode(
+                        query, **({"prompt": prefix} if prefix else {})
+                    ).tolist()
+                ),
+                query,
+                prefix,
+            )
+
+        return async_embedding_function

     elif embedding_engine in ["ollama", "openai", "azure_openai"]:
-        # Create async function based on engine
-        if embedding_engine == "openai":
-            func_async = lambda texts, prefix=None, user=None: generate_openai_batch_embeddings_async(
-                model=embedding_model,
-                texts=texts,
-                url=url,
-                key=key,
-                prefix=prefix,
-                user=user,
-            )
-        elif embedding_engine == "azure_openai":
-            func_async = lambda texts, prefix=None, user=None: generate_azure_openai_batch_embeddings_async(
-                model=embedding_model,
-                texts=texts,
-                url=url,
-                key=key,
-                version=azure_api_version,
-                prefix=prefix,
-                user=user,
-            )
-        elif embedding_engine == "ollama":
-            func_async = lambda texts, prefix=None, user=None: generate_ollama_batch_embeddings_async(
-                model=embedding_model,
-                texts=texts,
-                url=url,
-                key=key,
-                prefix=prefix,
-                user=user,
-            )
-
-        # Return async function for parallel batch processing (FastAPI-compatible)
-        async def embedding_wrapper(query, prefix=None, user=None):
-            if isinstance(query, list):
-                return await generate_multiple_async(query, prefix, user, func_async, embedding_batch_size)
-            else:
-                result = await func_async([query], prefix=prefix, user=user)
-                return result[0] if result else None
-
-        return embedding_wrapper
+        embedding_function = lambda query, prefix=None, user=None: generate_embeddings(
+            engine=embedding_engine,
+            model=embedding_model,
+            text=query,
+            prefix=prefix,
+            url=url,
+            key=key,
+            user=user,
+            azure_api_version=azure_api_version,
+        )
+
+        async def async_embedding_function(query, prefix, user):
+            if isinstance(query, list):
+                # Create batches
+                batches = [
+                    query[i : i + embedding_batch_size]
+                    for i in range(0, len(query), embedding_batch_size)
+                ]
+                log.debug(
+                    f"generate_multiple_async: Processing {len(batches)} batches in parallel"
+                )
+
+                # Execute all batches in parallel
+                tasks = [
+                    embedding_function(batch, prefix=prefix, user=user)
+                    for batch in batches
+                ]
+                batch_results = await asyncio.gather(*tasks)
+
+                # Flatten results
+                embeddings = []
+                for batch_embeddings in batch_results:
+                    if isinstance(batch_embeddings, list):
+                        embeddings.extend(batch_embeddings)
+
+                log.debug(
+                    f"generate_multiple_async: Generated {len(embeddings)} embeddings from {len(batches)} parallel batches"
+                )
+                return embeddings
+            else:
+                return await embedding_function(query, prefix, user)
+
+        return async_embedding_function
     else:
         raise ValueError(f"Unknown embedding engine: {embedding_engine}")
+
+
+async def generate_embeddings(
+    engine: str,
+    model: str,
+    text: Union[str, list[str]],
+    prefix: Union[str, None] = None,
+    **kwargs,
+):
+    url = kwargs.get("url", "")
+    key = kwargs.get("key", "")
+    user = kwargs.get("user")
+
+    if prefix is not None and RAG_EMBEDDING_PREFIX_FIELD_NAME is None:
+        if isinstance(text, list):
+            text = [f"{prefix}{text_element}" for text_element in text]
+        else:
+            text = f"{prefix}{text}"
+
+    if engine == "ollama":
+        embeddings = await agenerate_ollama_batch_embeddings(
+            **{
+                "model": model,
+                "texts": text if isinstance(text, list) else [text],
+                "url": url,
+                "key": key,
+                "prefix": prefix,
+                "user": user,
+            }
+        )
+        return embeddings[0] if isinstance(text, str) else embeddings
+    elif engine == "openai":
+        embeddings = await agenerate_openai_batch_embeddings(
+            model, text if isinstance(text, list) else [text], url, key, prefix, user
+        )
+        return embeddings[0] if isinstance(text, str) else embeddings
+    elif engine == "azure_openai":
+        azure_api_version = kwargs.get("azure_api_version", "")
+        embeddings = await agenerate_azure_openai_batch_embeddings(
+            model,
+            text if isinstance(text, list) else [text],
+            url,
+            key,
+            azure_api_version,
+            prefix,
+            user,
+        )
+        return embeddings[0] if isinstance(text, str) else embeddings


 def get_reranking_function(reranking_engine, reranking_model, reranking_function):
     if reranking_function is None:
         return None
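The batching idiom folded into get_embedding_function above is worth isolating. A standalone sketch of the same pattern (names here are illustrative):

    import asyncio

    async def embed_batched(texts, embed_one_batch, batch_size):
        # Split the input into fixed-size batches...
        batches = [texts[i : i + batch_size] for i in range(0, len(texts), batch_size)]
        # ...fire one request per batch concurrently...
        results = await asyncio.gather(*(embed_one_batch(b) for b in batches))
        # ...and flatten the per-batch embedding lists back into one list.
        return [vec for batch in results if isinstance(batch, list) for vec in batch]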
@@ -1055,198 +1232,6 @@ def get_model_path(model: str, update_model: bool = False):
     return model


-def generate_openai_batch_embeddings(
-    model: str,
-    texts: list[str],
-    url: str = "https://api.openai.com/v1",
-    key: str = "",
-    prefix: str = None,
-    user: UserModel = None,
-) -> Optional[list[list[float]]]:
-    try:
-        log.debug(
-            f"generate_openai_batch_embeddings:model {model} batch size: {len(texts)}"
-        )
-        json_data = {"input": texts, "model": model}
-        if isinstance(RAG_EMBEDDING_PREFIX_FIELD_NAME, str) and isinstance(prefix, str):
-            json_data[RAG_EMBEDDING_PREFIX_FIELD_NAME] = prefix
-
-        r = requests.post(
-            f"{url}/embeddings",
-            headers={
-                "Content-Type": "application/json",
-                "Authorization": f"Bearer {key}",
-                **(
-                    {
-                        "X-OpenWebUI-User-Name": quote(user.name, safe=" "),
-                        "X-OpenWebUI-User-Id": user.id,
-                        "X-OpenWebUI-User-Email": user.email,
-                        "X-OpenWebUI-User-Role": user.role,
-                    }
-                    if ENABLE_FORWARD_USER_INFO_HEADERS and user
-                    else {}
-                ),
-            },
-            json=json_data,
-        )
-        r.raise_for_status()
-        data = r.json()
-        if "data" in data:
-            return [elem["embedding"] for elem in data["data"]]
-        else:
-            raise "Something went wrong :/"
-    except Exception as e:
-        log.exception(f"Error generating openai batch embeddings: {e}")
-        return None
-
-
-def generate_azure_openai_batch_embeddings(
-    model: str,
-    texts: list[str],
-    url: str,
-    key: str = "",
-    version: str = "",
-    prefix: str = None,
-    user: UserModel = None,
-) -> Optional[list[list[float]]]:
-    try:
-        log.debug(
-            f"generate_azure_openai_batch_embeddings:deployment {model} batch size: {len(texts)}"
-        )
-        json_data = {"input": texts}
-        if isinstance(RAG_EMBEDDING_PREFIX_FIELD_NAME, str) and isinstance(prefix, str):
-            json_data[RAG_EMBEDDING_PREFIX_FIELD_NAME] = prefix
-
-        url = f"{url}/openai/deployments/{model}/embeddings?api-version={version}"
-
-        for _ in range(5):
-            r = requests.post(
-                url,
-                headers={
-                    "Content-Type": "application/json",
-                    "api-key": key,
-                    **(
-                        {
-                            "X-OpenWebUI-User-Name": quote(user.name, safe=" "),
-                            "X-OpenWebUI-User-Id": user.id,
-                            "X-OpenWebUI-User-Email": user.email,
-                            "X-OpenWebUI-User-Role": user.role,
-                        }
-                        if ENABLE_FORWARD_USER_INFO_HEADERS and user
-                        else {}
-                    ),
-                },
-                json=json_data,
-            )
-            if r.status_code == 429:
-                retry = float(r.headers.get("Retry-After", "1"))
-                time.sleep(retry)
-                continue
-            r.raise_for_status()
-            data = r.json()
-            if "data" in data:
-                return [elem["embedding"] for elem in data["data"]]
-            else:
-                raise Exception("Something went wrong :/")
-        return None
-    except Exception as e:
-        log.exception(f"Error generating azure openai batch embeddings: {e}")
-        return None
-
-
-def generate_ollama_batch_embeddings(
-    model: str,
-    texts: list[str],
-    url: str,
-    key: str = "",
-    prefix: str = None,
-    user: UserModel = None,
-) -> Optional[list[list[float]]]:
-    try:
-        log.debug(
-            f"generate_ollama_batch_embeddings:model {model} batch size: {len(texts)}"
-        )
-        json_data = {"input": texts, "model": model}
-        if isinstance(RAG_EMBEDDING_PREFIX_FIELD_NAME, str) and isinstance(prefix, str):
-            json_data[RAG_EMBEDDING_PREFIX_FIELD_NAME] = prefix
-
-        r = requests.post(
-            f"{url}/api/embed",
-            headers={
-                "Content-Type": "application/json",
-                "Authorization": f"Bearer {key}",
-                **(
-                    {
-                        "X-OpenWebUI-User-Name": quote(user.name, safe=" "),
-                        "X-OpenWebUI-User-Id": user.id,
-                        "X-OpenWebUI-User-Email": user.email,
-                        "X-OpenWebUI-User-Role": user.role,
-                    }
-                    if ENABLE_FORWARD_USER_INFO_HEADERS
-                    else {}
-                ),
-            },
-            json=json_data,
-        )
-        r.raise_for_status()
-        data = r.json()
-
-        if "embeddings" in data:
-            return data["embeddings"]
-        else:
-            raise "Something went wrong :/"
-    except Exception as e:
-        log.exception(f"Error generating ollama batch embeddings: {e}")
-        return None
-
-
-def generate_embeddings(
-    engine: str,
-    model: str,
-    text: Union[str, list[str]],
-    prefix: Union[str, None] = None,
-    **kwargs,
-):
-    url = kwargs.get("url", "")
-    key = kwargs.get("key", "")
-    user = kwargs.get("user")
-
-    if prefix is not None and RAG_EMBEDDING_PREFIX_FIELD_NAME is None:
-        if isinstance(text, list):
-            text = [f"{prefix}{text_element}" for text_element in text]
-        else:
-            text = f"{prefix}{text}"
-
-    # Run async embedding generation synchronously
-    if engine == "ollama":
-        embeddings = asyncio.run(generate_ollama_batch_embeddings_async(
-            model=model,
-            texts=text if isinstance(text, list) else [text],
-            url=url,
-            key=key,
-            prefix=prefix,
-            user=user,
-        ))
-        return embeddings[0] if isinstance(text, str) else embeddings
-    elif engine == "openai":
-        embeddings = asyncio.run(generate_openai_batch_embeddings_async(
-            model, text if isinstance(text, list) else [text], url, key, prefix, user
-        ))
-        return embeddings[0] if isinstance(text, str) else embeddings
-    elif engine == "azure_openai":
-        azure_api_version = kwargs.get("azure_api_version", "")
-        embeddings = asyncio.run(generate_azure_openai_batch_embeddings_async(
-            model,
-            text if isinstance(text, list) else [text],
-            url,
-            key,
-            azure_api_version,
-            prefix,
-            user,
-        ))
-        return embeddings[0] if isinstance(text, str) else embeddings


 import operator
 from typing import Optional, Sequence
@@ -1280,7 +1265,9 @@ class RerankCompressor(BaseDocumentCompressor):
         else:
             from sentence_transformers import util

-            query_embedding = await self.embedding_function(query, RAG_EMBEDDING_QUERY_PREFIX)
+            query_embedding = await self.embedding_function(
+                query, RAG_EMBEDDING_QUERY_PREFIX
+            )
             document_embedding = await self.embedding_function(
                 [doc.page_content for doc in documents], RAG_EMBEDDING_CONTENT_PREFIX
             )
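For context on this last hunk: the now-awaited embedding function feeds a cosine-similarity rerank inside RerankCompressor. A condensed sketch of that scoring step (illustrative; it assumes the util module imported in the hunk and plain list embeddings):

    from sentence_transformers import util

    def cosine_rank(query_embedding, document_embeddings, top_k):
        # util.cos_sim returns a [1 x N] similarity matrix for a single query.
        scores = util.cos_sim(query_embedding, document_embeddings)[0]
        ranked = sorted(enumerate(scores.tolist()), key=lambda p: p[1], reverse=True)
        return ranked[:top_k]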