From 984bb5888e370aeb54c8a464a4f944843d8e4f4c Mon Sep 17 00:00:00 2001
From: Gaofeng
Date: Tue, 25 Nov 2025 17:24:31 +0800
Subject: [PATCH] Add comments to several backend chat-completion functions;
 add the backend/open_webui/memory/cross_window_memory.py:last_process_payload
 hook
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 backend/open_webui/main.py                    |  71 ++++-
 backend/open_webui/memory/__init__.py         |   0
 .../open_webui/memory/cross_window_memory.py  |  21 ++
 backend/open_webui/routers/openai.py          | 120 ++++++-
 backend/open_webui/utils/chat.py              | 208 +++++++-----
 backend/open_webui/utils/middleware.py        | 298 ++++++++++++++----
 backend/open_webui/utils/misc.py              |  39 +++
 7 files changed, 610 insertions(+), 147 deletions(-)
 create mode 100644 backend/open_webui/memory/__init__.py
 create mode 100644 backend/open_webui/memory/cross_window_memory.py

diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py
index 1e517b6a4e..f3af7a3907 100644
--- a/backend/open_webui/main.py
+++ b/backend/open_webui/main.py
@@ -469,6 +469,7 @@ from open_webui.utils.chat import (
     chat_completed as chat_completed_handler,
     chat_action as chat_action_handler,
 )
+from open_webui.utils.misc import get_message_list
 from open_webui.utils.embeddings import generate_embeddings
 from open_webui.utils.middleware import process_chat_payload, process_chat_response
 from open_webui.utils.access_control import has_access
@@ -1592,7 +1593,7 @@ async def chat_completion(
     )
 
     # 8.2 Call the LLM to complete the chat (core step)
-    response = await chat_completion_handler(request, form_data, user)
+    response = await chat_completion_handler(request, form_data, user, chatting_completion=True)
 
     # 8.3 Update the database: save the model ID to the message record
     if metadata.get("chat_id") and metadata.get("message_id"):
@@ -1688,6 +1689,74 @@ generate_chat_completions = chat_completion
 generate_chat_completion = chat_completion
 
 
+@app.post("/api/chat/tutorial")
+async def chat_completion_tutorial(
+    request: Request, form_data: dict, user=Depends(get_verified_user)
+):
+    """
+    Minimal working example: reuses the core chat pipeline and supports
+    direct connections, message history, and streaming responses.
+
+    - The input is largely compatible with /api/chat/completions, but no
+      extra background tasks are scheduled.
+    - If chat_id/message_id are provided, the history is loaded and WS
+      events may fire; otherwise the SSE stream is returned directly.
+    """
+    # Direct-connection compatibility
+    model_item = form_data.pop("model_item", {})
+    if model_item.get("direct"):
+        request.state.direct = True
+        request.state.model = model_item
+
+    # Make sure models are available
+    if not request.app.state.MODELS:
+        await get_all_models(request, user=user)
+
+    model_id = form_data.get("model") or next(iter(request.app.state.MODELS.keys()))
+    if model_id not in request.app.state.MODELS:
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST, detail="Model not found"
+        )
+    model = request.app.state.MODELS[model_id]
+    if user.role == "user":
+        check_model_access(user, model)
+
+    # Fetch the message history: prefer the messages passed in, otherwise
+    # load the full history by chat_id
+    messages = form_data.get("messages") or []
+    chat_id = form_data.get("chat_id")
+    message_id = form_data.get("id") or form_data.get("message_id")
+    if not messages and chat_id:
+        messages_map = Chats.get_messages_map_by_chat_id(chat_id)
+        if messages_map:
+            if message_id and message_id in messages_map:
+                messages = get_message_list(messages_map, message_id)
+            else:
+                messages = list(messages_map.values())
+
+    # Prepare metadata so downstream event/DB logic can be reused
+    metadata = {
+        "user_id": user.id,
+        "chat_id": chat_id,
+        "message_id": message_id,
+        "session_id": form_data.get("session_id"),
+        "model": model,
+        "direct": model_item.get("direct", False),
+        "params": {},
+    }
+    request.state.metadata = metadata
+    form_data["metadata"] = metadata
+
+    # Core pipeline: preprocess -> call the model -> stream/non-stream handling
+    form_data["model"] = model_id
+    form_data["messages"] = messages
+    form_data["stream"] = True  # force streaming to keep the example simple
+
+    form_data, metadata, events = await process_chat_payload(
+        request, form_data, user, metadata, model
+    )
+    response = await chat_completion_handler(request, form_data, user)
+    return await process_chat_response(
+        request, response, form_data, user, metadata, model, events, tasks=None
+    )
+
+
 @app.post("/api/chat/completed")
 async def chat_completed(
     request: Request, form_data: dict, user=Depends(get_verified_user)
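For reference, a minimal client sketch for the new /api/chat/tutorial endpoint; the base URL, API key, and model id below are placeholders, not part of this patch:

    import requests

    resp = requests.post(
        "http://localhost:8080/api/chat/tutorial",
        headers={"Authorization": "Bearer YOUR_API_KEY"},
        json={"model": "gpt-4-turbo", "messages": [{"role": "user", "content": "Hi"}]},
        stream=True,  # the endpoint forces stream=True and answers with SSE
    )
    for line in resp.iter_lines():
        if line.startswith(b"data:"):
            print(line[len(b"data:"):].strip().decode())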
diff --git a/backend/open_webui/memory/__init__.py b/backend/open_webui/memory/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/backend/open_webui/memory/cross_window_memory.py b/backend/open_webui/memory/cross_window_memory.py
new file mode 100644
index 0000000000..c8e82521e8
--- /dev/null
+++ b/backend/open_webui/memory/cross_window_memory.py
@@ -0,0 +1,21 @@
+
+from typing import List, Dict
+
+
+def last_process_payload(
+    user_id: str,
+    session_id: str,
+    messages: List[Dict],
+):
+    """
+    Post-process the context information right before it is sent to the LLM API.
+
+    Args:
+        user_id (str): unique ID of the user.
+        session_id (str): ID of this user's current chat/session.
+        messages (List[Dict]): chat messages for this user and session, each
+            shaped like {"role": "system|user|assistant", "content": "...", "timestamp": 0}.
+    """
+    print("user_id:", user_id)
+    print("session_id:", session_id)
+    print("messages:", messages)
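The stub above only prints its inputs. A real implementation might persist each call for later cross-window recall; a minimal sketch, assuming a per-user JSONL log (the path and record layout are illustrative, not part of this patch):

    import json
    import pathlib

    def last_process_payload(user_id: str, session_id: str, messages: list):
        # Append one record per completion; a background job can build
        # cross-window memory from this log later.
        path = pathlib.Path("data/memory") / f"{user_id}.jsonl"
        path.parent.mkdir(parents=True, exist_ok=True)
        record = {"session_id": session_id, "messages": messages}
        with path.open("a", encoding="utf-8") as f:
            f.write(json.dumps(record, ensure_ascii=False) + "\n")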
diff --git a/backend/open_webui/routers/openai.py b/backend/open_webui/routers/openai.py
index 47b6ac601a..90b3f7b171 100644
--- a/backend/open_webui/routers/openai.py
+++ b/backend/open_webui/routers/openai.py
@@ -34,6 +34,8 @@ from open_webui.env import (
     BYPASS_MODEL_ACCESS_CONTROL,
 )
 from open_webui.models.users import UserModel
+from open_webui.memory.cross_window_memory import last_process_payload
+from open_webui.utils.misc import extract_timestamped_messages
 
 from open_webui.constants import ERROR_MESSAGES
 from open_webui.env import SRC_LOG_LEVELS
@@ -807,61 +809,123 @@ async def generate_chat_completion(
     form_data: dict,
     user=Depends(get_verified_user),
     bypass_filter: Optional[bool] = False,
+    chatting_completion: bool = False,
 ):
+    """
+    OpenAI-compatible chat-completion endpoint - forwards the request
+    directly to the OpenAI API or a compatible service.
+
+    This is the low-level API call in the OpenAI router. It:
+    1. Applies the model configuration (base_model_id, system prompt, parameter overrides)
+    2. Validates user permissions (model access control)
+    3. Handles the Azure OpenAI format conversion
+    4. Handles reasoning-model specifics
+    5. Forwards the HTTP request to the upstream API (streaming and non-streaming)
+
+    Args:
+        request: FastAPI Request object
+        form_data: OpenAI-format chat request
+            - model: model ID
+            - messages: message list
+            - stream: whether to stream the response
+            - temperature, max_tokens, and other parameters
+        user: verified user object
+        bypass_filter: whether to skip the permission checks
+
+    Returns:
+        - streaming: StreamingResponse (SSE)
+        - non-streaming: dict (OpenAI JSON format)
+
+    Raises:
+        HTTPException 403: no permission to access the model
+        HTTPException 404: model not found
+        HTTPException 500: upstream API connection failure
+    """
+
+    # print("user:", user)
+    # Example of the user object (profile_image_url truncated):
+    # id='55f85fb0-4aca-48bc-aea1-afce50ac989e'
+    # name='gaofeng1'
+    # email='h.summit1628935449@gmail.com'
+    # username=None
+    # role='user'
+    # profile_image_url='data:image/png;base64,iVBORw0KGgo...'
+    # bio=None
+    # gender=None
+    # date_of_birth=None
+    # info=None
+    # settings=UserSettings(ui={'memory': True})
+    # api_key=None
+    # oauth_sub=None
+    # last_active_at=1763997832
+    # updated_at=1763971141
+    # created_at=1763874812
+
+    # === 1. Permission-check configuration ===
     if BYPASS_MODEL_ACCESS_CONTROL:
         bypass_filter = True
 
-    idx = 0
+    idx = 0  # index into OPENAI_API_BASE_URLS
 
+    # === 2. Prepare the payload and extract metadata ===
     payload = {**form_data}
-    metadata = payload.pop("metadata", None)
+    metadata = payload.pop("metadata", None)  # internal metadata; never sent upstream
 
     model_id = form_data.get("model")
     model_info = Models.get_model_by_id(model_id)
+    # === 3. Apply the model configuration and permission checks ===
     # Check model info and override the payload
     if model_info:
+        # 3.1 If base_model_id is configured, substitute the underlying model ID
+        # e.g. custom model "my-gpt4" -> actually calls "gpt-4-turbo"
         if model_info.base_model_id:
             payload["model"] = model_info.base_model_id
            model_id = model_info.base_model_id
 
+        # 3.2 Apply model parameters (temperature, max_tokens, ...)
         params = model_info.params.model_dump()
 
         if params:
-            system = params.pop("system", None)
+            system = params.pop("system", None)  # pull out the system prompt
+            # Apply model parameters to the payload (overriding user-supplied ones)
             payload = apply_model_params_to_body_openai(params, payload)
+            # Inject or replace the system prompt
             payload = apply_system_prompt_to_body(system, payload, metadata, user)
 
+        # 3.3 Permission check: may this user access the model?
         # Check if user has access to the model
         if not bypass_filter and user.role == "user":
             if not (
-                user.id == model_info.user_id
+                user.id == model_info.user_id  # the user created the model
                 or has_access(
                     user.id, type="read", access_control=model_info.access_control
-                )
+                )  # or the user is on the access-control list
             ):
                 raise HTTPException(
                     status_code=403,
                     detail="Model not found",
                 )
     elif not bypass_filter:
+        # No model info and the filter is active: admins only
         if user.role != "admin":
             raise HTTPException(
                 status_code=403,
                 detail="Model not found",
             )
 
-    await get_all_models(request, user=user)
+    # === 4. Look up the OpenAI API configuration ===
+    await get_all_models(request, user=user)  # refresh the model list
     model = request.app.state.OPENAI_MODELS.get(model_id)
     if model:
-        idx = model["urlIdx"]
+        idx = model["urlIdx"]  # index of the API base URL
     else:
         raise HTTPException(
             status_code=404,
             detail="Model not found",
         )
 
+    # === 5. Fetch the API config and handle prefix_id ===
     # Get the API config for the model
     api_config = request.app.state.config.OPENAI_API_CONFIGS.get(
         str(idx),
@@ -870,10 +934,13 @@ async def generate_chat_completion(
         ),  # Legacy support
     )
 
+    # Strip the model-ID prefix if prefix_id is configured
+    # e.g. model ID "custom.gpt-4" -> the API receives "gpt-4"
     prefix_id = api_config.get("prefix_id", None)
     if prefix_id:
         payload["model"] = payload["model"].replace(f"{prefix_id}.", "")
 
+    # === 6. Pipeline mode: inject user info ===
     # Add user info to the payload if the model is a pipeline
     if "pipeline" in model and model.get("pipeline"):
         payload["user"] = {
@@ -886,32 +953,40 @@ async def generate_chat_completion(
     url = request.app.state.config.OPENAI_API_BASE_URLS[idx]
     key = request.app.state.config.OPENAI_API_KEYS[idx]
 
+    # === 7. Reasoning-model special handling ===
     # Check if model is a reasoning model that needs special handling
     if is_openai_reasoning_model(payload["model"]):
+        # Reasoning models (e.g. o1) use max_completion_tokens instead of max_tokens
        payload = openai_reasoning_model_handler(payload)
     elif "api.openai.com" not in url:
+        # Non-official OpenAI API: for backward compatibility, convert
+        # max_completion_tokens back to max_tokens
         # Remove "max_completion_tokens" from the payload for backward compatibility
         if "max_completion_tokens" in payload:
             payload["max_tokens"] = payload["max_completion_tokens"]
             del payload["max_completion_tokens"]
 
+    # Avoid carrying both max_tokens and max_completion_tokens at once
     if "max_tokens" in payload and "max_completion_tokens" in payload:
         del payload["max_tokens"]
 
+    # === 8. Convert the logit_bias format ===
     # Convert the modified body back to JSON
     if "logit_bias" in payload:
         payload["logit_bias"] = json.loads(
             convert_logit_bias_input_to_json(payload["logit_bias"])
         )
 
+    # === 9. Prepare the request headers and cookies ===
     headers, cookies = await get_headers_and_cookies(
         request, url, key, api_config, metadata, user=user
     )
+    # === 10. Azure OpenAI special handling ===
     if api_config.get("azure", False):
         api_version = api_config.get("api_version", "2023-03-15-preview")
         request_url, payload = convert_to_azure_payload(url, payload, api_version)
 
+        # Only set the api-key header when not using Azure Entra ID auth
         # Only set api-key header if not using Azure Entra ID authentication
         auth_type = api_config.get("auth_type", "bearer")
         if auth_type not in ("azure_ad", "microsoft_entra_id"):
@@ -920,16 +995,34 @@ async def generate_chat_completion(
             headers["api-version"] = api_version
 
         request_url = f"{request_url}/chat/completions?api-version={api_version}"
     else:
+        # Standard OpenAI-compatible API
         request_url = f"{url}/chat/completions"
 
-    payload = json.dumps(payload)
+    if chatting_completion:
+        try:
+            # Optional hook: log/audit the payload before it is sent upstream;
+            # last_process_payload is left for you to implement
+            log.debug(
+                f"chatting_completion hook user={user.id} chat_id={metadata.get('chat_id')} model={payload.get('model')}"
+            )
+            last_process_payload(
+                user_id=user.id,
+                session_id=metadata.get("chat_id"),
+                messages=extract_timestamped_messages(payload.get("messages", [])),
+            )
+        except Exception as e:
+            log.debug(f"chatting_completion hook failed: {e}")
+
+    payload = json.dumps(payload)  # serialize to a JSON string
+
+    # === 11. Initialize request-state variables ===
     r = None
     session = None
     streaming = False
     response = None
 
     try:
+        # === 12. Send the HTTP request to the upstream API ===
         session = aiohttp.ClientSession(
             trust_env=True, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
         )
@@ -943,8 +1036,10 @@ async def generate_chat_completion(
             ssl=AIOHTTP_CLIENT_SESSION_SSL,
         )
 
+        # === 13. Handle the response ===
         # Check if response is SSE
         if "text/event-stream" in r.headers.get("Content-Type", ""):
+            # Streaming response: forward the SSE stream as-is
             streaming = True
             return StreamingResponse(
                 r.content,
@@ -955,19 +1050,22 @@ async def generate_chat_completion(
                 ),
             )
         else:
+            # Non-streaming response: parse the JSON
             try:
                 response = await r.json()
             except Exception as e:
                 log.error(e)
-                response = await r.text()
+                response = await r.text()  # fall back to plain text if JSON parsing fails
 
+            # Handle error responses
             if r.status >= 400:
                 if isinstance(response, (dict, list)):
                     return JSONResponse(status_code=r.status, content=response)
                 else:
                     return PlainTextResponse(status_code=r.status, content=response)
 
-            return response
+            return response  # success
+
     except Exception as e:
         log.exception(e)
 
@@ -976,6 +1074,8 @@ async def generate_chat_completion(
             detail="CyberLover: Server Connection Error",
         )
     finally:
+        # === 14. Clean up resources ===
+        # Non-streaming responses must close the connection manually
+        # (streaming responses do this in a BackgroundTask)
         if not streaming:
             await cleanup_response(r, session)
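For reference, the messages handed to last_process_payload are normalized by extract_timestamped_messages (added in utils/misc.py at the end of this patch); a sketch of the mapping, with illustrative values:

    raw = [
        {"role": "user", "content": "Hello", "createdAt": 1763997832},
        {"role": "assistant", "content": [{"type": "text", "text": "Hi!"}]},
    ]
    # extract_timestamped_messages(raw) returns:
    # [{"role": "user", "content": "Hello", "timestamp": 1763997832},
    #  {"role": "assistant", "content": "Hi!", "timestamp": 0}]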
diff --git a/backend/open_webui/utils/chat.py b/backend/open_webui/utils/chat.py
index 8b6a0b9da2..31f617126a 100644
--- a/backend/open_webui/utils/chat.py
+++ b/backend/open_webui/utils/chat.py
@@ -167,124 +167,182 @@ async def generate_chat_completion(
     form_data: dict,
     user: Any,
     bypass_filter: bool = False,
+    chatting_completion: bool = False,
 ):
-    log.debug(f"generate_chat_completion: {form_data}")
-    if BYPASS_MODEL_ACCESS_CONTROL:
-        bypass_filter = True
+    """
+    Chat-completion generator - dispatches to the right low-level API
+    handler based on the model type.
+
+    This is the core routing function for chat completions. It:
+    1. Validates that the model exists and that the user may access it
+    2. Handles Direct mode (direct connection to an external API)
+    3. Handles Arena mode (random model selection for blind comparison)
+    4. Dispatches by model type:
+       - Pipe: pipeline plugin functions
+       - Ollama: local Ollama models
+       - OpenAI: OpenAI-compatible APIs (incl. Claude, Gemini, ...)
+
+    Args:
+        request: FastAPI Request object
+        form_data: OpenAI-format chat request data
+        user: user object
+        bypass_filter: whether to skip permission and pipeline-filter checks
+
+    Returns:
+        - streaming: StreamingResponse (SSE format)
+        - non-streaming: dict (OpenAI-compatible format)
+
+    Raises:
+        Exception: model not found or access denied
+    """
+    log.debug(f"generate_chat_completion: {form_data}")
+
+    # === 1. Permission-check configuration ===
+    if BYPASS_MODEL_ACCESS_CONTROL:
+        bypass_filter = True  # global switch: skip all permission checks
+
+    # === 2. Merge metadata ===
+    # Pick up metadata passed down via request.state.metadata (chat_id, user_id, ...)
     if hasattr(request.state, "metadata"):
         if "metadata" not in form_data:
             form_data["metadata"] = request.state.metadata
         else:
+            # Merge; request.state.metadata wins on conflicts
             form_data["metadata"] = {
                 **form_data["metadata"],
                 **request.state.metadata,
             }
 
+    # === 3. Decide where the model list comes from ===
     if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
+        # Direct mode: the user supplies the external API config (e.g. an OpenAI key)
         models = {
             request.state.model["id"]: request.state.model,
         }
         log.debug(f"direct connection to model: {models}")
     else:
+        # Standard mode: the platform's built-in model list
         models = request.app.state.MODELS
 
+    # === 4. Validate that the model exists ===
     model_id = form_data["model"]
     if model_id not in models:
         raise Exception("Model not found")
 
     model = models[model_id]
 
+    # === 5. Direct-mode branch: call the external API directly ===
     if getattr(request.state, "direct", False):
         return await generate_direct_chat_completion(
             request, form_data, user=user, models=models
         )
     else:
-        # Check if user has access to the model
+        # === 6. Standard mode: check user permissions ===
         if not bypass_filter and user.role == "user":
             try:
-                check_model_access(user, model)
+                check_model_access(user, model)  # RBAC check
             except Exception as e:
                 raise e
 
-        if model.get("owned_by") == "arena":
-            model_ids = model.get("info", {}).get("meta", {}).get("model_ids")
-            filter_mode = model.get("info", {}).get("meta", {}).get("filter_mode")
-            if model_ids and filter_mode == "exclude":
-                model_ids = [
-                    model["id"]
-                    for model in list(request.app.state.MODELS.values())
-                    if model.get("owned_by") != "arena" and model["id"] not in model_ids
-                ]
+        # === 7. Arena mode: pick a model at random for blind comparison ===
+        if False:
+            if model.get("owned_by") == "arena":
+                # Candidate model list
+                model_ids = model.get("info", {}).get("meta", {}).get("model_ids")
+                filter_mode = model.get("info", {}).get("meta", {}).get("filter_mode")
 
-            selected_model_id = None
-            if isinstance(model_ids, list) and model_ids:
-                selected_model_id = random.choice(model_ids)
-            else:
-                model_ids = [
-                    model["id"]
-                    for model in list(request.app.state.MODELS.values())
-                    if model.get("owned_by") != "arena"
-                ]
-                selected_model_id = random.choice(model_ids)
+                # In "exclude" mode, invert the model list
+                if model_ids and filter_mode == "exclude":
+                    model_ids = [
+                        model["id"]
+                        for model in list(request.app.state.MODELS.values())
+                        if model.get("owned_by") != "arena" and model["id"] not in model_ids
+                    ]
 
-            form_data["model"] = selected_model_id
+                # Pick one model at random
+                selected_model_id = None
+                if isinstance(model_ids, list) and model_ids:
+                    selected_model_id = random.choice(model_ids)
+                else:
+                    # None specified: choose among all non-Arena models
+                    model_ids = [
+                        model["id"]
+                        for model in list(request.app.state.MODELS.values())
+                        if model.get("owned_by") != "arena"
+                    ]
+                    selected_model_id = random.choice(model_ids)
 
-            if form_data.get("stream") == True:
+                # Swap in the selected model ID
+                form_data["model"] = selected_model_id
 
-                async def stream_wrapper(stream):
-                    yield f"data: {json.dumps({'selected_model_id': selected_model_id})}\n\n"
-                    async for chunk in stream:
-                        yield chunk
+                # Streaming: inject selected_model_id into the first chunk
+                if form_data.get("stream") == True:
 
-                response = await generate_chat_completion(
-                    request, form_data, user, bypass_filter=True
+                    async def stream_wrapper(stream):
+                        """Prepend the selected model ID to the stream"""
+                        yield f"data: {json.dumps({'selected_model_id': selected_model_id})}\n\n"
+                        async for chunk in stream:
+                            yield chunk
+
+                    # Recurse into ourselves, bypassing the Arena logic this time
+                    response = await generate_chat_completion(
+                        request, form_data, user, bypass_filter=True
+                    )
+                    return StreamingResponse(
+                        stream_wrapper(response.body_iterator),
+                        media_type="text/event-stream",
+                        background=response.background,
+                    )
+                else:
+                    # Non-streaming: add selected_model_id to the result directly
+                    return {
+                        **(
+                            await generate_chat_completion(
+                                request, form_data, user, bypass_filter=True
+                            )
+                        ),
+                        "selected_model_id": selected_model_id,
+                    }
 
-        if model.get("pipe"):
-            # Below does not require bypass_filter because this is the only route the uses this function and it is already bypassing the filter
-            return await generate_function_chat_completion(
-                request, form_data, user=user, models=models
-            )
-        if model.get("owned_by") == "ollama":
-            # Using /ollama/api/chat endpoint
-            form_data = convert_payload_openai_to_ollama(form_data)
-            response = await generate_ollama_chat_completion(
-                request=request,
-                form_data=form_data,
-                user=user,
-                bypass_filter=bypass_filter,
-            )
-            if form_data.get("stream"):
-                response.headers["content-type"] = "text/event-stream"
-                return StreamingResponse(
-                    convert_streaming_response_ollama_to_openai(response),
-                    headers=dict(response.headers),
-                    background=response.background,
+        # === 8. Pipeline mode: call a custom Python function ===
+        if False:
+            if model.get("pipe"):
+                return await generate_function_chat_completion(
+                    request, form_data, user=user, models=models
                 )
-            else:
-                return convert_response_ollama_to_openai(response)
-        else:
-            return await generate_openai_chat_completion(
-                request=request,
-                form_data=form_data,
-                user=user,
-                bypass_filter=bypass_filter,
-            )
+
+        # === 9. Ollama mode: call the local Ollama service ===
+        if False:
+            if model.get("owned_by") == "ollama":
+                # Convert OpenAI format -> Ollama format
+                form_data = convert_payload_openai_to_ollama(form_data)
+                response = await generate_ollama_chat_completion(
+                    request=request,
+                    form_data=form_data,
+                    user=user,
+                    bypass_filter=bypass_filter,
+                )
+
+                # Streaming: convert Ollama SSE -> OpenAI SSE
+                if form_data.get("stream"):
+                    response.headers["content-type"] = "text/event-stream"
+                    return StreamingResponse(
+                        convert_streaming_response_ollama_to_openai(response),
+                        headers=dict(response.headers),
+                        background=response.background,
+                    )
+                else:
+                    # Non-streaming: convert Ollama JSON -> OpenAI JSON
+                    return convert_response_ollama_to_openai(response)
+
+        # === 10. OpenAI-compatible mode: call the OpenAI API or a compatible service ===
+        return await generate_openai_chat_completion(
+            request=request,
+            form_data=form_data,
+            user=user,
+            bypass_filter=bypass_filter,
+            chatting_completion=chatting_completion,
+        )
 
 
 chat_completion = generate_chat_completion
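To summarize the wiring this patch adds, the chatting_completion flag travels main.py -> utils/chat.py -> routers/openai.py (a simplified sketch; the request/user objects are elided):

    # main.py: /api/chat/completions sets the flag
    response = await chat_completion_handler(
        request, form_data, user, chatting_completion=True
    )
    # utils/chat.py: generate_chat_completion forwards it to the OpenAI branch
    #   -> generate_openai_chat_completion(..., chatting_completion=chatting_completion)
    # routers/openai.py: just before the upstream POST, the hook fires
    #   -> last_process_payload(user_id=..., session_id=..., messages=...)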
diff --git a/backend/open_webui/utils/middleware.py b/backend/open_webui/utils/middleware.py
index dd42612eee..615dc89afd 100644
--- a/backend/open_webui/utils/middleware.py
+++ b/backend/open_webui/utils/middleware.py
@@ -994,25 +994,52 @@ def apply_params_to_form_data(form_data, model):
 
 
 async def process_chat_payload(request, form_data, user, metadata, model):
-    # Pipeline Inlet -> Filter Inlet -> Chat Memory -> Chat Web Search -> Chat Image Generation
-    # -> Chat Code Interpreter (Form Data Update) -> (Default) Chat Tools Function Calling
-    # -> Chat Files
+    """
+    Process the chat request payload - run pipelines, filters, feature
+    handlers, and tool injection.
+
+    This is the core pre-processing step for chat requests. It runs, in order:
+    1. Pipeline Inlet              - custom Python plugin pre-processing
+    2. Filter Inlet                - function-filter pre-processing
+    3. Chat Memory                 - inject conversation memory
+    4. Chat Web Search             - run a web search and inject the results
+    5. Chat Image Generation       - handle image-generation requests
+    6. Chat Code Interpreter       - inject the code-execution prompt
+    7. Chat Tools Function Calling - handle function/tool calls
+    8. Chat Files                  - uploads, knowledge-base files, RAG retrieval
+
+    Args:
+        request: FastAPI Request object
+        form_data: OpenAI-format chat request data
+        user: user object
+        metadata: metadata (chat_id, message_id, tool_ids, files, ...)
+        model: model configuration object
+
+    Returns:
+        tuple: (form_data, metadata, events)
+            - form_data: the processed request data
+            - metadata: the updated metadata
+            - events: events to send to the client (e.g. citation sources)
+    """
+    # === 1. Apply model parameters to the request ===
     form_data = apply_params_to_form_data(form_data, model)
     log.debug(f"form_data: {form_data}")
 
+    # === 2. Substitute variables in the system prompt ===
     system_message = get_system_message(form_data.get("messages", []))
     if system_message:  # Chat Controls/User Settings
         try:
+            # Replace system-prompt variables such as {{USER_NAME}} and {{CURRENT_DATE}}
             form_data = apply_system_prompt_to_body(
                 system_message.get("content"), form_data, metadata, user, replace=True
-            )  # Required to handle system prompt variables
+            )
         except:
             pass
 
-    event_emitter = get_event_emitter(metadata)
-    event_call = get_event_call(metadata)
+    # === 3. Initialize the event emitter and event callback ===
+    event_emitter = get_event_emitter(metadata)  # WebSocket event emitter
+    event_call = get_event_call(metadata)  # event-callback invoker
 
+    # === 4. Fetch the OAuth token ===
     oauth_token = None
     try:
         if request.cookies.get("oauth_session_id", None):
@@ -1023,9 +1050,10 @@ async def process_chat_payload(request, form_data, user, metadata, model):
     except Exception as e:
         log.error(f"Error getting OAuth token: {e}")
 
+    # === 5. Build extra params (consumed by pipelines/filters/tools) ===
     extra_params = {
-        "__event_emitter__": event_emitter,
-        "__event_call__": event_call,
+        "__event_emitter__": event_emitter,  # push real-time events to the client
+        "__event_call__": event_call,  # invoke event callbacks
         "__user__": user.model_dump() if isinstance(user, UserModel) else {},
         "__metadata__": metadata,
         "__request__": request,
@@ -1033,15 +1061,19 @@ async def process_chat_payload(request, form_data, user, metadata, model):
         "__oauth_token__": oauth_token,
     }
 
+    # === 6. Decide the model list and the task model ===
     # Initialize events to store additional event to be sent to the client
     # Initialize contexts and citation
     if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
+        # Direct mode: use the user's directly connected model
         models = {
             request.state.model["id"]: request.state.model,
         }
     else:
+        # Standard mode: all platform models
         models = request.app.state.MODELS
 
+    # Task model ID (used for background jobs such as tool calling and title generation)
     task_model_id = get_task_model_id(
         form_data["model"],
         request.app.state.config.TASK_MODEL,
@@ -1049,10 +1081,11 @@ async def process_chat_payload(request, form_data, user, metadata, model):
         models,
     )
 
-    events = []
-    sources = []
+    # === 7. Initialize the event and citation-source lists ===
+    events = []  # events to send to the client (e.g. "sources" citations)
+    sources = []  # document sources retrieved by RAG
 
-    # Folder "Project" handling
+    # === 8. Folder "Project" handling - inject the folder's system prompt and files ===
     # Check if the request has chat_id and is inside of a folder
     chat_id = metadata.get("chat_id", None)
     if chat_id and user:
@@ -1061,21 +1094,24 @@ async def process_chat_payload(request, form_data, user, metadata, model):
             folder = Folders.get_folder_by_id_and_user_id(chat.folder_id, user.id)
 
             if folder and folder.data:
+                # Inject the folder's system prompt
                 if "system_prompt" in folder.data:
                     form_data = apply_system_prompt_to_body(
                         folder.data["system_prompt"], form_data, metadata, user
                     )
+                # Inject the folder's attached files
                 if "files" in folder.data:
                     form_data["files"] = [
                         *folder.data["files"],
                         *form_data.get("files", []),
                     ]
Model "Knowledge" 处理 - 注入模型绑定的知识库 === user_message = get_last_user_message(form_data["messages"]) model_knowledge = model.get("info", {}).get("meta", {}).get("knowledge", False) if model_knowledge: + # 向前端发送知识库搜索状态 await event_emitter( { "type": "status", @@ -1089,6 +1125,7 @@ async def process_chat_payload(request, form_data, user, metadata, model): knowledge_files = [] for item in model_knowledge: + # 处理旧格式的 collection_name if item.get("collection_name"): knowledge_files.append( { @@ -1097,6 +1134,7 @@ async def process_chat_payload(request, form_data, user, metadata, model): "legacy": True, } ) + # 处理新格式的 collection_names(多个集合) elif item.get("collection_names"): knowledge_files.append( { @@ -1109,12 +1147,14 @@ async def process_chat_payload(request, form_data, user, metadata, model): else: knowledge_files.append(item) + # 合并模型知识库文件和用户上传文件 files = form_data.get("files", []) files.extend(knowledge_files) form_data["files"] = files variables = form_data.pop("variables", None) + # === 10. Pipeline Inlet 处理 - 执行自定义 Python 插件 === # Process the form_data through the pipeline try: form_data = await process_pipeline_inlet_filter( @@ -1123,6 +1163,7 @@ async def process_chat_payload(request, form_data, user, metadata, model): except Exception as e: raise e + # === 11. Filter Inlet 处理 - 执行函数过滤器 === try: filter_functions = [ Functions.get_function_by_id(filter_id) @@ -1141,23 +1182,28 @@ async def process_chat_payload(request, form_data, user, metadata, model): except Exception as e: raise Exception(f"{e}") + # === 12. 功能增强处理 (Features) === features = form_data.pop("features", None) if features: + # 12.1 记忆功能 - 注入历史对话记忆 if "memory" in features and features["memory"]: form_data = await chat_memory_handler( request, form_data, extra_params, user ) + # 12.2 网页搜索功能 - 执行网络搜索 if "web_search" in features and features["web_search"]: form_data = await chat_web_search_handler( request, form_data, extra_params, user ) + # 12.3 图像生成功能 - 处理图像生成请求 if "image_generation" in features and features["image_generation"]: form_data = await chat_image_generation_handler( request, form_data, extra_params, user ) + # 12.4 代码解释器功能 - 注入代码执行提示词 if "code_interpreter" in features and features["code_interpreter"]: form_data["messages"] = add_or_update_user_message( ( @@ -1168,33 +1214,33 @@ async def process_chat_payload(request, form_data, user, metadata, model): form_data["messages"], ) + # === 13. 提取工具和文件信息 === tool_ids = form_data.pop("tool_ids", None) files = form_data.pop("files", None) + # === 14. 文件夹类型文件展开 === prompt = get_last_user_message(form_data["messages"]) - # TODO: re-enable URL extraction from prompt - # urls = [] - # if prompt and len(prompt or "") < 500 and (not files or len(files) == 0): - # urls = extract_urls(prompt) if files: if not files: files = [] for file_item in files: + # 如果文件类型是 folder,展开其包含的所有文件 if file_item.get("type", "file") == "folder": # Get folder files folder_id = file_item.get("id", None) if folder_id: folder = Folders.get_folder_by_id_and_user_id(folder_id, user.id) if folder and folder.data and "files" in folder.data: + # 移除文件夹项,添加文件夹内的文件 files = [f for f in files if f.get("id", None) != folder_id] files = [*files, *folder.data["files"]] - # files = [*files, *[{"type": "url", "url": url, "name": url} for url in urls]] - # Remove duplicate files based on their content + # 去重文件(基于文件内容) files = list({json.dumps(f, sort_keys=True): f for f in files}.values()) + # === 15. 
+    # === 15. Update the metadata ===
     metadata = {
         **metadata,
         "tool_ids": tool_ids,
@@ -1202,25 +1248,28 @@ async def process_chat_payload(request, form_data, user, metadata, model):
     }
     form_data["metadata"] = metadata
 
+    # === 16. Prepare the tools dictionary ===
     # Server side tools
-    tool_ids = metadata.get("tool_ids", None)
+    tool_ids = metadata.get("tool_ids", None)  # server-side tool IDs
     # Client side tools
-    direct_tool_servers = metadata.get("tool_servers", None)
+    direct_tool_servers = metadata.get("tool_servers", None)  # client-connected tool servers
 
     log.debug(f"{tool_ids=}")
     log.debug(f"{direct_tool_servers=}")
 
-    tools_dict = {}
+    tools_dict = {}  # dictionary of all tools
 
-    mcp_clients = {}
-    mcp_tools_dict = {}
+    mcp_clients = {}  # MCP (Model Context Protocol) clients
+    mcp_tools_dict = {}  # MCP tools
 
     if tool_ids:
         for tool_id in tool_ids:
+            # === 16.1 MCP (Model Context Protocol) tools ===
             if tool_id.startswith("server:mcp:"):
                 try:
                     server_id = tool_id[len("server:mcp:") :]
 
+                    # Look up the MCP server connection settings
                     mcp_server_connection = None
                     for (
                         server_connection
@@ -1236,6 +1285,7 @@ async def process_chat_payload(request, form_data, user, metadata, model):
                         log.error(f"MCP server with id {server_id} not found")
                         continue
 
+                    # Resolve the auth type
                     auth_type = mcp_server_connection.get("auth_type", "")
                     headers = {}
 
@@ -1244,7 +1294,7 @@ async def process_chat_payload(request, form_data, user, metadata, model):
                             f"Bearer {mcp_server_connection.get('key', '')}"
                         )
                     elif auth_type == "none":
-                        # No authentication
+                        # No authentication required
                         pass
                     elif auth_type == "session":
                         headers["Authorization"] = (
@@ -1273,16 +1323,19 @@ async def process_chat_payload(request, form_data, user, metadata, model):
                             log.error(f"Error getting OAuth token: {e}")
                             oauth_token = None
 
+                    # Connect to the MCP server
                     mcp_clients[server_id] = MCPClient()
                     await mcp_clients[server_id].connect(
                         url=mcp_server_connection.get("url", ""),
                         headers=headers if headers else None,
                     )
 
+                    # Fetch the MCP tool list and register each tool
                     tool_specs = await mcp_clients[server_id].list_tool_specs()
                     for tool_spec in tool_specs:
 
                         def make_tool_function(client, function_name):
+                            """Create an async caller for each MCP tool"""
                             async def tool_function(**kwargs):
                                 return await client.call_tool(
                                     function_name,
@@ -1295,6 +1348,7 @@ async def process_chat_payload(request, form_data, user, metadata, model):
                             mcp_clients[server_id], tool_spec["name"]
                         )
 
+                        # Register the MCP tool in the tools dict
                         mcp_tools_dict[f"{server_id}_{tool_spec['name']}"] = {
                             "spec": {
                                 **tool_spec,
@@ -1309,6 +1363,7 @@ async def process_chat_payload(request, form_data, user, metadata, model):
                     log.debug(e)
                     continue
 
+    # === 16.2 Standard (function) tools ===
     tools_dict = await get_tools(
         request,
         tool_ids,
@@ -1320,9 +1375,11 @@ async def process_chat_payload(request, form_data, user, metadata, model):
             "__files__": metadata.get("files", []),
         },
    )
+    # Merge in the MCP tools
     if mcp_tools_dict:
         tools_dict = {**tools_dict, **mcp_tools_dict}
 
+    # === 16.3 Client-connected tool servers ===
     if direct_tool_servers:
         for tool_server in direct_tool_servers:
             tool_specs = tool_server.pop("specs", [])
@@ -1334,19 +1391,21 @@ async def process_chat_payload(request, form_data, user, metadata, model):
                     "server": tool_server,
                 }
 
+    # Keep the MCP clients in the metadata (for cleanup at the very end)
     if mcp_clients:
         metadata["mcp_clients"] = mcp_clients
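Why the make_tool_function factory above is needed: Python closures bind loop variables late, so without the factory every registered MCP tool would end up calling the last tool in the loop. A standalone illustration:

    funcs_bad = [lambda: name for name in ["a", "b"]]
    print([f() for f in funcs_bad])    # ['b', 'b'] - both closures see the final value

    def make(name):
        return lambda: name

    funcs_good = [make(name) for name in ["a", "b"]]
    print([f() for f in funcs_good])   # ['a', 'b'] - each closure keeps its own value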
+    # === 17. Tool-call handling ===
     if tools_dict:
         if metadata.get("params", {}).get("function_calling") == "native":
-            # If the function calling is native, then call the tools function calling handler
+            # Native function calling: hand the tool specs straight to the LLM
             metadata["tools"] = tools_dict
             form_data["tools"] = [
                 {"type": "function", "function": tool.get("spec", {})}
                 for tool in tools_dict.values()
             ]
         else:
-            # If the function calling is not native, then call the tools function calling handler
+            # Default mode: implement tool calling via prompting
             try:
                 form_data, flags = await chat_completion_tools_handler(
                     request, form_data, extra_params, user, models, tools_dict
@@ -1355,6 +1414,7 @@ async def process_chat_payload(request, form_data, user, metadata, model):
             except Exception as e:
                 log.exception(e)
 
+    # === 18. File handling - RAG retrieval ===
     try:
         form_data, flags = await chat_completion_files_handler(
             request, form_data, extra_params, user
@@ -1363,11 +1423,13 @@ async def process_chat_payload(request, form_data, user, metadata, model):
     except Exception as e:
         log.exception(e)
 
+    # === 19. Build the context string and inject it into the messages ===
     # If context is not empty, insert it into the messages
     if len(sources) > 0:
         context_string = ""
-        citation_idx_map = {}
+        citation_idx_map = {}  # maps document ID -> citation number
 
+        # Walk all sources and build the context string
         for source in sources:
             if "document" in source:
                 for document_text, document_metadata in zip(
@@ -1380,9 +1442,11 @@ async def process_chat_payload(request, form_data, user, metadata, model):
                         or "N/A"
                     )
 
+                    # Assign each source a unique citation number
                     if source_id not in citation_idx_map:
                         citation_idx_map[source_id] = len(citation_idx_map) + 1
 
+                    # Build an XML-style <source> tag for the chunk
                     context_string += (
                         f'<source id="{citation_idx_map[source_id]}">{document_text}</source>\n'
                    )
 
+    # === 20. Attach citation sources as client events ===
     if len(sources) > 0:
         events.append({"sources": sources})
 
+    # === 21. Finish the knowledge-base search status ===
     if model_knowledge:
         await event_emitter(
             {
@@ -1434,29 +1501,63 @@ async def process_chat_payload(request, form_data, user, metadata, model):
 
 async def process_chat_response(
     request, response, form_data, user, metadata, model, events, tasks
 ):
+    """
+    Process the chat response - normalize/dispatch the completion response.
+
+    This is the core post-processing step for chat responses. It:
+    1. Handles streaming (SSE/ndjson) and non-streaming (JSON) responses
+    2. Sends events to the client over WebSocket
+    3. Persists messages/errors/metadata to the database
+    4. Triggers background tasks (title, tags, follow-ups)
+    5. For streams: consumes upstream chunks, rebuilds the content,
+       handles tool calls, and forwards the deltas
+
+    Args:
+        request: FastAPI Request object
+        response: upstream response (dict/JSONResponse/StreamingResponse)
+        form_data: the original request payload (possibly modified upstream)
+        user: verified user object
+        metadata: chat/session context collected earlier
+        model: resolved model configuration
+        events: extra events to send (e.g. citation sources)
+        tasks: optional background tasks (title/tags/follow-ups)
+
+    Returns:
+        - non-streaming: dict (OpenAI JSON format)
+        - streaming: StreamingResponse (SSE/ndjson format)
+    """
+
+    # === Inner function: background-task handler ===
     async def background_tasks_handler():
+        """
+        Run background tasks once the response is complete:
+        1. Follow-ups generation - suggest follow-up questions
+        2. Title generation      - auto-generate a chat title
+        3. Tags generation       - auto-generate chat tags
+        """
         message = None
         messages = []
 
+        # Fetch the message history
         if "chat_id" in metadata and not metadata["chat_id"].startswith("local:"):
+            # Load the persisted chat history from the database
             messages_map = Chats.get_messages_map_by_chat_id(metadata["chat_id"])
             message = messages_map.get(metadata["message_id"]) if messages_map else None
             message_list = get_message_list(messages_map, metadata["message_id"])
 
-            # Remove details tags and files from the messages.
-            # as get_message_list creates a new list, it does not affect
-            # the original messages outside of this handler
-
+            # Clean the message contents: strip details tags and files.
+            # get_message_list returns a new list, so the originals outside
+            # this handler are untouched.
             messages = []
             for message in message_list:
                 content = message.get("content", "")
+                # Multimodal content (image + text): keep the text part
                 if isinstance(content, list):
                     for item in content:
                         if item.get("type") == "text":
                            content = item["text"]
                             break
 
+                # Strip <details> tags and Markdown images
                 if isinstance(content, str):
                     content = re.sub(
                         r"<details[^>]*>.*?<\/details>|!\[.*?\]\(.*?\)",
                         "",
                         content,
                         flags=re.S | re.I,
                     ).strip()
 
                 messages.append(
                     {
                         **message,
-                        "role": message.get(
-                            "role", "assistant"
-                        ),  # Safe fallback for missing role
+                        "role": message.get("role", "assistant"),  # safe fallback
                         "content": content,
                     }
                 )
         else:
-            # Local temp chat, get the model and message from the form_data
+            # Temporary ("local:") chat: take everything from form_data
             message = get_last_user_message_item(form_data.get("messages", []))
             messages = form_data.get("messages", [])
 
             if message:
                 message["model"] = form_data.get("model")
 
+        # Run the background tasks
         if message and "model" in message:
             if tasks and messages:
+                # === Task 1: Follow-ups generation ===
                 if (
                     TASKS.FOLLOW_UP_GENERATION in tasks
                     and tasks[TASKS.FOLLOW_UP_GENERATION]
                 ):
                     res = await generate_follow_ups(
                         request,
                         {
                             "model": message["model"],
                             "messages": messages,
                             "message_id": metadata["message_id"],
                             "chat_id": metadata["chat_id"],
                         },
                         user,
                     )
 
                     if res and isinstance(res, dict):
                         if len(res.get("choices", [])) == 1:
                             follow_ups_string = (
                                 res.get("choices")[0]
                                 .get("message", {})
                                 .get("content", "")
                             )
                         else:
                             follow_ups_string = ""
 
+                        # Extract the JSON object (first "{" through last "}")
                         follow_ups_string = follow_ups_string[
                             follow_ups_string.find("{") : follow_ups_string.rfind("}")
                             + 1
                         ]
 
                         try:
                             follow_ups = json.loads(follow_ups_string).get(
                                 "follow_ups", []
                             )
+                            # Push the follow-ups over WebSocket
                             await event_emitter(
                                 {
                                     "type": "chat:message:follow_ups",
                                     "data": {
                                         "follow_ups": follow_ups,
                                     },
                                 }
                             )
 
+                            # Persist to the database
                             if not metadata.get("chat_id", "").startswith("local:"):
                                 Chats.upsert_message_to_chat_by_id_and_message_id(
                                     metadata["chat_id"],
                                     metadata["message_id"],
                                     {
                                         "followUps": follow_ups,
                                     },
                                 )
                         except Exception as e:
                             pass
 
-                if not metadata.get("chat_id", "").startswith(
-                    "local:"
-                ):  # Only update titles and tags for non-temp chats
+                # === Tasks 2 & 3: title and tags (non-temporary chats only) ===
+                if not metadata.get("chat_id", "").startswith("local:"):
+                    # Task 2: title generation
                     if (
                         TASKS.TITLE_GENERATION in tasks
                         and tasks[TASKS.TITLE_GENERATION]
                     ):
                         user_message = get_last_user_message(messages)
                         if user_message:
                             res = await generate_title(
                                 request,
                                 {
                                     "model": message["model"],
                                     "messages": messages,
                                     "chat_id": metadata["chat_id"],
                                 },
                                 user,
                             )
 
                             if res and isinstance(res, dict):
                                 if len(res.get("choices", [])) == 1:
                                     title_string = (
                                         res.get("choices")[0]
                                         .get("message", {})
                                         .get("content", "")
                                     )
                                 else:
                                     title_string = ""
 
+                                # Extract the JSON object
                                 title_string = title_string[
                                     title_string.find("{") : title_string.rfind("}") + 1
                                 ]
 
                                 try:
                                     title = json.loads(title_string).get("title", "")
                                 except Exception as e:
                                     title = ""
 
                                 if not title:
                                     title = messages[0].get("content", user_message)
 
+                                # Update the database
                                 Chats.update_chat_title_by_id(
                                     metadata["chat_id"], title
                                 )
 
+                                # Push the title over WebSocket
                                 await event_emitter(
                                     {
                                         "type": "chat:title",
                                         "data": title,
                                     }
                                 )
+                        # First exchange (only 2 messages): use the user message as the title
                         elif len(messages) == 2:
                             title = messages[0].get("content", user_message)
 
                             Chats.update_chat_title_by_id(
                                 metadata["chat_id"], title
                             )
 
                             await event_emitter(
                                 {
                                     "type": "chat:title",
                                     "data": title,
                                 }
                             )
 
+                    # Task 3: tags generation
                     if TASKS.TAGS_GENERATION in tasks and tasks[TASKS.TAGS_GENERATION]:
                         res = await generate_chat_tags(
                             request,
                             {
                                 "model": message["model"],
                                 "messages": messages,
                                 "chat_id": metadata["chat_id"],
                             },
                             user,
                         )
 
                         if res and isinstance(res, dict):
                             if len(res.get("choices", [])) == 1:
                                 tags_string = (
                                     res.get("choices")[0]
                                     .get("message", {})
                                     .get("content", "")
                                 )
                             else:
                                 tags_string = ""
 
+                            # Extract the JSON object
                             tags_string = tags_string[
                                 tags_string.find("{") : tags_string.rfind("}") + 1
                             ]
 
                             try:
                                 tags = json.loads(tags_string).get("tags", [])
 
+                                # Update the database
                                 Chats.update_chat_tags_by_id(
                                     metadata["chat_id"], tags, user
                                 )
 
+                                # Push the tags over WebSocket
                                 await event_emitter(
                                     {
                                         "type": "chat:tags",
                                         "data": tags,
                                     }
                                 )
                             except Exception as e:
                                 pass
 
+    # === 1. Resolve the event emitters (only when a live session exists) ===
     event_emitter = None
     event_caller = None
     if (
@@ -1667,18 +1780,19 @@ async def process_chat_response(
         and "message_id" in metadata
         and metadata["message_id"]
     ):
-        event_emitter = get_event_emitter(metadata)
-        event_caller = get_event_call(metadata)
+        event_emitter = get_event_emitter(metadata)  # WebSocket event emitter
+        event_caller = get_event_call(metadata)  # event invoker
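The follow-ups/title/tags tasks above all pull a JSON object out of free-form model output by slicing from the first "{" to the last "}". A standalone illustration of the pattern:

    import json

    raw = 'Sure! ```json\n{"tags": ["python", "fastapi"]}\n``` Hope that helps.'
    sliced = raw[raw.find("{") : raw.rfind("}") + 1]
    assert json.loads(sliced)["tags"] == ["python", "fastapi"]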
+    # === 2. Non-streaming response handling ===
     if not isinstance(response, StreamingResponse):
         if event_emitter:
             try:
                 if isinstance(response, dict) or isinstance(response, JSONResponse):
+                    # Unwrap single-item lists (#17213)
                     if isinstance(response, list) and len(response) == 1:
-                        # If the response is a single-item list, unwrap it #17213
                         response = response[0]
 
+                    # Parse the body of a JSONResponse
                     if isinstance(response, JSONResponse) and isinstance(
                         response.body, bytes
                     ):
@@ -1693,6 +1807,7 @@ async def process_chat_response(
                     else:
                         response_data = response
 
+                    # Handle error responses
                     if "error" in response_data:
                         error = response_data.get("error")
 
@@ -1701,6 +1816,7 @@ async def process_chat_response(
                         else:
                             error = str(error)
 
+                        # Save the error to the database
                         Chats.upsert_message_to_chat_by_id_and_message_id(
                             metadata["chat_id"],
                             metadata["message_id"],
                             {
                                 "error": {"content": error},
                             },
                         )
+                        # Send the error event over WebSocket
                         if isinstance(error, str) or isinstance(error, dict):
                             await event_emitter(
                                 {
@@ -1821,7 +1938,7 @@ async def process_chat_response(
 
         return response
 
-    # Non standard response
+    # Non-standard response (not SSE/ndjson)
     if not any(
         content_type in response.headers["Content-Type"]
         for content_type in ["text/event-stream", "application/x-ndjson"]
@@ -1854,7 +1971,7 @@ async def process_chat_response(
             )
         ]
 
-    # Streaming response
+    # Streaming response: consume upstream SSE/ndjson, forward deltas/events, handle tools
     if event_emitter and event_caller:
         task_id = str(uuid4())  # Create a unique task ID.
         model_id = form_data.get("model", "")
@@ -2289,26 +2406,49 @@ async def process_chat_response(
             )
 
         async def stream_body_handler(response, form_data):
-            nonlocal content
-            nonlocal content_blocks
+            """
+            Stream-body handler - consume the upstream SSE stream and
+            process its content blocks.
+
+            Responsibilities:
+            1. Parse the SSE (Server-Sent Events) stream
+            2. Apply incremental tool-call (tool_calls) updates
+            3. Handle reasoning content (reasoning_content), e.g. <think> tags
+            4. Handle solution and code_interpreter tags
+            5. Optionally persist messages to the database in real time
+            6. Throttle delta pushes to avoid overloading the WebSocket
+
+            Args:
+                response: upstream StreamingResponse object
+                form_data: the original request data
+            """
+            nonlocal content  # the accumulated full text
+            nonlocal content_blocks  # content blocks (text/reasoning/code_interpreter)
 
-            response_tool_calls = []
+            response_tool_calls = []  # accumulated tool calls
 
-            delta_count = 0
+            # === 1. Initialize delta throttling ===
+            delta_count = 0  # deltas accumulated so far
             delta_chunk_size = max(
-                CHAT_RESPONSE_STREAM_DELTA_CHUNK_SIZE,
+                CHAT_RESPONSE_STREAM_DELTA_CHUNK_SIZE,  # globally configured chunk size
                 int(
                     metadata.get("params", {}).get("stream_delta_chunk_size") or 1
-                ),
+                ),  # user-configured chunk size
             )
-            last_delta_data = None
+            last_delta_data = None  # the most recent pending delta payload
 
             async def flush_pending_delta_data(threshold: int = 0):
+                """
+                Flush the pending delta payload.
+
+                Args:
+                    threshold: only emit once delta_count >= threshold
+                """
                 nonlocal delta_count
                 nonlocal last_delta_data
 
                 if delta_count >= threshold and last_delta_data:
+                    # Emit the accumulated delta over WebSocket
                     await event_emitter(
                         {
                             "type": "chat:completion",
@@ -2318,7 +2458,9 async def process_chat_response(
                     delta_count = 0
                     last_delta_data = None
 
+            # === 2. Consume the SSE stream ===
             async for line in response.body_iterator:
+                # Decode the byte stream into a string
                 line = (
                     line.decode("utf-8", "replace")
                     if isinstance(line, bytes)
@@ -2326,20 +2468,22 @@ async def process_chat_response(
                 )
                 data = line
 
                 # Skip empty lines
                 if not data.strip():
                     continue
 
-                # "data:" is the prefix for each event
+                # SSE format: every event starts with "data:"
                 if not data.startswith("data:"):
                     continue
 
-                # Remove the prefix
+                # Strip the "data:" prefix
                 data = data[len("data:") :].strip()
 
                 try:
+                    # Parse the JSON payload
                     data = json.loads(data)
 
+                    # === 3. Run stream-type filter functions ===
                     data, _ = await process_filter_functions(
                         request=request,
                         filter_functions=filter_functions,
@@ -2349,11 +2493,14 @@ async def process_chat_response(
                     )
 
                     if data:
+                        # Custom events embedded in the stream
                         if "event" in data:
                             await event_emitter(data.get("event", {}))
 
+                        # === 4. Arena mode: record the selected model ===
                         if "selected_model_id" in data:
                             model_id = data["selected_model_id"]
+                            # Persist the selected model ID
                             Chats.upsert_message_to_chat_by_id_and_message_id(
                                 metadata["chat_id"],
                                 metadata["message_id"],
@@ -2370,9 +2517,9 @@ async def process_chat_response(
                         else:
                             choices = data.get("choices", [])
 
-                            # 17421
+                            # === 5. usage and timings info (#17421) ===
                             usage = data.get("usage", {}) or {}
-                            usage.update(data.get("timings", {}))  # llama.cpp
+                            usage.update(data.get("timings", {}))  # llama.cpp timings
                             if usage:
                                 await event_emitter(
                                     {
@@ -2383,6 +2530,7 @@ async def process_chat_response(
                                     }
                                 )
 
+                            # === 6. Error responses ===
                             if not choices:
                                 error = data.get("error", {})
                                 if error:
@@ -2396,9 +2544,11 @@ async def process_chat_response(
                                     )
                                 continue
 
+                            # === 7. Extract the delta (incremental content) ===
                             delta = choices[0].get("delta", {})
                             delta_tool_calls = delta.get("tool_calls", None)
 
+                            # === 8. Tool calls ===
                             if delta_tool_calls:
                                 for delta_tool_call in delta_tool_calls:
                                     tool_call_index = delta_tool_call.get(
@@ -2406,7 +2556,7 @@ async def process_chat_response(
                                     )
 
                                     if tool_call_index is not None:
-                                        # Check if the tool call already exists
+                                        # Look for an existing tool call
                                         current_response_tool_call = None
                                         for (
                                             response_tool_call
@@ -2421,7 +2571,7 @@ async def process_chat_response(
                                                 break
 
                                         if current_response_tool_call is None:
-                                            # Add the new tool call
+                                            # Register the new tool call
                                             delta_tool_call.setdefault(
                                                 "function", {}
                                             )
@@ -2435,7 +2585,7 @@ async def process_chat_response(
                                                 delta_tool_call
                                             )
                                         else:
-                                            # Update the existing tool call
+                                            # Update the existing tool call (accumulate name and arguments)
                                             delta_name = delta_tool_call.get(
                                                 "function", {}
                                             ).get("name")
@@ -2457,14 +2607,17 @@ async def process_chat_response(
                                                     "arguments"
                                                 ] += delta_arguments
 
+                            # === 9. Text content ===
                             value = delta.get("content")
 
+                            # === 10. Reasoning content ===
                             reasoning_content = (
                                 delta.get("reasoning_content")
                                 or delta.get("reasoning")
                                 or delta.get("thinking")
                             )
                             if reasoning_content:
+                                # Create or update the reasoning block
                                 if (
                                     not content_blocks
                                     or content_blocks[-1]["type"] != "reasoning"
@@ -2483,6 +2636,7 @@ async def process_chat_response(
                                 else:
                                     reasoning_block = content_blocks[-1]
 
+                                # Accumulate the reasoning content
                                 reasoning_block["content"] += reasoning_content
 
                                 data = {
@@ -2491,7 +2645,9 @@ async def process_chat_response(
                                     )
                                 }
 
+                            # === 11. Plain text content ===
                             if value:
+                                # If the previous block is reasoning, close it and start a text block
                                 if (
                                     content_blocks
                                     and content_blocks[-1]["type"]
@@ -2515,6 +2671,7 @@ async def process_chat_response(
                                     }
                                 )
 
+                                # Accumulate the text content
                                 content = f"{content}{value}"
                                 if not content_blocks:
                                     content_blocks.append(
@@ -2528,6 +2685,8 @@ async def process_chat_response(
                                         content_blocks[-1]["content"] + value
                                     )
 
+                                # === 12. Detect and handle special tags ===
+                                # 12.1 Reasoning tags (e.g. <think>)
                                 if DETECT_REASONING_TAGS:
                                     content, content_blocks, _ = (
                                         tag_content_handler(
@@ -2538,6 +2697,7 @@ async def process_chat_response(
                                         )
                                     )
 
+                                    # Solution tags
                                     content, content_blocks, _ = (
                                         tag_content_handler(
                                             "solution",
@@ -2547,6 +2707,7 @@ async def process_chat_response(
                                         )
                                     )
 
+                                # 12.2 Code-interpreter tags
                                 if DETECT_CODE_INTERPRETER:
                                     content, content_blocks, end = (
                                         tag_content_handler(
@@ -2557,11 +2718,13 @@ async def process_chat_response(
                                         )
                                     )
 
+                                    # Stop streaming once the closing tag is seen
                                     if end:
                                         break
 
+                            # === 13. Real-time message persistence (optional) ===
                             if ENABLE_REALTIME_CHAT_SAVE:
-                                # Save message in the database
+                                # Save to the database
                                 Chats.upsert_message_to_chat_by_id_and_message_id(
                                     metadata["chat_id"],
                                     metadata["message_id"],
@@ -2572,18 +2735,22 @@ async def process_chat_response(
                                     },
                                 )
                             else:
+                                # Queue the payload for sending
                                 data = {
                                     "content": serialize_content_blocks(
                                         content_blocks
                                     ),
                                 }
 
+                            # === 14. Delta throttling ===
                             if delta:
                                 delta_count += 1
                                 last_delta_data = data
+                                # Flush once the threshold is reached
                                 if delta_count >= delta_chunk_size:
                                     await flush_pending_delta_data(delta_chunk_size)
                             else:
+                                # Non-delta payloads are sent immediately
                                 await event_emitter(
                                     {
                                         "type": "chat:completion",
@@ -2591,24 +2758,30 @@ async def process_chat_response(
                                     }
                                 )
                 except Exception as e:
+                    # Handle the end-of-stream marker
                     done = "data: [DONE]" in line
                     if done:
                         pass
                     else:
                         log.debug(f"Error: {e}")
                         continue
 
+            # === 15. Flush any remaining delta payload ===
             await flush_pending_delta_data()
 
+            # === 16. Clean up the content blocks ===
             if content_blocks:
-                # Clean up the last text block
+                # Trim trailing whitespace from the last text block
                 if content_blocks[-1]["type"] == "text":
                     content_blocks[-1]["content"] = content_blocks[-1][
                         "content"
                     ].strip()
 
+                    # Drop the block if it ended up empty
                     if not content_blocks[-1]["content"]:
                         content_blocks.pop()
 
+                        # Make sure at least one empty text block remains
                        if not content_blocks:
                             content_blocks.append(
                                 {
@@ -2617,6 +2790,7 @@ async def process_chat_response(
                                 }
                             )
 
+            # Mark the last reasoning block as finished
             if content_blocks[-1]["type"] == "reasoning":
                 reasoning_block = content_blocks[-1]
                 if reasoning_block.get("ended_at") is None:
@@ -2626,9 +2800,11 @@ async def process_chat_response(
                         - reasoning_block["started_at"]
                     )
 
+            # === 17. Record the tool calls ===
             if response_tool_calls:
                 tool_calls.append(response_tool_calls)
 
+            # === 18. Run the response cleanup (close the connection) ===
             if response.background:
                 await response.background()
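The delta throttling in stream_body_handler batches WebSocket pushes: each payload is a cumulative snapshot (serialize_content_blocks re-serializes all blocks so far), so intermediate payloads can be dropped safely and only the latest pending one needs to go out. A minimal standalone sketch of the pattern:

    class DeltaBatcher:
        """Emit at most one update per `chunk_size` deltas; each payload is
        a full snapshot, so skipping intermediate ones loses nothing."""

        def __init__(self, chunk_size, emit):
            self.chunk_size = chunk_size
            self.emit = emit
            self.count = 0
            self.pending = None

        def push(self, data):
            self.count += 1
            self.pending = data
            if self.count >= self.chunk_size:
                self.flush()

        def flush(self):
            if self.pending is not None:
                self.emit(self.pending)
            self.count = 0
            self.pending = None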
diff --git a/backend/open_webui/utils/misc.py b/backend/open_webui/utils/misc.py
index 9984e378fb..1a7b2a04cd 100644
--- a/backend/open_webui/utils/misc.py
+++ b/backend/open_webui/utils/misc.py
@@ -67,6 +67,45 @@ def get_messages_content(messages: list[dict]) -> str:
     )
 
 
+def extract_timestamped_messages(raw_msgs: list[dict]) -> list[dict]:
+    """
+    Normalize a message list into a uniform dict structure for downstream
+    persistence/auditing.
+
+    Args:
+        raw_msgs (list[dict]): messages in OpenAI format.
+
+    Returns:
+        list[dict]: each message carries role, content, and timestamp fields.
+    """
+    messages: list[dict] = []
+    for msg in raw_msgs:
+        if not isinstance(msg, dict):
+            continue
+        ts = (
+            msg.get("createdAt")
+            or msg.get("created_at")
+            or msg.get("timestamp")
+            or msg.get("updated_at")
+            or msg.get("updatedAt")
+            or 0
+        )
+        content = msg.get("content", "") or ""
+        if isinstance(content, list):
+            for item in content:
+                if item.get("type") == "text":
+                    content = item.get("text", "")
+                    break
+        messages.append(
+            {
+                "role": msg.get("role", "assistant"),
+                "content": str(content),
+                "timestamp": int(ts),
+            }
+        )
+
+    return messages
+
+
 def get_last_user_message_item(messages: list[dict]) -> Optional[dict]:
     for message in reversed(messages):
         if message["role"] == "user":