open-webui/backend/open_webui/utils/summary.py


from typing import Dict, List, Optional, Tuple
from open_webui.models.chats import Chats


def build_ordered_messages(
    messages_map: Optional[Dict], anchor_id: Optional[str] = None
) -> List[Dict]:
    """
    Rebuild the message map into an ordered list.

    Strategy:
        1. Preferred: follow the parentId chain (walk upward from anchor_id to the root message).
        2. Fallback: sort by timestamp (when anchor_id is missing or the chain walk fails).

    Args:
        messages_map: message map in the form
            {"msg-id": {"role": "user", "content": "...", "parentId": "...", "timestamp": 123456}}
        anchor_id: anchor message ID (the tail of the chain); traversal walks upward from it

    Returns:
        An ordered list of messages, each carrying an "id" field.
    """
    if not messages_map:
        return []

    # Ensure each message carries an "id" field
    def with_id(message_id: str, message: Dict) -> Dict:
        return {**message, **({"id": message_id} if "id" not in message else {})}

    # Mode 1: follow the parentId chain
    if anchor_id and anchor_id in messages_map:
        ordered: List[Dict] = []
        current_id: Optional[str] = anchor_id
        while current_id:
            current_msg = messages_map.get(current_id)
            if not current_msg:
                break
            ordered.insert(0, with_id(current_id, current_msg))
            current_id = current_msg.get("parentId")
        return ordered

    # Mode 2: sort by timestamp
    sortable: List[Tuple[int, str, Dict]] = []
    for mid, message in messages_map.items():
        ts = (
            message.get("createdAt")
            or message.get("created_at")
            or message.get("timestamp")
            or 0
        )
        sortable.append((int(ts), mid, message))
    sortable.sort(key=lambda x: x[0])
    return [with_id(mid, msg) for _, mid, msg in sortable]
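
# Illustration (hypothetical data, not taken from the application): given a
# three-message parentId chain such as
#   {"a": {"content": "...", "parentId": None},
#    "b": {"content": "...", "parentId": "a"},
#    "c": {"content": "...", "parentId": "b"}}
# build_ordered_messages(messages_map, anchor_id="c") walks c -> b -> a and
# returns the messages in the order [a, b, c]; without an anchor_id (or when the
# anchor is missing from the map) it falls back to sorting by timestamp.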


def get_recent_messages_by_user_id(user_id: str, chat_id: str, num: int) -> List[Dict]:
    """
    Fetch a user's most recent N messages across all of their chats, in chronological order.

    Args:
        user_id: user ID
        chat_id: chat ID (currently unused; messages are collected from every chat of the user)
        num: number of messages to return (<= 0 returns all messages)

    Returns:
        An ordered list of messages (the most recent num entries).
    """
    all_messages: List[Dict] = []

    # Walk through every chat belonging to the user
    chats = Chats.get_chat_list_by_user_id(user_id, include_archived=True)
    for chat in chats:
        messages_map = chat.chat.get("history", {}).get("messages", {}) or {}
        for mid, msg in messages_map.items():
            # Skip messages with empty content
            if msg.get("content", "") == "":
                continue
            ts = (
                msg.get("createdAt")
                or msg.get("created_at")
                or msg.get("timestamp")
                or 0
            )
            entry = {**msg, "id": mid}
            entry.setdefault("chat_id", chat.id)
            entry.setdefault("timestamp", int(ts))
            all_messages.append(entry)

    # Sort by timestamp
    all_messages.sort(key=lambda m: m.get("timestamp", 0))

    if num <= 0:
        return all_messages
    return all_messages[-num:]
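
# Usage sketch (assumes a populated database behind Chats.get_chat_list_by_user_id;
# "user-123" and "chat-456" are placeholder IDs, not values from the application):
#   recent = get_recent_messages_by_user_id("user-123", "chat-456", num=50)
# returns the 50 most recent non-empty messages across the user's chats, oldest
# first, each entry carrying "id", "chat_id" and "timestamp" fields.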


def slice_messages_with_summary(
    messages_map: Dict,
    boundary_message_id: Optional[str],
    anchor_id: Optional[str],
    pre_boundary: int = 20,
) -> List[Dict]:
    """
    Trim the message list around a summary boundary (return the N messages before
    the boundary plus everything after it).

    Strategy: keep the N messages preceding the summary boundary (for context)
    plus all messages after the boundary (the latest conversation).
    Goal: reduce token usage while keeping enough context.

    Args:
        messages_map: message map
        boundary_message_id: summary-boundary message ID (None returns all messages)
        anchor_id: anchor message ID (the tail of the chain)
        pre_boundary: number of messages to keep before the boundary (default 20)

    Returns:
        The trimmed, ordered message list.

    Example:
        100 messages, summary boundary at the 50th message, pre_boundary=20
        → returns the messages at indices 29-99 (71 messages in total)
    """
    ordered = build_ordered_messages(messages_map, anchor_id)
    if boundary_message_id:
        try:
            # Find the index of the summary-boundary message
            boundary_idx = next(
                idx
                for idx, msg in enumerate(ordered)
                if msg.get("id") == boundary_message_id
            )
            # Compute the start index of the slice
            start_idx = max(boundary_idx - pre_boundary, 0)
            ordered = ordered[start_idx:]
        except StopIteration:
            # Boundary message not found: return everything
            pass
    return ordered


def summarize(messages: List[Dict], old_summary: Optional[str] = None) -> str:
    """
    Generate a conversation summary (placeholder implementation).

    Args:
        messages: messages to summarize
        old_summary: previous summary (optional, currently unused)

    Returns:
        The summary string.

    TODO:
        - Implement incremental summarization (build the new summary on top of old_summary)
        - Support configurable summarization strategies (length, level of detail)
    """
    return "\n".join((m.get("content") or "")[:100] for m in messages)


def compute_token_count(messages: List[Dict]) -> int:
    """
    Estimate the token count of a list of messages (placeholder implementation).

    Current heuristic: 4 characters ≈ 1 token (rough estimate).
    TODO: plug in a real tokenizer (e.g. tiktoken for OpenAI models).
    """
    total_chars = 0
    for msg in messages:
        total_chars += len(msg.get("content") or "")
    return max(total_chars // 4, 0)
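

# Minimal self-check for the pure helpers above, a sketch rather than part of the
# module's public behavior. It can be run with `python -m open_webui.utils.summary`
# wherever the open_webui package is importable; the sample map and IDs below are
# made up for illustration.
if __name__ == "__main__":
    sample_map = {
        f"m{i}": {
            "role": "user" if i % 2 == 0 else "assistant",
            "content": f"message {i}",
            "parentId": f"m{i - 1}" if i > 0 else None,
            "timestamp": 1_700_000_000 + i,
        }
        for i in range(10)
    }

    # The parentId chain m9 -> m8 -> ... -> m0 is restored to chronological order
    ordered = build_ordered_messages(sample_map, anchor_id="m9")
    assert [m["id"] for m in ordered] == [f"m{i}" for i in range(10)]

    # Keep 2 messages before the boundary "m5" plus everything after it: m3..m9
    trimmed = slice_messages_with_summary(sample_map, "m5", "m9", pre_boundary=2)
    assert [m["id"] for m in trimmed] == [f"m{i}" for i in range(3, 10)]

    print(summarize(trimmed))
    print("approx. tokens:", compute_token_count(trimmed))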