# open-webui/backend/open_webui/utils/summary.py
from typing import Dict, List, Optional, Tuple, Sequence, Any
import json
import re
import os
from dataclasses import dataclass
from logging import getLogger

# openai is an optional dependency: when it is missing, summarization degrades
# gracefully (the client stays None and the summarizer methods return None).
try:
    from openai import OpenAI
except ImportError:
    OpenAI = None

from open_webui.models.chats import Chats
from open_webui.config import OPENAI_API_KEYS, OPENAI_API_BASE_URLS

log = getLogger(__name__)
# --- Constants & Prompts from persona_extractor ---

# Prompt for summarizing a chat transcript, optionally folding in an existing
# summary. The model must reply with JSON:
#   {"summary": "...", "table": {"who": ..., "how": ..., "why": ..., "what": ...}}
# NOTE: the prompt is Chinese and is sent to the model verbatim — its exact
# wording is part of the runtime behavior; do not edit casually.
SUMMARY_PROMPT = """你是一名“对话历史整理员”,请在保持事实准确的前提下,概括当前为止的聊天记录。
## 要求
1. 最终摘要不得超过 1000
2. 聚焦人物状态事件节点情绪/意图等关键信息将片段整合为连贯文字
3. 输出需包含 who / how / why / what 四个字段每项不超过 50
4. 禁止臆测或脏话所有内容都必须能在聊天中找到对应描述
5. 目标帮助后续对话快速回忆上下文与人物设定
已存在的摘要如无则写
{existing_summary}
聊天片段
---CHATS---
{chat_transcript}
---END---
请严格输出下列 JSON
{{
"summary": "不超过1000字的连贯摘要",
"table": {{
"who": "不超过50字",
"how": "不超过50字",
"why": "不超过50字",
"what": "不超过50字"
}}
}}
"""

# Prompt for merging two period summaries (A = older, B = newer) into one
# coherent summary. Same JSON output contract as SUMMARY_PROMPT.
MERGE_ONLY_PROMPT = """你是一名“对话历史整理员”。
请将以下两段对话摘要A B合并为一段连贯的更新后的对话历史摘要
摘要 A 是较早的时间段摘要 B 是较新的时间段
摘要 A ()
{summary_a}
摘要 B ()
{summary_b}
## 要求
1. 保持时间线的连贯性将新发生的事自然接续在旧事之后
2. 最终摘要不得超过 1000
3. 依然提取 who / how / why / what 四个关键要素基于合并后的全貌
4. 禁止臆测只基于提供的摘要内容
请严格输出下列 JSON
{{
"summary": "合并后的连贯摘要",
"table": {{
"who": "不超过50字",
"how": "不超过50字",
"why": "不超过50字",
"what": "不超过50字"
}}
}}
"""
@dataclass
class HistorySummary:
summary: str
table: dict[str, str]
@dataclass(slots=True)
class ChatMessage:
role: str
content: str
timestamp: Optional[Any] = None
def formatted(self) -> str:
return f"{self.role}: {self.content}"
class HistorySummarizer:
    """Generates and merges rolling chat-history summaries through an
    OpenAI-compatible chat-completions API.

    The model is prompted (in Chinese, see SUMMARY_PROMPT / MERGE_ONLY_PROMPT)
    to return a JSON object with a free-text ``summary`` plus a
    who/how/why/what ``table``; replies are parsed leniently by
    ``_parse_response``. All public methods return ``None`` instead of
    raising when no client is available or the API call fails.
    """

    # Completion-token floor for summary calls: a ~1000-character summary plus
    # its JSON envelope needs roughly 1500 tokens, so 2000 leaves headroom
    # against "length" truncation.
    _MIN_COMPLETION_TOKENS = 2000

    def __init__(
        self,
        *,
        client: Optional[Any] = None,
        model: str = "gpt-4.1-mini",
        max_output_tokens: int = 800,
        temperature: float = 0.1,
        max_messages: int = 120,
    ) -> None:
        """Create a summarizer.

        Args:
            client: pre-built OpenAI client; when None, one is constructed
                from open-webui config / environment variables (best effort,
                may remain None).
            model: chat-completions model id.
            max_output_tokens: baseline completion budget; a floor of
                ``_MIN_COMPLETION_TOKENS`` is applied to actual calls.
            temperature: sampling temperature.
            max_messages: number of trailing messages included in the prompt.
        """
        if client is None:
            client = self._build_default_client()
        self._client = client
        self._model = model
        self._max_output_tokens = max_output_tokens
        self._temperature = temperature
        self._max_messages = max_messages

    @staticmethod
    def _build_default_client() -> Optional[Any]:
        """Best-effort client construction from config or env; None on failure."""
        if OpenAI is None:
            log.warning("OpenAI client not available. Install openai>=1.0.0.")
            return None
        try:
            # Prefer the first configured key / base URL; fall back to env vars.
            api_keys = OPENAI_API_KEYS.value if hasattr(OPENAI_API_KEYS, "value") else []
            base_urls = OPENAI_API_BASE_URLS.value if hasattr(OPENAI_API_BASE_URLS, "value") else []
            api_key = api_keys[0] if api_keys else os.environ.get("OPENAI_API_KEY")
            base_url = base_urls[0] if base_urls else os.environ.get("OPENAI_API_BASE_URL")
            if not api_key:
                log.warning("No OpenAI API key found.")
                return None
            kwargs: Dict[str, Any] = {"api_key": api_key}
            if base_url:
                kwargs["base_url"] = base_url
            return OpenAI(**kwargs)
        except Exception as e:
            log.warning("Failed to init OpenAI client: %s", e)
            return None

    def _complete(self, prompt: str, max_tokens: int) -> str:
        """Run one single-turn completion and return the raw reply text.

        Propagates any exception from the underlying client; callers handle it.
        """
        response = self._client.chat.completions.create(
            model=self._model,
            messages=[{"role": "user", "content": prompt}],
            max_tokens=max_tokens,
            temperature=self._temperature,
        )
        # The full response contains user content — keep the dump at DEBUG
        # (previously logged at INFO while investigating empty replies).
        log.debug("Full Summary API Response: %s", response)
        choice = response.choices[0]
        if choice.finish_reason == "length":
            log.warning("Summary generation was truncated due to length limit!")
        payload = choice.message.content or ""
        log.debug("Completion payload (%d chars):\n%s", len(payload), payload)
        return payload

    def summarize(
        self,
        messages: Sequence[Dict],
        *,
        existing_summary: Optional[str] = None,
        max_tokens: Optional[int] = None,
    ) -> Optional[HistorySummary]:
        """Summarize *messages*, optionally folding in an existing summary.

        Returns:
            A HistorySummary, or None when there is nothing to summarize,
            no client is available, or the API call fails.
        """
        if not messages and not existing_summary:
            return None
        # Sort by numeric timestamp so the trailing-window slice is stable
        # even when the caller passed messages out of order.
        chronological = sorted(
            messages,
            key=lambda m: m.get("timestamp", 0)
            if isinstance(m.get("timestamp"), (int, float))
            else 0,
        )
        # Only the most recent window goes into the prompt; the existing
        # summary is expected to carry the older context.
        trail = chronological[-self._max_messages:]
        transcript = "\n".join(
            f"{m.get('role', 'user')}: {m.get('content', '')}" for m in trail
        )
        prompt = SUMMARY_PROMPT.format(
            # BUGFIX: the prompt instructs "write 无 when there is no prior
            # summary"; previously an empty string was substituted instead.
            existing_summary=existing_summary.strip() if existing_summary else "无",
            chat_transcript=transcript,
        )
        if not self._client:
            log.error("No OpenAI client available for summarization.")
            return None
        log.info("Starting summary generation for %d messages...", len(messages))
        try:
            # Floor the budget so the JSON reply is not cut off mid-summary.
            safe_max_tokens = max(
                max_tokens or self._max_output_tokens, self._MIN_COMPLETION_TOKENS
            )
            payload = self._complete(prompt, safe_max_tokens)
            log.info("Summary generation completed. Payload length: %d", len(payload))
            return self._parse_response(payload)
        except Exception as e:
            log.warning("Summarization failed: %s", e)
            return None

    def merge_summaries(
        self,
        summary_a: str,
        summary_b: str,
        *,
        max_tokens: Optional[int] = None,
    ) -> Optional[HistorySummary]:
        """Merge an older summary (A) with a newer one (B) into one summary.

        Returns:
            A HistorySummary, or None when both inputs are empty, no client
            is available, or the API call fails.
        """
        if not summary_a and not summary_b:
            return None
        prompt = MERGE_ONLY_PROMPT.format(
            summary_a=summary_a or "",
            summary_b=summary_b or "",
        )
        if not self._client:
            return None
        # BUGFIX: len() used to be called on the raw arguments, raising
        # TypeError before the API call whenever one of them was None.
        log.info(
            "Starting summary merge (A len=%d, B len=%d)...",
            len(summary_a or ""),
            len(summary_b or ""),
        )
        try:
            # BUGFIX: apply the same completion-token floor as summarize();
            # merges previously used the raw budget and could be truncated.
            safe_max_tokens = max(
                max_tokens or self._max_output_tokens, self._MIN_COMPLETION_TOKENS
            )
            payload = self._complete(prompt, safe_max_tokens)
            log.info("Summary merge completed successfully.")
            return self._parse_response(payload)
        except Exception as e:
            log.warning("Merge failed: %s", e)
            return None

    def _parse_response(self, payload: str) -> HistorySummary:
        """Leniently parse the model reply into a HistorySummary.

        Falls back to using the raw text as the summary when the reply is not
        a JSON object. The summary is hard-capped at 1000 characters to
        mirror the prompt contract.
        """
        data = _safe_json_loads(payload)
        # Non-dict (or empty) parse result with no leading "{": treat the
        # whole payload as plain-text summary.
        if not isinstance(data, dict) or (
            not data and not payload.strip().startswith("{")
        ):
            summary = payload.strip()
            table: Dict[str, str] = {}
        else:
            summary = str(data.get("summary", "")).strip()
            table_payload = data.get("table", {}) or {}
            table = {
                key: str(table_payload.get(key, "")).strip()
                for key in ("who", "how", "why", "what")
            }
            if not summary:
                summary = payload.strip()
        if len(summary) > 1000:
            summary = summary[:1000].rstrip() + "..."
        return HistorySummary(summary=summary, table=table)
def _safe_json_loads(raw: str) -> Dict[str, Any]:
try:
return json.loads(raw)
except json.JSONDecodeError:
# 简单的正则提取尝试
match = re.search(r'(\{.*\})', raw, re.DOTALL)
if match:
try:
return json.loads(match.group(1))
except json.JSONDecodeError:
pass
return {}
# --- Core Logic Modules ---
def build_ordered_messages(
    messages_map: Optional[Dict], anchor_id: Optional[str] = None
) -> List[Dict]:
    """Reconstruct an ordered message list from a message map.

    Strategy:
      1. Preferred: walk the parentId chain upward from *anchor_id* to the root.
      2. Fallback: sort by timestamp (when anchor_id is missing or unknown).

    BUGFIX: the chain walk now tracks visited ids — a corrupted map with a
    parentId cycle previously caused an infinite loop.

    Args:
        messages_map: message map, e.g.
            {"msg-id": {"role": "user", "content": "...", "parentId": "...", "timestamp": 123456}}
        anchor_id: chain-tail message id to trace back from.

    Returns:
        Ordered messages, each guaranteed to carry an "id" field.
    """
    if not messages_map:
        return []

    def with_id(message_id: str, message: Dict) -> Dict:
        # Inject the map key as "id" without clobbering an explicit one.
        return {**message, **({"id": message_id} if "id" not in message else {})}

    # Mode 1: parentId chain walk, newest (anchor) last.
    if anchor_id and anchor_id in messages_map:
        ordered: List[Dict] = []
        seen: set = set()
        current_id: Optional[str] = anchor_id
        while current_id and current_id not in seen:
            current_msg = messages_map.get(current_id)
            if not current_msg:
                break
            seen.add(current_id)
            ordered.insert(0, with_id(current_id, current_msg))
            current_id = current_msg.get("parentId")
        return ordered

    # Mode 2: timestamp sort; several timestamp field spellings are accepted.
    sortable: List[Tuple[int, str, Dict]] = []
    for mid, message in messages_map.items():
        ts = (
            message.get("createdAt")
            or message.get("created_at")
            or message.get("timestamp")
            or 0
        )
        sortable.append((int(ts), mid, message))
    # Sort only on the timestamp so ties keep insertion order (stable sort).
    sortable.sort(key=lambda item: item[0])
    return [with_id(mid, msg) for _, mid, msg in sortable]
def get_recent_messages_by_user_id(user_id: str, chat_id: str, num: int) -> List[Dict]:
    """Collect the most recent *num* messages for a user, preferring the
    current chat and topping up from other chats when it falls short.

    Args:
        user_id: owner of the chats to scan (archived chats included).
        chat_id: id of the chat treated as "current" (given priority).
        num: number of messages wanted; <= 0 returns everything.

    Returns:
        Messages sorted by timestamp (ascending); each entry carries "id",
        "chat_id" and a normalized integer "timestamp".
    """
    in_current: List[Dict] = []
    elsewhere: List[Dict] = []
    for chat in Chats.get_chat_list_by_user_id(user_id, include_archived=True):
        history_messages = chat.chat.get("history", {}).get("messages", {}) or {}
        # Route each chat's messages into the "current chat" bucket or the rest.
        bucket = in_current if str(chat.id) == str(chat_id) else elsewhere
        for mid, msg in history_messages.items():
            # Skip messages with empty text content.
            if msg.get("content", "") == "":
                continue
            ts = (
                msg.get("createdAt")
                or msg.get("created_at")
                or msg.get("timestamp")
                or 0
            )
            entry = {**msg, "id": mid}
            entry.setdefault("chat_id", chat.id)
            entry.setdefault("timestamp", int(ts))
            bucket.append(entry)

    in_current.sort(key=lambda m: m.get("timestamp", 0))
    elsewhere.sort(key=lambda m: m.get("timestamp", 0))

    if num <= 0:
        everything = in_current + elsewhere
        everything.sort(key=lambda m: m.get("timestamp", 0))
        return everything

    # Current-chat messages take priority; only top up when they fall short.
    if len(in_current) >= num:
        return in_current[-num:]
    shortfall = num - len(in_current)
    filler = elsewhere[-shortfall:] if elsewhere else []
    combined = filler + in_current
    combined.sort(key=lambda m: m.get("timestamp", 0))
    return combined
def slice_messages_with_summary(
    messages_map: Dict,
    boundary_message_id: Optional[str],
    anchor_id: Optional[str],
    pre_boundary: int = 20,
) -> List[Dict]:
    """Trim the ordered message list around a summary boundary.

    Keeps *pre_boundary* messages before the boundary (for context) plus
    every message after it (the latest conversation), cutting token cost
    while preserving continuity.

    Args:
        messages_map: message map (see build_ordered_messages).
        boundary_message_id: summary-boundary message id; None returns all.
        anchor_id: chain-tail message id.
        pre_boundary: messages to keep before the boundary (default 20).

    Returns:
        The trimmed, ordered message list.

    Example:
        100 messages with the boundary at index 50 and pre_boundary=20
        returns messages 29-99 (71 messages).
    """
    ordered = build_ordered_messages(messages_map, anchor_id)
    if not boundary_message_id:
        return ordered

    boundary_idx: Optional[int] = None
    for idx, msg in enumerate(ordered):
        if msg.get("id") == boundary_message_id:
            boundary_idx = idx
            break
    if boundary_idx is None:
        # Boundary message not present: fall back to the full list.
        return ordered

    start_idx = max(boundary_idx - pre_boundary, 0)
    return ordered[start_idx:]
def summarize(messages: List[Dict], old_summary: Optional[str] = None, model: Optional[str] = None) -> str:
    """Produce a conversation summary string.

    Args:
        messages: messages to summarize.
        old_summary: previously generated summary to fold in, if any.
        model: model id override; None keeps HistorySummarizer's default.

    Returns:
        The summary text, or "" when summarization was not possible.
    """
    kwargs = {"model": model} if model else {}
    engine = HistorySummarizer(**kwargs)
    outcome = engine.summarize(messages, existing_summary=old_summary)
    return outcome.summary if outcome else ""
def compute_token_count(messages: List[Dict]) -> int:
    """Estimate the token count of *messages* (placeholder implementation).

    Current heuristic: 4 characters per token, counting string contents and
    the "text" fields of list-style (multimodal) contents.
    TODO: wire up a real tokenizer (e.g. tiktoken for OpenAI models).
    """
    char_total = 0
    for message in messages:
        body = message.get('content')
        if isinstance(body, str):
            char_total += len(body)
        elif isinstance(body, list):
            # Multimodal content: only the textual parts contribute.
            char_total += sum(
                len(part['text'])
                for part in body
                if isinstance(part, dict) and 'text' in part
            )
    return max(char_total // 4, 0)