diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py
index 8527081e84..1e727e3806 100644
--- a/backend/open_webui/main.py
+++ b/backend/open_webui/main.py
@@ -413,6 +413,7 @@ from open_webui.utils.chat import (
     chat_completed as chat_completed_handler,
     chat_action as chat_action_handler,
 )
+from open_webui.utils.embeddings import generate_embeddings
 from open_webui.utils.middleware import process_chat_payload, process_chat_response
 from open_webui.utils.access_control import has_access
@@ -1545,6 +1546,37 @@ async def get_app_latest_release_version(user=Depends(get_verified_user)):
 async def get_app_changelog():
     return {key: CHANGELOG[key] for idx, key in enumerate(CHANGELOG) if idx < 5}
 
+##################################
+# Embeddings
+##################################
+
+@app.post("/api/embeddings")
+async def embeddings_endpoint(
+    request: Request,
+    form_data: dict,
+    user=Depends(get_verified_user)
+):
+    """
+    OpenAI-compatible embeddings endpoint.
+
+    This handler:
+    - Performs user/model checks and dispatches to the correct backend.
+    - Supports OpenAI, Ollama, arena models, pipelines, and any compatible provider.
+
+    Args:
+        request (Request): Request context.
+        form_data (dict): OpenAI-like payload (e.g., {"model": "...", "input": [...]})
+        user (UserModel): Authenticated user.
+
+    Returns:
+        dict: OpenAI-compatible embeddings response.
+    """
+    # Make sure models are loaded in app state
+    if not request.app.state.MODELS:
+        await get_all_models(request, user=user)
+    # Use generic dispatcher in utils.embeddings
+    return await generate_embeddings(request, form_data, user)
+
 ############################
 # OAuth Login & Callback
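
For reference, a minimal sketch of how a client could call the new route once this patch is applied. The host, API key, and model name below are placeholders, and authentication is assumed to use Open WebUI's usual Bearer token (API key or JWT) for a verified user:

import requests

BASE_URL = "http://localhost:8080"  # placeholder deployment URL
API_KEY = "sk-..."                  # placeholder Open WebUI API key / JWT

resp = requests.post(
    f"{BASE_URL}/api/embeddings",
    headers={"Authorization": f"Bearer {API_KEY}"},
    json={"model": "text-embedding-3-small", "input": ["Hello", "World"]},
)
resp.raise_for_status()
print(resp.json()["data"][0]["embedding"][:5])  # first few dimensions of the first vector
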
+ """ + idx = 0 + # Prepare payload/body + body = json.dumps(form_data) + # Find correct backend url/key based on model + await get_all_models(request, user=user) + model_id = form_data.get("model") + models = request.app.state.OPENAI_MODELS + if model_id in models: + idx = models[model_id]["urlIdx"] + url = request.app.state.config.OPENAI_API_BASE_URLS[idx] + key = request.app.state.config.OPENAI_API_KEYS[idx] + r = None + session = None + streaming = False + try: + session = aiohttp.ClientSession(trust_env=True) + r = await session.request( + method="POST", + url=f"{url}/embeddings", + data=body, + headers={ + "Authorization": f"Bearer {key}", + "Content-Type": "application/json", + **( + { + "X-OpenWebUI-User-Name": user.name, + "X-OpenWebUI-User-Id": user.id, + "X-OpenWebUI-User-Email": user.email, + "X-OpenWebUI-User-Role": user.role, + } + if ENABLE_FORWARD_USER_INFO_HEADERS and user else {} + ), + }, + ) + r.raise_for_status() + if "text/event-stream" in r.headers.get("Content-Type", ""): + streaming = True + return StreamingResponse( + r.content, + status_code=r.status, + headers=dict(r.headers), + background=BackgroundTask( + cleanup_response, response=r, session=session + ), + ) + else: + response_data = await r.json() + return response_data + except Exception as e: + log.exception(e) + detail = None + if r is not None: + try: + res = await r.json() + if "error" in res: + detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}" + except Exception: + detail = f"External: {e}" + raise HTTPException( + status_code=r.status if r else 500, + detail=detail if detail else "Open WebUI: Server Connection Error", + ) + finally: + if not streaming and session: + if r: + r.close() + await session.close() @router.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"]) async def proxy(path: str, request: Request, user=Depends(get_verified_user)): diff --git a/backend/open_webui/utils/embeddings.py b/backend/open_webui/utils/embeddings.py new file mode 100644 index 0000000000..c3fae66022 --- /dev/null +++ b/backend/open_webui/utils/embeddings.py @@ -0,0 +1,127 @@ +import random +import logging +import sys + +from fastapi import Request +from open_webui.models.users import UserModel +from open_webui.models.models import Models +from open_webui.utils.models import check_model_access +from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL, BYPASS_MODEL_ACCESS_CONTROL + +from open_webui.routers.openai import embeddings as openai_embeddings +from open_webui.routers.ollama import embeddings as ollama_embeddings +from open_webui.routers.ollama import GenerateEmbeddingsForm +from open_webui.routers.pipelines import process_pipeline_inlet_filter + + +from open_webui.utils.payload import convert_embedding_payload_openai_to_ollama +from open_webui.utils.response import convert_embedding_response_ollama_to_openai + +logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL) +log = logging.getLogger(__name__) +log.setLevel(SRC_LOG_LEVELS["MAIN"]) + + +async def generate_embeddings( + request: Request, + form_data: dict, + user: UserModel, + bypass_filter: bool = False, +): + """ + Dispatch and handle embeddings generation based on the model type (OpenAI, Ollama, Arena, pipeline, etc). + + Args: + request (Request): The FastAPI request context. + form_data (dict): The input data sent to the endpoint. + user (UserModel): The authenticated user. + bypass_filter (bool): If True, disables access filtering (default False). 
diff --git a/backend/open_webui/utils/embeddings.py b/backend/open_webui/utils/embeddings.py
new file mode 100644
index 0000000000..c3fae66022
--- /dev/null
+++ b/backend/open_webui/utils/embeddings.py
@@ -0,0 +1,127 @@
+import random
+import logging
+import sys
+
+from fastapi import Request
+from open_webui.models.users import UserModel
+from open_webui.models.models import Models
+from open_webui.utils.models import check_model_access
+from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL, BYPASS_MODEL_ACCESS_CONTROL
+
+from open_webui.routers.openai import embeddings as openai_embeddings
+from open_webui.routers.ollama import embeddings as ollama_embeddings
+from open_webui.routers.ollama import GenerateEmbeddingsForm
+from open_webui.routers.pipelines import process_pipeline_inlet_filter
+
+
+from open_webui.utils.payload import convert_embedding_payload_openai_to_ollama
+from open_webui.utils.response import convert_embedding_response_ollama_to_openai
+
+logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
+log = logging.getLogger(__name__)
+log.setLevel(SRC_LOG_LEVELS["MAIN"])
+
+
+async def generate_embeddings(
+    request: Request,
+    form_data: dict,
+    user: UserModel,
+    bypass_filter: bool = False,
+):
+    """
+    Dispatch and handle embeddings generation based on the model type (OpenAI, Ollama, Arena, pipeline, etc.).
+
+    Args:
+        request (Request): The FastAPI request context.
+        form_data (dict): The input data sent to the endpoint.
+        user (UserModel): The authenticated user.
+        bypass_filter (bool): If True, disables access filtering (default False).
+
+    Returns:
+        dict: The embeddings response, following OpenAI API compatibility.
+    """
+    if BYPASS_MODEL_ACCESS_CONTROL:
+        bypass_filter = True
+
+    # Attach extra metadata from request.state if present
+    if hasattr(request.state, "metadata"):
+        if "metadata" not in form_data:
+            form_data["metadata"] = request.state.metadata
+        else:
+            form_data["metadata"] = {
+                **form_data["metadata"],
+                **request.state.metadata,
+            }
+
+    # If "direct" flag present, use only that model
+    if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
+        models = {
+            request.state.model["id"]: request.state.model,
+        }
+    else:
+        models = request.app.state.MODELS
+
+    model_id = form_data.get("model")
+    if model_id not in models:
+        raise Exception("Model not found")
+    model = models[model_id]
+
+    # Access filtering
+    if not getattr(request.state, "direct", False):
+        if not bypass_filter and user.role == "user":
+            check_model_access(user, model)
+
+    # Arena "meta-model": select a submodel at random
+    if model.get("owned_by") == "arena":
+        model_ids = model.get("info", {}).get("meta", {}).get("model_ids")
+        filter_mode = model.get("info", {}).get("meta", {}).get("filter_mode")
+        if model_ids and filter_mode == "exclude":
+            model_ids = [
+                m["id"]
+                for m in list(models.values())
+                if m.get("owned_by") != "arena" and m["id"] not in model_ids
+            ]
+        if isinstance(model_ids, list) and model_ids:
+            selected_model_id = random.choice(model_ids)
+        else:
+            model_ids = [
+                m["id"]
+                for m in list(models.values())
+                if m.get("owned_by") != "arena"
+            ]
+            selected_model_id = random.choice(model_ids)
+        inner_form = dict(form_data)
+        inner_form["model"] = selected_model_id
+        response = await generate_embeddings(
+            request, inner_form, user, bypass_filter=True
+        )
+        # Tag which concrete model was chosen
+        if isinstance(response, dict):
+            response = {
+                **response,
+                "selected_model_id": selected_model_id,
+            }
+        return response
+
+    # Pipeline/Function models
+    if model.get("pipe"):
+        # The pipeline handler should provide an OpenAI-compatible schema
+        return await process_pipeline_inlet_filter(request, form_data, user, models)
+
+    # Ollama backend
+    if model.get("owned_by") == "ollama":
+        ollama_payload = convert_embedding_payload_openai_to_ollama(form_data)
+        form_obj = GenerateEmbeddingsForm(**ollama_payload)
+        response = await ollama_embeddings(
+            request=request,
+            form_data=form_obj,
+            user=user,
+        )
+        return convert_embedding_response_ollama_to_openai(response)
+
+    # Default: OpenAI or compatible backend
+    return await openai_embeddings(
+        request=request,
+        form_data=form_data,
+        user=user,
+    )
\ No newline at end of file
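
The arena branch is the least obvious part of the dispatcher: it draws a concrete submodel at random (honoring an "exclude" filter), re-dispatches with bypass_filter=True, and tags the result with selected_model_id. Roughly, the selection behaves like this standalone sketch with made-up model data:

import random

models = {
    "my-arena": {
        "owned_by": "arena",
        "info": {"meta": {"model_ids": ["bge-m3"], "filter_mode": "exclude"}},
    },
    "bge-m3": {"id": "bge-m3", "owned_by": "ollama"},
    "text-embedding-3-small": {"id": "text-embedding-3-small", "owned_by": "openai"},
}

meta = models["my-arena"]["info"]["meta"]
model_ids = meta.get("model_ids")
if model_ids and meta.get("filter_mode") == "exclude":
    # Exclude mode: candidates are every non-arena model *not* in the list.
    model_ids = [
        m["id"]
        for m in models.values()
        if m.get("owned_by") != "arena" and m["id"] not in model_ids
    ]
selected_model_id = random.choice(model_ids)  # here always "text-embedding-3-small"
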
diff --git a/backend/open_webui/utils/payload.py b/backend/open_webui/utils/payload.py
index 02eb0da22b..260b98032a 100644
--- a/backend/open_webui/utils/payload.py
+++ b/backend/open_webui/utils/payload.py
@@ -329,3 +329,34 @@ def convert_payload_openai_to_ollama(openai_payload: dict) -> dict:
         ollama_payload["format"] = format
 
     return ollama_payload
+
+
+def convert_embedding_payload_openai_to_ollama(openai_payload: dict) -> dict:
+    """
+    Convert an embeddings request payload from OpenAI format to Ollama format.
+
+    Args:
+        openai_payload (dict): The original payload designed for OpenAI API usage.
+
+    Returns:
+        dict: A payload compatible with the Ollama API embeddings endpoint.
+    """
+    ollama_payload = {
+        "model": openai_payload.get("model")
+    }
+    input_value = openai_payload.get("input")
+
+    # Ollama expects 'input' as a list, and 'prompt' as a single string.
+    if isinstance(input_value, list):
+        ollama_payload["input"] = input_value
+        ollama_payload["prompt"] = "\n".join(str(x) for x in input_value)
+    else:
+        ollama_payload["input"] = [input_value]
+        ollama_payload["prompt"] = str(input_value)
+
+    # Optionally forward other fields if present
+    for optional_key in ("options", "truncate", "keep_alive"):
+        if optional_key in openai_payload:
+            ollama_payload[optional_key] = openai_payload[optional_key]
+
+    return ollama_payload
diff --git a/backend/open_webui/utils/response.py b/backend/open_webui/utils/response.py
index 8c3f1a58eb..b454325d8a 100644
--- a/backend/open_webui/utils/response.py
+++ b/backend/open_webui/utils/response.py
@@ -125,3 +125,55 @@ async def convert_streaming_response_ollama_to_openai(ollama_streaming_response)
             yield line
 
     yield "data: [DONE]\n\n"
+
+def convert_embedding_response_ollama_to_openai(response) -> dict:
+    """
+    Convert the response from the Ollama embeddings endpoint to the OpenAI-compatible format.
+
+    Args:
+        response (dict): The response from the Ollama API,
+            e.g. {"embedding": [...], "model": "..."}
+            or {"embeddings": [{"embedding": [...], "index": 0}, ...], "model": "..."}
+
+    Returns:
+        dict: Response adapted to OpenAI's embeddings API format.
+            e.g. {
+                "object": "list",
+                "data": [
+                    {"object": "embedding", "embedding": [...], "index": 0},
+                    ...
+                ],
+                "model": "...",
+            }
+    """
+    # Ollama batch-style output
+    if isinstance(response, dict) and "embeddings" in response:
+        openai_data = []
+        for i, emb in enumerate(response["embeddings"]):
+            openai_data.append({
+                "object": "embedding",
+                "embedding": emb.get("embedding"),
+                "index": emb.get("index", i),
+            })
+        return {
+            "object": "list",
+            "data": openai_data,
+            "model": response.get("model"),
+        }
+    # Ollama single output
+    elif isinstance(response, dict) and "embedding" in response:
+        return {
+            "object": "list",
+            "data": [{
+                "object": "embedding",
+                "embedding": response["embedding"],
+                "index": 0,
+            }],
+            "model": response.get("model"),
+        }
+    # Already OpenAI-compatible?
+    elif isinstance(response, dict) and "data" in response and isinstance(response["data"], list):
+        return response
+
+    # Fallback: return as is if unrecognized
+    return response
\ No newline at end of file
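
Taken together, the two new helpers round-trip the Ollama path: the request is translated before calling the Ollama router, and the response is translated back into the OpenAI shape. A quick illustration of the expected values (the model name and vectors are illustrative):

from open_webui.utils.payload import convert_embedding_payload_openai_to_ollama
from open_webui.utils.response import convert_embedding_response_ollama_to_openai

ollama_request = convert_embedding_payload_openai_to_ollama(
    {"model": "bge-m3", "input": ["hello", "world"], "truncate": True}
)
# {"model": "bge-m3", "input": ["hello", "world"],
#  "prompt": "hello\nworld", "truncate": True}

openai_response = convert_embedding_response_ollama_to_openai(
    {"model": "bge-m3", "embedding": [0.1, 0.2, 0.3]}
)
# {"object": "list",
#  "data": [{"object": "embedding", "embedding": [0.1, 0.2, 0.3], "index": 0}],
#  "model": "bge-m3"}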