mirror of
https://github.com/open-webui/open-webui.git
synced 2025-12-14 13:25:20 +00:00
为后端 /api/chat/completions 接口写明注释
This commit is contained in:
parent
88b7ee97aa
commit
46c13d87d2
1 changed files with 103 additions and 33 deletions
|
|
@@ -1435,83 +1435,135 @@ async def chat_completion(
|
|||
form_data: dict,
|
||||
user=Depends(get_verified_user),
|
||||
):
|
||||
if not request.app.state.MODELS:
|
||||
await get_all_models(request, user=user)
|
||||
"""
|
||||
聊天完成接口 - 处理用户与 AI 模型的对话请求
|
||||
|
||||
model_id = form_data.get("model", None)
|
||||
model_item = form_data.pop("model_item", {})
|
||||
tasks = form_data.pop("background_tasks", None)
|
||||
核心功能:
|
||||
1. 模型验证: 检查模型是否存在及用户访问权限
|
||||
2. 元数据构建: 提取 chat_id, message_id, session_id 等上下文信息
|
||||
3. Payload 处理: 通过 process_chat_payload 处理消息、工具调用、文件等
|
||||
4. 聊天执行: 调用 chat_completion_handler 与 LLM 交互
|
||||
5. 响应处理: 通过 process_chat_response 处理流式/非流式响应
|
||||
6. 异步任务: 如果有 session_id,创建后台任务异步执行
|
||||
|
||||
Args:
|
||||
request: FastAPI Request 对象
|
||||
form_data: 聊天请求数据,包含:
|
||||
- model: 模型 ID
|
||||
- messages: 对话历史 (OpenAI 格式)
|
||||
- chat_id: 聊天会话 ID
|
||||
- id: 消息 ID
|
||||
- session_id: 会话 ID (用于异步任务)
|
||||
- tool_ids: 工具 ID 列表
|
||||
- files: 附加文件列表
|
||||
- stream: 是否流式响应
|
||||
user: 已验证的用户对象
|
||||
|
||||
Returns:
|
||||
- 同步模式: 返回 LLM 响应 (流式 StreamingResponse 或完整 JSON)
|
||||
- 异步模式: 返回 {"status": True, "task_id": "xxx"}
|
||||
|
||||
Raises:
|
||||
HTTPException 400: 模型不存在、无访问权限、参数错误
|
||||
HTTPException 404: Chat 不存在
|
||||
|
||||
处理流程:
|
||||
1. 加载所有模型到 app.state.MODELS
|
||||
2. 验证模型访问权限 (check_model_access)
|
||||
3. 构建 metadata (包含 user_id, chat_id, tool_ids 等)
|
||||
4. 定义 process_chat 内部函数:
|
||||
- 调用 process_chat_payload (处理 Pipeline/Filter/Tools)
|
||||
- 调用 chat_completion_handler (与 LLM 交互)
|
||||
- 更新数据库消息记录
|
||||
- 调用 process_chat_response (处理响应、事件发射)
|
||||
5. 根据是否有 session_id 决定同步/异步执行
|
||||
"""
|
||||
# === 1. 初始化阶段:加载模型列表 ===
|
||||
if not request.app.state.MODELS:
|
||||
await get_all_models(request, user=user) # 从数据库和后端服务加载所有可用模型
|
||||
|
||||
# === 2. 提取请求参数 ===
|
||||
model_id = form_data.get("model", None) # 用户选择的模型 ID (如 "gpt-4")
|
||||
model_item = form_data.pop("model_item", {}) # 模型元数据 (包含 direct 标志)
|
||||
tasks = form_data.pop("background_tasks", None) # 后台任务列表
|
||||
|
||||
metadata = {}
|
||||
try:
|
||||
# === 3. 模型验证与权限检查 ===
|
||||
if not model_item.get("direct", False):
|
||||
# 标准模式:使用平台内置模型
|
||||
if model_id not in request.app.state.MODELS:
|
||||
raise Exception("Model not found")
|
||||
|
||||
model = request.app.state.MODELS[model_id] # 从缓存获取模型配置
|
||||
model_info = Models.get_model_by_id(model_id) # 从数据库获取模型详细信息
|
||||
|
||||
model = request.app.state.MODELS[model_id]
|
||||
model_info = Models.get_model_by_id(model_id)
|
||||
|
||||
# Check if user has access to the model
|
||||
# 检查用户是否有权限访问该模型
|
||||
if not BYPASS_MODEL_ACCESS_CONTROL and (
|
||||
user.role != "admin" or not BYPASS_ADMIN_ACCESS_CONTROL
|
||||
):
|
||||
try:
|
||||
check_model_access(user, model)
|
||||
check_model_access(user, model) # 检查 RBAC 权限
|
||||
except Exception as e:
|
||||
raise e
|
||||
else:
|
||||
# Direct 模式:用户直接传入 OpenAI API 等外部模型配置
|
||||
model = model_item
|
||||
model_info = None
|
||||
|
||||
request.state.direct = True
|
||||
request.state.direct = True # 标记为直连模式
|
||||
request.state.model = model
|
||||
|
||||
# === 4. 提取模型参数 ===
|
||||
model_info_params = (
|
||||
model_info.params.model_dump() if model_info and model_info.params else {}
|
||||
)
|
||||
|
||||
# Chat Params
|
||||
# 流式响应分块大小 (用于控制 SSE 推送频率)
|
||||
stream_delta_chunk_size = form_data.get("params", {}).get(
|
||||
"stream_delta_chunk_size"
|
||||
)
|
||||
# 推理标签 (用于标记 AI 的思考过程,如 <think>...</think>)
|
||||
reasoning_tags = form_data.get("params", {}).get("reasoning_tags")
|
||||
|
||||
# Model Params
|
||||
# 模型参数优先级高于请求参数
|
||||
if model_info_params.get("stream_delta_chunk_size"):
|
||||
stream_delta_chunk_size = model_info_params.get("stream_delta_chunk_size")
|
||||
|
||||
if model_info_params.get("reasoning_tags") is not None:
|
||||
reasoning_tags = model_info_params.get("reasoning_tags")
|
||||
|
||||
# === 5. 构建元数据 (metadata) - 贯穿整个处理流程的上下文 ===
|
||||
metadata = {
|
||||
"user_id": user.id,
|
||||
"chat_id": form_data.pop("chat_id", None),
|
||||
"message_id": form_data.pop("id", None),
|
||||
"session_id": form_data.pop("session_id", None),
|
||||
"filter_ids": form_data.pop("filter_ids", []),
|
||||
"tool_ids": form_data.get("tool_ids", None),
|
||||
"tool_servers": form_data.pop("tool_servers", None),
|
||||
"files": form_data.get("files", None),
|
||||
"features": form_data.get("features", {}),
|
||||
"variables": form_data.get("variables", {}),
|
||||
"model": model,
|
||||
"direct": model_item.get("direct", False),
|
||||
"chat_id": form_data.pop("chat_id", None), # 聊天会话 ID
|
||||
"message_id": form_data.pop("id", None), # 当前消息 ID
|
||||
"session_id": form_data.pop("session_id", None), # WebSocket 会话 ID (异步任务)
|
||||
"filter_ids": form_data.pop("filter_ids", []), # Pipeline Filter ID 列表
|
||||
"tool_ids": form_data.get("tool_ids", None), # 工具/函数调用 ID 列表
|
||||
"tool_servers": form_data.pop("tool_servers", None), # 外部工具服务器配置
|
||||
"files": form_data.get("files", None), # 用户上传的文件列表
|
||||
"features": form_data.get("features", {}), # 功能开关 (如 web_search)
|
||||
"variables": form_data.get("variables", {}), # 模板变量
|
||||
"model": model, # 模型配置对象
|
||||
"direct": model_item.get("direct", False), # 是否直连模式
|
||||
"params": {
|
||||
"stream_delta_chunk_size": stream_delta_chunk_size,
|
||||
"reasoning_tags": reasoning_tags,
|
||||
"function_calling": (
|
||||
"native"
|
||||
"native" # 原生函数调用 (如 OpenAI Function Calling)
|
||||
if (
|
||||
form_data.get("params", {}).get("function_calling") == "native"
|
||||
or model_info_params.get("function_calling") == "native"
|
||||
)
|
||||
else "default"
|
||||
else "default" # 默认模式 (通过 Prompt 实现)
|
||||
),
|
||||
},
|
||||
}
|
||||
|
||||
# === 6. 权限二次验证:检查用户是否拥有该 chat ===
|
||||
if metadata.get("chat_id") and (user and user.role != "admin"):
|
||||
if not metadata["chat_id"].startswith("local:"):
|
||||
if not metadata["chat_id"].startswith("local:"): # local: 前缀表示临时会话
|
||||
chat = Chats.get_chat_by_id_and_user_id(metadata["chat_id"], user.id)
|
||||
if chat is None:
|
||||
raise HTTPException(
|
||||
|
|
@@ -1519,8 +1571,9 @@ async def chat_completion(
|
|||
detail=ERROR_MESSAGES.DEFAULT(),
|
||||
)
|
||||
|
||||
request.state.metadata = metadata
|
||||
form_data["metadata"] = metadata
|
||||
# === 7. 保存元数据到请求状态和 form_data ===
|
||||
request.state.metadata = metadata # 供其他中间件/处理器访问
|
||||
form_data["metadata"] = metadata # 传递给下游处理函数
|
||||
|
||||
except Exception as e:
|
||||
log.debug(f"Error processing chat metadata: {e}")
|
||||
|
|
@@ -1529,13 +1582,19 @@ async def chat_completion(
|
|||
detail=str(e),
|
||||
)
|
||||
|
||||
# === 8. 定义内部处理函数 process_chat ===
|
||||
async def process_chat(request, form_data, user, metadata, model):
|
||||
"""处理完整的聊天流程:Payload 处理 → LLM 调用 → 响应处理"""
|
||||
try:
|
||||
# 8.1 Payload 预处理:执行 Pipeline Filters、工具注入、RAG 检索等
|
||||
form_data, metadata, events = await process_chat_payload(
|
||||
request, form_data, user, metadata, model
|
||||
)
|
||||
|
||||
# 8.2 调用 LLM 完成对话 (核心)
|
||||
response = await chat_completion_handler(request, form_data, user)
|
||||
|
||||
# 8.3 更新数据库:保存模型 ID 到消息记录
|
||||
if metadata.get("chat_id") and metadata.get("message_id"):
|
||||
try:
|
||||
if not metadata["chat_id"].startswith("local:"):
|
||||
|
|
@@ -1549,23 +1608,28 @@ async def chat_completion(
|
|||
except:
|
||||
pass
|
||||
|
||||
# 8.4 响应后处理:执行后置 Pipeline、事件发射、任务回调等
|
||||
return await process_chat_response(
|
||||
request, response, form_data, user, metadata, model, events, tasks
|
||||
)
|
||||
|
||||
# 8.5 异常处理:取消任务
|
||||
except asyncio.CancelledError:
|
||||
log.info("Chat processing was cancelled")
|
||||
try:
|
||||
event_emitter = get_event_emitter(metadata)
|
||||
await event_emitter(
|
||||
{"type": "chat:tasks:cancel"},
|
||||
{"type": "chat:tasks:cancel"}, # 通知前端任务已取消
|
||||
)
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
# 8.6 异常处理:记录错误到数据库并通知前端
|
||||
except Exception as e:
|
||||
log.debug(f"Error processing chat payload: {e}")
|
||||
if metadata.get("chat_id") and metadata.get("message_id"):
|
||||
# Update the chat message with the error
|
||||
try:
|
||||
# 将错误信息保存到消息记录
|
||||
if not metadata["chat_id"].startswith("local:"):
|
||||
Chats.upsert_message_to_chat_by_id_and_message_id(
|
||||
metadata["chat_id"],
|
||||
|
|
@@ -1575,6 +1639,7 @@ async def chat_completion(
|
|||
},
|
||||
)
|
||||
|
||||
# 通过 WebSocket 发送错误事件到前端
|
||||
event_emitter = get_event_emitter(metadata)
|
||||
await event_emitter(
|
||||
{
|
||||
|
|
@@ -1588,21 +1653,25 @@ async def chat_completion(
|
|||
|
||||
except:
|
||||
pass
|
||||
|
||||
# 8.7 清理资源:断开 MCP 客户端连接
|
||||
finally:
|
||||
try:
|
||||
if mcp_clients := metadata.get("mcp_clients"):
|
||||
for client in mcp_clients.values():
|
||||
await client.disconnect()
|
||||
await client.disconnect() # 断开 Model Context Protocol 客户端
|
||||
except Exception as e:
|
||||
log.debug(f"Error cleaning up: {e}")
|
||||
pass
|
||||
|
||||
# === 9. 决定执行模式:异步任务 vs 同步执行 ===
|
||||
if (
|
||||
metadata.get("session_id")
|
||||
and metadata.get("chat_id")
|
||||
and metadata.get("message_id")
|
||||
):
|
||||
# Asynchronous Chat Processing
|
||||
# 异步模式:创建后台任务,立即返回 task_id 给前端
|
||||
# 前端通过 WebSocket 监听任务状态和流式响应
|
||||
task_id, _ = await create_task(
|
||||
request.app.state.redis,
|
||||
process_chat(request, form_data, user, metadata, model),
|
||||
|
|
@@ -1610,6 +1679,7 @@ async def chat_completion(
|
|||
)
|
||||
return {"status": True, "task_id": task_id}
|
||||
else:
|
||||
# 同步模式:直接执行并返回响应 (流式或完整)
|
||||
return await process_chat(request, form_data, user, metadata, model)
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue