Mirror of https://github.com/open-webui/open-webui.git (synced 2025-12-12 04:15:25 +00:00)
Merge pull request #14667 from hdnh2006/main

feat: OpenAI-compatible `/api/embeddings` endpoint with provider-agnostic OpenWebUI architecture

Commit 14e158fde9
5 changed files with 321 additions and 0 deletions
backend/open_webui/main.py

@@ -413,6 +413,7 @@ from open_webui.utils.chat import (
     chat_completed as chat_completed_handler,
     chat_action as chat_action_handler,
 )
+from open_webui.utils.embeddings import generate_embeddings
 from open_webui.utils.middleware import process_chat_payload, process_chat_response
 from open_webui.utils.access_control import has_access
@@ -1545,6 +1546,37 @@ async def get_app_latest_release_version(user=Depends(get_verified_user)):
 async def get_app_changelog():
     return {key: CHANGELOG[key] for idx, key in enumerate(CHANGELOG) if idx < 5}
 
 
+##################################
+# Embeddings
+##################################
+
+
+@app.post("/api/embeddings")
+async def embeddings_endpoint(
+    request: Request,
+    form_data: dict,
+    user=Depends(get_verified_user),
+):
+    """
+    OpenAI-compatible embeddings endpoint.
+
+    This handler:
+      - Performs user/model checks and dispatches to the correct backend.
+      - Supports OpenAI, Ollama, arena models, pipelines, and any compatible provider.
+
+    Args:
+        request (Request): The request context.
+        form_data (dict): An OpenAI-like payload (e.g., {"model": "...", "input": [...]}).
+        user (UserModel): The authenticated user.
+
+    Returns:
+        dict: An OpenAI-compatible embeddings response.
+    """
+    # Make sure models are loaded in app state
+    if not request.app.state.MODELS:
+        await get_all_models(request, user=user)
+
+    # Use the generic dispatcher in utils.embeddings
+    return await generate_embeddings(request, form_data, user)
+
+
 ############################
 # OAuth Login & Callback
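The new route makes an Open WebUI deployment answer embeddings requests in the same shape as the OpenAI API. A minimal client sketch follows; the base URL, API key, and model name are placeholders for a local deployment, not values prescribed by this PR:

```python
import requests

# Hypothetical call against a local Open WebUI instance;
# URL, key, and model name are illustrative.
response = requests.post(
    "http://localhost:8080/api/embeddings",
    headers={"Authorization": "Bearer YOUR_OPEN_WEBUI_API_KEY"},
    json={"model": "text-embedding-3-small", "input": ["Hello, world!"]},
)
response.raise_for_status()
embedding = response.json()["data"][0]["embedding"]
print(f"Got a {len(embedding)}-dimensional embedding")
```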
backend/open_webui/routers/openai.py

@@ -886,6 +886,85 @@ async def generate_chat_completion(
                 r.close()
             await session.close()
 
 
+async def embeddings(request: Request, form_data: dict, user):
+    """
+    Calls the embeddings endpoint for OpenAI-compatible providers.
+
+    Args:
+        request (Request): The FastAPI request context.
+        form_data (dict): OpenAI-compatible embeddings payload.
+        user (UserModel): The authenticated user.
+
+    Returns:
+        dict: OpenAI-compatible embeddings response.
+    """
+    idx = 0
+
+    # Prepare payload/body
+    body = json.dumps(form_data)
+
+    # Find correct backend url/key based on model
+    await get_all_models(request, user=user)
+    model_id = form_data.get("model")
+    models = request.app.state.OPENAI_MODELS
+    if model_id in models:
+        idx = models[model_id]["urlIdx"]
+
+    url = request.app.state.config.OPENAI_API_BASE_URLS[idx]
+    key = request.app.state.config.OPENAI_API_KEYS[idx]
+
+    r = None
+    session = None
+    streaming = False
+
+    try:
+        session = aiohttp.ClientSession(trust_env=True)
+        r = await session.request(
+            method="POST",
+            url=f"{url}/embeddings",
+            data=body,
+            headers={
+                "Authorization": f"Bearer {key}",
+                "Content-Type": "application/json",
+                **(
+                    {
+                        "X-OpenWebUI-User-Name": user.name,
+                        "X-OpenWebUI-User-Id": user.id,
+                        "X-OpenWebUI-User-Email": user.email,
+                        "X-OpenWebUI-User-Role": user.role,
+                    }
+                    if ENABLE_FORWARD_USER_INFO_HEADERS and user
+                    else {}
+                ),
+            },
+        )
+        r.raise_for_status()
+
+        if "text/event-stream" in r.headers.get("Content-Type", ""):
+            streaming = True
+            return StreamingResponse(
+                r.content,
+                status_code=r.status,
+                headers=dict(r.headers),
+                background=BackgroundTask(
+                    cleanup_response, response=r, session=session
+                ),
+            )
+        else:
+            response_data = await r.json()
+            return response_data
+    except Exception as e:
+        log.exception(e)
+        detail = None
+        if r is not None:
+            try:
+                res = await r.json()
+                if "error" in res:
+                    detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}"
+            except Exception:
+                detail = f"External: {e}"
+        raise HTTPException(
+            status_code=r.status if r else 500,
+            detail=detail if detail else "Open WebUI: Server Connection Error",
+        )
+    finally:
+        if not streaming and session:
+            if r:
+                r.close()
+            await session.close()
+
+
 @router.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
 async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
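One detail worth noting: when ENABLE_FORWARD_USER_INFO_HEADERS is enabled, the proxied request carries the requesting user's identity alongside the provider key. A sketch of the resulting header set, with all values illustrative:

```python
# Illustrative headers sent to f"{url}/embeddings" when
# ENABLE_FORWARD_USER_INFO_HEADERS is on; every value is an example.
headers = {
    "Authorization": "Bearer sk-...",  # provider key from OPENAI_API_KEYS[idx]
    "Content-Type": "application/json",
    "X-OpenWebUI-User-Name": "Jane Doe",
    "X-OpenWebUI-User-Id": "user-123",
    "X-OpenWebUI-User-Email": "jane@example.com",
    "X-OpenWebUI-User-Role": "user",
}
```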
backend/open_webui/utils/embeddings.py (new file, 127 lines)

@@ -0,0 +1,127 @@
+import random
+import logging
+import sys
+
+from fastapi import Request
+from open_webui.models.users import UserModel
+from open_webui.models.models import Models
+from open_webui.utils.models import check_model_access
+from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL, BYPASS_MODEL_ACCESS_CONTROL
+
+from open_webui.routers.openai import embeddings as openai_embeddings
+from open_webui.routers.ollama import embeddings as ollama_embeddings
+from open_webui.routers.ollama import GenerateEmbeddingsForm
+from open_webui.routers.pipelines import process_pipeline_inlet_filter
+
+from open_webui.utils.payload import convert_embedding_payload_openai_to_ollama
+from open_webui.utils.response import convert_embedding_response_ollama_to_openai
+
+logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
+log = logging.getLogger(__name__)
+log.setLevel(SRC_LOG_LEVELS["MAIN"])
+
+
+async def generate_embeddings(
+    request: Request,
+    form_data: dict,
+    user: UserModel,
+    bypass_filter: bool = False,
+):
+    """
+    Dispatch and handle embeddings generation based on the model type
+    (OpenAI, Ollama, arena, pipeline, etc.).
+
+    Args:
+        request (Request): The FastAPI request context.
+        form_data (dict): The input data sent to the endpoint.
+        user (UserModel): The authenticated user.
+        bypass_filter (bool): If True, disables access filtering (default False).
+
+    Returns:
+        dict: The embeddings response, following OpenAI API compatibility.
+    """
+    if BYPASS_MODEL_ACCESS_CONTROL:
+        bypass_filter = True
+
+    # Attach extra metadata from request.state if present
+    if hasattr(request.state, "metadata"):
+        if "metadata" not in form_data:
+            form_data["metadata"] = request.state.metadata
+        else:
+            form_data["metadata"] = {
+                **form_data["metadata"],
+                **request.state.metadata,
+            }
+
+    # If the "direct" flag is present, use only that model
+    if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
+        models = {
+            request.state.model["id"]: request.state.model,
+        }
+    else:
+        models = request.app.state.MODELS
+
+    model_id = form_data.get("model")
+    if model_id not in models:
+        raise Exception("Model not found")
+    model = models[model_id]
+
+    # Access filtering
+    if not getattr(request.state, "direct", False):
+        if not bypass_filter and user.role == "user":
+            check_model_access(user, model)
+
+    # Arena "meta-model": select a submodel at random
+    if model.get("owned_by") == "arena":
+        model_ids = model.get("info", {}).get("meta", {}).get("model_ids")
+        filter_mode = model.get("info", {}).get("meta", {}).get("filter_mode")
+        if model_ids and filter_mode == "exclude":
+            model_ids = [
+                m["id"]
+                for m in list(models.values())
+                if m.get("owned_by") != "arena" and m["id"] not in model_ids
+            ]
+        if isinstance(model_ids, list) and model_ids:
+            selected_model_id = random.choice(model_ids)
+        else:
+            model_ids = [
+                m["id"]
+                for m in list(models.values())
+                if m.get("owned_by") != "arena"
+            ]
+            selected_model_id = random.choice(model_ids)
+
+        inner_form = dict(form_data)
+        inner_form["model"] = selected_model_id
+        response = await generate_embeddings(
+            request, inner_form, user, bypass_filter=True
+        )
+        # Tag which concrete model was chosen
+        if isinstance(response, dict):
+            response = {
+                **response,
+                "selected_model_id": selected_model_id,
+            }
+        return response
+
+    # Pipeline/function models
+    if model.get("pipe"):
+        # The pipeline handler should provide an OpenAI-compatible schema
+        return await process_pipeline_inlet_filter(request, form_data, user, models)
+
+    # Ollama backend
+    if model.get("owned_by") == "ollama":
+        ollama_payload = convert_embedding_payload_openai_to_ollama(form_data)
+        form_obj = GenerateEmbeddingsForm(**ollama_payload)
+        response = await ollama_embeddings(
+            request=request,
+            form_data=form_obj,
+            user=user,
+        )
+        return convert_embedding_response_ollama_to_openai(response)
+
+    # Default: OpenAI or compatible backend
+    return await openai_embeddings(
+        request=request,
+        form_data=form_data,
+        user=user,
+    )
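Server-side code can call the dispatcher directly; this is what the new /api/embeddings route in main.py does. A minimal sketch, as a fragment that assumes an async context, a live FastAPI Request with models already loaded, and an authenticated user:

```python
# Hypothetical server-side call; `request` and `user` come from the
# surrounding FastAPI handler, and the model ID/input are illustrative.
result = await generate_embeddings(
    request,
    {"model": "llama3", "input": ["text to embed"]},
    user,
)
# For arena models, the dict reply additionally carries
# "selected_model_id", naming the concrete submodel that was chosen.
```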
backend/open_webui/utils/payload.py

@@ -329,3 +329,34 @@ def convert_payload_openai_to_ollama(openai_payload: dict) -> dict:
         ollama_payload["format"] = format
 
     return ollama_payload
+
+
+def convert_embedding_payload_openai_to_ollama(openai_payload: dict) -> dict:
+    """
+    Convert an embeddings request payload from OpenAI format to Ollama format.
+
+    Args:
+        openai_payload (dict): The original payload designed for OpenAI API usage.
+
+    Returns:
+        dict: A payload compatible with the Ollama API embeddings endpoint.
+    """
+    ollama_payload = {
+        "model": openai_payload.get("model")
+    }
+    input_value = openai_payload.get("input")
+
+    # Ollama expects 'input' as a list, and 'prompt' as a single string.
+    if isinstance(input_value, list):
+        ollama_payload["input"] = input_value
+        ollama_payload["prompt"] = "\n".join(str(x) for x in input_value)
+    else:
+        ollama_payload["input"] = [input_value]
+        ollama_payload["prompt"] = str(input_value)
+
+    # Optionally forward other fields if present
+    for optional_key in ("options", "truncate", "keep_alive"):
+        if optional_key in openai_payload:
+            ollama_payload[optional_key] = openai_payload[optional_key]
+
+    return ollama_payload
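A quick illustration of the mapping for a batch request; the model name and input strings are illustrative:

```python
payload = convert_embedding_payload_openai_to_ollama(
    {
        "model": "nomic-embed-text",
        "input": ["first text", "second text"],
        "truncate": True,
    }
)
# payload == {
#     "model": "nomic-embed-text",
#     "input": ["first text", "second text"],
#     "prompt": "first text\nsecond text",  # joined for prompt-style endpoints
#     "truncate": True,                     # optional field forwarded as-is
# }
```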
backend/open_webui/utils/response.py

@@ -125,3 +125,55 @@ async def convert_streaming_response_ollama_to_openai(ollama_streaming_response)
         yield line
 
     yield "data: [DONE]\n\n"
+
+
+def convert_embedding_response_ollama_to_openai(response) -> dict:
+    """
+    Convert the response from the Ollama embeddings endpoint to the OpenAI-compatible format.
+
+    Args:
+        response (dict): The response from the Ollama API,
+            e.g. {"embedding": [...], "model": "..."}
+            or {"embeddings": [{"embedding": [...], "index": 0}, ...], "model": "..."}
+
+    Returns:
+        dict: Response adapted to OpenAI's embeddings API format, e.g.
+            {
+                "object": "list",
+                "data": [
+                    {"object": "embedding", "embedding": [...], "index": 0},
+                    ...
+                ],
+                "model": "...",
+            }
+    """
+    # Ollama batch-style output
+    if isinstance(response, dict) and "embeddings" in response:
+        openai_data = []
+        for i, emb in enumerate(response["embeddings"]):
+            openai_data.append({
+                "object": "embedding",
+                "embedding": emb.get("embedding"),
+                "index": emb.get("index", i),
+            })
+        return {
+            "object": "list",
+            "data": openai_data,
+            "model": response.get("model"),
+        }
+    # Ollama single output
+    elif isinstance(response, dict) and "embedding" in response:
+        return {
+            "object": "list",
+            "data": [{
+                "object": "embedding",
+                "embedding": response["embedding"],
+                "index": 0,
+            }],
+            "model": response.get("model"),
+        }
+    # Already OpenAI-compatible?
+    elif isinstance(response, dict) and "data" in response and isinstance(response["data"], list):
+        return response
+
+    # Fallback: return as-is if unrecognized
+    return response
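And the reverse direction, converting a batch-style Ollama reply, again with illustrative values:

```python
converted = convert_embedding_response_ollama_to_openai(
    {
        "model": "nomic-embed-text",
        "embeddings": [
            {"embedding": [0.1, 0.2, 0.3], "index": 0},
            {"embedding": [0.4, 0.5, 0.6], "index": 1},
        ],
    }
)
# converted == {
#     "object": "list",
#     "data": [
#         {"object": "embedding", "embedding": [0.1, 0.2, 0.3], "index": 0},
#         {"object": "embedding", "embedding": [0.4, 0.5, 0.6], "index": 1},
#     ],
#     "model": "nomic-embed-text",
# }
```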