open-webui/backend/open_webui/utils/summary.py

from typing import Dict, List, Optional, Tuple, Sequence, Any
import json
import re
import os
from dataclasses import dataclass
from logging import getLogger

try:
    from openai import OpenAI
except ImportError:
    OpenAI = None

from open_webui.models.chats import Chats
from open_webui.config import OPENAI_API_KEYS, OPENAI_API_BASE_URLS

log = getLogger(__name__)

# --- Constants & Prompts from persona_extractor ---

SUMMARY_PROMPT = """You are a "chat history curator". Summarize the chat log so far while staying factually accurate.

## Requirements
1. The final summary must not exceed 1000 characters.
2. Focus on key information such as character state, event milestones, and emotions/intent, and weave the fragments into coherent prose.
3. The output must contain the four fields who / how / why / what, each at most 50 characters.
4. No speculation and no profanity; everything must be traceable to the chat log.
5. Goal: help later turns quickly recall the context and character setup.

Existing summary (write "None" if there is none):
{existing_summary}

Chat fragment:
---CHATS---
{chat_transcript}
---END---

Output strictly the following JSON:
{{
  "summary": "coherent summary of at most 1000 characters",
  "table": {{
    "who": "at most 50 characters",
    "how": "at most 50 characters",
    "why": "at most 50 characters",
    "what": "at most 50 characters"
  }}
}}
"""
MERGE_ONLY_PROMPT = """You are a "chat history curator".
Merge the two conversation summaries below (A and B) into one coherent, updated history summary.
Summary A covers the earlier period; summary B covers the more recent one.

[Summary A (old)]
{summary_a}

[Summary B (new)]
{summary_b}

## Requirements
1. Keep the timeline coherent: new events should follow naturally after the old ones.
2. The final summary must not exceed 1000 characters.
3. Still extract the four key elements who / how / why / what (based on the merged whole).
4. No speculation; rely only on the provided summaries.

Output strictly the following JSON:
{{
  "summary": "merged coherent summary",
  "table": {{
    "who": "at most 50 characters",
    "how": "at most 50 characters",
    "why": "at most 50 characters",
    "what": "at most 50 characters"
  }}
}}
"""


@dataclass
class HistorySummary:
    summary: str
    table: Dict[str, str]


@dataclass(slots=True)
class ChatMessage:
    role: str
    content: str
    timestamp: Optional[Any] = None

    def formatted(self) -> str:
        return f"{self.role}: {self.content}"


class HistorySummarizer:
    def __init__(
        self,
        *,
        client: Optional[Any] = None,
        model: str = "gpt-4.1-mini",
        max_output_tokens: int = 800,
        temperature: float = 0.1,
        max_messages: int = 120,
    ) -> None:
        self._fallback_client = None
        # Fallback key for testing. Never commit a literal secret here; read it
        # from the environment instead.
        self._fallback_key = os.environ.get("OPENAI_FALLBACK_API_KEY", "")
        if client is None:
            if OpenAI is None:
                log.warning("OpenAI client not available. Install openai>=1.0.0.")
            else:
                try:
                    # Try to read the API key and base URL from the app config.
                    api_keys = OPENAI_API_KEYS.value if hasattr(OPENAI_API_KEYS, "value") else []
                    base_urls = OPENAI_API_BASE_URLS.value if hasattr(OPENAI_API_BASE_URLS, "value") else []
                    api_key = api_keys[0] if api_keys else os.environ.get("OPENAI_API_KEY")
                    base_url = base_urls[0] if base_urls else os.environ.get("OPENAI_API_BASE_URL")
                    if api_key:
                        kwargs = {"api_key": api_key}
                        if base_url:
                            kwargs["base_url"] = base_url
                        client = OpenAI(**kwargs)
                    elif self._fallback_key:
                        # No configured key found; use the fallback key as primary.
                        log.warning("No OpenAI API key found, using fallback key as primary.")
                        client = OpenAI(api_key=self._fallback_key)
                    # Initialize the fallback client (fallback key + official base URL),
                    # unless the primary client already uses the fallback key.
                    if api_key and self._fallback_key:
                        try:
                            self._fallback_client = OpenAI(
                                api_key=self._fallback_key,
                                base_url="https://api.openai.com/v1",
                            )
                            log.debug("Fallback OpenAI client initialized with official Base URL.")
                        except Exception as e:
                            log.warning(f"Failed to init fallback client: {e}")
                except Exception as e:
                    log.warning(f"Failed to init OpenAI client: {e}")
        self._client = client
        self._model = model
        self._max_output_tokens = max_output_tokens
        self._temperature = temperature
        self._max_messages = max_messages
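
    # Construction sketch (illustrative): the summarizer can run off the app
    # config, or off an injected client, e.g. one pointing at a self-hosted
    # endpoint. `my_client` and the URL below are hypothetical.
    #
    #   from openai import OpenAI
    #   my_client = OpenAI(api_key="sk-...", base_url="http://localhost:8000/v1")
    #   summarizer = HistorySummarizer(client=my_client, model="gpt-4o-mini")
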
    def summarize(
        self,
        messages: Sequence[Dict],
        *,
        existing_summary: Optional[str] = None,
        max_tokens: Optional[int] = None,
    ) -> Optional[HistorySummary]:
        if not messages:
            return None
        # Render the dict messages into a transcript for the prompt.
        trail = messages[-self._max_messages:]
        transcript = "\n".join(
            f"{m.get('role', 'user')}: {m.get('content', '')}" for m in trail
        )
        prompt = SUMMARY_PROMPT.format(
            existing_summary=existing_summary.strip() if existing_summary else "None",
            chat_transcript=transcript,
        )
        if not self._client:
            log.error("No OpenAI client available for summarization.")
            return None
        log.info(f"Starting summary generation for {len(messages)} messages...")
        # Try the primary client first.
        try:
            response = self._client.chat.completions.create(
                model=self._model,
                messages=[{"role": "user", "content": prompt}],
                max_tokens=max_tokens or self._max_output_tokens,
                temperature=self._temperature,
            )
            payload = response.choices[0].message.content or ""
            log.info("Summary generation completed successfully (Primary).")
            log.info(f"Primary Summary Content:\n{payload}")
            return self._parse_response(payload)
        except Exception as e:
            log.warning(f"Primary summarization failed: {e}")
        # Fall back to the secondary client.
        if self._fallback_client:
            log.warning("Attempting fallback OpenAI client...")
            try:
                response = self._fallback_client.chat.completions.create(
                    model="gpt-4o-mini",  # Force a standard model for the fallback
                    messages=[{"role": "user", "content": prompt}],
                    max_tokens=max_tokens or self._max_output_tokens,
                    temperature=self._temperature,
                )
                payload = response.choices[0].message.content or ""
                log.info("Summary generation completed successfully (Fallback).")
                log.info(f"Fallback Summary Content:\n{payload}")
                return self._parse_response(payload)
            except Exception as fb_e:
                log.error(f"Fallback summarization also failed: {fb_e}")
        else:
            log.error("No fallback client available.")
        return None
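
    # Usage sketch (illustrative; the message dicts are hypothetical):
    #
    #   result = summarizer.summarize(
    #       [
    #           {"role": "user", "content": "Plan a weekend trip to Kyoto."},
    #           {"role": "assistant", "content": "Day 1: Fushimi Inari..."},
    #       ],
    #       existing_summary=None,
    #   )
    #   if result:
    #       print(result.summary, result.table["who"])
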
    def merge_summaries(
        self,
        summary_a: str,
        summary_b: str,
        *,
        max_tokens: Optional[int] = None,
    ) -> Optional[HistorySummary]:
        if not summary_a and not summary_b:
            return None
        prompt = MERGE_ONLY_PROMPT.format(
            summary_a=summary_a or "",
            summary_b=summary_b or "",
        )
        if not self._client:
            return None
        log.info(f"Starting summary merge (A len={len(summary_a or '')}, B len={len(summary_b or '')})...")
        # Try the primary client first.
        try:
            response = self._client.chat.completions.create(
                model=self._model,
                messages=[{"role": "user", "content": prompt}],
                max_tokens=max_tokens or self._max_output_tokens,
                temperature=self._temperature,
            )
            payload = response.choices[0].message.content or ""
            log.info("Summary merge completed successfully (Primary).")
            return self._parse_response(payload)
        except Exception as e:
            log.warning(f"Primary merge failed: {e}")
        # Fall back to the secondary client.
        if self._fallback_client:
            log.warning("Attempting merge with the fallback OpenAI client...")
            try:
                response = self._fallback_client.chat.completions.create(
                    model="gpt-4o-mini",
                    messages=[{"role": "user", "content": prompt}],
                    max_tokens=max_tokens or self._max_output_tokens,
                    temperature=self._temperature,
                )
                payload = response.choices[0].message.content or ""
                log.info("Summary merge completed successfully (Fallback).")
                return self._parse_response(payload)
            except Exception as fb_e:
                log.error(f"Fallback merge also failed: {fb_e}")
        return None
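
    # Merge sketch (illustrative; `old_summary`, `new_summary`, and the
    # `save_summary` persistence hook are hypothetical):
    #
    #   merged = summarizer.merge_summaries(old_summary, new_summary)
    #   if merged:
    #       save_summary(merged.summary)
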
    def _parse_response(self, payload: str) -> HistorySummary:
        data = _safe_json_loads(payload)
        # If parsing yielded nothing usable (not a dict, or an empty result for a
        # non-JSON payload), fall back to the raw payload.
        if not isinstance(data, dict) or (not data and not payload.strip().startswith("{")):
            summary = payload.strip()
            table = {}
        else:
            summary = str(data.get("summary", "")).strip()
            table_payload = data.get("table", {}) or {}
            table = {
                "who": str(table_payload.get("who", "")).strip(),
                "how": str(table_payload.get("how", "")).strip(),
                "why": str(table_payload.get("why", "")).strip(),
                "what": str(table_payload.get("what", "")).strip(),
            }
        if not summary:
            summary = payload.strip()
        if len(summary) > 1000:
            summary = summary[:1000].rstrip() + "..."
        return HistorySummary(summary=summary, table=table)


def _safe_json_loads(raw: str) -> Dict[str, Any]:
    try:
        return json.loads(raw)
    except json.JSONDecodeError:
        # Simple regex-based extraction attempt (e.g. JSON wrapped in prose or fences).
        match = re.search(r"(\{.*\})", raw, re.DOTALL)
        if match:
            try:
                return json.loads(match.group(1))
            except json.JSONDecodeError:
                pass
    return {}
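
# Behavior sketch (illustrative): the regex fallback recovers JSON that a model
# wrapped in prose or Markdown fences.
#
#   _safe_json_loads('{"summary": "ok", "table": {}}')
#   # -> {"summary": "ok", "table": {}}
#   _safe_json_loads('Here you go:\n```json\n{"summary": "ok", "table": {}}\n```')
#   # -> {"summary": "ok", "table": {}}  (extracted via the {...} regex)
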
# --- Core Logic Modules ---

def build_ordered_messages(
    messages_map: Optional[Dict], anchor_id: Optional[str] = None
) -> List[Dict]:
    """
    Reconstruct an ordered message list from a message map.

    Strategy:
        1. Preferred: follow the parentId chain (walk upward from anchor_id to the root).
        2. Fallback: sort by timestamp (when anchor_id is missing or the walk fails).

    Args:
        messages_map: message map shaped like
            {"msg-id": {"role": "user", "content": "...", "parentId": "...", "timestamp": 123456}}
        anchor_id: anchor message ID (the tail of the chain; the walk starts here)

    Returns:
        An ordered list of messages, each carrying an "id" field.
    """
    if not messages_map:
        return []

    # Ensure every message carries its id.
    def with_id(message_id: str, message: Dict) -> Dict:
        return {**message, **({"id": message_id} if "id" not in message else {})}

    # Mode 1: walk the parentId chain.
    if anchor_id and anchor_id in messages_map:
        ordered: List[Dict] = []
        current_id: Optional[str] = anchor_id
        while current_id:
            current_msg = messages_map.get(current_id)
            if not current_msg:
                break
            ordered.insert(0, with_id(current_id, current_msg))
            current_id = current_msg.get("parentId")
        return ordered

    # Mode 2: sort by timestamp.
    sortable: List[Tuple[int, str, Dict]] = []
    for mid, message in messages_map.items():
        ts = (
            message.get("createdAt")
            or message.get("created_at")
            or message.get("timestamp")
            or 0
        )
        sortable.append((int(ts), mid, message))
    sortable.sort(key=lambda x: x[0])
    return [with_id(mid, msg) for _, mid, msg in sortable]
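
# Walkthrough (illustrative map; the ids are hypothetical): with an anchor, the
# parentId chain wins even if timestamps are out of order.
#
#   messages_map = {
#       "m1": {"role": "user", "content": "hi", "parentId": None, "timestamp": 2},
#       "m2": {"role": "assistant", "content": "hello", "parentId": "m1", "timestamp": 1},
#   }
#   build_ordered_messages(messages_map, anchor_id="m2")
#   # -> [{"id": "m1", ...}, {"id": "m2", ...}]   (chain order: root first)
#   build_ordered_messages(messages_map)
#   # -> [{"id": "m2", ...}, {"id": "m1", ...}]   (timestamp order as fallback)
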
def get_recent_messages_by_user_id(user_id: str, chat_id: str, num: int) -> List[Dict]:
    """
    Fetch the user's most recent N messages across all chats (in chronological order).

    Args:
        user_id: user ID
        chat_id: currently unused; kept for interface compatibility
        num: number of messages to fetch (<= 0 returns all)

    Returns:
        An ordered message list (the most recent num messages).
    """
    all_messages: List[Dict] = []
    # Walk all of the user's chats.
    chats = Chats.get_chat_list_by_user_id(user_id, include_archived=True)
    for chat in chats:
        messages_map = chat.chat.get("history", {}).get("messages", {}) or {}
        for mid, msg in messages_map.items():
            # Skip messages with empty content.
            if msg.get("content", "") == "":
                continue
            ts = (
                msg.get("createdAt")
                or msg.get("created_at")
                or msg.get("timestamp")
                or 0
            )
            entry = {**msg, "id": mid}
            entry.setdefault("chat_id", chat.id)
            entry.setdefault("timestamp", int(ts))
            all_messages.append(entry)
    # Sort by timestamp.
    all_messages.sort(key=lambda m: m.get("timestamp", 0))
    if num <= 0:
        return all_messages
    return all_messages[-num:]
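
# Usage sketch (illustrative; "user-123" is a hypothetical ID, and chat_id is
# currently ignored by the implementation above):
#
#   recent = get_recent_messages_by_user_id("user-123", chat_id="", num=50)
#   transcript = "\n".join(f"{m.get('role')}: {m.get('content')}" for m in recent)
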
def slice_messages_with_summary(
    messages_map: Dict,
    boundary_message_id: Optional[str],
    anchor_id: Optional[str],
    pre_boundary: int = 20,
) -> List[Dict]:
    """
    Trim the message list around a summary boundary (keep N messages before the
    boundary plus everything after it).

    Strategy: keep pre_boundary messages before the summary boundary (for context)
    plus all messages after it (the latest conversation).
    Purpose: cut token usage while preserving enough context.

    Args:
        messages_map: message map
        boundary_message_id: summary boundary message ID (None returns all messages)
        anchor_id: anchor message ID (the tail of the chain)
        pre_boundary: number of messages to keep before the boundary (default 20)

    Returns:
        The trimmed, ordered message list.

    Example:
        100 messages with the summary boundary at the 50th message and pre_boundary=20
        -> returns messages at indices 29-99 (71 messages)
    """
    ordered = build_ordered_messages(messages_map, anchor_id)
    if boundary_message_id:
        try:
            # Locate the index of the boundary message.
            boundary_idx = next(
                idx for idx, msg in enumerate(ordered) if msg.get("id") == boundary_message_id
            )
            # Compute the slice start.
            start_idx = max(boundary_idx - pre_boundary, 0)
            ordered = ordered[start_idx:]
        except StopIteration:
            # The boundary message does not exist; return everything.
            pass
    return ordered
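
# Boundary arithmetic sketch (illustrative; the ids "m0".."m99" are hypothetical):
# with 100 messages in chain order and the boundary at "m49" (the 50th message),
# pre_boundary=20 keeps indices 29-99.
#
#   kept = slice_messages_with_summary(messages_map, "m49", anchor_id="m99")
#   # len(kept) == 71; kept[0]["id"] == "m29"
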
def summarize(messages: List[Dict], old_summary: Optional[str] = None) -> str:
    """
    Generate a conversation summary (thin convenience wrapper).

    Args:
        messages: messages to summarize
        old_summary: previous summary (optional), passed through as existing context

    Returns:
        The summary string.

    TODO:
        - Implement true incremental summarization on top of old_summary
        - Support configurable summary strategies (length, level of detail)
    """
    summarizer = HistorySummarizer()
    result = summarizer.summarize(messages, existing_summary=old_summary)
    return result.summary if result else ""


def compute_token_count(messages: List[Dict]) -> int:
    """
    Estimate the token count of a message list (placeholder implementation).

    Current heuristic: 4 characters ~= 1 token (rough estimate).
    TODO: plug in a real tokenizer (e.g. tiktoken for OpenAI models).
    """
    total_chars = 0
    for msg in messages:
        content = msg.get("content")
        if isinstance(content, str):
            total_chars += len(content)
        elif isinstance(content, list):
            # Multimodal content: count only the text parts.
            for item in content:
                if isinstance(item, dict) and "text" in item:
                    total_chars += len(item["text"])
    return max(total_chars // 4, 0)
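
# Estimation sketch (illustrative): 11 + 17 characters across two messages
# -> 28 // 4 == 7 estimated tokens.
#
#   compute_token_count([
#       {"role": "user", "content": "hello world"},         # 11 chars
#       {"role": "assistant", "content": [                  # multimodal content
#           {"type": "text", "text": "general greetings"},  # 17 chars
#       ]},
#   ])
#   # -> 7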