fix: optimize summary generation logic and model compatibility
This commit is contained in:
parent e34177ba69
commit 78b8089556
3 changed files with 368 additions and 31 deletions
@@ -1655,27 +1655,71 @@ async def chat_completion(
        return

    # Fetch the message list
    ordered = get_recent_messages_by_user_id(user.id, chat_id, 100)
    if CHAT_DEBUG_FLAG:
        print(f"[summary:init] chat_id={chat_id} recent messages={len(ordered)} (current chat prioritized)")
    # If there is a very large history (e.g. just imported), taking only the most recent N messages may lose early context.
    # Strategy: summarize in chunks. First build a base summary from the earlier messages, then combine it with the recent messages into the final summary.

    if not ordered:
    # 1. Fetch all messages (in chronological order)
    all_messages = get_recent_messages_by_user_id(user.id, chat_id, -1)  # -1 means fetch everything

    if CHAT_DEBUG_FLAG:
        print(f"[summary:init] chat_id={chat_id} total messages={len(all_messages)}")

    if not all_messages:
        if CHAT_DEBUG_FLAG:
            print(f"[summary:init] chat_id={chat_id} no messages available, skipping generation")
        return

    # Call the LLM to generate the summary and save it
    summary_text = summarize(ordered, None)
    last_id = ordered[-1].get("id") if ordered else None
    recent_ids = [m.get("id") for m in ordered[-20:] if m.get("id")]  # record the last 20 messages as cold-start messages
    # 2. If the volume is moderate (<= 200), generate directly
    if len(all_messages) <= 200:
        messages_for_summary = all_messages
        model_id = form_data.get("model")
        summary_text = summarize(messages_for_summary, None, model=model_id)
    else:
        # 3. If there are many messages, compress them in chunks.
        # The strategy here:
        # a. Take the first 50% of the messages (or the first N) and generate an "Initial Summary"
        # b. Use that Initial Summary as existing_summary and generate the final summary together with the later half

        if CHAT_DEBUG_FLAG:
            print(f"[summary:init] chat_id={chat_id} too many messages, running the chunked summary strategy...")

        # Simple two-way split: old block (everything but the last 100) + new block (the last 100).
        # This keeps the most recent conversation finely processed while distant history is compressed.
        split_idx = max(len(all_messages) - 100, 0)
        older_messages = all_messages[:split_idx]
        recent_messages = all_messages[split_idx:]

        model_id = form_data.get("model")

        # Step 1: compress the old messages.
        # Note: if there are still too many old messages (>500), summarize internally keeps only the last max_messages (default 120),
        # so for very large imports we might need to compress in a loop. For performance we only do this "old-history compression" once.
        # To let summarize handle more old messages we could temporarily raise the max_messages limit,
        # or call it in batches. Given performance, we only take "highlight slices" of the old messages (e.g. every few messages, or just the two ends).
        # Simplified approach: pass older_messages straight to summarize and let it slice/truncate internally.
        # TODO: the truly complete approach is an iterative reduce, but that takes too long (see the sketch after this hunk).

        if CHAT_DEBUG_FLAG:
            print(f"[summary:init] generating old-history summary (msgs={len(older_messages)})...")
        base_summary = summarize(older_messages, None, model=model_id)

        if CHAT_DEBUG_FLAG:
            print(f"[summary:init] generating final summary (base_summary_len={len(base_summary)}, recent_msgs={len(recent_messages)})...")
        # Step 2: generate the final summary from the old summary + the recent messages
        summary_text = summarize(recent_messages, old_summary=base_summary, model=model_id)

        messages_for_summary = recent_messages  # used below to compute last_id etc.

    last_id = messages_for_summary[-1].get("id") if messages_for_summary else None
    recent_ids = [m.get("id") for m in messages_for_summary[-20:] if m.get("id")]  # record the last 20 messages as cold-start messages

    if CHAT_DEBUG_FLAG:
        print(
            f"[summary:init] chat_id={chat_id} generated first summary, msg_count={len(ordered)}, last_id={last_id}, recent_ids={len(recent_ids)}"
            f"[summary:init] chat_id={chat_id} generated first summary, msg_count={len(messages_for_summary)}, last_id={last_id}, recent_ids={len(recent_ids)}"
        )

        print("[summary:init]: ordered")
        for i in ordered:
        print("[summary:init]: messages_for_summary")
        for i in messages_for_summary:
            print(i['role'], " ", i['content'][:100])

    res = Chats.set_summary_by_user_id_and_chat_id(
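The TODO in this hunk defers a true iterative reduce. As a minimal sketch (a hypothetical helper, not part of this commit), such a fold over the history could reuse this module's summarize() signature:

def reduce_summarize(messages, chunk_size=100, model=None):
    # Fold fixed-size chunks into a running summary: each pass feeds the
    # accumulated summary back in as old_summary alongside the next chunk.
    summary = None
    for i in range(0, len(messages), chunk_size):
        summary = summarize(messages[i:i + chunk_size], old_summary=summary, model=model)
    return summary or ""

Each chunk costs one LLM call, which is why the commit settles for a single old/new split instead.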
@@ -2092,7 +2092,10 @@ async def process_chat_response(
f"[summary:update] chat_id={chat_id} 切片总数={len(ordered)} 摘要参与消息数={len(summary_messages)} "
|
||||
f"上次摘要 message_id={existing_summary.get('last_message_id') if existing_summary else None}"
|
||||
)
|
||||
summary_text = summarize(summary_messages, old_summary)
|
||||
|
||||
# 获取当前模型ID,确保使用正确的模型进行摘要更新
|
||||
model_id = model.get("id") if model else None
|
||||
summary_text = summarize(summary_messages, old_summary, model=model_id)
|
||||
last_msg_id = (
|
||||
summary_messages[-1].get("id")
|
||||
if summary_messages
|
||||
|
|
|
|||
|
|
@@ -1,7 +1,269 @@
from typing import Dict, List, Optional, Tuple
from typing import Dict, List, Optional, Tuple, Sequence, Any
import json
import re
import os
from dataclasses import dataclass
from logging import getLogger

try:
    from openai import OpenAI
except ImportError:
    OpenAI = None

from open_webui.models.chats import Chats
from open_webui.config import OPENAI_API_KEYS, OPENAI_API_BASE_URLS

log = getLogger(__name__)

# --- Constants & Prompts from persona_extractor ---

SUMMARY_PROMPT = """You are a "conversation-history curator". While staying factually accurate, condense the chat history so far.
## Requirements
1. The final summary must not exceed 1000 characters.
2. Focus on key information such as character state, event milestones, and emotion/intent, and weave the fragments into coherent prose.
3. The output must contain the four fields who / how / why / what, each at most 50 characters.
4. No speculation or profanity; everything must be traceable to the chat itself.
5. Goal: help later turns quickly recall the context and character setup.

Existing summary (write "none" if absent):
{existing_summary}

Chat fragment:
---CHATS---
{chat_transcript}
---END---

Output strictly the following JSON:
{{
    "summary": "a coherent summary of at most 1000 characters",
    "table": {{
        "who": "at most 50 characters",
        "how": "at most 50 characters",
        "why": "at most 50 characters",
        "what": "at most 50 characters"
    }}
}}
"""

MERGE_ONLY_PROMPT = """You are a "conversation-history curator".
Merge the two conversation summaries below (A and B) into a single coherent, updated history summary.
Summary A covers the earlier period; summary B covers the more recent one.

[Summary A (old)]
{summary_a}

[Summary B (new)]
{summary_b}

## Requirements
1. Keep the timeline coherent: new events follow naturally after the old ones.
2. The final summary must not exceed 1000 characters.
3. Still extract the four key elements who / how / why / what (based on the merged whole).
4. No speculation; use only the content of the provided summaries.

Output strictly the following JSON:
{{
    "summary": "the merged, coherent summary",
    "table": {{
        "who": "at most 50 characters",
        "how": "at most 50 characters",
        "why": "at most 50 characters",
        "what": "at most 50 characters"
    }}
}}
"""


@dataclass
class HistorySummary:
    summary: str
    table: dict[str, str]


@dataclass(slots=True)
class ChatMessage:
    role: str
    content: str
    timestamp: Optional[Any] = None

    def formatted(self) -> str:
        return f"{self.role}: {self.content}"


class HistorySummarizer:
    def __init__(
        self,
        *,
        client: Optional[Any] = None,
        model: str = "gpt-4.1-mini",
        max_output_tokens: int = 800,
        temperature: float = 0.1,
        max_messages: int = 120,
    ) -> None:
        if client is None:
            if OpenAI is None:
                log.warning("OpenAI client not available. Install openai>=1.0.0.")
            else:
                try:
                    # Try to read the API key and base URL from the app config
                    api_keys = OPENAI_API_KEYS.value if hasattr(OPENAI_API_KEYS, "value") else []
                    base_urls = OPENAI_API_BASE_URLS.value if hasattr(OPENAI_API_BASE_URLS, "value") else []

                    api_key = api_keys[0] if api_keys else os.environ.get("OPENAI_API_KEY")
                    base_url = base_urls[0] if base_urls else os.environ.get("OPENAI_API_BASE_URL")

                    if api_key:
                        kwargs = {"api_key": api_key}
                        if base_url:
                            kwargs["base_url"] = base_url
                        client = OpenAI(**kwargs)
                    else:
                        log.warning("No OpenAI API key found.")

                except Exception as e:
                    log.warning(f"Failed to init OpenAI client: {e}")

        self._client = client
        self._model = model
        self._max_output_tokens = max_output_tokens
        self._temperature = temperature
        self._max_messages = max_messages

    def summarize(
        self,
        messages: Sequence[Dict],
        *,
        existing_summary: Optional[str] = None,
        max_tokens: Optional[int] = None,
    ) -> Optional[HistorySummary]:
        if not messages and not existing_summary:
            return None

        # Render the dict messages into the prompt transcript.
        # Sort by timestamp first so out-of-order messages don't corrupt the slice.
        sorted_messages = sorted(messages, key=lambda m: m.get('timestamp', 0) if isinstance(m.get('timestamp'), (int, float)) else 0)

        # With an existing_summary we could pass fewer messages here, but for
        # simplicity we still take the most recent max_messages.
        trail = sorted_messages[-self._max_messages :]
        transcript = "\n".join(f"{m.get('role', 'user')}: {m.get('content', '')}" for m in trail)

        prompt = SUMMARY_PROMPT.format(
            existing_summary=existing_summary.strip() if existing_summary else "none",
            chat_transcript=transcript,
        )

        if not self._client:
            log.error("No OpenAI client available for summarization.")
            return None

        log.info(f"Starting summary generation for {len(messages)} messages...")

        # Try the primary client first
        try:
            # Raise the max_tokens floor so the summary isn't truncated and the JSON wrapper has room;
            # empirically, a 1000-character summary plus the JSON structure needs about 1500 tokens
            safe_max_tokens = max(max_tokens or self._max_output_tokens, 2000)

            response = self._client.chat.completions.create(
                model=self._model,
                messages=[{"role": "user", "content": prompt}],
                max_tokens=safe_max_tokens,
                temperature=self._temperature,
            )

            # Debug: log the full response to investigate empty-content issues
            log.info(f"Full Summary API Response: {response}")

            payload = response.choices[0].message.content or ""
            finish_reason = response.choices[0].finish_reason

            if finish_reason == "length":
                log.warning("Summary generation was truncated due to length limit!")

            log.info(f"Summary generation completed. Payload length: {len(payload)}")
            log.info(f"Summary Content:\n{payload}")

            return self._parse_response(payload)

        except Exception as e:
            log.warning(f"Summarization failed: {e}")
            return None

    def merge_summaries(
        self,
        summary_a: str,
        summary_b: str,
        *,
        max_tokens: Optional[int] = None,
    ) -> Optional[HistorySummary]:
        if not summary_a and not summary_b:
            return None

        prompt = MERGE_ONLY_PROMPT.format(
            summary_a=summary_a or "none",
            summary_b=summary_b or "none",
        )

        if not self._client:
            return None

        log.info(f"Starting summary merge (A len={len(summary_a)}, B len={len(summary_b)})...")

        # Try the primary client
        try:
            response = self._client.chat.completions.create(
                model=self._model,
                messages=[{"role": "user", "content": prompt}],
                max_tokens=max_tokens or self._max_output_tokens,
                temperature=self._temperature,
            )

            payload = response.choices[0].message.content or ""
            log.info("Summary merge completed successfully.")
            return self._parse_response(payload)

        except Exception as e:
            log.warning(f"Merge failed: {e}")
            return None

    def _parse_response(self, payload: str) -> HistorySummary:
        data = _safe_json_loads(payload)

        # If the parsed data is empty or not a dict, fall back to the raw payload
        if not isinstance(data, dict) or (not data and not payload.strip().startswith("{")):
            summary = payload.strip()
            table = {}
        else:
            summary = str(data.get("summary", "")).strip()
            table_payload = data.get("table", {}) or {}
            table = {
                "who": str(table_payload.get("who", "")).strip(),
                "how": str(table_payload.get("how", "")).strip(),
                "why": str(table_payload.get("why", "")).strip(),
                "what": str(table_payload.get("what", "")).strip(),
            }

        if not summary:
            summary = payload.strip()

        if len(summary) > 1000:
            summary = summary[:1000].rstrip() + "..."

        return HistorySummary(summary=summary, table=table)
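For reference, a minimal usage sketch of the class above (the message dict and model name are illustrative, and an API key must be configured or an OpenAI-compatible client injected via client= for the call to succeed):

summarizer = HistorySummarizer(model="gpt-4.1-mini", max_messages=120)
result = summarizer.summarize(
    [{"role": "user", "content": "hi", "timestamp": 1700000000}],
    existing_summary=None,
)
if result:
    # result is a HistorySummary with .summary text and the who/how/why/what table
    log.info(f"{result.summary} / who={result.table.get('who')}")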


def _safe_json_loads(raw: str) -> Dict[str, Any]:
    try:
        return json.loads(raw)
    except json.JSONDecodeError:
        # Simple regex-extraction fallback
        match = re.search(r'(\{.*\})', raw, re.DOTALL)
        if match:
            try:
                return json.loads(match.group(1))
            except json.JSONDecodeError:
                pass
    return {}
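The regex fallback lets parsing survive models that wrap their JSON in prose or code fences: greedy {.*} with DOTALL spans from the first brace to the last. A few illustrative calls:

_safe_json_loads('{"summary": "ok"}')                      # -> {'summary': 'ok'}
_safe_json_loads('Here it is:\n{"summary": "ok"}\nDone.')  # fallback path -> {'summary': 'ok'}
_safe_json_loads('no json here')                           # -> {}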


# --- Core Logic Modules ---

def build_ordered_messages(
    messages_map: Optional[Dict], anchor_id: Optional[str] = None
@@ -58,21 +320,26 @@ def build_ordered_messages(
def get_recent_messages_by_user_id(user_id: str, chat_id: str, num: int) -> List[Dict]:
    """
    Fetch the user's most recent N messages globally (in chronological order)
    Fetch the user's most recent N messages (current chat first, then chronological order)

    Args:
        user_id: user ID
        chat_id: current chat ID (used for prioritization)
        num: number of messages to fetch (<= 0 returns everything)

    Returns:
        Ordered message list (the most recent num messages)
        Ordered message list (current chat prioritized, topped up from the global recents when short)
    """
    all_messages: List[Dict] = []
    current_chat_messages: List[Dict] = []
    other_messages: List[Dict] = []

    # Iterate over all of the user's chats
    chats = Chats.get_chat_list_by_user_id(user_id, include_archived=True)
    for chat in chats:
        messages_map = chat.chat.get("history", {}).get("messages", {}) or {}
        # Simple check for whether this is the current chat
        is_current_chat = (str(chat.id) == str(chat_id))

        for mid, msg in messages_map.items():
            # Skip empty content
            if msg.get("content", "") == "":
@@ -86,15 +353,34 @@ def get_recent_messages_by_user_id(user_id: str, chat_id: str, num: int) -> List
            entry = {**msg, "id": mid}
            entry.setdefault("chat_id", chat.id)
            entry.setdefault("timestamp", int(ts))
            all_messages.append(entry)

    # Sort by timestamp
    all_messages.sort(key=lambda m: m.get("timestamp", 0))
            if is_current_chat:
                current_chat_messages.append(entry)
            else:
                other_messages.append(entry)

    # Sort each bucket separately
    current_chat_messages.sort(key=lambda m: m.get("timestamp", 0))
    other_messages.sort(key=lambda m: m.get("timestamp", 0))

    if num <= 0:
        return all_messages
        combined = current_chat_messages + other_messages
        combined.sort(key=lambda m: m.get("timestamp", 0))
        return combined

    return all_messages[-num:]
    # Strategy: keep the current chat's messages first
    if len(current_chat_messages) >= num:
        return current_chat_messages[-num:]

    # Top up the shortfall
    needed = num - len(current_chat_messages)
    supplement = other_messages[-needed:] if other_messages else []

    # Merge and sort by time one final time
    final_list = supplement + current_chat_messages
    final_list.sort(key=lambda m: m.get("timestamp", 0))

    return final_list
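A worked example of the top-up path above, with hypothetical timestamps:

# num=3, current chat holds [t=10, t=30], other chats hold [t=5, t=20, t=40]
# -> needed=1, supplement=[t=40] (the most recent of the others)
# -> final_list, time-sorted: [t=10, t=30, t=40]

Note that the shortfall is filled from the globally most recent messages of the other chats, then the merged list is re-sorted so the transcript stays chronological.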


def slice_messages_with_summary(
@@ -140,22 +426,21 @@ def slice_messages_with_summary(
    return ordered


def summarize(messages: List[Dict], old_summary: Optional[str] = None) -> str:
def summarize(messages: List[Dict], old_summary: Optional[str] = None, model: Optional[str] = None) -> str:
    """
    Generate a conversation summary (placeholder interface)
    Generate a conversation summary

    Args:
        messages: list of messages to summarize
        old_summary: previous summary (optional, currently unused)
        old_summary: previous summary
        model: model ID to use (if None, the class-internal default is used)

    Returns:
        Summary string

    TODO:
    - Implement incremental summarization (generate a new summary from old_summary)
    - Support configurable summarization strategy (length, level of detail)
    """
    return "\n".join(m.get("content")[:100] for m in messages)
    summarizer = HistorySummarizer(model=model) if model else HistorySummarizer()
    result = summarizer.summarize(messages, existing_summary=old_summary)
    return result.summary if result else ""
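The call sites in chat_completion above use this wrapper in two passes; a sketch, assuming the messages carry role/content/timestamp keys and model_id is available:

base_summary = summarize(older_messages, None, model=model_id)                        # compress old history
final_summary = summarize(recent_messages, old_summary=base_summary, model=model_id)  # incremental pass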


def compute_token_count(messages: List[Dict]) -> int:
    """
@@ -166,7 +451,12 @@ def compute_token_count(messages: List[Dict]) -> int:
"""
|
||||
total_chars = 0
|
||||
for msg in messages:
|
||||
total_chars += len(msg['content'])
|
||||
content = msg.get('content')
|
||||
if isinstance(content, str):
|
||||
total_chars += len(content)
|
||||
elif isinstance(content, list):
|
||||
for item in content:
|
||||
if isinstance(item, dict) and 'text' in item:
|
||||
total_chars += len(item['text'])
|
||||
|
||||
return max(total_chars // 4, 0)
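The result is a chars // 4 heuristic: a rough token estimate that works reasonably for English text but tends to undercount CJK text, which usually runs closer to one token per character. For example:

compute_token_count([{"content": "a" * 400}])                               # -> 100
compute_token_count([{"content": [{"type": "text", "text": "12345678"}]}])  # -> 2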