mirror of https://github.com/open-webui/open-webui.git
Add documentation comments to the backend /api/chat/completions endpoint
This commit is contained in:
parent 88b7ee97aa
commit 46c13d87d2
1 changed file with 103 additions and 33 deletions
@@ -1435,83 +1435,135 @@ async def chat_completion(
     form_data: dict,
     user=Depends(get_verified_user),
 ):
-    if not request.app.state.MODELS:
-        await get_all_models(request, user=user)
-
-    model_id = form_data.get("model", None)
-    model_item = form_data.pop("model_item", {})
-    tasks = form_data.pop("background_tasks", None)
+    """
+    Chat completion endpoint - handles a user's conversation request to an AI model.
+
+    Core responsibilities:
+    1. Model validation: check that the model exists and the user may access it
+    2. Metadata construction: extract context such as chat_id, message_id, session_id
+    3. Payload processing: run process_chat_payload for messages, tool calls, files, etc.
+    4. Chat execution: call chat_completion_handler to interact with the LLM
+    5. Response processing: run process_chat_response for streaming/non-streaming output
+    6. Background tasks: if a session_id is present, run the chat as a background task
+
+    Args:
+        request: the FastAPI Request object
+        form_data: the chat request payload, containing:
+            - model: model ID
+            - messages: conversation history (OpenAI format)
+            - chat_id: chat session ID
+            - id: message ID
+            - session_id: session ID (used for background tasks)
+            - tool_ids: list of tool IDs
+            - files: list of attached files
+            - stream: whether to stream the response
+        user: the verified user object
+
+    Returns:
+        - Synchronous mode: the LLM response (a streaming StreamingResponse or full JSON)
+        - Asynchronous mode: {"status": True, "task_id": "xxx"}
+
+    Raises:
+        HTTPException 400: model not found, access denied, or invalid parameters
+        HTTPException 404: chat does not exist
+
+    Processing flow:
+        1. Load all models into app.state.MODELS
+        2. Validate model access (check_model_access)
+        3. Build the metadata dict (user_id, chat_id, tool_ids, ...)
+        4. Define the inner process_chat function:
+            - call process_chat_payload (pipelines/filters/tools)
+            - call chat_completion_handler (LLM interaction)
+            - update the message record in the database
+            - call process_chat_response (response handling, event emission)
+        5. Run synchronously or as a background task depending on session_id
+    """
+    # === 1. Initialization: load the model list ===
+    if not request.app.state.MODELS:
+        await get_all_models(request, user=user)  # load all available models from the database and backends
+
+    # === 2. Extract request parameters ===
+    model_id = form_data.get("model", None)  # model ID selected by the user (e.g. "gpt-4")
+    model_item = form_data.pop("model_item", {})  # model metadata (includes the direct flag)
+    tasks = form_data.pop("background_tasks", None)  # list of background tasks
 
     metadata = {}
     try:
+        # === 3. Model validation and access control ===
         if not model_item.get("direct", False):
+            # Standard mode: use a model managed by the platform
             if model_id not in request.app.state.MODELS:
                 raise Exception("Model not found")
 
-            model = request.app.state.MODELS[model_id]
-            model_info = Models.get_model_by_id(model_id)
+            model = request.app.state.MODELS[model_id]  # model config from the in-memory cache
+            model_info = Models.get_model_by_id(model_id)  # model details from the database
 
-            # Check if user has access to the model
+            # Check whether the user has permission to access the model
             if not BYPASS_MODEL_ACCESS_CONTROL and (
                 user.role != "admin" or not BYPASS_ADMIN_ACCESS_CONTROL
             ):
                 try:
-                    check_model_access(user, model)
+                    check_model_access(user, model)  # enforce RBAC permissions
                 except Exception as e:
                     raise e
         else:
+            # Direct mode: the user supplies an external model config (e.g. an OpenAI-compatible API)
             model = model_item
             model_info = None
 
-            request.state.direct = True
+            request.state.direct = True  # mark the request as direct mode
             request.state.model = model
 
+        # === 4. Extract model parameters ===
         model_info_params = (
             model_info.params.model_dump() if model_info and model_info.params else {}
         )
 
-        # Chat Params
+        # Streaming delta chunk size (controls how often SSE chunks are pushed)
         stream_delta_chunk_size = form_data.get("params", {}).get(
             "stream_delta_chunk_size"
         )
+        # Reasoning tags (mark the model's thinking process, e.g. <think>...</think>)
         reasoning_tags = form_data.get("params", {}).get("reasoning_tags")
 
-        # Model Params
+        # Model parameters take precedence over request parameters
         if model_info_params.get("stream_delta_chunk_size"):
             stream_delta_chunk_size = model_info_params.get("stream_delta_chunk_size")
 
         if model_info_params.get("reasoning_tags") is not None:
             reasoning_tags = model_info_params.get("reasoning_tags")
 
+        # === 5. Build the metadata dict - context carried through the whole pipeline ===
         metadata = {
             "user_id": user.id,
-            "chat_id": form_data.pop("chat_id", None),
-            "message_id": form_data.pop("id", None),
-            "session_id": form_data.pop("session_id", None),
-            "filter_ids": form_data.pop("filter_ids", []),
-            "tool_ids": form_data.get("tool_ids", None),
-            "tool_servers": form_data.pop("tool_servers", None),
-            "files": form_data.get("files", None),
-            "features": form_data.get("features", {}),
-            "variables": form_data.get("variables", {}),
-            "model": model,
-            "direct": model_item.get("direct", False),
+            "chat_id": form_data.pop("chat_id", None),  # chat session ID
+            "message_id": form_data.pop("id", None),  # current message ID
+            "session_id": form_data.pop("session_id", None),  # WebSocket session ID (background tasks)
+            "filter_ids": form_data.pop("filter_ids", []),  # pipeline filter IDs
+            "tool_ids": form_data.get("tool_ids", None),  # tool/function-calling IDs
+            "tool_servers": form_data.pop("tool_servers", None),  # external tool server configs
+            "files": form_data.get("files", None),  # files uploaded by the user
+            "features": form_data.get("features", {}),  # feature flags (e.g. web_search)
+            "variables": form_data.get("variables", {}),  # template variables
+            "model": model,  # model config object
+            "direct": model_item.get("direct", False),  # whether direct mode is active
             "params": {
                 "stream_delta_chunk_size": stream_delta_chunk_size,
                 "reasoning_tags": reasoning_tags,
                 "function_calling": (
-                    "native"
+                    "native"  # native function calling (e.g. OpenAI function calling)
                     if (
                         form_data.get("params", {}).get("function_calling") == "native"
                         or model_info_params.get("function_calling") == "native"
                     )
-                    else "default"
+                    else "default"  # default mode (implemented via prompting)
                 ),
             },
         }
 
+        # === 6. Second permission check: make sure the user owns this chat ===
         if metadata.get("chat_id") and (user and user.role != "admin"):
-            if not metadata["chat_id"].startswith("local:"):
+            if not metadata["chat_id"].startswith("local:"):  # the local: prefix marks a temporary chat
                 chat = Chats.get_chat_by_id_and_user_id(metadata["chat_id"], user.id)
                 if chat is None:
                     raise HTTPException(
@@ -1519,8 +1571,9 @@ async def chat_completion(
                         detail=ERROR_MESSAGES.DEFAULT(),
                     )
 
-        request.state.metadata = metadata
-        form_data["metadata"] = metadata
+        # === 7. Store the metadata on the request state and in form_data ===
+        request.state.metadata = metadata  # available to other middleware/handlers
+        form_data["metadata"] = metadata  # passed on to downstream handlers
 
     except Exception as e:
         log.debug(f"Error processing chat metadata: {e}")
@@ -1529,13 +1582,19 @@ async def chat_completion(
             detail=str(e),
         )
 
+    # === 8. Define the inner handler process_chat ===
    async def process_chat(request, form_data, user, metadata, model):
+        """Run the full chat flow: payload processing → LLM call → response processing"""
        try:
+            # 8.1 Payload preprocessing: run pipeline filters, tool injection, RAG retrieval, etc.
            form_data, metadata, events = await process_chat_payload(
                request, form_data, user, metadata, model
            )
 
+            # 8.2 Call the LLM to complete the chat (the core step)
            response = await chat_completion_handler(request, form_data, user)
 
+            # 8.3 Update the database: save the model ID on the message record
            if metadata.get("chat_id") and metadata.get("message_id"):
                try:
                    if not metadata["chat_id"].startswith("local:"):
@@ -1549,23 +1608,28 @@ async def chat_completion(
                except:
                    pass
 
+            # 8.4 Response postprocessing: run post pipelines, emit events, fire task callbacks, etc.
            return await process_chat_response(
                request, response, form_data, user, metadata, model, events, tasks
            )
 
+        # 8.5 Exception handling: the task was cancelled
        except asyncio.CancelledError:
            log.info("Chat processing was cancelled")
            try:
                event_emitter = get_event_emitter(metadata)
                await event_emitter(
-                    {"type": "chat:tasks:cancel"},
+                    {"type": "chat:tasks:cancel"},  # tell the frontend the task was cancelled
                )
            except Exception as e:
                pass
 
+        # 8.6 Exception handling: record the error in the database and notify the frontend
        except Exception as e:
            log.debug(f"Error processing chat payload: {e}")
            if metadata.get("chat_id") and metadata.get("message_id"):
-                # Update the chat message with the error
                try:
+                    # Save the error message on the message record
                    if not metadata["chat_id"].startswith("local:"):
                        Chats.upsert_message_to_chat_by_id_and_message_id(
                            metadata["chat_id"],
@@ -1575,6 +1639,7 @@ async def chat_completion(
                            },
                        )
 
+                    # Send the error event to the frontend over WebSocket
                    event_emitter = get_event_emitter(metadata)
                    await event_emitter(
                        {
@@ -1588,21 +1653,25 @@ async def chat_completion(
                except:
                    pass
 
+        # 8.7 Clean up resources: disconnect MCP clients
        finally:
            try:
                if mcp_clients := metadata.get("mcp_clients"):
                    for client in mcp_clients.values():
-                        await client.disconnect()
+                        await client.disconnect()  # disconnect the Model Context Protocol client
            except Exception as e:
                log.debug(f"Error cleaning up: {e}")
                pass
 
+    # === 9. Decide the execution mode: background task vs. synchronous ===
    if (
        metadata.get("session_id")
        and metadata.get("chat_id")
        and metadata.get("message_id")
    ):
-        # Asynchronous Chat Processing
+        # Async mode: create a background task and return the task_id to the frontend immediately;
+        # the frontend tracks task status and streamed output over WebSocket
        task_id, _ = await create_task(
            request.app.state.redis,
            process_chat(request, form_data, user, metadata, model),
@@ -1610,6 +1679,7 @@ async def chat_completion(
        )
        return {"status": True, "task_id": task_id}
    else:
+        # Sync mode: run the chat directly and return the response (streaming or complete)
        return await process_chat(request, form_data, user, metadata, model)
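
For orientation, a minimal client-side sketch of the synchronous path the new docstring describes. This is not part of the commit; the base URL, API key, and model ID are placeholders.

# Hypothetical client sketch: synchronous call, full-JSON response.
# Base URL, API key, and model ID below are placeholders.
import requests

resp = requests.post(
    "http://localhost:8080/api/chat/completions",
    headers={"Authorization": "Bearer YOUR_API_KEY"},
    json={
        "model": "gpt-4",  # model ID, validated against app.state.MODELS
        "messages": [{"role": "user", "content": "Hello!"}],  # OpenAI-format history
        "stream": False,  # no session_id/chat_id/id -> handled synchronously
    },
    timeout=60,
)
resp.raise_for_status()
print(resp.json())  # the complete LLM response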
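A similar sketch for the streaming variant. The endpoint returns a StreamingResponse when stream is true; the parsing below assumes OpenAI-style "data: {...}" SSE lines and should be treated as illustrative rather than authoritative.

# Hypothetical streaming sketch: assumes OpenAI-style SSE delta chunks.
import json
import requests

with requests.post(
    "http://localhost:8080/api/chat/completions",
    headers={"Authorization": "Bearer YOUR_API_KEY"},
    json={
        "model": "gpt-4",
        "messages": [{"role": "user", "content": "Hello!"}],
        "stream": True,  # ask for a streaming (SSE) response
    },
    stream=True,
    timeout=60,
) as resp:
    resp.raise_for_status()
    for line in resp.iter_lines(decode_unicode=True):
        if not line or not line.startswith("data: "):
            continue
        data = line[len("data: "):]
        if data == "[DONE]":  # end-of-stream sentinel
            break
        chunk = json.loads(data)
        delta = chunk["choices"][0].get("delta", {})
        print(delta.get("content", ""), end="", flush=True)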
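Finally, a sketch of the background-task path: per the docstring, when session_id, chat_id, and id are all present the endpoint returns a task handle instead of the completion, and results arrive over the WebSocket session. The session_id must belong to a live socket.io session; all IDs here are placeholders.

# Hypothetical background-task sketch: IDs are placeholders and the
# session_id must come from a live WebSocket (socket.io) connection.
import requests

resp = requests.post(
    "http://localhost:8080/api/chat/completions",
    headers={"Authorization": "Bearer YOUR_API_KEY"},
    json={
        "model": "gpt-4",
        "messages": [{"role": "user", "content": "Hello!"}],
        "chat_id": "CHAT_ID",  # an existing chat owned by the user
        "id": "MESSAGE_ID",  # the message ID within that chat
        "session_id": "SESSION_ID",  # live WebSocket session ID
    },
    timeout=60,
)
print(resp.json())  # -> {"status": True, "task_id": "..."}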