mirror of
https://github.com/open-webui/open-webui.git
synced 2025-12-13 04:45:19 +00:00
fix: 私有模型无法正常回复的问题
This commit is contained in:
parent
6147cc9acf
commit
945c9f7c4d
5 changed files with 131 additions and 81 deletions
|
|
@ -111,6 +111,7 @@ from open_webui.models.functions import Functions
|
|||
from open_webui.models.models import Models
|
||||
from open_webui.models.users import UserModel, Users
|
||||
from open_webui.models.chats import Chats
|
||||
from open_webui.models.user_model_credentials import UserModelCredentials
|
||||
|
||||
from open_webui.config import (
|
||||
# Ollama
|
||||
|
|
@ -476,6 +477,7 @@ from open_webui.utils.models import (
|
|||
get_all_base_models,
|
||||
check_model_access,
|
||||
get_filtered_models,
|
||||
transform_user_model_if_needed,
|
||||
)
|
||||
from open_webui.utils.email_utils import EmailVerificationManager
|
||||
from open_webui.utils.chat import (
|
||||
|
|
@ -1527,6 +1529,7 @@ async def chat_completion(
|
|||
await get_all_models(request, user=user) # 从数据库和后端服务加载所有可用模型
|
||||
|
||||
# === 2. 提取请求参数 ===
|
||||
form_data = await transform_user_model_if_needed(form_data, user)
|
||||
model_id = form_data.get("model", None) # 用户选择的模型 ID (如 "gpt-4")
|
||||
model_item = form_data.pop("model_item", {}) # 模型元数据 (包含 direct 标志)
|
||||
tasks = form_data.pop("background_tasks", None) # 后台任务列表
|
||||
|
|
@ -1537,7 +1540,7 @@ async def chat_completion(
|
|||
if not model_item.get("direct", False):
|
||||
# 标准模式:使用平台内置模型
|
||||
if model_id not in request.app.state.MODELS:
|
||||
raise Exception("Model not found")
|
||||
raise Exception(f"Model not found: {model_id}")
|
||||
|
||||
model = request.app.state.MODELS[model_id] # 从缓存获取模型配置
|
||||
model_info = Models.get_model_by_id(model_id) # 从数据库获取模型详细信息
|
||||
|
|
|
|||
|
|
@ -860,89 +860,87 @@ async def generate_chat_completion(
|
|||
# updated_at=1763971141
|
||||
# created_at=1763874812
|
||||
|
||||
# === 1. 权限检查配置 ===
|
||||
if BYPASS_MODEL_ACCESS_CONTROL:
|
||||
bypass_filter = True
|
||||
|
||||
idx = 0 # 用于标识使用哪个 OPENAI_API_BASE_URL
|
||||
|
||||
# === 2. 准备 Payload 和提取元数据 ===
|
||||
# === 1. 准备 Payload 和清理内部参数 ===
|
||||
payload = {**form_data}
|
||||
metadata = payload.pop("metadata", None) # 移除内部元数据,不发送给上游 API
|
||||
metadata = payload.pop("metadata", None)
|
||||
|
||||
# 清理掉上游 API 不认识的、仅供后端内部逻辑使用的参数,避免请求被拒绝
|
||||
payload.pop("is_user_model", None)
|
||||
payload.pop("model_item", None)
|
||||
|
||||
model_id = form_data.get("model")
|
||||
model_info = Models.get_model_by_id(model_id)
|
||||
|
||||
# === 3. 应用模型配置和权限检查 ===
|
||||
# Check model info and override the payload
|
||||
if model_info:
|
||||
# 3.1 如果配置了 base_model_id,替换为底层模型 ID
|
||||
# 例如:自定义模型 "my-gpt4" → 实际调用 "gpt-4-turbo"
|
||||
if model_info.base_model_id:
|
||||
payload["model"] = model_info.base_model_id
|
||||
model_id = model_info.base_model_id
|
||||
# === 2. 动态凭据和配置选择 ===
|
||||
# 初始化所有后续逻辑需要用到的变量
|
||||
url, key, api_config, model, model_info = None, None, {}, None, None
|
||||
|
||||
# 3.2 应用模型参数(temperature, max_tokens 等)
|
||||
params = model_info.params.model_dump()
|
||||
# 通过检查 request.state.direct 标志,来区分是“私有模型”还是“普通模型”
|
||||
if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
|
||||
# --- 私有模型路径 ---
|
||||
# 对于私有模型(由上游函数识别并标记),直接从 request.state 中提取其独立的凭据和配置
|
||||
direct_model_config = request.state.model
|
||||
url = direct_model_config.get("base_url")
|
||||
key = direct_model_config.get("api_key")
|
||||
api_config = direct_model_config.get("config") or {} # 保证 api_config 始终是字典,避免 NoneType 错误
|
||||
|
||||
# 私有模型没有数据库中的 model_info 记录,但其本身(direct_model_config)就是完整的模型定义
|
||||
model_info = None
|
||||
model = direct_model_config
|
||||
|
||||
if params:
|
||||
system = params.pop("system", None) # 提取 system prompt
|
||||
|
||||
# 应用模型参数到 payload(覆盖用户传入的参数)
|
||||
payload = apply_model_params_to_body_openai(params, payload)
|
||||
# 注入或替换 system prompt
|
||||
payload = apply_system_prompt_to_body(system, payload, metadata, user)
|
||||
|
||||
# 3.3 权限检查:验证用户是否有权限访问该模型
|
||||
# Check if user has access to the model
|
||||
if not bypass_filter and user.role == "user":
|
||||
if not (
|
||||
user.id == model_info.user_id # 用户是模型创建者
|
||||
or has_access(
|
||||
user.id, type="read", access_control=model_info.access_control
|
||||
) # 或用户在访问控制列表中
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="Model not found",
|
||||
)
|
||||
elif not bypass_filter:
|
||||
# 如果模型信息不存在且未绕过过滤器,只有管理员可访问
|
||||
if user.role != "admin":
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="Model not found",
|
||||
)
|
||||
|
||||
# === 4. 查找 OpenAI API 配置 ===
|
||||
await get_all_models(request, user=user) # 刷新模型列表
|
||||
model = request.app.state.OPENAI_MODELS.get(model_id)
|
||||
if model:
|
||||
idx = model["urlIdx"] # 获取 API 基础 URL 索引
|
||||
# 健壮性检查:如果私有模型未提供 URL,则回退到使用第一个全局配置的 URL
|
||||
if not url:
|
||||
url = request.app.state.config.OPENAI_API_BASE_URLS[0]
|
||||
# 如果私有模型也未提供 Key,则同时回退到使用全局 Key
|
||||
if not key:
|
||||
key = request.app.state.config.OPENAI_API_KEYS[0]
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Model not found",
|
||||
# --- 普通平台模型路径 ---
|
||||
# 保持原始逻辑,从数据库和全局配置中获取模型的凭据和设置
|
||||
model_info = Models.get_model_by_id(model_id)
|
||||
if model_info:
|
||||
# 如果模型配置了 base_model_id,则实际请求时使用基础模型
|
||||
if model_info.base_model_id:
|
||||
payload["model"] = model_info.base_model_id
|
||||
model_id = model_info.base_model_id
|
||||
|
||||
# 应用数据库中为该模型保存的特定参数(如 temperature, system_prompt 等)
|
||||
params = model_info.params.model_dump()
|
||||
if params:
|
||||
system = params.pop("system", None)
|
||||
payload = apply_model_params_to_body_openai(params, payload)
|
||||
payload = apply_system_prompt_to_body(system, payload, metadata, user)
|
||||
|
||||
# 权限检查
|
||||
if not bypass_filter and user.role == "user":
|
||||
if not (user.id == model_info.user_id or has_access(user.id, type="read", access_control=model_info.access_control)):
|
||||
raise HTTPException(status_code=403, detail="Model not found")
|
||||
elif not bypass_filter and user.role != "admin":
|
||||
raise HTTPException(status_code=403, detail="Model not found")
|
||||
|
||||
# 从缓存的平台模型列表中查找模型,以确定其所属的 API (urlIdx)
|
||||
await get_all_models(request, user=user)
|
||||
model = request.app.state.OPENAI_MODELS.get(model_id)
|
||||
if model:
|
||||
idx = model["urlIdx"]
|
||||
else:
|
||||
raise HTTPException(status_code=404, detail=f"Model not found: {model_id}")
|
||||
|
||||
# 根据索引从全局配置中获取 url, key 和 api_config
|
||||
url = request.app.state.config.OPENAI_API_BASE_URLS[idx]
|
||||
key = request.app.state.config.OPENAI_API_KEYS[idx]
|
||||
api_config = request.app.state.config.OPENAI_API_CONFIGS.get(
|
||||
str(idx),
|
||||
request.app.state.config.OPENAI_API_CONFIGS.get(url, {}), # Legacy support
|
||||
)
|
||||
|
||||
# === 5. 获取 API 配置并处理 prefix_id ===
|
||||
# Get the API config for the model
|
||||
api_config = request.app.state.config.OPENAI_API_CONFIGS.get(
|
||||
str(idx),
|
||||
request.app.state.config.OPENAI_API_CONFIGS.get(
|
||||
request.app.state.config.OPENAI_API_BASE_URLS[idx], {}
|
||||
), # Legacy support
|
||||
)
|
||||
|
||||
# 移除模型 ID 前缀(如果配置了 prefix_id)
|
||||
# 例如:模型 ID "custom.gpt-4" → 发送给 API 的是 "gpt-4"
|
||||
# === 3. 应用通用配置和特殊处理 ===
|
||||
# 移除模型 ID 的前缀(如果配置了)
|
||||
prefix_id = api_config.get("prefix_id", None)
|
||||
if prefix_id:
|
||||
payload["model"] = payload["model"].replace(f"{prefix_id}.", "")
|
||||
|
||||
# === 6. Pipeline 模式:注入用户信息 ===
|
||||
# Add user info to the payload if the model is a pipeline
|
||||
if "pipeline" in model and model.get("pipeline"):
|
||||
# 如果是 Pipeline 类型的模型,注入用户信息
|
||||
if model and "pipeline" in model and model.get("pipeline"):
|
||||
payload["user"] = {
|
||||
"name": user.name,
|
||||
"id": user.id,
|
||||
|
|
@ -950,9 +948,6 @@ async def generate_chat_completion(
|
|||
"role": user.role,
|
||||
}
|
||||
|
||||
url = request.app.state.config.OPENAI_API_BASE_URLS[idx]
|
||||
key = request.app.state.config.OPENAI_API_KEYS[idx]
|
||||
|
||||
# === 7. 推理模型特殊处理 ===
|
||||
# Check if model is a reasoning model that needs special handling
|
||||
if is_openai_reasoning_model(payload["model"]):
|
||||
|
|
|
|||
|
|
@ -2,10 +2,10 @@
|
|||
用户私有模型凭据管理 - API 路由层
|
||||
|
||||
提供的接口:
|
||||
- GET /api/user/models/credentials - 获取当前用户的所有私有模型凭据
|
||||
- POST /api/user/models/credentials - 创建新的私有模型凭据
|
||||
- PUT /api/user/models/credentials/{id} - 更新指定的私有模型凭据
|
||||
- DELETE /api/user/models/credentials/{id} - 删除指定的私有模型凭据
|
||||
- GET /api/v1/user/models - 获取当前用户的所有私有模型凭据
|
||||
- POST /api/v1/user/models - 创建新的私有模型凭据
|
||||
- PUT /api/v1/user/models/{id} - 更新指定的私有模型凭据
|
||||
- DELETE /api/v1/user/models/{id} - 删除指定的私有模型凭据
|
||||
|
||||
安全设计:
|
||||
- 所有接口都需要用户认证(get_verified_user)
|
||||
|
|
|
|||
|
|
@ -45,7 +45,7 @@ from open_webui.utils.plugin import (
|
|||
load_function_module_by_id,
|
||||
get_function_module_from_cache,
|
||||
)
|
||||
from open_webui.utils.models import get_all_models, check_model_access
|
||||
from open_webui.utils.models import get_all_models, check_model_access, transform_user_model_if_needed
|
||||
from open_webui.utils.payload import convert_payload_openai_to_ollama
|
||||
from open_webui.utils.response import (
|
||||
convert_response_ollama_to_openai,
|
||||
|
|
@ -249,9 +249,14 @@ async def generate_chat_completion(
|
|||
model = models[model_id]
|
||||
|
||||
# === 5. Direct 模式分支:直连外部 API ===
|
||||
# 原始逻辑是调用 generate_direct_chat_completion,但该函数实际上是一个无法处理 API 请求的“断头路”。
|
||||
# 经过调试发现,真正能将请求发送到上游 OpenAI 兼容 API 的是 generate_openai_chat_completion。
|
||||
# 因此,当识别为直连模式时(例如用户私有模型),将请求直接导向 generate_openai_chat_completion。
|
||||
if getattr(request.state, "direct", False):
|
||||
return await generate_direct_chat_completion(
|
||||
request, form_data, user=user, models=models
|
||||
return await generate_openai_chat_completion(
|
||||
request=request,
|
||||
form_data=form_data,
|
||||
user=user,
|
||||
)
|
||||
else:
|
||||
# === 6. 标准模式:检查用户权限 ===
|
||||
|
|
@ -370,6 +375,13 @@ async def chat_completed(request: Request, form_data: dict, user: Any):
|
|||
if not request.app.state.MODELS:
|
||||
await get_all_models(request, user=user)
|
||||
|
||||
form_data = await transform_user_model_if_needed(form_data, user)
|
||||
model_item = form_data.get("model_item", {})
|
||||
|
||||
if model_item.get("direct", False):
|
||||
request.state.direct = True
|
||||
request.state.model = model_item
|
||||
|
||||
if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
|
||||
models = {
|
||||
request.state.model["id"]: request.state.model,
|
||||
|
|
|
|||
|
|
@ -371,3 +371,43 @@ def get_filtered_models(models, user):
|
|||
return filtered_models
|
||||
else:
|
||||
return models
|
||||
|
||||
|
||||
async def transform_user_model_if_needed(form_data: dict, user: "UserModel") -> dict:
    """Rewrite *form_data* in place when it targets a user's private model.

    If ``model_item.credential_id`` carries the ``"user:"`` prefix, the
    referenced per-user credential is loaded and ``model_item`` is replaced
    with a *direct* model definition (base_url / api_key / config) so that
    downstream completion handlers can call the upstream API directly.
    Requests for ordinary platform models pass through unchanged.

    Args:
        form_data: Chat-completion request body; mutated and returned.
        user: The authenticated user owning the credential.

    Returns:
        The (possibly mutated) ``form_data`` dict.

    Raises:
        HTTPException: 404 when the referenced credential does not exist
            or does not belong to *user*.
    """
    # `model_item` may be absent OR present-but-None; normalize to a dict
    # so the `.get()` below cannot raise AttributeError.
    model_item = form_data.get("model_item") or {}

    credential_id = model_item.get("credential_id")
    if credential_id and credential_id.startswith("user:"):
        # Imported lazily to avoid import cycles at module load time.
        from open_webui.models.user_model_credentials import UserModelCredentials
        from open_webui.constants import ERROR_MESSAGES
        from fastapi import HTTPException, status

        # Strip only the leading "user:" marker. (str.replace would also
        # mangle any later occurrence of "user:" inside the id itself.)
        cred_id = credential_id.removeprefix("user:")
        cred = UserModelCredentials.get_credential_by_id_and_user_id(cred_id, user.id)

        if not cred:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail=ERROR_MESSAGES.USER_MODEL_CREDENTIAL_NOT_FOUND(),
            )

        # Rebuild model_item as a self-contained "direct" model definition
        # carrying the user's own credentials.
        model_item.clear()
        model_item.update(
            {
                "direct": True,
                "id": cred.model_id,
                "name": cred.name or cred.model_id,
                "owned_by": "user",
                "base_url": cred.base_url,
                "api_key": cred.api_key,
                "config": cred.config,
            }
        )

        # Point the request at the real upstream model id and write the
        # rebuilt item back (covers the case where the key was missing/None).
        form_data["model"] = cred.model_id
        form_data["model_item"] = model_item

    return form_data
|
||||
|
|
|
|||
Loading…
Reference in a new issue