diff --git a/backend/open_webui/env.py b/backend/open_webui/env.py
index 1960fde666..4b92737e23 100644
--- a/backend/open_webui/env.py
+++ b/backend/open_webui/env.py
@@ -621,6 +621,15 @@ else:
     except Exception:
         CHAT_RESPONSE_MAX_TOOL_CALL_RETRIES = 30
 
+# Global debug switch (enabled by default)
+CHAT_DEBUG_FLAG = os.environ.get("CHAT_DEBUG_FLAG", "True").lower() == "true"
+
+# Default thresholds for summary/chat behavior
+SUMMARY_TOKEN_THRESHOLD_DEFAULT = os.environ.get("SUMMARY_TOKEN_THRESHOLD", "3000")
+try:
+    SUMMARY_TOKEN_THRESHOLD_DEFAULT = int(SUMMARY_TOKEN_THRESHOLD_DEFAULT)
+except Exception:
+    SUMMARY_TOKEN_THRESHOLD_DEFAULT = 3000
 
 ####################################
 # WEBSOCKET SUPPORT
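Both additions above are instances of the same read-env-with-fallback pattern. A possible shared helper, shown only as a sketch — the names `_int_env` and `_bool_env` are hypothetical and not part of this patch:

```python
import os

def _int_env(name: str, default: int) -> int:
    # Hypothetical helper: read an integer env var, falling back on missing/bad input.
    try:
        return int(os.environ.get(name, default))
    except (TypeError, ValueError):
        return default

def _bool_env(name: str, default: bool = True) -> bool:
    # Hypothetical helper: "true"/"True"/"TRUE" map to True, anything else to False.
    return os.environ.get(name, str(default)).lower() == "true"

SUMMARY_TOKEN_THRESHOLD_DEFAULT = _int_env("SUMMARY_TOKEN_THRESHOLD", 3000)
CHAT_DEBUG_FLAG = _bool_env("CHAT_DEBUG_FLAG", True)
```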
log.debug(f"Error processing chat payload: {e}") + log.exception(f"Error processing chat payload: {e}") if metadata.get("chat_id") and metadata.get("message_id"): try: # 将错误信息保存到消息记录 diff --git a/backend/open_webui/memory/cross_window_memory.py b/backend/open_webui/memory/cross_window_memory.py index c8e82521e8..9d6f6150f9 100644 --- a/backend/open_webui/memory/cross_window_memory.py +++ b/backend/open_webui/memory/cross_window_memory.py @@ -16,6 +16,7 @@ def last_process_payload( messages (List[Dict]): 该用户在该对话下的聊天消息列表, 形如 {"role": "system|user|assistant", "content": "...", "timestamp": 0}。 """ - print("user_id:", user_id) - print("session_id:", session_id) - print("messages:", messages) + return + # print("user_id:", user_id) + # print("session_id:", session_id) + # print("messages:", messages) diff --git a/backend/open_webui/models/chats.py b/backend/open_webui/models/chats.py index cfcbc004b7..ccf3b24052 100644 --- a/backend/open_webui/models/chats.py +++ b/backend/open_webui/models/chats.py @@ -252,6 +252,62 @@ class ChatTable: return chat.chat.get("history", {}).get("messages", {}).get(message_id, {}) + def get_summary_by_user_id_and_chat_id( + self, user_id: str, chat_id: str + ) -> Optional[dict]: + """ + 读取 chat.meta.summary,包含摘要内容及摘要边界(last_message_id/timestamp)。 + """ + chat = self.get_chat_by_id_and_user_id(chat_id, user_id) + if chat is None: + return None + + return chat.meta.get("summary", None) if isinstance(chat.meta, dict) else None + + def set_summary_by_user_id_and_chat_id( + self, + user_id: str, + chat_id: str, + summary: str, + last_message_id: Optional[str], + last_timestamp: Optional[int], + recent_message_ids: Optional[list[str]] = None, + ) -> Optional[ChatModel]: + """ + 写入 chat.meta.summary,并更新更新时间。 + """ + try: + with get_db() as db: + chat = db.query(Chat).filter_by(id=chat_id, user_id=user_id).first() + + if chat is None: + return None + + meta = chat.meta if isinstance(chat.meta, dict) else {} + new_meta = { + **meta, + "summary": { + "content": summary, + "last_message_id": last_message_id, + "last_timestamp": last_timestamp, + }, + **( + {"recent_message_id_for_cold_start": recent_message_ids} + if recent_message_ids is not None + else {} + ), + } + + # 重新赋值以触发 SQLAlchemy 变更检测 + chat.meta = new_meta + chat.updated_at = int(time.time()) + db.commit() + db.refresh(chat) + return ChatModel.model_validate(chat) + except Exception as e: + log.exception(f"set_summary_by_user_id_and_chat_id failed: {e}") + return None + def upsert_message_to_chat_by_id_and_message_id( self, id: str, message_id: str, message: dict ) -> Optional[ChatModel]: diff --git a/backend/open_webui/routers/openai.py b/backend/open_webui/routers/openai.py index 90b3f7b171..1f1f527ea7 100644 --- a/backend/open_webui/routers/openai.py +++ b/backend/open_webui/routers/openai.py @@ -1004,12 +1004,6 @@ async def generate_chat_completion( log.debug( f"chatting_completion hook user={user.id} chat_id={metadata.get('chat_id')} model={payload.get('model')}" ) - - last_process_payload( - user_id = user.id, - session_id = metadata.get("chat_id"), - messages = extract_timestamped_messages(payload.get("messages", [])), - ) except Exception as e: log.debug(f"chatting_completion 钩子执行失败: {e}") diff --git a/backend/open_webui/utils/middleware.py b/backend/open_webui/utils/middleware.py index c9aa947436..f1e8a35be3 100644 --- a/backend/open_webui/utils/middleware.py +++ b/backend/open_webui/utils/middleware.py @@ -101,6 +101,11 @@ from open_webui.utils.filter import ( from open_webui.utils.code_interpreter 
diff --git a/backend/open_webui/routers/openai.py b/backend/open_webui/routers/openai.py
index 90b3f7b171..1f1f527ea7 100644
--- a/backend/open_webui/routers/openai.py
+++ b/backend/open_webui/routers/openai.py
@@ -1004,12 +1004,6 @@ async def generate_chat_completion(
         log.debug(
             f"chatting_completion hook user={user.id} chat_id={metadata.get('chat_id')} model={payload.get('model')}"
         )
-
-        last_process_payload(
-            user_id = user.id,
-            session_id = metadata.get("chat_id"),
-            messages = extract_timestamped_messages(payload.get("messages", [])),
-        )
     except Exception as e:
         log.debug(f"chatting_completion hook failed: {e}")
diff --git a/backend/open_webui/utils/middleware.py b/backend/open_webui/utils/middleware.py
index c9aa947436..f1e8a35be3 100644
--- a/backend/open_webui/utils/middleware.py
+++ b/backend/open_webui/utils/middleware.py
@@ -101,6 +101,11 @@ from open_webui.utils.filter import (
 from open_webui.utils.code_interpreter import execute_code_jupyter
 from open_webui.utils.payload import apply_system_prompt_to_body
 from open_webui.utils.mcp.client import MCPClient
+from open_webui.utils.summary import (
+    summarize,
+    compute_token_count,
+    slice_messages_with_summary,
+)
 
 from open_webui.config import (
@@ -114,9 +119,12 @@ from open_webui.env import (
     GLOBAL_LOG_LEVEL,
     CHAT_RESPONSE_STREAM_DELTA_CHUNK_SIZE,
     CHAT_RESPONSE_MAX_TOOL_CALL_RETRIES,
+    SUMMARY_TOKEN_THRESHOLD_DEFAULT,
     BYPASS_MODEL_ACCESS_CONTROL,
     ENABLE_REALTIME_CHAT_SAVE,
     ENABLE_QUERIES_CACHE,
+
+    CHAT_DEBUG_FLAG,
 )
 from open_webui.constants import TASKS
@@ -1031,14 +1039,35 @@ async def process_chat_payload(request, form_data, user, metadata, model):
     Process the chat request payload — run pipelines, filters, feature enhancement, and tool injection
 
     This is the core preprocessing function for chat requests; it runs in this order:
-    1. Pipeline Inlet — custom Python plugin preprocessing
-    2. Filter Inlet — function filter preprocessing
-    3. Chat Memory — inject conversation memory
-    4. Chat Web Search — run a web search and inject the results
-    5. Chat Image Generation — handle image generation requests
-    6. Chat Code Interpreter — inject the code-execution prompt
-    7. Chat Tools Function Calling — handle function/tool calls
-    8. Chat Files — handle uploads, knowledge-base files, RAG retrieval
+    1. Pipeline Inlet — custom Python plugin preprocessing [disabled]
+    2. Filter Inlet — function filter preprocessing [disabled]
+    3. Chat Memory — inject conversation memory [active]
+    4. Chat Web Search — run a web search and inject the results [disabled]
+    5. Chat Image Generation — handle image generation requests [disabled]
+    6. Chat Code Interpreter — inject the code-execution prompt [disabled]
+    7. Chat Tools Function Calling — handle function/tool calls [disabled]
+    8. Chat Files — handle uploads, knowledge-base files, RAG retrieval [disabled]
+
+    === Disabled features ===
+    Currently disabled (skipped via `if False`):
+    - Folder "Project" system prompt and file injection
+    - Model "Knowledge" knowledge-base injection
+    - Pipeline Inlet custom plugin processing
+    - Filter Inlet function filter processing
+    - Image generation
+    - Code interpreter
+    - MCP (Model Context Protocol) tool connections
+    - Standard (function) tool loading and invocation
+    - Web search
+    - File handling and RAG retrieval (vector DB lookup, context injection)
+
+    Currently active:
+    - Model parameter application
+    - System prompt variable substitution
+    - OAuth token retrieval
+    - Memory
+    - Direct tool servers (called from the client)
+    - Folder-type file expansion
 
     Args:
         request: FastAPI Request object
@@ -1057,6 +1086,87 @@ async def process_chat_payload(request, form_data, user, metadata, model):
     form_data = apply_params_to_form_data(form_data, model)
     log.debug(f"form_data: {form_data}")
 
+    # === 1. Summary-based context trimming and system prompt ===
+    # - backend/open_webui/utils/middleware.py:process_chat_payload:
+    #   each request first reads the stored summary, slices the messages at last_summary_id
+    #   into "20 before the boundary + everything after", and then puts a single
+    #   summary-built system prompt at the front;
+    chat_id = metadata.get("chat_id", None)
+    summary_record = None
+    ordered_messages = form_data.get("messages", [])
+    cold_start_ids = []
+
+    # 1.1 Read the summary record and trim the message list
+    if chat_id and user and not str(chat_id).startswith("local:"):
+        try:
+            # Fetch the summary record for the current chat
+            summary_record = Chats.get_summary_by_user_id_and_chat_id(
+                user.id, chat_id
+            )
+            # Fetch the cold-start message ID list (used to inject key historical messages)
+            chat_item = Chats.get_chat_by_id_and_user_id(chat_id, user.id)
+            if chat_item and isinstance(chat_item.meta, dict):
+                cold_start_ids = chat_item.meta.get("recent_message_id_for_cold_start", []) or []
+
+            if CHAT_DEBUG_FLAG:
+                print(
+                    f"[summary:payload] chat_id={chat_id} has summary={bool(summary_record)} "
+                    f"last summarized message_id={summary_record.get('last_message_id') if summary_record else None}"
+                )
+            # Trim at the summary boundary: keep the 20 messages before it + everything after
+            messages_map = Chats.get_messages_map_by_chat_id(chat_id) or {}
+            anchor_id = metadata.get("message_id")
+            ordered_messages = slice_messages_with_summary(
+                messages_map,
+                summary_record.get("last_message_id") if summary_record else None,
+                anchor_id,
+                pre_boundary=20,
+            )
+
+            if CHAT_DEBUG_FLAG:
+                print("[summary:payload]: 20 messages before the summary + everything after it")
+                for i in ordered_messages:
+                    print(i['role'], " ", i['content'][:100])
+                print(
+                    f"[summary:payload] chat_id={chat_id} messages after slicing={len(ordered_messages)} current anchor={anchor_id}"
+                )
+        except Exception as e:
+            print(f"summary preprocessing failed: {e}")
+
+    # 1.2 Append cold-start messages (so important context isn't lost)
+    if cold_start_ids and chat_id and not str(chat_id).startswith("local:"):
+        messages_map = Chats.get_messages_map_by_chat_id(chat_id) or {}
+        seen_ids = {m.get("id") for m in ordered_messages if m.get("id")}
+        if CHAT_DEBUG_FLAG:
+            print("[summary:payload:cold_start]")
+        for mid in cold_start_ids:
+            msg = messages_map.get(mid)
+            if not msg:
+                continue
+            print(msg['role'], " ", msg['content'][:100])
+            if mid in seen_ids:  # skip messages that are already present
+                continue
+            ordered_messages.append({**msg, "id": mid})
+            seen_ids.add(mid)
+
+        if CHAT_DEBUG_FLAG:
+            print(
+                f"[summary:payload:cold_start] chat_id={chat_id} appended cold-start messages={len(cold_start_ids)}"
+            )
+
+    # 1.3 Inject the summary system prompt (replacing existing system messages)
+    if summary_record and summary_record.get("content"):
+        summary_system_message = {
+            "role": "system",
+            "content": f"Conversation History Summary:\n{summary_record.get('content', '')}",
+        }
+        # Drop old system messages and insert the summary system prompt at the front
+        ordered_messages = [
+            m for m in ordered_messages if m.get("role") != "system"
+        ]
+        ordered_messages = [summary_system_message, *ordered_messages]
+
+    if ordered_messages:
+        form_data["messages"] = ordered_messages
+
     # === 2. Handle system prompt variable substitution ===
     system_message = get_system_message(form_data.get("messages", []))
     if system_message:  # Chat Controls/User Settings
@@ -1118,7 +1228,7 @@ async def process_chat_payload(request, form_data, user, metadata, model):
     events = []   # events to send to the frontend (e.g. source citations)
     sources = []  # document sources retrieved by RAG
 
-    # === 8. Folder "Project" handling — inject the folder's system prompt and files ===
+    # === 8. Folder "Project" handling — inject the folder's system prompt and files [disabled] ===
     # Check if the request has chat_id and is inside of a folder
     chat_id = metadata.get("chat_id", None)
     if False:
@@ -1140,81 +1250,83 @@ async def process_chat_payload(request, form_data, user, metadata, model):
                 *form_data.get("files", []),
             ]
 
-    # === 9. Model "Knowledge" handling — inject the model's bound knowledge bases ===
+    # === 9. Model "Knowledge" handling — inject the model's bound knowledge bases [disabled] ===
     user_message = get_last_user_message(form_data["messages"])
     model_knowledge = model.get("info", {}).get("meta", {}).get("knowledge", False)
 
-    if model_knowledge:
-        # Send knowledge-search status to the frontend
-        await event_emitter(
-            {
-                "type": "status",
-                "data": {
-                    "action": "knowledge_search",
-                    "query": user_message,
-                    "done": False,
-                },
-            }
-        )
-
-        knowledge_files = []
-        for item in model_knowledge:
-            # Handle the legacy collection_name format
-            if item.get("collection_name"):
-                knowledge_files.append(
-                    {
-                        "id": item.get("collection_name"),
-                        "name": item.get("name"),
-                        "legacy": True,
-                    }
-                )
-            # Handle the new collection_names format (multiple collections)
-            elif item.get("collection_names"):
-                knowledge_files.append(
-                    {
-                        "name": item.get("name"),
-                        "type": "collection",
-                        "collection_names": item.get("collection_names"),
-                        "legacy": True,
-                    }
-                )
-            else:
-                knowledge_files.append(item)
-
-        # Merge model knowledge-base files with user-uploaded files
-        files = form_data.get("files", [])
-        files.extend(knowledge_files)
-        form_data["files"] = files
-
-    variables = form_data.pop("variables", None)
-
-    # === 10. Pipeline Inlet handling — run custom Python plugins ===
-    # Process the form_data through the pipeline
-    try:
-        form_data = await process_pipeline_inlet_filter(
-            request, form_data, user, models
-        )
-    except Exception as e:
-        raise e
-
-    # === 11. Filter Inlet handling — run function filters ===
-    try:
-        filter_functions = [
-            Functions.get_function_by_id(filter_id)
-            for filter_id in get_sorted_filter_ids(
-                request, model, metadata.get("filter_ids", [])
-            )
-        ]
-
-        form_data, flags = await process_filter_functions(
-            request=request,
-            filter_functions=filter_functions,
-            filter_type="inlet",
-            form_data=form_data,
-            extra_params=extra_params,
-        )
-    except Exception as e:
-        raise Exception(f"{e}")
+    if False:
+        if model_knowledge:
+            # Send knowledge-search status to the frontend
+            await event_emitter(
+                {
+                    "type": "status",
+                    "data": {
+                        "action": "knowledge_search",
+                        "query": user_message,
+                        "done": False,
+                    },
+                }
+            )
+
+            knowledge_files = []
+            for item in model_knowledge:
+                # Handle the legacy collection_name format
+                if item.get("collection_name"):
+                    knowledge_files.append(
+                        {
+                            "id": item.get("collection_name"),
+                            "name": item.get("name"),
+                            "legacy": True,
+                        }
+                    )
+                # Handle the new collection_names format (multiple collections)
+                elif item.get("collection_names"):
+                    knowledge_files.append(
+                        {
+                            "name": item.get("name"),
+                            "type": "collection",
+                            "collection_names": item.get("collection_names"),
+                            "legacy": True,
+                        }
+                    )
+                else:
+                    knowledge_files.append(item)
+
+            # Merge model knowledge-base files with user-uploaded files
+            files = form_data.get("files", [])
+            files.extend(knowledge_files)
+            form_data["files"] = files
+
+    variables = form_data.pop("variables", None)
+
+    # === 10. Pipeline Inlet handling — run custom Python plugins [disabled] ===
+    # Process the form_data through the pipeline
+    if False:
+        try:
+            form_data = await process_pipeline_inlet_filter(
+                request, form_data, user, models
+            )
+        except Exception as e:
+            raise e
+
+    # === 11. Filter Inlet handling — run function filters [disabled] ===
+    if False:
+        try:
+            filter_functions = [
+                Functions.get_function_by_id(filter_id)
+                for filter_id in get_sorted_filter_ids(
+                    request, model, metadata.get("filter_ids", [])
+                )
+            ]
+
+            form_data, flags = await process_filter_functions(
+                request=request,
+                filter_functions=filter_functions,
+                filter_type="inlet",
+                form_data=form_data,
+                extra_params=extra_params,
+            )
+        except Exception as e:
+            raise Exception(f"{e}")
 
     # === 12. Feature enhancement (Features) ===
     features = form_data.pop("features", None)
@@ -1225,28 +1337,31 @@ async def process_chat_payload(request, form_data, user, metadata, model):
             request, form_data, extra_params, user, metadata
         )
 
-    # 12.2 Web search — run the web search
-    if "web_search" in features and features["web_search"]:
-        form_data = await chat_web_search_handler(
-            request, form_data, extra_params, user
-        )
+    # 12.2 Web search — run the web search [disabled]
+    if False:
+        if "web_search" in features and features["web_search"]:
+            form_data = await chat_web_search_handler(
+                request, form_data, extra_params, user
+            )
 
-    # 12.3 Image generation — handle image generation requests
-    if "image_generation" in features and features["image_generation"]:
-        form_data = await chat_image_generation_handler(
-            request, form_data, extra_params, user
-        )
+    # 12.3 Image generation — handle image generation requests [disabled]
+    if False:
+        if "image_generation" in features and features["image_generation"]:
+            form_data = await chat_image_generation_handler(
+                request, form_data, extra_params, user
+            )
 
-    # 12.4 Code interpreter — inject the code-execution prompt
-    if "code_interpreter" in features and features["code_interpreter"]:
-        form_data["messages"] = add_or_update_user_message(
-            (
-                request.app.state.config.CODE_INTERPRETER_PROMPT_TEMPLATE
-                if request.app.state.config.CODE_INTERPRETER_PROMPT_TEMPLATE != ""
-                else DEFAULT_CODE_INTERPRETER_PROMPT
-            ),
-            form_data["messages"],
-        )
+    # 12.4 Code interpreter — inject the code-execution prompt [disabled]
+    if False:
+        if "code_interpreter" in features and features["code_interpreter"]:
+            form_data["messages"] = add_or_update_user_message(
+                (
+                    request.app.state.config.CODE_INTERPRETER_PROMPT_TEMPLATE
+                    if request.app.state.config.CODE_INTERPRETER_PROMPT_TEMPLATE != ""
+                    else DEFAULT_CODE_INTERPRETER_PROMPT
+                ),
+                form_data["messages"],
+            )
 
     # === 13. Extract tool and file information ===
     tool_ids = form_data.pop("tool_ids", None)
@@ -1282,7 +1397,7 @@ async def process_chat_payload(request, form_data, user, metadata, model):
     }
     form_data["metadata"] = metadata
 
-    # === 16. Prepare the tools dict ===
+    # === 16. Prepare the tools dict [partially disabled] ===
     # Server side tools
     tool_ids = metadata.get("tool_ids", None)  # server-side tool IDs
     # Client side tools
@@ -1296,122 +1411,124 @@ async def process_chat_payload(request, form_data, user, metadata, model):
     mcp_clients = {}     # MCP (Model Context Protocol) clients
     mcp_tools_dict = {}  # MCP tools dict
 
-    if tool_ids:
-        for tool_id in tool_ids:
-            # === 16.1 Handle MCP (Model Context Protocol) tools ===
-            if tool_id.startswith("server:mcp:"):
-                try:
-                    server_id = tool_id[len("server:mcp:") :]
-
-                    # Look up the MCP server connection config
-                    mcp_server_connection = None
-                    for (
-                        server_connection
-                    ) in request.app.state.config.TOOL_SERVER_CONNECTIONS:
-                        if (
-                            server_connection.get("type", "") == "mcp"
-                            and server_connection.get("info", {}).get("id") == server_id
-                        ):
-                            mcp_server_connection = server_connection
-                            break
-
-                    if not mcp_server_connection:
-                        log.error(f"MCP server with id {server_id} not found")
-                        continue
-
-                    # Handle the auth type
-                    auth_type = mcp_server_connection.get("auth_type", "")
-
-                    headers = {}
-                    if auth_type == "bearer":
-                        headers["Authorization"] = (
-                            f"Bearer {mcp_server_connection.get('key', '')}"
-                        )
-                    elif auth_type == "none":
-                        # No auth required
-                        pass
-                    elif auth_type == "session":
-                        headers["Authorization"] = (
-                            f"Bearer {request.state.token.credentials}"
-                        )
-                    elif auth_type == "system_oauth":
-                        oauth_token = extra_params.get("__oauth_token__", None)
-                        if oauth_token:
-                            headers["Authorization"] = (
-                                f"Bearer {oauth_token.get('access_token', '')}"
-                            )
-                    elif auth_type == "oauth_2.1":
-                        try:
-                            splits = server_id.split(":")
-                            server_id = splits[-1] if len(splits) > 1 else server_id
-
-                            oauth_token = await request.app.state.oauth_client_manager.get_oauth_token(
-                                user.id, f"mcp:{server_id}"
-                            )
-
-                            if oauth_token:
-                                headers["Authorization"] = (
-                                    f"Bearer {oauth_token.get('access_token', '')}"
-                                )
-                        except Exception as e:
-                            log.error(f"Error getting OAuth token: {e}")
-                            oauth_token = None
-
-                    # Connect to the MCP server
-                    mcp_clients[server_id] = MCPClient()
-                    await mcp_clients[server_id].connect(
-                        url=mcp_server_connection.get("url", ""),
-                        headers=headers if headers else None,
-                    )
-
-                    # Fetch the MCP tool list and register each tool
-                    tool_specs = await mcp_clients[server_id].list_tool_specs()
-                    for tool_spec in tool_specs:
-
-                        def make_tool_function(client, function_name):
-                            """Create an async call wrapper for each MCP tool"""
-                            async def tool_function(**kwargs):
-                                return await client.call_tool(
-                                    function_name,
-                                    function_args=kwargs,
-                                )
-
-                            return tool_function
-
-                        tool_function = make_tool_function(
-                            mcp_clients[server_id], tool_spec["name"]
-                        )
-
-                        # Register the MCP tool in the tools dict
-                        mcp_tools_dict[f"{server_id}_{tool_spec['name']}"] = {
-                            "spec": {
-                                **tool_spec,
-                                "name": f"{server_id}_{tool_spec['name']}",
-                            },
-                            "callable": tool_function,
-                            "type": "mcp",
-                            "client": mcp_clients[server_id],
-                            "direct": False,
-                        }
-                except Exception as e:
-                    log.debug(e)
-                    continue
-
-    # === 16.2 Fetch standard tools (function tools) ===
-    tools_dict = await get_tools(
-        request,
-        tool_ids,
-        user,
-        {
-            **extra_params,
-            "__model__": models[task_model_id],
-            "__messages__": form_data["messages"],
-            "__files__": metadata.get("files", []),
-        },
-    )
-    # Merge the MCP tools
-    if mcp_tools_dict:
-        tools_dict = {**tools_dict, **mcp_tools_dict}
+    # === 16.1-16.2 MCP tools and standard tools [disabled] ===
+    if False:
+        if tool_ids:
+            for tool_id in tool_ids:
+                # === 16.1 Handle MCP (Model Context Protocol) tools ===
+                if tool_id.startswith("server:mcp:"):
+                    try:
+                        server_id = tool_id[len("server:mcp:") :]
+
+                        # Look up the MCP server connection config
+                        mcp_server_connection = None
+                        for (
+                            server_connection
+                        ) in request.app.state.config.TOOL_SERVER_CONNECTIONS:
+                            if (
+                                server_connection.get("type", "") == "mcp"
+                                and server_connection.get("info", {}).get("id") == server_id
+                            ):
+                                mcp_server_connection = server_connection
+                                break
+
+                        if not mcp_server_connection:
+                            log.error(f"MCP server with id {server_id} not found")
+                            continue
+
+                        # Handle the auth type
+                        auth_type = mcp_server_connection.get("auth_type", "")
+
+                        headers = {}
+                        if auth_type == "bearer":
+                            headers["Authorization"] = (
+                                f"Bearer {mcp_server_connection.get('key', '')}"
+                            )
+                        elif auth_type == "none":
+                            # No auth required
+                            pass
+                        elif auth_type == "session":
+                            headers["Authorization"] = (
+                                f"Bearer {request.state.token.credentials}"
+                            )
+                        elif auth_type == "system_oauth":
+                            oauth_token = extra_params.get("__oauth_token__", None)
+                            if oauth_token:
+                                headers["Authorization"] = (
+                                    f"Bearer {oauth_token.get('access_token', '')}"
+                                )
+                        elif auth_type == "oauth_2.1":
+                            try:
+                                splits = server_id.split(":")
+                                server_id = splits[-1] if len(splits) > 1 else server_id
+
+                                oauth_token = await request.app.state.oauth_client_manager.get_oauth_token(
+                                    user.id, f"mcp:{server_id}"
+                                )
+
+                                if oauth_token:
+                                    headers["Authorization"] = (
+                                        f"Bearer {oauth_token.get('access_token', '')}"
+                                    )
+                            except Exception as e:
+                                log.error(f"Error getting OAuth token: {e}")
+                                oauth_token = None
+
+                        # Connect to the MCP server
+                        mcp_clients[server_id] = MCPClient()
+                        await mcp_clients[server_id].connect(
+                            url=mcp_server_connection.get("url", ""),
+                            headers=headers if headers else None,
+                        )
+
+                        # Fetch the MCP tool list and register each tool
+                        tool_specs = await mcp_clients[server_id].list_tool_specs()
+                        for tool_spec in tool_specs:
+
+                            def make_tool_function(client, function_name):
+                                """Create an async call wrapper for each MCP tool"""
+                                async def tool_function(**kwargs):
+                                    return await client.call_tool(
+                                        function_name,
+                                        function_args=kwargs,
+                                    )
+
+                                return tool_function
+
+                            tool_function = make_tool_function(
+                                mcp_clients[server_id], tool_spec["name"]
+                            )
+
+                            # Register the MCP tool in the tools dict
+                            mcp_tools_dict[f"{server_id}_{tool_spec['name']}"] = {
+                                "spec": {
+                                    **tool_spec,
+                                    "name": f"{server_id}_{tool_spec['name']}",
+                                },
+                                "callable": tool_function,
+                                "type": "mcp",
+                                "client": mcp_clients[server_id],
+                                "direct": False,
+                            }
+                    except Exception as e:
+                        log.debug(e)
+                        continue
+
+        # === 16.2 Fetch standard tools (function tools) ===
+        tools_dict = await get_tools(
+            request,
+            tool_ids,
+            user,
+            {
+                **extra_params,
+                "__model__": models[task_model_id],
+                "__messages__": form_data["messages"],
+                "__files__": metadata.get("files", []),
+            },
+        )
+        # Merge the MCP tools
+        if mcp_tools_dict:
+            tools_dict = {**tools_dict, **mcp_tools_dict}
 
     # === 16.3 Handle direct (client-side) tool servers ===
     if direct_tool_servers:
@@ -1429,79 +1546,82 @@ async def process_chat_payload(request, form_data, user, metadata, model):
     if mcp_clients:
         metadata["mcp_clients"] = mcp_clients
 
-    # === 17. Tool-call handling ===
-    if tools_dict:
-        if metadata.get("params", {}).get("function_calling") == "native":
-            # Native function-calling mode: pass the tools straight to the LLM
-            metadata["tools"] = tools_dict
-            form_data["tools"] = [
-                {"type": "function", "function": tool.get("spec", {})}
-                for tool in tools_dict.values()
-            ]
-        else:
-            # Default mode: implement tool calling via prompting
-            try:
-                form_data, flags = await chat_completion_tools_handler(
-                    request, form_data, extra_params, user, models, tools_dict
-                )
-                sources.extend(flags.get("sources", []))
-            except Exception as e:
-                log.exception(e)
-
-    # === 18. File handling — RAG retrieval ===
-    try:
-        form_data, flags = await chat_completion_files_handler(
-            request, form_data, extra_params, user
-        )
-        sources.extend(flags.get("sources", []))
-    except Exception as e:
-        log.exception(e)
-
-    # === 19. Build the context string and inject it into the messages ===
-    # If context is not empty, insert it into the messages
-    if len(sources) > 0:
-        context_string = ""
-        citation_idx_map = {}  # citation index map (document ID → citation number)
-
-        # Walk all sources and build the context string
-        for source in sources:
-            if "document" in source:
-                for document_text, document_metadata in zip(
-                    source["document"], source["metadata"]
-                ):
-                    source_name = source.get("source", {}).get("name", None)
-                    source_id = (
-                        document_metadata.get("source", None)
-                        or source.get("source", {}).get("id", None)
-                        or "N/A"
-                    )
-
-                    # Assign a unique citation number to each source
-                    if source_id not in citation_idx_map:
-                        citation_idx_map[source_id] = len(citation_idx_map) + 1
-
-                    # Build the XML-style source tag
-                    context_string += (
-                        f'<source id="{citation_idx_map[source_id]}"'
-                        + (f' name="{source_name}"' if source_name else "")
-                        + f">{document_text}</source>\n"
-                    )
-
-        context_string = context_string.strip()
-        if prompt is None:
-            raise Exception("No user message found")
-
-        # Inject the context into the user message via the RAG template
-        if context_string != "":
-            form_data["messages"] = add_or_update_user_message(
-                rag_template(
-                    request.app.state.config.RAG_TEMPLATE,
-                    context_string,
-                    prompt,
-                ),
-                form_data["messages"],
-                append=False,
-            )
+    # === 17. Tool-call handling [disabled] ===
+    if False:
+        if tools_dict:
+            if metadata.get("params", {}).get("function_calling") == "native":
+                # Native function-calling mode: pass the tools straight to the LLM
+                metadata["tools"] = tools_dict
+                form_data["tools"] = [
+                    {"type": "function", "function": tool.get("spec", {})}
+                    for tool in tools_dict.values()
+                ]
+            else:
+                # Default mode: implement tool calling via prompting
+                try:
+                    form_data, flags = await chat_completion_tools_handler(
+                        request, form_data, extra_params, user, models, tools_dict
+                    )
+                    sources.extend(flags.get("sources", []))
+                except Exception as e:
+                    log.exception(e)
+
+    # === 18. File handling — RAG retrieval [disabled] ===
+    if False:
+        try:
+            form_data, flags = await chat_completion_files_handler(
+                request, form_data, extra_params, user
+            )
+            sources.extend(flags.get("sources", []))
+        except Exception as e:
+            log.exception(e)
+
+    # === 19. Build the context string and inject it into the messages [disabled] ===
+    # If context is not empty, insert it into the messages
+
+    if len(sources) > 0:
+        context_string = ""
+        citation_idx_map = {}  # citation index map (document ID → citation number)
+
+        # Walk all sources and build the context string
+        for source in sources:
+            if "document" in source:
+                for document_text, document_metadata in zip(
+                    source["document"], source["metadata"]
+                ):
+                    source_name = source.get("source", {}).get("name", None)
+                    source_id = (
+                        document_metadata.get("source", None)
+                        or source.get("source", {}).get("id", None)
+                        or "N/A"
+                    )
+
+                    # Assign a unique citation number to each source
+                    if source_id not in citation_idx_map:
+                        citation_idx_map[source_id] = len(citation_idx_map) + 1
+
+                    # Build the XML-style source tag
+                    context_string += (
+                        f'<source id="{citation_idx_map[source_id]}"'
+                        + (f' name="{source_name}"' if source_name else "")
+                        + f">{document_text}</source>\n"
+                    )
+
+        context_string = context_string.strip()
+        if prompt is None:
+            raise Exception("No user message found")
+
+        # Inject the context into the user message via the RAG template
+        if context_string != "":
+            form_data["messages"] = add_or_update_user_message(
+                rag_template(
+                    request.app.state.config.RAG_TEMPLATE,
+                    context_string,
+                    prompt,
+                ),
+                form_data["messages"],
+                append=False,
+            )
 
     # === 20. Collate citation sources and add them to the events ===
     # If there are citations, add them to the data_items
@@ -1529,6 +1649,7 @@ async def process_chat_payload(request, form_data, user, metadata, model):
             }
         )
 
+    # print(form_data["messages"])
     return form_data, metadata, events
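A quick sanity-check of the slicing step used in §1.1 above — a five-message parentId chain with the summary boundary at `m3` and `pre_boundary=1` (values chosen for illustration; the real call uses `pre_boundary=20`):

```python
from open_webui.utils.summary import slice_messages_with_summary

messages_map = {
    "m1": {"role": "user", "content": "hi", "parentId": None},
    "m2": {"role": "assistant", "content": "hello", "parentId": "m1"},
    "m3": {"role": "user", "content": "topic A", "parentId": "m2"},
    "m4": {"role": "assistant", "content": "about A", "parentId": "m3"},
    "m5": {"role": "user", "content": "latest question", "parentId": "m4"},
}

ordered = slice_messages_with_summary(
    messages_map,
    boundary_message_id="m3",  # the summary already covers everything up to m3
    anchor_id="m5",            # the current message (the chain tail)
    pre_boundary=1,            # keep 1 message before the boundary
)

# build_ordered_messages walks parentId links m5 → m1, giving [m1..m5];
# the boundary m3 sits at index 2, so the slice starts at max(2 - 1, 0) = 1.
assert [m["id"] for m in ordered] == ["m2", "m3", "m4", "m5"]
```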
@@ -1560,137 +1681,201 @@ async def process_chat_response(
     - streaming: StreamingResponse (SSE/ndjson format)
     """
 
-    # === Inner function: background task handler ===
+    # Record the upstream usage so background tasks can check the summary threshold
+    usage_holder = {"usage": None}
+
+    # ========================================
+    # Inner function: background task handler
+    # ========================================
     async def background_tasks_handler():
         """
-        Run background tasks after the response completes:
-        1. Follow-ups generation — suggest follow-up questions
-        2. Title generation — auto-generate a chat title
-        3. Tags generation — auto-generate chat tags
-        """
-        message = None
-        messages = []
-
-        # Fetch the message history
+        Run background tasks asynchronously after the response completes
+
+        Task types:
+        1. Follow-ups generation — use the LLM to suggest 3-5 follow-up questions to keep the conversation going
+        2. Title generation — auto-generate a chat title from the conversation (first exchange only)
+        3. Tags generation — auto-generate chat category tags (e.g. "tech", "work")
+
+        Data flow:
+        - Input: the full message history read from the database (message_list), or transient messages from form_data
+        - Processing: call the LLM to produce the task results (JSON format)
+        - Output: push to the frontend over WebSocket + persist to the database
+
+        Edge cases:
+        - Temporary chats (chat_id starting with "local:"): no title or tags, only Follow-ups
+        - Follow-ups may be generated for temporary chats but are not persisted
+        - Each task is toggled via the tasks dict (TASKS.TITLE_GENERATION etc.)
+        """
+        message = None   # the current AI reply (used for its model field)
+        messages = []    # the full message history (input for the LLM generation tasks)
+
+        # ----------------------------------------
+        # Step 1: fetch the message history
+        # ----------------------------------------
+        # Data flow: read the message list from the database or from form_data
         if "chat_id" in metadata and not metadata["chat_id"].startswith("local:"):
             # Read the persisted chat history from the database
+            # Shape: messages_map = {"message-id": {"role": "user", "content": "...", ...}}
             messages_map = Chats.get_messages_map_by_chat_id(metadata["chat_id"])
             message = messages_map.get(metadata["message_id"]) if messages_map else None
+            # Build the ordered message chain (from the root to the current message_id)
+            # get_message_list follows parentId links and returns the full history
             message_list = get_message_list(messages_map, metadata["message_id"])
 
-            # Clean message content: strip details tags and files
-            # get_message_list creates a new list and does not touch the original messages
+            # ----------------------------------------
+            # Step 2: clean the message content
+            # ----------------------------------------
+            # Rationale: drop content that is useless for the LLM tasks to cut token usage
+            # - strip <details> collapsible blocks (usually technical details/logs)
+            # - strip Markdown image references (![alt](url))
+            # get_message_list creates a new list, so the database originals are untouched
             messages = []
             for message in message_list:
                 content = message.get("content", "")
-                # Handle multimodal content (image + text)
+
+                # Edge case: multimodal content (OpenAI format: [{"type": "text", "text": "..."}, {"type": "image_url", ...}])
+                # Extract the plain-text part of a multimodal message
                 if isinstance(content, list):
                     for item in content:
                         if item.get("type") == "text":
                             content = item["text"]
                             break
 
-                # Strip <details> tags and Markdown images
+                # Regex cleanup: remove content that is useless for the generation tasks
                 if isinstance(content, str):
                     content = re.sub(
-                        r"<details\b[^>]*>.*?<\/details>|!\[.*?\]\(.*?\)",
+                        r"<details\b[^>]*>.*?<\/details>|!\[.*?\]\(.*?\)",  # <details>...</details> or ![](url)
                         "",
                         content,
-                        flags=re.S | re.I,
+                        flags=re.S | re.I,  # re.S lets . match newlines; re.I ignores case
                     ).strip()
 
+                # Build the cleaned message object
                 messages.append(
                     {
-                        **message,
-                        "role": message.get("role", "assistant"),  # safe fallback
-                        "content": content,
+                        **message,                                 # keep the original fields (id, timestamp, parentId, ...)
+                        "role": message.get("role", "assistant"),  # edge case: safe fallback when role is missing
+                        "content": content,                        # use the cleaned content
                     }
                 )
         else:
-            # Temporary chat (local:): take messages from form_data
+            # Edge case: temporary chat (chat_id starts with "local:")
+            # Data flow: take the history straight from the request payload (form_data), no DB query
             message = get_last_user_message_item(form_data.get("messages", []))
             messages = form_data.get("messages", [])
 
             if message:
-                message["model"] = form_data.get("model")
+                message["model"] = form_data.get("model")  # fill in the model field (used by the tasks below)
 
-        # Run the background tasks
+        # ----------------------------------------
+        # Step 3: run the background tasks
+        # ----------------------------------------
+        # Only run when message is valid and carries a model field
+        # Edge case: skip all tasks when messages is empty or tasks is not configured
         if message and "model" in message:
             if tasks and messages:
-                # === Task 1: Follow-ups generation ===
-                if (
-                    TASKS.FOLLOW_UP_GENERATION in tasks
-                    and tasks[TASKS.FOLLOW_UP_GENERATION]
-                ):
-                    res = await generate_follow_ups(
-                        request,
-                        {
-                            "model": message["model"],
-                            "messages": messages,
-                            "message_id": metadata["message_id"],
-                            "chat_id": metadata["chat_id"],
-                        },
-                        user,
-                    )
-
-                    if res and isinstance(res, dict):
-                        if len(res.get("choices", [])) == 1:
-                            response_message = res.get("choices", [])[0].get(
-                                "message", {}
-                            )
-
-                            follow_ups_string = response_message.get(
-                                "content"
-                            ) or response_message.get("reasoning_content", "")
-                        else:
-                            follow_ups_string = ""
-
-                        # Extract the JSON object (from the first { to the last })
-                        follow_ups_string = follow_ups_string[
-                            follow_ups_string.find("{") : follow_ups_string.rfind("}")
-                            + 1
-                        ]
-
-                        try:
-                            follow_ups = json.loads(follow_ups_string).get(
-                                "follow_ups", []
-                            )
-                            # Send the Follow-ups over WebSocket
-                            await event_emitter(
-                                {
-                                    "type": "chat:message:follow_ups",
-                                    "data": {
-                                        "follow_ups": follow_ups,
-                                    },
-                                }
-                            )
-
-                            # Persist to the database
-                            if not metadata.get("chat_id", "").startswith("local:"):
-                                Chats.upsert_message_to_chat_by_id_and_message_id(
-                                    metadata["chat_id"],
-                                    metadata["message_id"],
-                                    {
-                                        "followUps": follow_ups,
-                                    },
-                                )
-                        except Exception as e:
-                            pass
-
-                # === Tasks 2 & 3: title and tag generation (non-temporary chats only) ===
+                # ========================================
+                # Task 1: Follow-ups generation [disabled]
+                # ========================================
+                # Use the LLM to suggest 3-5 follow-up questions based on the history,
+                # nudging the user to continue the conversation
+                if False:
+                    if (
+                        TASKS.FOLLOW_UP_GENERATION in tasks
+                        and tasks[TASKS.FOLLOW_UP_GENERATION]
+                    ):
+                        # Call the LLM to generate the Follow-ups
+                        # Data flow: messages → LLM → JSON follow_ups list
+                        res = await generate_follow_ups(
+                            request,
+                            {
+                                "model": message["model"],  # the model used in this conversation
+                                "messages": messages,       # the full cleaned message history
+                                "message_id": metadata["message_id"],
+                                "chat_id": metadata["chat_id"],
+                            },
+                            user,
+                        )
+
+                        # Edge case: check that the LLM response is valid
+                        if res and isinstance(res, dict):
+                            if len(res.get("choices", [])) == 1:
+                                response_message = res.get("choices", [])[0].get(
+                                    "message", {}
+                                )
+                                # Extract content (prefer content, fall back to reasoning_content)
+                                follow_ups_string = response_message.get(
+                                    "content"
+                                ) or response_message.get("reasoning_content", "")
+                            else:
+                                # Edge case: the LLM returned zero or multiple choices
+                                follow_ups_string = ""
+
+                            # Data cleanup: extract the JSON object (first { to last })
+                            # The LLM may wrap the JSON in explanatory text, so trim it
+                            follow_ups_string = follow_ups_string[
+                                follow_ups_string.find("{") : follow_ups_string.rfind("}")
+                                + 1
+                            ]
+
+                            try:
+                                # Parse the JSON: {"follow_ups": ["q1", "q2", "q3"]}
+                                follow_ups = json.loads(follow_ups_string).get(
+                                    "follow_ups", []
+                                )
+
+                                # Data flow: push to the frontend in real time over WebSocket
+                                await event_emitter(
+                                    {
+                                        "type": "chat:message:follow_ups",
+                                        "data": {
+                                            "follow_ups": follow_ups,
+                                        },
+                                    }
+                                )
+
+                                # Data flow: persist to the database (non-temporary chats only)
+                                # Edge case: temporary chats (local:) are not persisted
+                                if not metadata.get("chat_id", "").startswith("local:"):
+                                    Chats.upsert_message_to_chat_by_id_and_message_id(
+                                        metadata["chat_id"],
+                                        metadata["message_id"],
+                                        {
+                                            "followUps": follow_ups,
+                                        },
+                                    )
+
+                            except Exception as e:
+                                # Edge case: JSON parsing failed (malformed LLM output)
+                                # Fail silently; do not disturb the main flow
+                                pass
+
+                # ========================================
+                # Tasks 2 & 3: title and tag generation (non-temporary chats only)
+                # ========================================
+                # Edge case: temporary chats (local:) need no title or tags; skip both tasks
                 if not metadata.get("chat_id", "").startswith("local:"):
+                    # ========================================
                     # Task 2: title generation
+                    # ========================================
+                    # Auto-generate a chat title (instead of showing "New Chat"),
+                    # triggered after the first exchange completes
                     if (
                         TASKS.TITLE_GENERATION in tasks
                         and tasks[TASKS.TITLE_GENERATION]
                     ):
+                        # Use the last user message as the fallback title
                         user_message = get_last_user_message(messages)
                         if user_message and len(user_message) > 100:
+                            # Edge case: truncate overly long messages so the title stays short
                             user_message = user_message[:100] + "..."
 
                         if tasks[TASKS.TITLE_GENERATION]:
+                            # Call the LLM to generate the title
+                            # Data flow: messages → LLM → JSON title string
                             res = await generate_title(
                                 request,
                                 {
@@ -1701,12 +1886,15 @@ async def process_chat_response(
                                 user,
                             )
 
+                            # Edge case: check that the LLM response is valid
                             if res and isinstance(res, dict):
                                 if len(res.get("choices", [])) == 1:
                                     response_message = res.get("choices", [])[0].get(
                                         "message", {}
                                     )
 
+                                    # Extract the content (multi-level fallback)
+                                    # Priority: content > reasoning_content > current AI reply > user message
                                     title_string = (
                                         response_message.get("content")
                                         or response_message.get(
@@ -1715,36 +1903,41 @@ async def process_chat_response(
                                         or message.get("content", user_message)
                                     )
                                 else:
+                                    # Edge case: the LLM returned zero or multiple choices
                                     title_string = ""
 
-                                # Extract the JSON object
+                                # Data cleanup: extract the JSON object
                                 title_string = title_string[
                                     title_string.find("{") : title_string.rfind("}") + 1
                                 ]
 
                                 try:
+                                    # Parse the JSON: {"title": "generated title"}
                                     title = json.loads(title_string).get(
                                         "title", user_message
                                     )
                                 except Exception as e:
+                                    # Edge case: JSON parsing failed
                                     title = ""
 
+                                # Edge case: fall back to the first user message when the title is empty
                                 if not title:
                                     title = messages[0].get("content", user_message)
 
-                                # Update the database
+                                # Data flow: update the database
                                 Chats.update_chat_title_by_id(
                                     metadata["chat_id"], title
                                 )
 
-                                # Send the title over WebSocket
+                                # Data flow: send the title to the frontend over WebSocket
                                 await event_emitter(
                                     {
                                         "type": "chat:title",
                                         "data": title,
                                     }
                                 )
-                        # If there are only 2 messages (first exchange), use the user message as the title
+                        # Edge case: a simple conversation (exactly 2 messages: 1 user + 1 assistant)
+                        # Use the user message as the title directly — no LLM call needed (saves cost)
                         elif len(messages) == 2:
                             title = messages[0].get("content", user_message)
 
@@ -1757,8 +1950,14 @@ async def process_chat_response(
                                 }
                             )
 
+                    # ========================================
                     # Task 3: tag generation
+                    # ========================================
+                    # Use the LLM to generate chat category tags (e.g. "tech", "work", "life")
+                    # so users can organize and search their chats
                     if TASKS.TAGS_GENERATION in tasks and tasks[TASKS.TAGS_GENERATION]:
+                        # Call the LLM to generate the tags
+                        # Data flow: messages → LLM → JSON tags array
                         res = await generate_chat_tags(
                             request,
                             {
@@ -1769,31 +1968,36 @@ async def process_chat_response(
                             user,
                         )
 
+                        # Edge case: check that the LLM response is valid
                         if res and isinstance(res, dict):
                             if len(res.get("choices", [])) == 1:
                                 response_message = res.get("choices", [])[0].get(
                                     "message", {}
                                 )
+                                # Extract content (prefer content, fall back to reasoning_content)
                                 tags_string = response_message.get(
                                     "content"
                                 ) or response_message.get("reasoning_content", "")
                             else:
+                                # Edge case: the LLM returned zero or multiple choices
                                 tags_string = ""
 
-                            # Extract the JSON object
+                            # Data cleanup: extract the JSON object
                             tags_string = tags_string[
                                 tags_string.find("{") : tags_string.rfind("}") + 1
                             ]
 
                            try:
+                                # Parse the JSON: {"tags": ["tech", "work", "Python"]}
                                tags = json.loads(tags_string).get("tags", [])
-                                # Update the database
+
+                                # Data flow: update the database (saved to chat.meta.tags)
                                Chats.update_chat_tags_by_id(
                                    metadata["chat_id"], tags, user
                                )
-                                # Send the tags over WebSocket
+                                # Data flow: send the tags to the frontend over WebSocket
                                await event_emitter(
                                    {
                                        "type": "chat:tags",
@@ -1801,9 +2005,98 @@ async def process_chat_response(
                                    }
                                )
                            except Exception as e:
+                                # Edge case: JSON parsing failed
+                                # Fail silently; do not disturb the main flow
                                pass
 
-    # === 1. Resolve the event emitters (only used with a live session) ===
+            # ========================================
+            # Task 4: summary update (threshold-based)
+            # ========================================
+            try:
+                # Get chat_id
+                chat_id = metadata.get("chat_id")
+                if chat_id and not str(chat_id).startswith("local:"):
+                    # The threshold comes from the config or the global default
+                    threshold = getattr(
+                        request.app.state.config,
+                        "SUMMARY_TOKEN_THRESHOLD",
+                        SUMMARY_TOKEN_THRESHOLD_DEFAULT,
+                    )
+
+                    # Prefer the upstream usage; fall back to counting tokens ourselves
+                    tokens = None
+                    if isinstance(usage_holder.get("usage"), dict):
+                        usage = usage_holder["usage"]
+                        tokens = usage.get("total_tokens") or usage.get(
+                            "prompt_tokens"
+                        )
+                    if tokens is None:
+                        tokens = compute_token_count(messages or [])
+
+                    # If the threshold is exceeded
+                    if tokens >= threshold:
+                        if CHAT_DEBUG_FLAG:
+                            print(
+                                f"[summary:update] chat_id={chat_id} tokens={tokens} threshold={threshold}"
+                            )
+
+                        # Read the existing summary
+                        existing_summary = Chats.get_summary_by_user_id_and_chat_id(
+                            user.id, chat_id
+                        )
+                        old_summary = (
+                            existing_summary.get("content")
+                            if existing_summary
+                            else None
+                        )
+
+                        # Take the 20 messages before the summary boundary + everything after
+                        messages_map = Chats.get_messages_map_by_chat_id(chat_id) or {}
+                        ordered = slice_messages_with_summary(
+                            messages_map,
+                            existing_summary.get("last_message_id") if existing_summary else None,
+                            metadata.get("message_id"),
+                            pre_boundary=20,
+                        )
+
+                        summary_messages = [
+                            msg
+                            for msg in ordered
+                            if msg.get("role") in ("user", "assistant")
+                        ]
+                        if CHAT_DEBUG_FLAG:
+                            print(
+                                f"[summary:update] chat_id={chat_id} slice total={len(ordered)} messages in summary={len(summary_messages)} "
+                                f"last summarized message_id={existing_summary.get('last_message_id') if existing_summary else None}"
+                            )
+                        summary_text = summarize(summary_messages, old_summary)
+                        last_msg_id = (
+                            summary_messages[-1].get("id")
+                            if summary_messages
+                            else metadata.get("message_id")
+                        )
+                        Chats.set_summary_by_user_id_and_chat_id(
+                            user.id,
+                            chat_id,
+                            summary_text,
+                            last_msg_id,
+                            int(time.time()),
+                            recent_message_ids=[],
+                        )
+                    else:
+                        if CHAT_DEBUG_FLAG:
+                            print(
+                                f"[summary:update] chat_id={chat_id} tokens={tokens} below threshold={threshold}"
+                            )
+            except Exception as e:
+                log.debug(f"summary update skipped: {e}")
+
+    # ========================================
+    # Phase 1: event emitter initialization
+    # ========================================
+    # Only initialize the WebSocket event emitter in async mode (with a session_id)
+    # Data flow: push real-time events (completion, error, title, ...) to the frontend over WebSocket
+    # Edge case: in sync mode (no session_id), event_emitter and event_caller stay None
    event_emitter = None
    event_caller = None
    if (
@@ -1814,43 +2107,66 @@ async def process_chat_response(
        and "message_id" in metadata
        and metadata["message_id"]
    ):
-        event_emitter = get_event_emitter(metadata)  # WebSocket event emitter
-        event_caller = get_event_call(metadata)      # event caller
+        # Get the WebSocket event emitter (pushes events to the frontend)
+        # event_emitter: async def (event: dict) -> None
+        event_emitter = get_event_emitter(metadata)
+        # Get the event caller (used to invoke MCP/tools etc.)
+        # event_caller: async def (event: dict) -> Any
+        event_caller = get_event_call(metadata)
 
-    # === 2. Non-streaming response handling ===
+    # ========================================
+    # Phase 2: non-streaming response handling
+    # ========================================
+    # Handle a complete JSON response from the LLM (no SSE stream)
+    # Data flow: response (dict/JSONResponse) → parse → WebSocket push + database persistence
+    # Trigger: response is not a StreamingResponse instance
    if not isinstance(response, StreamingResponse):
+        # Edge case: only push over WebSocket and write to the DB in async mode (with event_emitter)
        if event_emitter:
            try:
+                # ----------------------------------------
+                # Step 1: response type check and parsing
+                # ----------------------------------------
+                # Edge case: both dict and JSONResponse response types are supported
                if isinstance(response, dict) or isinstance(response, JSONResponse):
-                    # Unwrap single-item lists
+                    # Edge case: unwrap single-item lists (some LLMs may return [response])
+                    # Data flow: [response] → response
                    if isinstance(response, list) and len(response) == 1:
                        response = response[0]
 
-                    # Parse the JSONResponse body
+                    # Edge case: a JSONResponse needs its body (bytes) parsed as JSON
+                    # Data flow: JSONResponse.body (bytes) → response_data (dict)
                    if isinstance(response, JSONResponse) and isinstance(
                        response.body, bytes
                    ):
                        try:
                            response_data = json.loads(
-                                response.body.decode("utf-8", "replace")
+                                response.body.decode("utf-8", "replace")  # "replace" handles invalid UTF-8
                            )
                        except json.JSONDecodeError:
+                            # Edge case: JSON parsing failed; build an error response
                            response_data = {
                                "error": {"detail": "Invalid JSON response"}
                            }
                    else:
+                        # A dict is used as-is
                        response_data = response
 
-                    # Handle error responses
+                    # ----------------------------------------
+                    # Step 2: error response handling
+                    # ----------------------------------------
+                    # The LLM returned an error (rate limit, model unavailable, ...)
+                    # Data flow: error → database + WebSocket push
                    if "error" in response_data:
                        error = response_data.get("error")
+                        # Edge case: normalize the error format (dict or str)
                        if isinstance(error, dict):
-                            error = error.get("detail", error)
+                            error = error.get("detail", error)  # take the detail field, or keep as-is
                        else:
                            error = str(error)
 
-                        # Save the error to the database
+                        # Data flow: save the error to the database (message.error field)
                        Chats.upsert_message_to_chat_by_id_and_message_id(
                            metadata["chat_id"],
                            metadata["message_id"],
@@ -1858,7 +2174,8 @@ async def process_chat_response(
                                "error": {"content": error},
                            },
                        )
-                        # Send the error event over WebSocket
+
+                        # Data flow: push the error event to the frontend over WebSocket
                        if isinstance(error, str) or isinstance(error, dict):
                            await event_emitter(
                                {
@@ -1867,6 +2184,11 @@ async def process_chat_response(
                                }
                            )
 
+                    # ----------------------------------------
+                    # Step 3: Arena mode (blind model selection)
+                    # ----------------------------------------
+                    # In Arena mode a model is picked at random; persist selected_model_id
+                    # Data flow: selected_model_id → database (message.selectedModelId field)
                    if "selected_model_id" in response_data:
                        Chats.upsert_message_to_chat_by_id_and_message_id(
                            metadata["chat_id"],
@@ -1876,20 +2198,29 @@ async def process_chat_response(
                            },
                        )
 
+                    # ----------------------------------------
+                    # Step 4: success response handling
+                    # ----------------------------------------
+                    # The LLM returned a normal, complete response
+                    # Data flow: content → WebSocket push + database + webhook notification + background tasks
                    choices = response_data.get("choices", [])
                    if choices and choices[0].get("message", {}).get("content"):
                        content = response_data["choices"][0]["message"]["content"]
 
                        if content:
+                            # Data flow: 1st WebSocket push — the full LLM response data
                            await event_emitter(
                                {
                                    "type": "chat:completion",
-                                    "data": response_data,
+                                    "data": response_data,  # full fields: choices, usage, model, ...
                                }
                            )
 
+                            # Get the chat title (used by the webhook notification and the 2nd push)
                            title = Chats.get_chat_title_by_id(metadata["chat_id"])
 
+                            # Data flow: 2nd WebSocket push — mark completion and attach the title
+                            # The frontend stops its loading animation on done=True
                            await event_emitter(
                                {
                                    "type": "chat:completion",
@@ -1901,7 +2232,8 @@ async def process_chat_response(
                                }
                            )
 
-                            # Save message in the database
+                            # Data flow: save the AI reply to the database
+                            # Fields: role="assistant", content (the complete reply)
                            Chats.upsert_message_to_chat_by_id_and_message_id(
                                metadata["chat_id"],
                                metadata["message_id"],
@@ -1911,14 +2243,20 @@ async def process_chat_response(
                                },
                            )
 
-                            # Send a webhook notification if the user is not active
+                            # ----------------------------------------
+                            # Step 5: webhook notification (user offline)
+                            # ----------------------------------------
+                            # Notify via webhook (Slack, Discord, WeCom, ...) when the user is away
+                            # Edge case: only fires when the user configured a webhook_url and is not active
                            if not get_active_status_by_user_id(user.id):
                                webhook_url = Users.get_user_webhook_url_by_id(user.id)
                                if webhook_url:
                                    await post_webhook(
-                                        request.app.state.WEBUI_NAME,
-                                        webhook_url,
+                                        request.app.state.WEBUI_NAME,  # application name (e.g. "Open WebUI")
+                                        webhook_url,                   # the user's configured webhook URL
+                                        # message body: title + link + content
                                        f"{title} - {request.app.state.config.WEBUI_URL}/c/{metadata['chat_id']}\n\n{content}",
+                                        # structured data (for the webhook receiver to parse)
                                        {
                                            "action": "chat",
                                            "message": content,
@@ -1927,21 +2265,37 @@ async def process_chat_response(
                                        },
                                    )
 
+                            # ----------------------------------------
+                            # Step 6: run the background tasks
+                            # ----------------------------------------
+                            # Generate Follow-ups, title, and tags asynchronously
                            await background_tasks_handler()
 
+                    # ----------------------------------------
+                    # Step 7: merge extra events (e.g. RAG sources)
+                    # ----------------------------------------
+                    # process_chat_payload may have queued extra events (e.g. retrieved sources)
+                    # Data flow: events (list) → extra_response (dict) → merged into response_data
                    if events and isinstance(events, list):
                        extra_response = {}
                        for event in events:
                            if isinstance(event, dict):
+                                # dict event: merge into extra_response
                                extra_response.update(event)
                            else:
+                                # string event: set it to True (e.g. "web_search": True)
                                extra_response[event] = True
 
+                        # Merge strategy: extra_response first (lower priority), response_data last (higher priority)
                        response_data = {
                            **extra_response,
                            **response_data,
                        }
 
+                    # ----------------------------------------
+                    # Step 8: re-wrap the response object
+                    # ----------------------------------------
+                    # Keep the response type consistent (dict or JSONResponse)
                    if isinstance(response, dict):
                        response = response_data
                    if isinstance(response, JSONResponse):
@@ -1952,11 +2306,19 @@ async def process_chat_response(
                        )
 
            except Exception as e:
+                # Edge case: catch everything so the response handling never crashes
+                # Log it and move on (fail silently)
                log.debug(f"Error occurred while processing request: {e}")
                pass
 
+        # Return the processed response (events merged, database updated)
        return response
    else:
+        # ----------------------------------------
+        # Sync mode: no event_emitter (no session_id)
+        # ----------------------------------------
+        # Return the response directly — no WebSocket push, no database writes
+        # Edge case: only merge the events into the response
        if events and isinstance(events, list) and isinstance(response, dict):
            extra_response = {}
            for event in events:
@@ -1965,6 +2327,7 @@ async def process_chat_response(
                else:
                    extra_response[event] = True
 
+            # Merge the events into the response
            response = {
                **extra_response,
                **response,
@@ -1972,13 +2335,23 @@ async def process_chat_response(
 
        return response
 
-    # Non standard response (not SSE/ndjson)
+    # ========================================
+    # Phase 3: streaming pre-checks
+    # ========================================
+    # Edge case: a non-standard streaming response (neither SSE nor ndjson)
+    # Return the raw response without stream processing
    if not any(
        content_type in response.headers["Content-Type"]
        for content_type in ["text/event-stream", "application/x-ndjson"]
    ):
        return response
 
+    # ----------------------------------------
+    # Step 1: OAuth token retrieval
+    # ----------------------------------------
+    # If the user logged in via OAuth (Google, GitHub, ...), fetch the OAuth access token
+    # It is passed to filter functions and tools (which may call external APIs)
+    # Edge case: on failure oauth_token stays None (the main flow continues)
    oauth_token = None
    try:
        if request.cookies.get("oauth_session_id", None):
@@ -1987,17 +2360,30 @@ async def process_chat_response(
                request.cookies.get("oauth_session_id", None),
            )
    except Exception as e:
+        # Edge case: OAuth lookup failed (e.g. expired session); log it, don't abort
        log.error(f"Error getting OAuth token: {e}")
 
+    # ----------------------------------------
+    # Step 2: prepare the extra params (passed to filter functions and tools)
+    # ----------------------------------------
+    # Data flow: extra_params → filter functions → tool calls
    extra_params = {
-        "__event_emitter__": event_emitter,
-        "__event_call__": event_caller,
-        "__user__": user.model_dump() if isinstance(user, UserModel) else {},
-        "__metadata__": metadata,
-        "__oauth_token__": oauth_token,
-        "__request__": request,
-        "__model__": model,
+        "__event_emitter__": event_emitter,  # WebSocket event emitter
+        "__event_call__": event_caller,      # event caller (MCP/tools)
+        "__user__": user.model_dump() if isinstance(user, UserModel) else {},  # user info
+        "__metadata__": metadata,            # chat context (chat_id, message_id, ...)
+        "__oauth_token__": oauth_token,      # OAuth access token (may be None)
+        "__request__": request,              # FastAPI Request object
+        "__model__": model,                  # model config object
    }
+
+    # ----------------------------------------
+    # Step 3: load the filter functions
+    # ----------------------------------------
+    # Filter functions may intercept/rewrite every delta of the streamed response
+    # Data flow: filter_ids → load filter function objects from the DB → sort by priority
+    # Use cases: moderation, sensitive-word filtering, format conversion, ...
    filter_functions = [
        Functions.get_function_by_id(filter_id)
        for filter_id in get_sorted_filter_ids(
@@ -2005,11 +2391,23 @@ async def process_chat_response(
        )
    ]
 
-    # Streaming response: consume upstream SSE/ndjson, forward deltas/events, handle tools
+    # ========================================
+    # Phase 4: streaming response handling
+    # ========================================
+    # Consume the upstream SSE/ndjson stream, process it chunk by chunk, and forward it
+    # Trigger: async mode (event_emitter and event_caller available)
+    # Data flow: upstream SSE → parse deltas → filters → accumulate content → WebSocket push → database
    if event_emitter and event_caller:
-        task_id = str(uuid4())  # Create a unique task ID.
+        # Create a unique task ID (used for task tracking and cancellation)
+        task_id = str(uuid4())
        model_id = form_data.get("model", "")
 
+        # ========================================
+        # Helper 1: split_content_and_whitespace
+        # ========================================
+        # Separate the content from its trailing whitespace (streaming-push optimization)
+        # Goal: avoid pushing too early while a code block is still open, which renders badly
+        # Data flow: content → (content_stripped, original_whitespace)
        def split_content_and_whitespace(content):
            content_stripped = content.rstrip()
            original_whitespace = (
@@ -2019,33 +2417,64 @@ async def process_chat_response(
            )
            return content_stripped, original_whitespace
 
+        # ========================================
+        # Helper 2: is_opening_code_block
+        # ========================================
+        # Detect whether the content ends in an unclosed code block
+        # Principle: count the ``` fences; an even segment count means the last ``` opened a new block
+        # Edge case: used to decide whether to delay the push (wait for the block to close)
        def is_opening_code_block(content):
            backtick_segments = content.split("```")
-            # Even number of segments means the last backticks are opening a new block
+            # An even number of segments means the last backticks are opening a new block
            return len(backtick_segments) > 1 and len(backtick_segments) % 2 == 0
 
-        # Handle as a background task
+        # ========================================
+        # Response handler (background task)
+        # ========================================
+        # Process the streamed response in the background without blocking the main thread
+        # Data flow: response (StreamingResponse) → consume SSE → parse deltas → accumulate → WebSocket + DB
        async def response_handler(response, events):
+            # ========================================
+            # Helper 3: serialize_content_blocks
+            # ========================================
+            # Serialize the content-block array into a readable string (frontend display and DB storage)
+            # Data flow: content_blocks (list) → content (string)
+            # Args:
+            #   - content_blocks: block array [{"type": "text", "content": "..."}, {"type": "tool_calls", ...}]
+            #   - raw: raw mode (True = keep original tags; False = render HTML <details> collapsible blocks)
            def serialize_content_blocks(content_blocks, raw=False):
                content = ""
 
                for block in content_blocks:
+                    # ----------------------------------------
+                    # Type 1: plain text block
+                    # ----------------------------------------
+                    # Append the text content, adding a newline after each block
                    if block["type"] == "text":
                        block_content = block["content"].strip()
                        if block_content:
                            content = f"{content}{block_content}\n"
 
+                    # ----------------------------------------
+                    # Type 2: tool-call block
+                    # ----------------------------------------
+                    # Render tool calls and their results as an HTML <details> collapsible block
+                    # Data flow: tool_calls + results → <details> HTML tags
                    elif block["type"] == "tool_calls":
                        attributes = block.get("attributes", {})
 
-                        tool_calls = block.get("content", [])
-                        results = block.get("results", [])
+                        tool_calls = block.get("content", [])  # tool-call list [{"id": "call_1", "function": {"name": "web_search", "arguments": "{}"}}]
+                        results = block.get("results", [])     # tool-result list [{"tool_call_id": "call_1", "content": "..."}]
 
+                        # Make sure there is a newline before the block (formatting)
                        if content and not content.endswith("\n"):
                            content += "\n"
 
+                        # ========== Branch 1: tools already executed (results present) ==========
                        if results:
                            tool_calls_display_content = ""
 
+                            # Walk the tool calls and match each to its result
                            for tool_call in tool_calls:
                                tool_call_id = tool_call.get("id", "")
@@ -2056,6 +2485,7 @@ async def process_chat_response(
                                    "arguments", ""
                                )
 
+                                # Find the matching tool result
                                tool_result = None
                                tool_result_files = None
                                for result in results:
@@ -2064,17 +2494,25 @@ async def process_chat_response(
                                        tool_result_files = result.get("files", None)
                                        break
 
+                                # Render the tool-call result
                                if tool_result is not None:
+                                    # Tool executed successfully: done="true"
                                    tool_result_embeds = result.get("embeds", "")
+                                    # HTML-escape to prevent XSS
                                    tool_calls_display_content = f'{tool_calls_display_content}<details type="tool_calls" done="true" id="{tool_call_id}" name="{tool_name}" arguments="{html.escape(json.dumps(tool_arguments))}" result="{html.escape(json.dumps(tool_result))}" files="{html.escape(json.dumps(tool_result_files)) if tool_result_files else ""}" embeds="{html.escape(json.dumps(tool_result_embeds))}">\n<summary>Tool Executed</summary>\n</details>\n'
                                else:
+                                    # Tool still running or failed: done="false"
                                    tool_calls_display_content = f'{tool_calls_display_content}<details type="tool_calls" done="false" id="{tool_call_id}" name="{tool_name}" arguments="{html.escape(json.dumps(tool_arguments))}">\n<summary>Executing...</summary>\n</details>\n'
 
+                            # Only append to content when raw=False (raw mode skips tool calls)
                            if not raw:
                                content = f"{content}{tool_calls_display_content}"
+
+                        # ========== Branch 2: tools not executed yet (no results) ==========
                        else:
                            tool_calls_display_content = ""
 
+                            # Render every tool call in the "executing" state
                            for tool_call in tool_calls:
                                tool_call_id = tool_call.get("id", "")
                                tool_name = tool_call.get("function", {}).get(
                                    "name", ""
                                )
@@ -2089,104 +2527,156 @@ async def process_chat_response(
                                if not raw:
                                    content = f"{content}{tool_calls_display_content}"
 
+                    # ----------------------------------------
+                    # Type 3: reasoning block
+                    # ----------------------------------------
+                    # Render the LLM's thought process (e.g. the content of <think>-style tags from reasoning models)
+                    # Data flow: reasoning content → Markdown quote format ("> " prefix) → <details> collapsible block
                    elif block["type"] == "reasoning":
+                        # Format the reasoning: prefix every line with "> " (Markdown quote format)
                        reasoning_display_content = "\n".join(
                            (f"> {line}" if not line.startswith(">") else line)
                            for line in block["content"].splitlines()
                        )
 
-                        reasoning_duration = block.get("duration", None)
+                        reasoning_duration = block.get("duration", None)  # reasoning time (seconds)
 
-                        start_tag = block.get("start_tag", "")
-                        end_tag = block.get("end_tag", "")
+                        start_tag = block.get("start_tag", "")  # original opening tag
+                        end_tag = block.get("end_tag", "")      # original closing tag
 
+                        # Make sure there is a newline before the block
                        if content and not content.endswith("\n"):
                            content += "\n"
 
+                        # Branch 1: reasoning finished (duration present)
                        if reasoning_duration is not None:
                            if raw:
+                                # raw mode: keep the original tags
                                content = (
                                    f'{content}{start_tag}{block["content"]}{end_tag}\n'
                                )
                            else:
+                                # standard mode: render a collapsible block showing the reasoning time
                                content = f'{content}<details type="reasoning" done="true" duration="{reasoning_duration}">\n<summary>Thought for {reasoning_duration} seconds</summary>\n{reasoning_display_content}\n</details>\n'
+
+                        # Branch 2: reasoning still in progress (no duration)
                        else:
                            if raw:
+                                # raw mode: keep the original tags
                                content = (
                                    f'{content}{start_tag}{block["content"]}{end_tag}\n'
                                )
                            else:
+                                # standard mode: render the "thinking" state
                                content = f'{content}<details type="reasoning" done="false">\n<summary>Thinking…</summary>\n{reasoning_display_content}\n</details>\n'
 
+                    # ----------------------------------------
+                    # Type 4: code-interpreter block
+                    # ----------------------------------------
+                    # Render executed code and its output
+                    # Data flow: code + output → Markdown code block + <details> collapsible block
                    elif block["type"] == "code_interpreter":
                        attributes = block.get("attributes", {})
-                        output = block.get("output", None)
-                        lang = attributes.get("lang", "")
+                        output = block.get("output", None)  # code execution output
+                        lang = attributes.get("lang", "")   # programming language (e.g. python)
 
+                        # Detect and handle an unclosed code block
+                        # Avoid inserting the code-interpreter block while the LLM is still emitting a fence
                        content_stripped, original_whitespace = (
                            split_content_and_whitespace(content)
                        )
                        if is_opening_code_block(content_stripped):
-                            # Remove trailing backticks that would open a new block
+                            # Remove the trailing ``` that is opening a new block
+                            # Edge case: prevents consecutive ``` fences from breaking the rendering
                            content = (
                                content_stripped.rstrip("`").rstrip()
                                + original_whitespace
                            )
                        else:
-                            # Keep content as is - either closing backticks or no backticks
+                            # Keep the content as-is — either closing backticks or no backticks
                            content = content_stripped + original_whitespace
 
+                        # Make sure there is a newline before the block
                        if content and not content.endswith("\n"):
                            content += "\n"
 
+                        # Branch 1: code executed (output present)
                        if output:
+                            # HTML-escape to prevent XSS
                            output = html.escape(json.dumps(output))
 
                            if raw:
+                                # raw mode: use the custom tags
                                content = f'{content}<code_interpreter type="code" lang="{lang}">\n{block["content"]}\n</code_interpreter>\n```output\n{output}\n```\n'
                            else:
+                                # standard mode: render a collapsible block with the code and its output
                                content = f'{content}<details type="code_interpreter" done="true" output="{output}">\n<summary>Analyzed</summary>\n```{lang}\n{block["content"]}\n```\n</details>\n'
+
+                        # Branch 2: code still running (no output)
                        else:
                            if raw:
+                                # raw mode: use the custom tags
                                content = f'{content}<code_interpreter type="code" lang="{lang}">\n{block["content"]}\n</code_interpreter>\n'
                            else:
+                                # standard mode: render the "analyzing" state
                                content = f'{content}<details type="code_interpreter" done="false">\n<summary>Analyzing...</summary>\n```{lang}\n{block["content"]}\n```\n</details>\n'
 
+                    # ----------------------------------------
+                    # Type 5: unknown block type (fallback)
+                    # ----------------------------------------
+                    # Edge case: custom or future block types
                    else:
                        block_content = str(block["content"]).strip()
                        if block_content:
                            content = f"{content}{block['type']}: {block_content}\n"
 
+                # Return the serialized string (leading/trailing whitespace removed)
                return content.strip()
 
+            # ========================================
+            # Helper 4: convert_content_blocks_to_messages
+            # ========================================
+            # Convert the content-block array into an OpenAI-style message list
+            # Data flow: content_blocks (list) → messages (list)
+            # Used when iterating on tool calls: the history must be in OpenAI messages format
+            # Example:
+            #   in:  [{"type": "text", "content": "..."}, {"type": "tool_calls", "content": [...], "results": [...]}]
+            #   out: [{"role": "assistant", "content": "..."}, {"role": "tool", "tool_call_id": "...", "content": "..."}]
            def convert_content_blocks_to_messages(content_blocks, raw=False):
                messages = []
 
+                # Temporarily accumulate non-tool-call blocks
                temp_blocks = []
 
                for idx, block in enumerate(content_blocks):
+                    # On a tool-call block: flush the temp blocks + add the tool-call message + add the tool-result messages
                    if block["type"] == "tool_calls":
+                        # 1. Serialize the accumulated blocks into an assistant message
                        messages.append(
                            {
                                "role": "assistant",
                                "content": serialize_content_blocks(temp_blocks, raw),
-                                "tool_calls": block.get("content"),
+                                "tool_calls": block.get("content"),  # the tool-call list
                            }
                        )
 
+                        # 2. Convert each tool result into a "tool" message
                        results = block.get("results", [])
                        for result in results:
                            messages.append(
                                {
                                    "role": "tool",
-                                    "tool_call_id": result["tool_call_id"],
+                                    "tool_call_id": result["tool_call_id"],  # links back to the corresponding tool call
                                    "content": result.get("content", "") or "",
                                }
                            )
+
+                        # Reset the temp blocks (start accumulating the next span)
                        temp_blocks = []
                    else:
+                        # Non-tool-call block: keep accumulating
                        temp_blocks.append(block)
 
+                # Handle any remaining temp blocks (the final span)
                if temp_blocks:
                    content = serialize_content_blocks(temp_blocks, raw)
                    if content:
@@ -2199,11 +2689,23 @@ async def process_chat_response(
 
                return messages
 
+            # ========================================
+            # Helper 5: tag_content_handler
+            # ========================================
+            # Detect and handle special tags in the streamed content (e.g. <think>)
+            # Data flow: content (string) + content_blocks (list) → detect tags → split/create blocks
+            # Args:
+            #   - content_type: tag type ("reasoning", "solution", "code_interpreter")
+            #   - tags: list of tag pairs [[start_tag, end_tag], ...]
+            #   - content: the accumulated content string
+            #   - content_blocks: the block array (modified in place)
+            # Returns: (content, content_blocks, end_flag)
+            #   - end_flag: whether a closing tag was seen (True = this block type has ended)
            def tag_content_handler(content_type, tags, content, content_blocks):
                end_flag = False
 
                def extract_attributes(tag_content):
-                    """Extract attributes from a tag if they exist."""
+                    """Extract attributes from a tag if they exist (e.g. <tag attr="value">)."""
                    attributes = {}
                    if not tag_content:  # Ensure tag_content is not None
                        return attributes
@@ -2555,6 +3057,7 @@ async def process_chat_response(
                                usage = data.get("usage", {}) or {}
                                usage.update(data.get("timings", {}))  # llama.cpp timings
                                if usage:
+                                    usage_holder["usage"] = usage
                                    await event_emitter(
                                        {
                                            "type": "chat:completion",
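The `is_opening_code_block` parity trick above is easy to convince yourself of in isolation: splitting on ``` yields one more segment than there are fences, so an even segment count means an odd number of fences, i.e. the last fence opened a block that never closed. A standalone sketch:

```python
def is_opening_code_block(content: str) -> bool:
    backtick_segments = content.split("```")
    # an even number of segments = an odd number of ``` fences = the last fence is unclosed
    return len(backtick_segments) > 1 and len(backtick_segments) % 2 == 0

assert is_opening_code_block("Here is code:\n```python\nprint(1)")       # one fence: still open
assert not is_opening_code_block("```python\nprint(1)\n```\nAll done.")  # two fences: closed
assert not is_opening_code_block("no fences at all")                     # zero fences
```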
diff --git a/backend/open_webui/utils/summary.py b/backend/open_webui/utils/summary.py
new file mode 100644
index 0000000000..9ac493a1b3
--- /dev/null
+++ b/backend/open_webui/utils/summary.py
@@ -0,0 +1,172 @@
+from typing import Dict, List, Optional, Tuple
+
+from open_webui.models.chats import Chats
+
+
+def build_ordered_messages(
+    messages_map: Optional[Dict], anchor_id: Optional[str] = None
+) -> List[Dict]:
+    """
+    Restore a message map to an ordered list.
+
+    Strategy:
+    1. Preferred: walk the parentId chain (backtrack from anchor_id to the root message)
+    2. Fallback: sort by timestamp (no anchor_id, or the chain walk fails)
+
+    Args:
+        messages_map: message map, e.g. {"msg-id": {"role": "user", "content": "...", "parentId": "...", "timestamp": 123456}}
+        anchor_id: anchor message ID (chain tail); the walk starts here and goes upward
+
+    Returns:
+        Ordered message list; every message carries an "id" field.
+    """
+    if not messages_map:
+        return []
+
+    # Make sure each message carries its own id
+    def with_id(message_id: str, message: Dict) -> Dict:
+        return {**message, **({"id": message_id} if "id" not in message else {})}
+
+    # Mode 1: walk the parentId chain
+    if anchor_id and anchor_id in messages_map:
+        ordered: List[Dict] = []
+        current_id: Optional[str] = anchor_id
+
+        while current_id:
+            current_msg = messages_map.get(current_id)
+            if not current_msg:
+                break
+            ordered.insert(0, with_id(current_id, current_msg))
+            current_id = current_msg.get("parentId")
+
+        return ordered
+
+    # Mode 2: sort by timestamp
+    sortable: List[Tuple[int, str, Dict]] = []
+    for mid, message in messages_map.items():
+        ts = (
+            message.get("createdAt")
+            or message.get("created_at")
+            or message.get("timestamp")
+            or 0
+        )
+        sortable.append((int(ts), mid, message))
+
+    sortable.sort(key=lambda x: x[0])
+    return [with_id(mid, msg) for _, mid, msg in sortable]
+
+
+def get_recent_messages_by_user_id(user_id: str, chat_id: str, num: int) -> List[Dict]:
+    """
+    Get the user's most recent N messages across all chats, in chronological order.
+
+    Args:
+        user_id: user ID
+        chat_id: current chat ID (accepted for future use; this implementation scans every chat)
+        num: number of messages to return (<= 0 returns all)
+
+    Returns:
+        Ordered message list (the most recent num messages).
+    """
+    all_messages: List[Dict] = []
+
+    # Walk every chat belonging to the user
+    chats = Chats.get_chat_list_by_user_id(user_id, include_archived=True)
+    for chat in chats:
+        messages_map = chat.chat.get("history", {}).get("messages", {}) or {}
+        for mid, msg in messages_map.items():
+            # Skip empty content
+            if msg.get("content", "") == "":
+                continue
+            ts = (
+                msg.get("createdAt")
+                or msg.get("created_at")
+                or msg.get("timestamp")
+                or 0
+            )
+            entry = {**msg, "id": mid}
+            entry.setdefault("chat_id", chat.id)
+            entry.setdefault("timestamp", int(ts))
+            all_messages.append(entry)
+
+    # Sort by timestamp
+    all_messages.sort(key=lambda m: m.get("timestamp", 0))
+
+    if num <= 0:
+        return all_messages
+
+    return all_messages[-num:]
+
+
+def slice_messages_with_summary(
+    messages_map: Dict,
+    boundary_message_id: Optional[str],
+    anchor_id: Optional[str],
+    pre_boundary: int = 20,
+) -> List[Dict]:
+    """
+    Trim the message list around the summary boundary.
+
+    Strategy: keep pre_boundary messages before the summary boundary (for context)
+    plus all messages after it (the latest conversation).
+    Goal: cut token usage while preserving enough context.
+
+    Args:
+        messages_map: message map
+        boundary_message_id: summary boundary message ID (None returns all messages)
+        anchor_id: anchor message ID (chain tail)
+        pre_boundary: how many messages to keep before the boundary (default 20)
+
+    Returns:
+        Trimmed, ordered message list.
+
+    Example:
+        100 messages, boundary at the 50th message (index 49), pre_boundary=20
+        → returns indices 29-99 (71 messages)
+    """
+    ordered = build_ordered_messages(messages_map, anchor_id)
+
+    if boundary_message_id:
+        try:
+            # Find the index of the summary boundary message
+            boundary_idx = next(
+                idx for idx, msg in enumerate(ordered) if msg.get("id") == boundary_message_id
+            )
+            # Compute the slice start
+            start_idx = max(boundary_idx - pre_boundary, 0)
+            ordered = ordered[start_idx:]
+        except StopIteration:
+            # Boundary message not found: return everything
+            pass
+
+    return ordered
+
+
+def summarize(messages: List[Dict], old_summary: Optional[str] = None) -> str:
+    """
+    Generate a conversation summary (placeholder implementation).
+
+    Args:
+        messages: messages to summarize
+        old_summary: previous summary (optional, currently unused)
+
+    Returns:
+        Summary string.
+
+    TODO:
+        - Implement incremental summarization (build the new summary on top of old_summary)
+        - Support configurable summarization strategies (length, level of detail)
+    """
+    return "\n".join((m.get("content") or "")[:100] for m in messages)
+
+
+def compute_token_count(messages: List[Dict]) -> int:
+    """
+    Count tokens in a message list (placeholder implementation).
+
+    Current heuristic: 4 characters ≈ 1 token (rough estimate).
+    TODO: plug in a real tokenizer (e.g. tiktoken for OpenAI models).
+    """
+    total_chars = 0
+    for msg in messages:
+        total_chars += len(msg.get("content") or "")
+
+    return max(total_chars // 4, 0)
+
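For the tokenizer TODO above, one plausible shape is a tiktoken-backed counter that keeps the 4-characters-per-token heuristic as a fallback; the model name and fallback encoding below are assumptions, not part of this patch:

```python
from typing import Dict, List


def count_tokens(messages: List[Dict], model: str = "gpt-4o") -> int:
    """Token count via tiktoken when available; heuristic fallback otherwise."""
    texts = [m.get("content") or "" for m in messages]
    try:
        import tiktoken

        try:
            enc = tiktoken.encoding_for_model(model)
        except KeyError:
            # Unknown model name: fall back to a common encoding
            enc = tiktoken.get_encoding("cl100k_base")
        return sum(len(enc.encode(t)) for t in texts)
    except ImportError:
        # Same rough estimate as compute_token_count: 4 chars ≈ 1 token
        return max(sum(len(t) for t in texts) // 4, 0)
```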
diff --git a/src/lib/components/chat/Chat.svelte b/src/lib/components/chat/Chat.svelte
index a4a76b4a6a..cabba95262 100644
--- a/src/lib/components/chat/Chat.svelte
+++ b/src/lib/components/chat/Chat.svelte
@@ -1998,11 +1998,7 @@ const getCombinedModelById = (modelId) => {
 		const isUserModel = combinedModel?.source === 'user';
 		const credential = combinedModel?.credential;

-		const stream =
-			model?.info?.params?.stream_response ??
-			$settings?.params?.stream_response ??
-			params?.stream_response ??
-			true;
+		const stream = true;

 		let messages = [
 			params?.system || $settings.system
diff --git a/src/lib/components/layout/Sidebar.svelte b/src/lib/components/layout/Sidebar.svelte
index c2e63cd81c..826dc1c529 100644
--- a/src/lib/components/layout/Sidebar.svelte
+++ b/src/lib/components/layout/Sidebar.svelte
@@ -1021,19 +1021,19 @@
 				bind:folderRegistry
 				{folders}
 				{shiftKey}
-				onDelete={(folderId) => {
+				onDelete={async (folderId) => {
 					selectedFolder.set(null);
-					initChatList();
+					await initChatList();
 				}}
-				on:update={() => {
-					initChatList();
+				on:update={async () => {
+					await initChatList();
 				}}
 				on:import={(e) => {
 					const { folderId, items } = e.detail;
 					importChatHandler(items, false, folderId);
 				}}
 				on:change={async () => {
-					initChatList();
+					await initChatList();
 				}}
 			/>
@@ -1085,7 +1085,7 @@
 							const res = await toggleChatPinnedStatusById(localStorage.token, chat.id);
 						}

-						initChatList();
+						await initChatList();
 					}
 				} else if (type === 'folder') {
 					if (folders[id].parent_id === null) {
@@ -1154,7 +1154,7 @@
 							const res = await toggleChatPinnedStatusById(localStorage.token, chat.id);
 						}

-						initChatList();
+						await initChatList();
 					}
 				}
 			}}
@@ -1177,7 +1177,7 @@
 					selectedChatId = null;
 				}}
 				on:change={async () => {
-					initChatList();
+					await initChatList();
 				}}
 				on:tag={(e) => {
 					const { type, name } = e.detail;
@@ -1237,7 +1237,7 @@
 					selectedChatId = null;
 				}}
 				on:change={async () => {
-					initChatList();
+					await initChatList();
 				}}
 				on:tag={(e) => {
 					const { type, name } = e.detail;