diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index 901e4e76c6..41ee380ef7 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -111,6 +111,7 @@ from open_webui.models.functions import Functions from open_webui.models.models import Models from open_webui.models.users import UserModel, Users from open_webui.models.chats import Chats +from open_webui.models.user_model_credentials import UserModelCredentials from open_webui.config import ( # Ollama @@ -476,6 +477,7 @@ from open_webui.utils.models import ( get_all_base_models, check_model_access, get_filtered_models, + transform_user_model_if_needed, ) from open_webui.utils.email_utils import EmailVerificationManager from open_webui.utils.chat import ( @@ -1527,6 +1529,7 @@ async def chat_completion( await get_all_models(request, user=user) # 从数据库和后端服务加载所有可用模型 # === 2. 提取请求参数 === + form_data = await transform_user_model_if_needed(form_data, user) model_id = form_data.get("model", None) # 用户选择的模型 ID (如 "gpt-4") model_item = form_data.pop("model_item", {}) # 模型元数据 (包含 direct 标志) tasks = form_data.pop("background_tasks", None) # 后台任务列表 @@ -1537,7 +1540,7 @@ async def chat_completion( if not model_item.get("direct", False): # 标准模式:使用平台内置模型 if model_id not in request.app.state.MODELS: - raise Exception("Model not found") + raise Exception(f"Model not found: {model_id}") model = request.app.state.MODELS[model_id] # 从缓存获取模型配置 model_info = Models.get_model_by_id(model_id) # 从数据库获取模型详细信息 diff --git a/backend/open_webui/routers/openai.py b/backend/open_webui/routers/openai.py index fe56e37ae4..7bf2c30740 100644 --- a/backend/open_webui/routers/openai.py +++ b/backend/open_webui/routers/openai.py @@ -860,89 +860,87 @@ async def generate_chat_completion( # updated_at=1763971141 # created_at=1763874812 - # === 1. 权限检查配置 === - if BYPASS_MODEL_ACCESS_CONTROL: - bypass_filter = True - - idx = 0 # 用于标识使用哪个 OPENAI_API_BASE_URL - - # === 2. 准备 Payload 和提取元数据 === + # === 1. 
准备 Payload 和清理内部参数 === payload = {**form_data} - metadata = payload.pop("metadata", None) # 移除内部元数据,不发送给上游 API + metadata = payload.pop("metadata", None) + + # 清理掉上游 API 不认识的、仅供后端内部逻辑使用的参数,避免请求被拒绝 + payload.pop("is_user_model", None) + payload.pop("model_item", None) model_id = form_data.get("model") - model_info = Models.get_model_by_id(model_id) - # === 3. 应用模型配置和权限检查 === - # Check model info and override the payload - if model_info: - # 3.1 如果配置了 base_model_id,替换为底层模型 ID - # 例如:自定义模型 "my-gpt4" → 实际调用 "gpt-4-turbo" - if model_info.base_model_id: - payload["model"] = model_info.base_model_id - model_id = model_info.base_model_id + # === 2. 动态凭据和配置选择 === + # 初始化所有后续逻辑需要用到的变量 + url, key, api_config, model, model_info = None, None, {}, None, None - # 3.2 应用模型参数(temperature, max_tokens 等) - params = model_info.params.model_dump() + # 通过检查 request.state.direct 标志,来区分是“私有模型”还是“普通模型” + if getattr(request.state, "direct", False) and hasattr(request.state, "model"): + # --- 私有模型路径 --- + # 对于私有模型(由上游函数识别并标记),直接从 request.state 中提取其独立的凭据和配置 + direct_model_config = request.state.model + url = direct_model_config.get("base_url") + key = direct_model_config.get("api_key") + api_config = direct_model_config.get("config") or {} # 保证 api_config 始终是字典,避免 NoneType 错误 + + # 私有模型没有数据库中的 model_info 记录,但其本身(direct_model_config)就是完整的模型定义 + model_info = None + model = direct_model_config - if params: - system = params.pop("system", None) # 提取 system prompt - - # 应用模型参数到 payload(覆盖用户传入的参数) - payload = apply_model_params_to_body_openai(params, payload) - # 注入或替换 system prompt - payload = apply_system_prompt_to_body(system, payload, metadata, user) - - # 3.3 权限检查:验证用户是否有权限访问该模型 - # Check if user has access to the model - if not bypass_filter and user.role == "user": - if not ( - user.id == model_info.user_id # 用户是模型创建者 - or has_access( - user.id, type="read", access_control=model_info.access_control - ) # 或用户在访问控制列表中 - ): - raise HTTPException( - status_code=403, - detail="Model not found", - 
) - elif not bypass_filter: - # 如果模型信息不存在且未绕过过滤器,只有管理员可访问 - if user.role != "admin": - raise HTTPException( - status_code=403, - detail="Model not found", - ) - - # === 4. 查找 OpenAI API 配置 === - await get_all_models(request, user=user) # 刷新模型列表 - model = request.app.state.OPENAI_MODELS.get(model_id) - if model: - idx = model["urlIdx"] # 获取 API 基础 URL 索引 + # 健壮性检查:如果私有模型未提供 URL,则回退到使用第一个全局配置的 URL + if not url: + url = request.app.state.config.OPENAI_API_BASE_URLS[0] + # 如果私有模型也未提供 Key,则同时回退到使用全局 Key + if not key: + key = request.app.state.config.OPENAI_API_KEYS[0] else: - raise HTTPException( - status_code=404, - detail="Model not found", + # --- 普通平台模型路径 --- + # 保持原始逻辑,从数据库和全局配置中获取模型的凭据和设置 + model_info = Models.get_model_by_id(model_id) + if model_info: + # 如果模型配置了 base_model_id,则实际请求时使用基础模型 + if model_info.base_model_id: + payload["model"] = model_info.base_model_id + model_id = model_info.base_model_id + + # 应用数据库中为该模型保存的特定参数(如 temperature, system_prompt 等) + params = model_info.params.model_dump() + if params: + system = params.pop("system", None) + payload = apply_model_params_to_body_openai(params, payload) + payload = apply_system_prompt_to_body(system, payload, metadata, user) + + # 权限检查 + if not bypass_filter and user.role == "user": + if not (user.id == model_info.user_id or has_access(user.id, type="read", access_control=model_info.access_control)): + raise HTTPException(status_code=403, detail="Model not found") + elif not bypass_filter and user.role != "admin": + raise HTTPException(status_code=403, detail="Model not found") + + # 从缓存的平台模型列表中查找模型,以确定其所属的 API (urlIdx) + await get_all_models(request, user=user) + model = request.app.state.OPENAI_MODELS.get(model_id) + if model: + idx = model["urlIdx"] + else: + raise HTTPException(status_code=404, detail=f"Model not found: {model_id}") + + # 根据索引从全局配置中获取 url, key 和 api_config + url = request.app.state.config.OPENAI_API_BASE_URLS[idx] + key = request.app.state.config.OPENAI_API_KEYS[idx] + api_config = 
request.app.state.config.OPENAI_API_CONFIGS.get( + str(idx), + request.app.state.config.OPENAI_API_CONFIGS.get(url, {}), # Legacy support ) - # === 5. 获取 API 配置并处理 prefix_id === - # Get the API config for the model - api_config = request.app.state.config.OPENAI_API_CONFIGS.get( - str(idx), - request.app.state.config.OPENAI_API_CONFIGS.get( - request.app.state.config.OPENAI_API_BASE_URLS[idx], {} - ), # Legacy support - ) - - # 移除模型 ID 前缀(如果配置了 prefix_id) - # 例如:模型 ID "custom.gpt-4" → 发送给 API 的是 "gpt-4" + # === 3. 应用通用配置和特殊处理 === + # 移除模型 ID 的前缀(如果配置了) prefix_id = api_config.get("prefix_id", None) if prefix_id: payload["model"] = payload["model"].replace(f"{prefix_id}.", "") - # === 6. Pipeline 模式:注入用户信息 === - # Add user info to the payload if the model is a pipeline - if "pipeline" in model and model.get("pipeline"): + # 如果是 Pipeline 类型的模型,注入用户信息 + if model and "pipeline" in model and model.get("pipeline"): payload["user"] = { "name": user.name, "id": user.id, @@ -950,9 +948,6 @@ async def generate_chat_completion( "role": user.role, } - url = request.app.state.config.OPENAI_API_BASE_URLS[idx] - key = request.app.state.config.OPENAI_API_KEYS[idx] - # === 7. 
推理模型特殊处理 === # Check if model is a reasoning model that needs special handling if is_openai_reasoning_model(payload["model"]): diff --git a/backend/open_webui/routers/user_models.py b/backend/open_webui/routers/user_models.py index 541dacdd56..e01de2074e 100644 --- a/backend/open_webui/routers/user_models.py +++ b/backend/open_webui/routers/user_models.py @@ -2,10 +2,10 @@ 用户私有模型凭据管理 - API 路由层 提供的接口: -- GET /api/user/models/credentials - 获取当前用户的所有私有模型凭据 -- POST /api/user/models/credentials - 创建新的私有模型凭据 -- PUT /api/user/models/credentials/{id} - 更新指定的私有模型凭据 -- DELETE /api/user/models/credentials/{id} - 删除指定的私有模型凭据 +- GET /api/v1/user/models - 获取当前用户的所有私有模型凭据 +- POST /api/v1/user/models - 创建新的私有模型凭据 +- PUT /api/v1/user/models/{id} - 更新指定的私有模型凭据 +- DELETE /api/v1/user/models/{id} - 删除指定的私有模型凭据 安全设计: - 所有接口都需要用户认证(get_verified_user) diff --git a/backend/open_webui/utils/chat.py b/backend/open_webui/utils/chat.py index fbdbaaa0a8..812f6ea6ea 100644 --- a/backend/open_webui/utils/chat.py +++ b/backend/open_webui/utils/chat.py @@ -45,7 +45,7 @@ from open_webui.utils.plugin import ( load_function_module_by_id, get_function_module_from_cache, ) -from open_webui.utils.models import get_all_models, check_model_access +from open_webui.utils.models import get_all_models, check_model_access, transform_user_model_if_needed from open_webui.utils.payload import convert_payload_openai_to_ollama from open_webui.utils.response import ( convert_response_ollama_to_openai, @@ -249,9 +249,14 @@ async def generate_chat_completion( model = models[model_id] # === 5. 
async def transform_user_model_if_needed(form_data: dict, user: "UserModel") -> dict:
    """Rewrite a request that targets a user-owned private model.

    If ``form_data["model_item"]["credential_id"]`` is a string starting with
    ``"user:"``, the remainder is looked up as a credential owned by ``user``.
    On success ``model_item`` is rewritten in place into a "direct" model
    definition built from that credential, and ``form_data["model"]`` is set
    to the credential's ``model_id``. Any other request passes through
    unchanged, except that the ``"model"`` and ``"model_item"`` keys are
    always present on return (``None`` / ``{}`` when they were missing).

    Args:
        form_data: Request body; mutated in place and also returned.
        user: The authenticated user; only ``user.id`` is read, and only when
            a ``"user:"`` credential is referenced.

    Returns:
        The same ``form_data`` dict.

    Raises:
        HTTPException: 404 when the referenced credential does not exist for
            this user.
    """
    user_prefix = "user:"

    model_id = form_data.get("model")

    # A JSON ``null`` (or any non-dict value) must not reach ``.get`` below;
    # treat it like a missing model_item.
    model_item = form_data.get("model_item")
    if not isinstance(model_item, dict):
        model_item = {}

    credential_id = model_item.get("credential_id")
    if isinstance(credential_id, str) and credential_id.startswith(user_prefix):
        # Kept as function-local imports, as in the original code
        # (presumably to avoid an import cycle -- TODO confirm).
        from open_webui.models.user_model_credentials import UserModelCredentials
        from open_webui.constants import ERROR_MESSAGES
        from fastapi import HTTPException, status

        # Strip only the leading marker; ``str.replace`` would also remove
        # any later "user:" occurrence inside the id itself.
        cred_id = credential_id[len(user_prefix):]
        cred = UserModelCredentials.get_credential_by_id_and_user_id(cred_id, user.id)

        if not cred:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail=ERROR_MESSAGES.USER_MODEL_CREDENTIAL_NOT_FOUND(),
            )

        # Mutate the existing dict so every holder of this model_item sees
        # the resolved definition.
        # NOTE(review): model_item now carries the raw api_key; callers should
        # strip it before echoing model metadata anywhere -- confirm.
        model_item.clear()
        model_item.update(
            {
                "direct": True,
                "id": cred.model_id,
                "name": cred.name or cred.model_id,
                "owned_by": "user",
                "base_url": cred.base_url,
                "api_key": cred.api_key,
                "config": cred.config,
            }
        )
        model_id = cred.model_id

    form_data["model_item"] = model_item
    form_data["model"] = model_id

    return form_data