diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py
index adcf521eae..1e517b6a4e 100644
--- a/backend/open_webui/main.py
+++ b/backend/open_webui/main.py
@@ -1435,83 +1435,135 @@ async def chat_completion(
     form_data: dict,
     user=Depends(get_verified_user),
 ):
-    if not request.app.state.MODELS:
-        await get_all_models(request, user=user)
+    """
+    Chat completion endpoint - handles a user's conversation request with an AI model.
 
-    model_id = form_data.get("model", None)
-    model_item = form_data.pop("model_item", {})
-    tasks = form_data.pop("background_tasks", None)
+    Core responsibilities:
+    1. Model validation: check that the model exists and that the user may access it
+    2. Metadata construction: extract context such as chat_id, message_id, session_id
+    3. Payload processing: handle messages, tool calls, files, etc. via process_chat_payload
+    4. Chat execution: call chat_completion_handler to interact with the LLM
+    5. Response handling: process streaming/non-streaming responses via process_chat_response
+    6. Async tasks: if a session_id is present, run as a background task
+
+    Args:
+        request: FastAPI Request object
+        form_data: chat request payload, containing:
+            - model: model ID
+            - messages: conversation history (OpenAI format)
+            - chat_id: chat session ID
+            - id: message ID
+            - session_id: session ID (used for async tasks)
+            - tool_ids: list of tool IDs
+            - files: list of attached files
+            - stream: whether to stream the response
+        user: the verified user object
+
+    Returns:
+        - Synchronous mode: the LLM response (a StreamingResponse or a complete JSON body)
+        - Asynchronous mode: {"status": True, "task_id": "xxx"}
+
+    Raises:
+        HTTPException 400: model not found, access denied, or invalid parameters
+        HTTPException 404: chat does not exist
+
+    Processing flow:
+        1. Load all models into app.state.MODELS
+        2. Verify model access permissions (check_model_access)
+        3. Build metadata (user_id, chat_id, tool_ids, etc.)
+        4. Define the inner process_chat function:
+            - call process_chat_payload (pipelines/filters/tools)
+            - call chat_completion_handler (LLM interaction)
+            - update the message record in the database
+            - call process_chat_response (response handling, event emission)
+        5. Execute synchronously or asynchronously depending on session_id
+    """
+    # === 1. Initialization: load the model list ===
+    if not request.app.state.MODELS:
+        await get_all_models(request, user=user)  # load all available models from the database and backends
+
+    # === 2. Extract request parameters ===
+    model_id = form_data.get("model", None)  # model ID selected by the user (e.g. "gpt-4")
+    model_item = form_data.pop("model_item", {})  # model metadata (includes the direct flag)
+    tasks = form_data.pop("background_tasks", None)  # background task list
 
     metadata = {}
     try:
+        # === 3. Model validation and permission checks ===
         if not model_item.get("direct", False):
+            # Standard mode: use a model registered on the platform
             if model_id not in request.app.state.MODELS:
                 raise Exception("Model not found")
+
+            model = request.app.state.MODELS[model_id]  # model config from the in-memory cache
+            model_info = Models.get_model_by_id(model_id)  # model details from the database
 
-            model = request.app.state.MODELS[model_id]
-            model_info = Models.get_model_by_id(model_id)
-
-            # Check if user has access to the model
+            # Check whether the user may access this model
             if not BYPASS_MODEL_ACCESS_CONTROL and (
                 user.role != "admin" or not BYPASS_ADMIN_ACCESS_CONTROL
             ):
                 try:
-                    check_model_access(user, model)
+                    check_model_access(user, model)  # RBAC permission check
                 except Exception as e:
                     raise e
         else:
+            # Direct mode: the user supplies an external model config (e.g. an OpenAI-compatible API)
             model = model_item
             model_info = None
 
-            request.state.direct = True
+            request.state.direct = True  # flag the request as direct mode
             request.state.model = model
 
+        # === 4. Extract model parameters ===
         model_info_params = (
             model_info.params.model_dump() if model_info and model_info.params else {}
         )
 
-        # Chat Params
+        # Streaming delta chunk size (controls how often SSE chunks are pushed)
         stream_delta_chunk_size = form_data.get("params", {}).get(
             "stream_delta_chunk_size"
         )
+        # Reasoning tags (used to mark the model's thinking process)
         reasoning_tags = form_data.get("params", {}).get("reasoning_tags")
 
-        # Model Params
+        # Model-level params take precedence over request params
        if model_info_params.get("stream_delta_chunk_size"):
             stream_delta_chunk_size = model_info_params.get("stream_delta_chunk_size")
 
         if model_info_params.get("reasoning_tags") is not None:
             reasoning_tags = model_info_params.get("reasoning_tags")
 
+        # === 5. Build metadata - the context threaded through the whole processing flow ===
         metadata = {
             "user_id": user.id,
-            "chat_id": form_data.pop("chat_id", None),
-            "message_id": form_data.pop("id", None),
-            "session_id": form_data.pop("session_id", None),
-            "filter_ids": form_data.pop("filter_ids", []),
-            "tool_ids": form_data.get("tool_ids", None),
-            "tool_servers": form_data.pop("tool_servers", None),
-            "files": form_data.get("files", None),
-            "features": form_data.get("features", {}),
-            "variables": form_data.get("variables", {}),
-            "model": model,
-            "direct": model_item.get("direct", False),
+            "chat_id": form_data.pop("chat_id", None),  # chat session ID
+            "message_id": form_data.pop("id", None),  # current message ID
+            "session_id": form_data.pop("session_id", None),  # WebSocket session ID (async tasks)
+            "filter_ids": form_data.pop("filter_ids", []),  # pipeline filter IDs
+            "tool_ids": form_data.get("tool_ids", None),  # tool/function-calling IDs
+            "tool_servers": form_data.pop("tool_servers", None),  # external tool server configs
+            "files": form_data.get("files", None),  # files uploaded by the user
+            "features": form_data.get("features", {}),  # feature flags (e.g. web_search)
+            "variables": form_data.get("variables", {}),  # template variables
+            "model": model,  # the model config object
+            "direct": model_item.get("direct", False),  # direct mode flag
             "params": {
                 "stream_delta_chunk_size": stream_delta_chunk_size,
                 "reasoning_tags": reasoning_tags,
                 "function_calling": (
-                    "native"
+                    "native"  # native function calling (e.g. OpenAI function calling)
                     if (
                         form_data.get("params", {}).get("function_calling") == "native"
                         or model_info_params.get("function_calling") == "native"
                     )
-                    else "default"
+                    else "default"  # default mode (implemented via prompting)
                 ),
             },
         }
 
+        # === 6. Secondary permission check: verify the user owns this chat ===
         if metadata.get("chat_id") and (user and user.role != "admin"):
-            if not metadata["chat_id"].startswith("local:"):
+            if not metadata["chat_id"].startswith("local:"):  # the "local:" prefix marks a temporary chat
                 chat = Chats.get_chat_by_id_and_user_id(metadata["chat_id"], user.id)
                 if chat is None:
                     raise HTTPException(
@@ -1519,8 +1571,9 @@ async def chat_completion(
                         detail=ERROR_MESSAGES.DEFAULT(),
                     )
 
-        request.state.metadata = metadata
-        form_data["metadata"] = metadata
+        # === 7. Store metadata on the request state and in form_data ===
+        request.state.metadata = metadata  # accessible to other middleware/handlers
+        form_data["metadata"] = metadata  # passed to downstream handlers
 
     except Exception as e:
         log.debug(f"Error processing chat metadata: {e}")
@@ -1529,13 +1582,19 @@ async def chat_completion(
             detail=str(e),
         )
 
+    # === 8. Define the inner handler process_chat ===
     async def process_chat(request, form_data, user, metadata, model):
+        """Run the full chat flow: payload processing → LLM call → response processing."""
         try:
+            # 8.1 Payload preprocessing: run pipeline filters, tool injection, RAG retrieval, etc.
             form_data, metadata, events = await process_chat_payload(
                 request, form_data, user, metadata, model
             )
 
+            # 8.2 Call the LLM to generate the completion (the core step)
             response = await chat_completion_handler(request, form_data, user)
+
+            # 8.3 Database update: persist the model ID on the message record
             if metadata.get("chat_id") and metadata.get("message_id"):
                 try:
                     if not metadata["chat_id"].startswith("local:"):
@@ -1549,23 +1608,28 @@ async def chat_completion(
                 except:
                     pass
 
+            # 8.4 Response post-processing: post pipelines, event emission, task callbacks, etc.
             return await process_chat_response(
                 request, response, form_data, user, metadata, model, events, tasks
             )
+
+        # 8.5 Exception handling: task cancellation
         except asyncio.CancelledError:
             log.info("Chat processing was cancelled")
             try:
                 event_emitter = get_event_emitter(metadata)
                 await event_emitter(
-                    {"type": "chat:tasks:cancel"},
+                    {"type": "chat:tasks:cancel"},  # notify the frontend that the task was cancelled
                 )
             except Exception as e:
                 pass
+
+        # 8.6 Exception handling: record the error and notify the frontend
         except Exception as e:
             log.debug(f"Error processing chat payload: {e}")
             if metadata.get("chat_id") and metadata.get("message_id"):
-                # Update the chat message with the error
                 try:
+                    # Save the error message to the message record
                     if not metadata["chat_id"].startswith("local:"):
                         Chats.upsert_message_to_chat_by_id_and_message_id(
                             metadata["chat_id"],
@@ -1575,6 +1639,7 @@ async def chat_completion(
                             },
                         )
 
+                    # Send an error event to the frontend over WebSocket
                     event_emitter = get_event_emitter(metadata)
                     await event_emitter(
                         {
@@ -1588,21 +1653,25 @@ async def chat_completion(
                 except:
                     pass
 
+
+        # 8.7 Cleanup: disconnect MCP clients
         finally:
             try:
                 if mcp_clients := metadata.get("mcp_clients"):
                     for client in mcp_clients.values():
-                        await client.disconnect()
+                        await client.disconnect()  # disconnect the Model Context Protocol client
             except Exception as e:
                 log.debug(f"Error cleaning up: {e}")
                 pass
 
+    # === 9. Choose the execution mode: background task vs. synchronous ===
     if (
         metadata.get("session_id")
         and metadata.get("chat_id")
         and metadata.get("message_id")
     ):
-        # Asynchronous Chat Processing
+        # Async mode: create a background task and immediately return a task_id to the frontend;
+        # the frontend follows task status and streaming output over WebSocket
         task_id, _ = await create_task(
             request.app.state.redis,
             process_chat(request, form_data, user, metadata, model),
@@ -1610,6 +1679,7 @@ async def chat_completion(
         )
         return {"status": True, "task_id": task_id}
     else:
+        # Sync mode: execute directly and return the response (streamed or complete)
         return await process_chat(request, form_data, user, metadata, model)
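
For reviewers: the new docstring describes two execution modes (synchronous HTTP response vs. background task acknowledged by a `task_id`). Below is a minimal, untested client sketch of both modes under stated assumptions: the base URL, API token, endpoint path (`/api/chat/completions`), and the `session_id` value are illustrative, not part of this diff, and the sync-mode response is assumed to follow the OpenAI chat completion shape.

```python
# Minimal client sketch for the two execution modes of chat_completion.
# Assumptions (not from this diff): base URL, token, endpoint path, session_id value.
import uuid

import httpx

BASE_URL = "http://localhost:8080"  # assumed Open WebUI host
TOKEN = "sk-..."                    # assumed API key / JWT

payload = {
    "model": "gpt-4",                                      # form_data["model"]
    "messages": [{"role": "user", "content": "Hello!"}],   # OpenAI-format history
    "chat_id": str(uuid.uuid4()),                          # chat session ID
    "id": str(uuid.uuid4()),                               # message ID
    # Supplying session_id (together with chat_id and message_id) selects the
    # async path: the server returns {"status": True, "task_id": ...} at once
    # and delivers the actual output over the WebSocket session instead.
    "session_id": "my-socket-session",                     # hypothetical session ID
    "stream": False,
}

resp = httpx.post(
    f"{BASE_URL}/api/chat/completions",                    # assumed endpoint path
    json=payload,
    headers={"Authorization": f"Bearer {TOKEN}"},
    timeout=120,
)
data = resp.json()

if data.get("task_id"):
    # Async mode: the HTTP call only acknowledges the background task;
    # results arrive via the WebSocket session identified by session_id.
    print("background task started:", data["task_id"])
else:
    # Sync mode (no session_id): the completion is returned directly,
    # assuming an OpenAI-style response body.
    print(data["choices"][0]["message"]["content"])
```

The design choice the sketch illustrates: keying async execution off `session_id` lets the server detach long-running generation from the HTTP request lifecycle, so the connection is freed immediately while the frontend keeps receiving task status and streamed tokens over the existing WebSocket.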