fix: optimize summary generation logic and model compatibility

ChuangWang 2025-12-10 04:10:47 +08:00
parent e34177ba69
commit 78b8089556
3 changed files with 368 additions and 31 deletions

View file

@@ -1655,27 +1655,71 @@ async def chat_completion(
return
# Fetch the message list
ordered = get_recent_messages_by_user_id(user.id, chat_id, 100)
# If the history is very long (e.g. freshly imported), taking only the most recent N messages can drop early context
# Strategy: summarize in chunks. First condense the older messages into a base summary, then fold in the recent messages for the final summary
# 1. Fetch all messages (sorted by time)
all_messages = get_recent_messages_by_user_id(user.id, chat_id, -1) # -1 means fetch everything
if CHAT_DEBUG_FLAG:
print(f"[summary:init] chat_id={chat_id} 最近消息数={len(ordered)} (优先当前会话)")
print(f"[summary:init] chat_id={chat_id} 全量消息数={len(all_messages)}")
if not ordered:
if not all_messages:
if CHAT_DEBUG_FLAG:
print(f"[summary:init] chat_id={chat_id} 无可用消息,跳过生成")
return
# Call the LLM to generate the summary and save it
summary_text = summarize(ordered, None)
last_id = ordered[-1].get("id") if ordered else None
recent_ids = [m.get("id") for m in ordered[-20:] if m.get("id")] # 记录最近20条消息为冷启动消息
# 2. If the message count is moderate (<= 200), summarize directly
if len(all_messages) <= 200:
messages_for_summary = all_messages
model_id = form_data.get("model")
summary_text = summarize(messages_for_summary, None, model=model_id)
else:
# 3. If the message count is large, compress in chunks
# The strategy here:
# a. Condense the older messages into an "Initial Summary"
# b. Pass that Initial Summary as existing_summary together with the recent messages to produce the final summary
if CHAT_DEBUG_FLAG:
print(f"[summary:init] chat_id={chat_id} 消息过多,执行分块摘要策略...")
# Simple split: an older chunk (everything except the last 100) + a recent chunk (the last 100)
# This keeps the latest conversation handled in detail while distant history gets compressed
split_idx = max(len(all_messages) - 100, 0)
older_messages = all_messages[:split_idx]
recent_messages = all_messages[split_idx:]
model_id = form_data.get("model")
# Step 1: compress the older messages
# Note: if the older chunk is still very large (>500), summarize() internally keeps only the last max_messages (default 120),
# so huge imports may need iterative compression. For performance we do a single "old history" pass for now.
# To let summarize() handle more old messages we could temporarily raise the max_messages limit,
# or call it in batches. Given the performance cost, we only take the "highlights" of the old messages (e.g. sample every few messages, or keep both ends).
# Simplified approach: pass older_messages straight to summarize() and let it slice/truncate internally
# TODO: the truly complete approach is an iterative reduce, but it takes too long (a sketch follows this hunk).
if CHAT_DEBUG_FLAG:
print(f"[summary:init] 生成旧历史摘要 (msgs={len(older_messages)})...")
base_summary = summarize(older_messages, None, model=model_id)
if CHAT_DEBUG_FLAG:
print(f"[summary:init] 生成最终摘要 (base_summary_len={len(base_summary)}, recent_msgs={len(recent_messages)})...")
# Step 2: build the final summary from the old summary + the recent messages
summary_text = summarize(recent_messages, old_summary=base_summary, model=model_id)
messages_for_summary = recent_messages # used below to compute last_id etc.
last_id = messages_for_summary[-1].get("id") if messages_for_summary else None
recent_ids = [m.get("id") for m in messages_for_summary[-20:] if m.get("id")] # 记录最近20条消息为冷启动消息
if CHAT_DEBUG_FLAG:
print(
f"[summary:init] chat_id={chat_id} 生成首条摘要msg_count={len(ordered)}, last_id={last_id}, recent_ids={len(recent_ids)}"
f"[summary:init] chat_id={chat_id} 生成首条摘要msg_count={len(messages_for_summary)}, last_id={last_id}, recent_ids={len(recent_ids)}"
)
print("[summary:init]: ordered")
for i in ordered:
print("[summary:init]: messages_for_summary")
for i in messages_for_summary:
print(i['role'], " ", i['content'][:100])
res = Chats.set_summary_by_user_id_and_chat_id(
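The TODO above defers the iterative reduce. Below is a minimal sketch of that loop, assuming summarize() keeps the (messages, old_summary, model) signature this commit introduces; reduce_summarize and chunk_size are illustrative names, not part of the change:

def reduce_summarize(messages, chunk_size=200, model=None):
    # Fold the history left to right: each chunk is condensed together
    # with the rolling summary built from all earlier chunks.
    summary = None
    for i in range(0, len(messages), chunk_size):
        chunk = messages[i : i + chunk_size]
        summary = summarize(chunk, old_summary=summary, model=model)
    return summary or ""

This covers arbitrarily large imports at the cost of one LLM call per chunk, which is exactly the latency the comment above opts out of.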

View file

@@ -2092,7 +2092,10 @@ async def process_chat_response(
f"[summary:update] chat_id={chat_id} 切片总数={len(ordered)} 摘要参与消息数={len(summary_messages)} "
f"上次摘要 message_id={existing_summary.get('last_message_id') if existing_summary else None}"
)
summary_text = summarize(summary_messages, old_summary)
# Get the current model ID so the summary update uses the correct model
model_id = model.get("id") if model else None
summary_text = summarize(summary_messages, old_summary, model=model_id)
last_msg_id = (
summary_messages[-1].get("id")
if summary_messages

View file

@@ -1,7 +1,269 @@
from typing import Dict, List, Optional, Tuple
from typing import Dict, List, Optional, Tuple, Sequence, Any
import json
import re
import os
from dataclasses import dataclass
from logging import getLogger
try:
from openai import OpenAI
except ImportError:
OpenAI = None
from open_webui.models.chats import Chats
from open_webui.config import OPENAI_API_KEYS, OPENAI_API_BASE_URLS
log = getLogger(__name__)
# --- Constants & Prompts from persona_extractor ---
SUMMARY_PROMPT = """你是一名“对话历史整理员”,请在保持事实准确的前提下,概括当前为止的聊天记录。
## 要求
1. 最终摘要不得超过 1000
2. 聚焦人物状态事件节点情绪/意图等关键信息将片段整合为连贯文字
3. 输出需包含 who / how / why / what 四个字段每项不超过 50
4. 禁止臆测或脏话所有内容都必须能在聊天中找到对应描述
5. 目标帮助后续对话快速回忆上下文与人物设定
已存在的摘要如无则写
{existing_summary}
聊天片段
---CHATS---
{chat_transcript}
---END---
请严格输出下列 JSON
{{
"summary": "不超过1000字的连贯摘要",
"table": {{
"who": "不超过50字",
"how": "不超过50字",
"why": "不超过50字",
"what": "不超过50字"
}}
}}
"""
MERGE_ONLY_PROMPT = """你是一名“对话历史整理员”。
请将以下两段对话摘要A B合并为一段连贯的更新后的对话历史摘要
摘要 A 是较早的时间段摘要 B 是较新的时间段
摘要 A ()
{summary_a}
摘要 B ()
{summary_b}
## 要求
1. 保持时间线的连贯性将新发生的事自然接续在旧事之后
2. 最终摘要不得超过 1000
3. 依然提取 who / how / why / what 四个关键要素基于合并后的全貌
4. 禁止臆测只基于提供的摘要内容
请严格输出下列 JSON
{{
"summary": "合并后的连贯摘要",
"table": {{
"who": "不超过50字",
"how": "不超过50字",
"why": "不超过50字",
"what": "不超过50字"
}}
}}
"""
@dataclass
class HistorySummary:
summary: str
table: dict[str, str]
@dataclass(slots=True)
class ChatMessage:
role: str
content: str
timestamp: Optional[Any] = None
def formatted(self) -> str:
return f"{self.role}: {self.content}"
class HistorySummarizer:
def __init__(
self,
*,
client: Optional[Any] = None,
model: str = "gpt-4.1-mini",
max_output_tokens: int = 800,
temperature: float = 0.1,
max_messages: int = 120,
) -> None:
if client is None:
if OpenAI is None:
log.warning("OpenAI client not available. Install openai>=1.0.0.")
else:
try:
# Try to read the API key and base URL from the config
api_keys = OPENAI_API_KEYS.value if hasattr(OPENAI_API_KEYS, "value") else []
base_urls = OPENAI_API_BASE_URLS.value if hasattr(OPENAI_API_BASE_URLS, "value") else []
api_key = api_keys[0] if api_keys else os.environ.get("OPENAI_API_KEY")
base_url = base_urls[0] if base_urls else os.environ.get("OPENAI_API_BASE_URL")
if api_key:
kwargs = {"api_key": api_key}
if base_url:
kwargs["base_url"] = base_url
client = OpenAI(**kwargs)
else:
log.warning("No OpenAI API key found.")
except Exception as e:
log.warning(f"Failed to init OpenAI client: {e}")
self._client = client
self._model = model
self._max_output_tokens = max_output_tokens
self._temperature = temperature
self._max_messages = max_messages
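# Construction sketch (illustrative, not part of this commit): since the
# client is injected, any object exposing chat.completions.create() works,
# e.g. a test stub or a pre-configured instance:
#   HistorySummarizer(client=OpenAI(api_key="sk-..."), model="gpt-4.1-mini")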
def summarize(
self,
messages: Sequence[Dict],
*,
existing_summary: Optional[str] = None,
max_tokens: Optional[int] = None,
) -> Optional[HistorySummary]:
if not messages and not existing_summary:
return None
# Build the prompt transcript from the dict messages
# Sort by timestamp first so out-of-order input doesn't corrupt the slice below
sorted_messages = sorted(messages, key=lambda m: m.get('timestamp', 0) if isinstance(m.get('timestamp'), (int, float)) else 0)
# With an existing_summary we could shrink the message window here, or keep taking the most recent ones
# For simplicity we still take the last max_messages
trail = sorted_messages[-self._max_messages :]
transcript = "\n".join(f"{m.get('role', 'user')}: {m.get('content', '')}" for m in trail)
prompt = SUMMARY_PROMPT.format(
existing_summary=existing_summary.strip() if existing_summary else "",
chat_transcript=transcript,
)
if not self._client:
log.error("No OpenAI client available for summarization.")
return None
log.info(f"Starting summary generation for {len(messages)} messages...")
# Try primary client first
try:
# Raise the max_tokens floor so a long summary isn't truncated, leaving the JSON envelope enough room
# Empirically, a 1000-character summary plus the JSON structure needs about 1500 tokens
safe_max_tokens = max(max_tokens or self._max_output_tokens, 2000)
response = self._client.chat.completions.create(
model=self._model,
messages=[{"role": "user", "content": prompt}],
max_tokens=safe_max_tokens,
temperature=self._temperature,
)
# Debug: Print full response to investigate empty content issues
log.info(f"Full Summary API Response: {response}")
payload = response.choices[0].message.content or ""
finish_reason = response.choices[0].finish_reason
if finish_reason == "length":
log.warning("Summary generation was truncated due to length limit!")
log.info(f"Summary generation completed. Payload length: {len(payload)}")
log.info(f"Summary Content:\n{payload}")
return self._parse_response(payload)
except Exception as e:
log.warning(f"Summarization failed: {e}")
return None
def merge_summaries(
self,
summary_a: str,
summary_b: str,
*,
max_tokens: Optional[int] = None,
) -> Optional[HistorySummary]:
if not summary_a and not summary_b:
return None
prompt = MERGE_ONLY_PROMPT.format(
summary_a=summary_a or "",
summary_b=summary_b or "",
)
if not self._client:
return None
log.info(f"Starting summary merge (A len={len(summary_a)}, B len={len(summary_b)})...")
# Try primary client
try:
response = self._client.chat.completions.create(
model=self._model,
messages=[{"role": "user", "content": prompt}],
max_tokens=max_tokens or self._max_output_tokens,
temperature=self._temperature,
)
payload = response.choices[0].message.content or ""
log.info("Summary merge completed successfully.")
return self._parse_response(payload)
except Exception as e:
log.warning(f"Merge failed: {e}")
return None
def _parse_response(self, payload: str) -> HistorySummary:
data = _safe_json_loads(payload)
# If the parsed data is empty or not a dict, fall back to using the raw payload
if not isinstance(data, dict) or (not data and not payload.strip().startswith("{")):
summary = payload.strip()
table = {}
else:
summary = str(data.get("summary", "")).strip()
table_payload = data.get("table", {}) or {}
table = {
"who": str(table_payload.get("who", "")).strip(),
"how": str(table_payload.get("how", "")).strip(),
"why": str(table_payload.get("why", "")).strip(),
"what": str(table_payload.get("what", "")).strip(),
}
if not summary:
summary = payload.strip()
if len(summary) > 1000:
summary = summary[:1000].rstrip() + "..."
return HistorySummary(summary=summary, table=table)
def _safe_json_loads(raw: str) -> Dict[str, Any]:
try:
return json.loads(raw)
except json.JSONDecodeError:
# Fall back to a simple regex extraction
match = re.search(r'(\{.*\})', raw, re.DOTALL)
if match:
try:
return json.loads(match.group(1))
except json.JSONDecodeError:
pass
return {}
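# Illustration (not part of this commit): the regex fallback lets fenced
# replies through. For example,
#   _safe_json_loads('```json\n{"summary": "ok"}\n```')
# fails plain json.loads, but the DOTALL match on r'(\{.*\})' recovers
# {'summary': 'ok'}; a reply with no braces at all still yields {}.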
# --- Core Logic Modules ---
def build_ordered_messages(
messages_map: Optional[Dict], anchor_id: Optional[str] = None
@@ -58,21 +320,26 @@ def build_ordered_messages(
def get_recent_messages_by_user_id(user_id: str, chat_id: str, num: int) -> List[Dict]:
"""
Fetch the user's global most recent N messages, in chronological order
Fetch the user's most recent N messages (current chat first), in chronological order
Args:
user_id: the user ID
chat_id: the current chat ID (given priority when selecting)
num: number of messages to fetch (<= 0 returns everything)
Returns:
An ordered list of messages (the most recent num)
An ordered list of messages (current chat first, topped up from the global recents when short)
"""
all_messages: List[Dict] = []
current_chat_messages: List[Dict] = []
other_messages: List[Dict] = []
# Iterate over all of the user's chats
chats = Chats.get_chat_list_by_user_id(user_id, include_archived=True)
for chat in chats:
messages_map = chat.chat.get("history", {}).get("messages", {}) or {}
# Quick check for whether this is the current chat
is_current_chat = (str(chat.id) == str(chat_id))
for mid, msg in messages_map.items():
# Skip empty content
if msg.get("content", "") == "":
@@ -86,15 +353,34 @@ def get_recent_messages_by_user_id(user_id: str, chat_id: str, num: int) -> List
entry = {**msg, "id": mid}
entry.setdefault("chat_id", chat.id)
entry.setdefault("timestamp", int(ts))
all_messages.append(entry)
if is_current_chat:
current_chat_messages.append(entry)
else:
other_messages.append(entry)
# Sort by timestamp
all_messages.sort(key=lambda m: m.get("timestamp", 0))
# Sort each bucket separately
current_chat_messages.sort(key=lambda m: m.get("timestamp", 0))
other_messages.sort(key=lambda m: m.get("timestamp", 0))
if num <= 0:
return all_messages
combined = current_chat_messages + other_messages
combined.sort(key=lambda m: m.get("timestamp", 0))
return combined
return all_messages[-num:]
# Strategy: prefer keeping the current chat's messages
if len(current_chat_messages) >= num:
return current_chat_messages[-num:]
# Top up the shortfall
needed = num - len(current_chat_messages)
supplement = other_messages[-needed:] if other_messages else []
# Merge and do a final time-based sort
final_list = supplement + current_chat_messages
final_list.sort(key=lambda m: m.get("timestamp", 0))
return final_list
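# Behavior sketch (illustrative values, not part of this commit): with
# num=100, 30 messages in the current chat, and 500 elsewhere, the result
# is those 30 plus the 70 globally most recent others, merged and re-sorted
# by timestamp, so the supplement may predate or interleave with the
# current chat's messages.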
def slice_messages_with_summary(
@@ -140,22 +426,21 @@ def slice_messages_with_summary(
return ordered
def summarize(messages: List[Dict], old_summary: Optional[str] = None) -> str:
def summarize(messages: List[Dict], old_summary: Optional[str] = None, model: Optional[str] = None) -> str:
"""
Generate a conversation summary (placeholder interface)
Generate a conversation summary
Args:
messages: the messages to summarize
old_summary: the previous summary (optional, currently unused)
old_summary: the previous summary
model: the model ID to use; if None, the summarizer's internal default is used
Returns:
The summary string
TODO:
- Implement incremental summarization (build the new summary on top of old_summary)
- Support configurable summary strategy (length, level of detail)
"""
return "\n".join(m.get("content")[:100] for m in messages)
summarizer = HistorySummarizer(model=model) if model else HistorySummarizer()
result = summarizer.summarize(messages, existing_summary=old_summary)
return result.summary if result else ""
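# Call pattern used by chat_completion above (the model id is illustrative):
#   base = summarize(older_messages, None, model="gpt-4.1-mini")
#   final = summarize(recent_messages, old_summary=base, model="gpt-4.1-mini")
# Each call builds a fresh HistorySummarizer, so a per-call model override
# never leaks into later calls.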
def compute_token_count(messages: List[Dict]) -> int:
"""
@@ -166,7 +451,12 @@ def compute_token_count(messages: List[Dict]) -> int:
"""
total_chars = 0
for msg in messages:
total_chars += len(msg['content'])
content = msg.get('content')
if isinstance(content, str):
total_chars += len(content)
elif isinstance(content, list):
for item in content:
if isinstance(item, dict) and 'text' in item:
total_chars += len(item['text'])
return max(total_chars // 4, 0)
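For a quick sanity check of the chars-to-tokens heuristic above (roughly 4 characters per token, in line with OpenAI's usual rule of thumb), with illustrative values:

messages = [
    {"content": "hello world"},                     # plain string, 11 chars
    {"content": [{"type": "text", "text": "hi"}]},  # multimodal text part, 2 chars
    {"content": None},                              # non-string, non-list: ignored
]
print(compute_token_count(messages))  # (11 + 2) // 4 == 3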