Add explanatory comments to the backend /api/chat/completions endpoint

Gaofeng 2025-11-24 21:50:03 +08:00
parent 88b7ee97aa
commit 46c13d87d2


@@ -1435,83 +1435,135 @@ async def chat_completion(
form_data: dict,
user=Depends(get_verified_user),
):
if not request.app.state.MODELS:
await get_all_models(request, user=user)
"""
聊天完成接口 - 处理用户与 AI 模型的对话请求
model_id = form_data.get("model", None)
model_item = form_data.pop("model_item", {})
tasks = form_data.pop("background_tasks", None)
Core responsibilities:
1. Model validation: check that the model exists and that the user may access it
2. Metadata construction: extract context such as chat_id, message_id, and session_id
3. Payload processing: handle messages, tool calls, files, etc. via process_chat_payload
4. Chat execution: call chat_completion_handler to interact with the LLM
5. Response handling: process streaming/non-streaming responses via process_chat_response
6. Async tasks: if a session_id is present, run the work as a background task
Args:
request: FastAPI Request object
form_data: chat request payload, containing:
- model: model ID
- messages: conversation history (OpenAI format)
- chat_id: chat session ID
- id: message ID
- session_id: session ID (used for async tasks)
- tool_ids: list of tool IDs
- files: list of attached files
- stream: whether to stream the response
user: the verified user object
Returns:
- Synchronous mode: the LLM response (streaming StreamingResponse or full JSON)
- Asynchronous mode: {"status": True, "task_id": "xxx"}
Raises:
HTTPException 400: model not found, no access permission, or invalid parameters
HTTPException 404: chat does not exist
Processing flow:
1. Load all models into app.state.MODELS
2. Verify model access permissions (check_model_access)
3. Build the metadata (user_id, chat_id, tool_ids, etc.)
4. Define the inner process_chat function:
- call process_chat_payload (pipelines/filters/tools)
- call chat_completion_handler (the LLM interaction)
- update the chat message record in the database
- call process_chat_response (response handling and event emission)
5. Run synchronously or asynchronously depending on session_id
"""
# === 1. Initialization: load the model list ===
if not request.app.state.MODELS:
await get_all_models(request, user=user) # load all available models from the database and backend services
# === 2. Extract request parameters ===
model_id = form_data.get("model", None) # model ID selected by the user (e.g. "gpt-4")
model_item = form_data.pop("model_item", {}) # model metadata (carries the direct flag)
tasks = form_data.pop("background_tasks", None) # list of background tasks
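# Note on the pop()/get() split used throughout this handler (an observation,
# not documented behaviour): pop() strips transport-only fields such as
# "model_item" and "background_tasks" out of form_data so they are not
# forwarded to the model backend, while get() leaves payload fields in place.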
metadata = {}
try:
# === 3. Model validation and access check ===
if not model_item.get("direct", False):
# Standard mode: use a model registered on the platform
if model_id not in request.app.state.MODELS:
raise Exception("Model not found")
model = request.app.state.MODELS[model_id] # model config from the in-memory cache
model_info = Models.get_model_by_id(model_id) # full model record from the database
model = request.app.state.MODELS[model_id]
model_info = Models.get_model_by_id(model_id)
# Check if the user has access to the model
if not BYPASS_MODEL_ACCESS_CONTROL and (
user.role != "admin" or not BYPASS_ADMIN_ACCESS_CONTROL
):
try:
check_model_access(user, model)
check_model_access(user, model) # RBAC permission check
except Exception as e:
raise e
else:
# Direct mode: the client passes an external model config (e.g. an OpenAI-compatible API) directly
model = model_item
model_info = None
request.state.direct = True
request.state.direct = True # mark as direct-connection mode
request.state.model = model
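# A sketch of the two request shapes this branch distinguishes (illustrative
# values; only the "direct" flag is what the code actually inspects):
standard_request = {
    "model": "gpt-4",    # must exist in request.app.state.MODELS; RBAC enforced
    "model_item": {},    # empty/absent -> standard mode
}
direct_request = {
    "model": "my-external-model",
    "model_item": {
        "direct": True,  # bypass the registry and use this config as-is
        # ...external connection details supplied by the client...
    },
}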
# === 4. Extract model parameters ===
model_info_params = (
model_info.params.model_dump() if model_info and model_info.params else {}
)
# Chat Params
# chunk size for streamed deltas (controls how often SSE pushes are sent)
stream_delta_chunk_size = form_data.get("params", {}).get(
"stream_delta_chunk_size"
)
# reasoning tags (mark the model's thinking process, e.g. <think>...</think>)
reasoning_tags = form_data.get("params", {}).get("reasoning_tags")
# Model Params
# model-level params take precedence over request params
if model_info_params.get("stream_delta_chunk_size"):
stream_delta_chunk_size = model_info_params.get("stream_delta_chunk_size")
if model_info_params.get("reasoning_tags") is not None:
reasoning_tags = model_info_params.get("reasoning_tags")
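# Minimal sketch of the precedence rule implemented above (the real code mixes
# a truthiness check and an `is not None` check; this helper uses the latter):
def resolve_param(name, request_params, model_params):
    value = request_params.get(name)        # request-level value as the default
    if model_params.get(name) is not None:  # model-level setting wins if present
        value = model_params.get(name)
    return value

assert resolve_param("reasoning_tags", {"reasoning_tags": ["a"]}, {"reasoning_tags": ["b"]}) == ["b"]
assert resolve_param("stream_delta_chunk_size", {"stream_delta_chunk_size": 8}, {}) == 8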
# === 5. Build the metadata - the context object threaded through the whole flow ===
metadata = {
"user_id": user.id,
"chat_id": form_data.pop("chat_id", None),
"message_id": form_data.pop("id", None),
"session_id": form_data.pop("session_id", None),
"filter_ids": form_data.pop("filter_ids", []),
"tool_ids": form_data.get("tool_ids", None),
"tool_servers": form_data.pop("tool_servers", None),
"files": form_data.get("files", None),
"features": form_data.get("features", {}),
"variables": form_data.get("variables", {}),
"model": model,
"direct": model_item.get("direct", False),
"chat_id": form_data.pop("chat_id", None), # 聊天会话 ID
"message_id": form_data.pop("id", None), # 当前消息 ID
"session_id": form_data.pop("session_id", None), # WebSocket 会话 ID (异步任务)
"filter_ids": form_data.pop("filter_ids", []), # Pipeline Filter ID 列表
"tool_ids": form_data.get("tool_ids", None), # 工具/函数调用 ID 列表
"tool_servers": form_data.pop("tool_servers", None), # 外部工具服务器配置
"files": form_data.get("files", None), # 用户上传的文件列表
"features": form_data.get("features", {}), # 功能开关 (如 web_search)
"variables": form_data.get("variables", {}), # 模板变量
"model": model, # 模型配置对象
"direct": model_item.get("direct", False), # 是否直连模式
"params": {
"stream_delta_chunk_size": stream_delta_chunk_size,
"reasoning_tags": reasoning_tags,
"function_calling": (
"native"
"native" # 原生函数调用 (如 OpenAI Function Calling)
if (
form_data.get("params", {}).get("function_calling") == "native"
or model_info_params.get("function_calling") == "native"
)
else "default"
else "default" # 默认模式 (通过 Prompt 实现)
),
},
}
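# Sketch of the function_calling resolution embedded in the dict above:
# "native" if either the request or the model record asks for it, otherwise
# "default" (prompt-based tool calling).
def resolve_function_calling(request_params: dict, model_params: dict) -> str:
    if (
        request_params.get("function_calling") == "native"
        or model_params.get("function_calling") == "native"
    ):
        return "native"
    return "default"

assert resolve_function_calling({}, {"function_calling": "native"}) == "native"
assert resolve_function_calling({}, {}) == "default"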
# === 6. Second permission check: verify the user owns this chat ===
if metadata.get("chat_id") and (user and user.role != "admin"):
if not metadata["chat_id"].startswith("local:"):
if not metadata["chat_id"].startswith("local:"): # local: 前缀表示临时会话
chat = Chats.get_chat_by_id_and_user_id(metadata["chat_id"], user.id)
if chat is None:
raise HTTPException(
@@ -1519,8 +1571,9 @@ async def chat_completion(
detail=ERROR_MESSAGES.DEFAULT(),
)
request.state.metadata = metadata
form_data["metadata"] = metadata
# === 7. Store the metadata on the request state and in form_data ===
request.state.metadata = metadata # accessible to other middleware/handlers
form_data["metadata"] = metadata # passed on to downstream handlers
except Exception as e:
log.debug(f"Error processing chat metadata: {e}")
@@ -1529,13 +1582,19 @@ async def chat_completion(
detail=str(e),
)
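# Sketch of the ownership guard above. Chats.get_chat_by_id_and_user_id is the
# real accessor; the dict-based store here is a stand-in for illustration only.
def user_owns_chat(chat_id: str, user_id: str, store: dict) -> bool:
    if chat_id.startswith("local:"):  # temporary client-side session: nothing
        return True                   # is persisted, so there is nothing to check
    chat = store.get(chat_id)
    return chat is not None and chat["user_id"] == user_id

assert user_owns_chat("local:tmp", "u1", {}) is True
assert user_owns_chat("c1", "u1", {"c1": {"user_id": "u2"}}) is False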
# === 8. Define the inner handler process_chat ===
async def process_chat(request, form_data, user, metadata, model):
"""处理完整的聊天流程Payload 处理 → LLM 调用 → 响应处理"""
try:
# 8.1 Payload preprocessing: run pipeline filters, tool injection, RAG retrieval, etc.
form_data, metadata, events = await process_chat_payload(
request, form_data, user, metadata, model
)
# 8.2 Call the LLM to complete the conversation (the core step)
response = await chat_completion_handler(request, form_data, user)
# 8.3 Update the database: save the model ID on the message record
if metadata.get("chat_id") and metadata.get("message_id"):
try:
if not metadata["chat_id"].startswith("local:"):
@@ -1549,23 +1608,28 @@ async def chat_completion(
except:
pass
# 8.4 Response post-processing: run post-pipelines, event emission, task callbacks, etc.
return await process_chat_response(
request, response, form_data, user, metadata, model, events, tasks
)
# 8.5 Exception handling: task cancellation
except asyncio.CancelledError:
log.info("Chat processing was cancelled")
try:
event_emitter = get_event_emitter(metadata)
await event_emitter(
{"type": "chat:tasks:cancel"},
{"type": "chat:tasks:cancel"}, # 通知前端任务已取消
)
except Exception as e:
pass
# 8.6 Exception handling: record the error in the database and notify the frontend
except Exception as e:
log.debug(f"Error processing chat payload: {e}")
if metadata.get("chat_id") and metadata.get("message_id"):
# Update the chat message with the error
try:
# save the error details to the message record
if not metadata["chat_id"].startswith("local:"):
Chats.upsert_message_to_chat_by_id_and_message_id(
metadata["chat_id"],
@@ -1575,6 +1639,7 @@ async def chat_completion(
},
)
# send an error event to the frontend over WebSocket
event_emitter = get_event_emitter(metadata)
await event_emitter(
{
@@ -1588,21 +1653,25 @@ async def chat_completion(
except:
pass
# 8.7 Cleanup: disconnect MCP client connections
finally:
try:
if mcp_clients := metadata.get("mcp_clients"):
for client in mcp_clients.values():
await client.disconnect()
await client.disconnect() # disconnect the Model Context Protocol client
except Exception as e:
log.debug(f"Error cleaning up: {e}")
pass
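# The inner function above is, in outline, a three-stage pipeline. A stripped
# sketch (process_chat_payload / chat_completion_handler / process_chat_response
# are the real stages from this module; error persistence, cancel events, and
# MCP cleanup from the full version are omitted here):
async def pipeline_sketch(request, form_data, user, metadata, model, tasks):
    form_data, metadata, events = await process_chat_payload(  # filters, tools, RAG
        request, form_data, user, metadata, model
    )
    response = await chat_completion_handler(request, form_data, user)  # LLM call
    return await process_chat_response(  # post-processing and event emission
        request, response, form_data, user, metadata, model, events, tasks
    )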
# === 9. Choose the execution mode: background task vs. synchronous ===
if (
metadata.get("session_id")
and metadata.get("chat_id")
and metadata.get("message_id")
):
# Asynchronous Chat Processing
# Async mode: create a background task and immediately return a task_id to the frontend,
# which follows task status and the streamed response over the WebSocket
task_id, _ = await create_task(
request.app.state.redis,
process_chat(request, form_data, user, metadata, model),
@@ -1610,6 +1679,7 @@ async def chat_completion(
)
return {"status": True, "task_id": task_id}
else:
# Sync mode: run inline and return the response (streaming or complete)
return await process_chat(request, form_data, user, metadata, model)
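# End-to-end shape of the async path as a self-contained sketch. In the real
# code, create_task registers the coroutine under a Redis-backed task registry
# and results stream back over the WebSocket session; the in-memory registry
# below only mirrors the control flow.
import asyncio
import uuid

TASKS: dict = {}

async def submit(coro) -> str:
    task_id = str(uuid.uuid4())
    TASKS[task_id] = asyncio.create_task(coro)  # start work in the background
    return task_id                              # client tracks progress by ID

# Mirroring the endpoint's two modes:
#   session_id + chat_id + message_id present -> {"status": True, "task_id": await submit(...)}
#   otherwise                                 -> await process_chat(...) inline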