From 78b80895565a905a22b91beec65ed16427d7581f Mon Sep 17 00:00:00 2001
From: ChuangWang
Date: Wed, 10 Dec 2025 04:10:47 +0800
Subject: [PATCH] fix: optimize summary generation logic and model compatibility

---
 backend/open_webui/main.py             |  64 ++++-
 backend/open_webui/utils/middleware.py |   5 +-
 backend/open_webui/utils/summary.py    | 330 +++++++++++++++++++++++--
 3 files changed, 368 insertions(+), 31 deletions(-)

diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py
index 41ee380ef7..f1b5a5182e 100644
--- a/backend/open_webui/main.py
+++ b/backend/open_webui/main.py
@@ -1655,27 +1655,71 @@ async def chat_completion(
         return
 
     # Fetch the message list
-    ordered = get_recent_messages_by_user_id(user.id, chat_id, 100)
+    # When the history is very large (e.g. a freshly imported chat), keeping only
+    # the most recent N messages can drop early context.
+    # Strategy: summarize in chunks. First build a base summary from the older
+    # messages, then fold the recent messages into the final summary.
+
+    # 1. Fetch all messages (in time order)
+    all_messages = get_recent_messages_by_user_id(user.id, chat_id, -1)  # -1 fetches everything
+
     if CHAT_DEBUG_FLAG:
-        print(f"[summary:init] chat_id={chat_id} recent messages={len(ordered)} (current chat first)")
+        print(f"[summary:init] chat_id={chat_id} total messages={len(all_messages)}")
 
-    if not ordered:
+    if not all_messages:
         if CHAT_DEBUG_FLAG:
             print(f"[summary:init] chat_id={chat_id} no usable messages, skipping generation")
         return
 
-    # Call the LLM to generate the summary and save it
-    summary_text = summarize(ordered, None)
-    last_id = ordered[-1].get("id") if ordered else None
-    recent_ids = [m.get("id") for m in ordered[-20:] if m.get("id")]  # record the last 20 messages as cold-start messages
+    # 2. Moderate volume (<= 200): summarize directly
+    if len(all_messages) <= 200:
+        messages_for_summary = all_messages
+        model_id = form_data.get("model")
+        summary_text = summarize(messages_for_summary, None, model=model_id)
+    else:
+        # 3. Large volume: compress in chunks
+        # The strategy is:
+        #   a. Summarize the older messages (roughly the first part of the
+        #      history) into an "initial summary"
+        #   b. Pass that initial summary as existing_summary together with the
+        #      later messages to produce the final summary
+
+        if CHAT_DEBUG_FLAG:
+            print(f"[summary:init] chat_id={chat_id} too many messages, using the chunked summary strategy...")
+
+        # Simple split: an "older" block (everything but the last 100) and a
+        # "recent" block (the last 100). The most recent conversation is handled
+        # in detail while the distant past gets compressed.
+        split_idx = max(len(all_messages) - 100, 0)
+        older_messages = all_messages[:split_idx]
+        recent_messages = all_messages[split_idx:]
+
+        model_id = form_data.get("model")
+
+        # Step 1: compress the older messages.
+        # Note: if the older block is still very large (>500), summarize internally
+        # keeps only the last max_messages (default 120), so very large imports
+        # would need iterative compression. For performance we run a single
+        # "old history" pass here. To cover more old messages we could raise the
+        # max_messages limit or batch the calls (e.g. sample every few messages,
+        # or keep both ends). Simplified approach: hand older_messages to
+        # summarize and let it slice/truncate internally.
+        # TODO: the fully correct solution is a reduce loop, but it is too slow;
+        # see the sketch after this file's diff.
+
+        if CHAT_DEBUG_FLAG:
+            print(f"[summary:init] building the old-history summary (msgs={len(older_messages)})...")
+        base_summary = summarize(older_messages, None, model=model_id)
+
+        if CHAT_DEBUG_FLAG:
+            print(f"[summary:init] building the final summary (base_summary_len={len(base_summary)}, recent_msgs={len(recent_messages)})...")
+        # Step 2: final summary from the old summary plus the recent messages
+        summary_text = summarize(recent_messages, old_summary=base_summary, model=model_id)
+
+        messages_for_summary = recent_messages  # used below to compute last_id etc.
+
+    last_id = messages_for_summary[-1].get("id") if messages_for_summary else None
+    recent_ids = [m.get("id") for m in messages_for_summary[-20:] if m.get("id")]  # record the last 20 messages as cold-start messages
 
     if CHAT_DEBUG_FLAG:
         print(
-            f"[summary:init] chat_id={chat_id} generating first summary, msg_count={len(ordered)}, last_id={last_id}, recent_ids={len(recent_ids)}"
+            f"[summary:init] chat_id={chat_id} generating first summary, msg_count={len(messages_for_summary)}, last_id={last_id}, recent_ids={len(recent_ids)}"
         )
-        print("[summary:init]: ordered")
-        for i in ordered:
+        print("[summary:init]: messages_for_summary")
+        for i in messages_for_summary:
             print(i['role'], " ", i['content'][:100])
 
     res = Chats.set_summary_by_user_id_and_chat_id(
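The TODO above notes that the fully correct approach is a reduce loop over the entire history. A minimal sketch of that idea, assuming the summarize signature introduced in this patch; the reduce_history name and the chunk size are illustrative, not part of the patch:

    def reduce_history(all_messages, model_id, chunk_size=100):
        # Fold the history chunk by chunk: each pass compresses the previous
        # running summary plus one more chunk of messages into a new summary.
        running_summary = None
        for start in range(0, len(all_messages), chunk_size):
            chunk = all_messages[start:start + chunk_size]
            running_summary = summarize(chunk, old_summary=running_summary, model=model_id)
        return running_summary or ""

Each chunk stays under summarize's internal max_messages window, at the cost of one LLM call per chunk, which is why the patch settles for a single compression pass.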
diff --git a/backend/open_webui/utils/middleware.py b/backend/open_webui/utils/middleware.py
index d72a5b6ece..9a484cebbe 100644
--- a/backend/open_webui/utils/middleware.py
+++ b/backend/open_webui/utils/middleware.py
@@ -2092,7 +2092,10 @@ async def process_chat_response(
                 f"[summary:update] chat_id={chat_id} total slices={len(ordered)} messages in summary={len(summary_messages)} "
                 f"previous summary message_id={existing_summary.get('last_message_id') if existing_summary else None}"
             )
-            summary_text = summarize(summary_messages, old_summary)
+
+            # Resolve the current model ID so the summary update uses the right model
+            model_id = model.get("id") if model else None
+            summary_text = summarize(summary_messages, old_summary, model=model_id)
             last_msg_id = (
                 summary_messages[-1].get("id")
                 if summary_messages
diff --git a/backend/open_webui/utils/summary.py b/backend/open_webui/utils/summary.py
index 9ac493a1b3..54e9f4d5c6 100644
--- a/backend/open_webui/utils/summary.py
+++ b/backend/open_webui/utils/summary.py
@@ -1,7 +1,269 @@
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional, Tuple, Sequence, Any
+import json
+import re
+import os
+from dataclasses import dataclass
+from logging import getLogger
+
+try:
+    from openai import OpenAI
+except ImportError:
+    OpenAI = None
 
 from open_webui.models.chats import Chats
+from open_webui.config import OPENAI_API_KEYS, OPENAI_API_BASE_URLS
+
+log = getLogger(__name__)
+
+# --- Constants & Prompts from persona_extractor ---
+
+SUMMARY_PROMPT = """You are a "conversation history curator". Summarize the chat log so far while staying factually accurate.
+## Requirements
+1. The final summary must not exceed 1000 characters.
+2. Focus on key information such as character state, event milestones, and emotions/intentions, and merge the fragments into coherent prose.
+3. The output must include the four fields who / how / why / what, each at most 50 characters.
+4. No speculation or profanity; everything must be traceable to the chat log.
+5. Goal: help later turns quickly recall the context and character settings.
+
+Existing summary (write "none" if there is none):
+{existing_summary}
+
+Chat transcript:
+---CHATS---
+{chat_transcript}
+---END---
+
+Output strictly the following JSON:
+{{
+  "summary": "a coherent summary of at most 1000 characters",
+  "table": {{
+    "who": "at most 50 characters",
+    "how": "at most 50 characters",
+    "why": "at most 50 characters",
+    "what": "at most 50 characters"
+  }}
+}}
+"""
+
+MERGE_ONLY_PROMPT = """You are a "conversation history curator".
+Merge the two conversation summaries below (A and B) into one coherent, updated summary of the conversation history.
+Summary A covers the earlier period; summary B covers the more recent one.
+
+[Summary A (old)]
+{summary_a}
+
+[Summary B (new)]
+{summary_b}
+
+## Requirements
+1. Keep the timeline coherent: the newer events follow naturally after the older ones.
+2. The final summary must not exceed 1000 characters.
+3. Still extract the four key fields who / how / why / what (based on the merged whole).
+4. No speculation; rely only on the content of the two summaries.
+
+Output strictly the following JSON:
+{{
+  "summary": "the merged coherent summary",
+  "table": {{
+    "who": "at most 50 characters",
+    "how": "at most 50 characters",
+    "why": "at most 50 characters",
+    "what": "at most 50 characters"
+  }}
+}}
+"""
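Both prompts demand the same reply shape. For reference, a well-formed reply (field values invented purely for illustration) and what the parsing helpers defined further down in this module extract from it:

    reply = '''{
      "summary": "Alice asked Bob to review the deployment plan; after two revisions Bob approved it.",
      "table": {
        "who": "Alice, Bob",
        "how": "chat discussion, two plan revisions",
        "why": "the deployment needed sign-off",
        "what": "plan reviewed and approved"
      }
    }'''

    data = _safe_json_loads(reply)  # defined at the end of this module
    assert data["table"]["who"] == "Alice, Bob"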
+@dataclass
+class HistorySummary:
+    summary: str
+    table: dict[str, str]
+
+
+@dataclass(slots=True)
+class ChatMessage:
+    role: str
+    content: str
+    timestamp: Optional[Any] = None
+
+    def formatted(self) -> str:
+        return f"{self.role}: {self.content}"
+
+
+class HistorySummarizer:
+    def __init__(
+        self,
+        *,
+        client: Optional[Any] = None,
+        model: str = "gpt-4.1-mini",
+        max_output_tokens: int = 800,
+        temperature: float = 0.1,
+        max_messages: int = 120,
+    ) -> None:
+        if client is None:
+            if OpenAI is None:
+                log.warning("OpenAI client not available. Install openai>=1.0.0.")
+            else:
+                try:
+                    # Read the API key and base URL from the config
+                    api_keys = OPENAI_API_KEYS.value if hasattr(OPENAI_API_KEYS, "value") else []
+                    base_urls = OPENAI_API_BASE_URLS.value if hasattr(OPENAI_API_BASE_URLS, "value") else []
+
+                    api_key = api_keys[0] if api_keys else os.environ.get("OPENAI_API_KEY")
+                    base_url = base_urls[0] if base_urls else os.environ.get("OPENAI_API_BASE_URL")
+
+                    if api_key:
+                        kwargs = {"api_key": api_key}
+                        if base_url:
+                            kwargs["base_url"] = base_url
+                        client = OpenAI(**kwargs)
+                    else:
+                        log.warning("No OpenAI API key found.")
+
+                except Exception as e:
+                    log.warning(f"Failed to init OpenAI client: {e}")
+
+        self._client = client
+        self._model = model
+        self._max_output_tokens = max_output_tokens
+        self._temperature = temperature
+        self._max_messages = max_messages
+
+    def summarize(
+        self,
+        messages: Sequence[Dict],
+        *,
+        existing_summary: Optional[str] = None,
+        max_tokens: Optional[int] = None,
+    ) -> Optional[HistorySummary]:
+        if not messages and not existing_summary:
+            return None
+
+        # Build the prompt transcript straight from the message dicts.
+        # Sort by timestamp first so out-of-order input cannot corrupt the slice below.
+        sorted_messages = sorted(
+            messages,
+            key=lambda m: m.get('timestamp', 0) if isinstance(m.get('timestamp'), (int, float)) else 0,
+        )
+
+        # With an existing_summary we could shrink this window, but for simplicity
+        # we always keep the most recent max_messages
+        trail = sorted_messages[-self._max_messages :]
+        transcript = "\n".join(f"{m.get('role', 'user')}: {m.get('content', '')}" for m in trail)
+
+        prompt = SUMMARY_PROMPT.format(
+            existing_summary=existing_summary.strip() if existing_summary else "none",
+            chat_transcript=transcript,
+        )
+
+        if not self._client:
+            log.error("No OpenAI client available for summarization.")
+            return None
+
+        log.info(f"Starting summary generation for {len(messages)} messages...")
+
+        try:
+            # Raise the max_tokens floor so the summary is not truncated and the JSON
+            # wrapper has enough room; empirically a 1000-character summary plus the
+            # JSON structure needs about 1500 tokens
+            safe_max_tokens = max(max_tokens or self._max_output_tokens, 2000)
+
+            response = self._client.chat.completions.create(
+                model=self._model,
+                messages=[{"role": "user", "content": prompt}],
+                max_tokens=safe_max_tokens,
+                temperature=self._temperature,
+            )
+
+            # Debug: log the full response to investigate empty-content issues
+            log.info(f"Full Summary API Response: {response}")
+
+            payload = response.choices[0].message.content or ""
+            finish_reason = response.choices[0].finish_reason
+
+            if finish_reason == "length":
+                log.warning("Summary generation was truncated due to length limit!")
+
+            log.info(f"Summary generation completed. Payload length: {len(payload)}")
+            log.info(f"Summary Content:\n{payload}")
+
+            return self._parse_response(payload)
+
+        except Exception as e:
+            log.warning(f"Summarization failed: {e}")
+            return None
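A truncated reply is only logged above, and _parse_response falls back to the raw (possibly cut-off) payload. A caller that needs complete JSON could retry once with a larger budget. A possible hardening, not part of this patch; complete_with_retry is a hypothetical helper:

    def complete_with_retry(client, model_id, prompt, max_tokens=2000):
        # Retry once with a doubled token budget when the reply was cut off
        for budget in (max_tokens, max_tokens * 2):
            response = client.chat.completions.create(
                model=model_id,
                messages=[{"role": "user", "content": prompt}],
                max_tokens=budget,
                temperature=0.1,
            )
            if response.choices[0].finish_reason != "length":
                break
        return response.choices[0].message.content or ""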
+    def merge_summaries(
+        self,
+        summary_a: str,
+        summary_b: str,
+        *,
+        max_tokens: Optional[int] = None,
+    ) -> Optional[HistorySummary]:
+        if not summary_a and not summary_b:
+            return None
+
+        prompt = MERGE_ONLY_PROMPT.format(
+            summary_a=summary_a or "none",
+            summary_b=summary_b or "none",
+        )
+
+        if not self._client:
+            return None
+
+        log.info(f"Starting summary merge (A len={len(summary_a)}, B len={len(summary_b)})...")
+
+        try:
+            response = self._client.chat.completions.create(
+                model=self._model,
+                messages=[{"role": "user", "content": prompt}],
+                max_tokens=max_tokens or self._max_output_tokens,
+                temperature=self._temperature,
+            )
+
+            payload = response.choices[0].message.content or ""
+            log.info("Summary merge completed successfully.")
+            return self._parse_response(payload)
+
+        except Exception as e:
+            log.warning(f"Merge failed: {e}")
+            return None
+
+    def _parse_response(self, payload: str) -> HistorySummary:
+        data = _safe_json_loads(payload)
+
+        # If parsing failed or returned a non-dict, fall back to the raw payload
+        if not isinstance(data, dict) or not data:
+            summary = payload.strip()
+            table = {}
+        else:
+            summary = str(data.get("summary", "")).strip()
+            table_payload = data.get("table", {}) or {}
+            table = {
+                "who": str(table_payload.get("who", "")).strip(),
+                "how": str(table_payload.get("how", "")).strip(),
+                "why": str(table_payload.get("why", "")).strip(),
+                "what": str(table_payload.get("what", "")).strip(),
+            }
+
+        if not summary:
+            summary = payload.strip()
+
+        if len(summary) > 1000:
+            summary = summary[:1000].rstrip() + "..."
+
+        return HistorySummary(summary=summary, table=table)
+
+
+def _safe_json_loads(raw: str) -> Dict[str, Any]:
+    try:
+        return json.loads(raw)
+    except json.JSONDecodeError:
+        # Fall back to extracting the first JSON object with a regex
+        match = re.search(r'(\{.*\})', raw, re.DOTALL)
+        if match:
+            try:
+                return json.loads(match.group(1))
+            except json.JSONDecodeError:
+                pass
+    return {}
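The regex fallback matters in practice because models often wrap JSON in markdown fences, and the outermost braces are still recoverable. For illustration:

    raw = '```json\n{"summary": "short recap", "table": {}}\n```'
    # json.loads(raw) raises JSONDecodeError, but the brace-matching
    # fallback recovers the embedded object
    assert _safe_json_loads(raw)["summary"] == "short recap"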
+
+
+# --- Core Logic Modules ---
 
 def build_ordered_messages(
     messages_map: Optional[Dict], anchor_id: Optional[str] = None
@@ -58,21 +320,26 @@ def build_ordered_messages(
 
 def get_recent_messages_by_user_id(user_id: str, chat_id: str, num: int) -> List[Dict]:
     """
-    Fetch the user's most recent N messages across all chats (in time order)
+    Fetch the user's most recent N messages (current chat first, then in time order)
 
     Args:
         user_id: the user ID
+        chat_id: the current chat ID (its messages are preferred)
         num: the number of messages to fetch (<= 0 returns everything)
 
     Returns:
-        an ordered list of messages (the most recent num entries)
+        an ordered list of messages (current chat preferred, topped up with the
+        most recent global messages when it falls short)
     """
-    all_messages: List[Dict] = []
+    current_chat_messages: List[Dict] = []
+    other_messages: List[Dict] = []
 
     # Iterate over all of the user's chats
     chats = Chats.get_chat_list_by_user_id(user_id, include_archived=True)
     for chat in chats:
         messages_map = chat.chat.get("history", {}).get("messages", {}) or {}
 
+        # Naive check for whether this is the current chat
+        is_current_chat = (str(chat.id) == str(chat_id))
+
         for mid, msg in messages_map.items():
             # Skip empty content
             if msg.get("content", "") == "":
@@ -86,15 +353,34 @@ def get_recent_messages_by_user_id(user_id: str, chat_id: str, num: int) -> List
             entry = {**msg, "id": mid}
             entry.setdefault("chat_id", chat.id)
             entry.setdefault("timestamp", int(ts))
-            all_messages.append(entry)
+
+            if is_current_chat:
+                current_chat_messages.append(entry)
+            else:
+                other_messages.append(entry)
 
-    # Sort by timestamp
-    all_messages.sort(key=lambda m: m.get("timestamp", 0))
+    # Sort each list separately
+    current_chat_messages.sort(key=lambda m: m.get("timestamp", 0))
+    other_messages.sort(key=lambda m: m.get("timestamp", 0))
 
     if num <= 0:
-        return all_messages
+        combined = current_chat_messages + other_messages
+        combined.sort(key=lambda m: m.get("timestamp", 0))
+        return combined
 
-    return all_messages[-num:]
+    # Strategy: prefer keeping the current chat's messages
+    if len(current_chat_messages) >= num:
+        return current_chat_messages[-num:]
+
+    # Top up the shortfall from other chats
+    needed = num - len(current_chat_messages)
+    supplement = other_messages[-needed:] if other_messages else []
+
+    # Merge and do a final sort by timestamp
+    final_list = supplement + current_chat_messages
+    final_list.sort(key=lambda m: m.get("timestamp", 0))
+
+    return final_list
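A quick worked illustration of the top-up strategy, with invented numbers and no DB involved: 60 messages in the current chat and num=100 pulls in the 40 newest messages from other chats, then re-sorts everything by timestamp:

    current_chat_messages = [{"id": f"c{i}", "timestamp": 1000 + i} for i in range(60)]
    other_messages = [{"id": f"o{i}", "timestamp": i} for i in range(500)]

    num = 100
    needed = num - len(current_chat_messages)      # 40
    supplement = other_messages[-needed:]          # the 40 newest from other chats
    final_list = sorted(supplement + current_chat_messages, key=lambda m: m["timestamp"])
    assert len(final_list) == num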
 
 
 def slice_messages_with_summary(
@@ -140,22 +426,21 @@
     return ordered
 
 
-def summarize(messages: List[Dict], old_summary: Optional[str] = None) -> str:
+def summarize(messages: List[Dict], old_summary: Optional[str] = None, model: Optional[str] = None) -> str:
     """
-    Generate a conversation summary (placeholder interface)
+    Generate a conversation summary
 
     Args:
         messages: the messages to summarize
-        old_summary: the previous summary (optional; currently unused)
+        old_summary: the previous summary
+        model: the model ID to use (falls back to the class default when None)
 
     Returns:
         the summary string
-
-    TODO:
-    - implement incremental summarization (build the new summary on top of old_summary)
-    - support configurable summary policies (length, verbosity)
     """
-    return "\n".join(m.get("content")[:100] for m in messages)
+    summarizer = HistorySummarizer(model=model) if model else HistorySummarizer()
+    result = summarizer.summarize(messages, existing_summary=old_summary)
+    return result.summary if result else ""
 
 def compute_token_count(messages: List[Dict]) -> int:
     """
@@ -166,7 +451,12 @@
     """
     total_chars = 0
    for msg in messages:
-        total_chars += len(msg['content'])
-
+        content = msg.get('content')
+        if isinstance(content, str):
+            total_chars += len(content)
+        elif isinstance(content, list):
+            # Multimodal content: count only the text parts
+            for item in content:
+                if isinstance(item, dict) and 'text' in item:
+                    total_chars += len(item['text'])
+
     return max(total_chars // 4, 0)
-
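The chars-divided-by-4 heuristic now also covers multimodal content parts. For illustration, with OpenAI-style content lists assumed and invented values:

    msgs = [
        {"role": "user", "content": "How do I reset my password?"},
        {"role": "user", "content": [
            {"type": "text", "text": "What is in this image?"},
            {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}},  # not counted
        ]},
    ]

    expected = (len(msgs[0]["content"]) + len(msgs[1]["content"][0]["text"])) // 4
    assert compute_token_count(msgs) == expected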