open-webui/backend/open_webui/utils/summary.py

173 lines
5.3 KiB
Python
Raw Normal View History

from typing import Dict, List, Optional, Tuple
from open_webui.models.chats import Chats
def build_ordered_messages(
messages_map: Optional[Dict], anchor_id: Optional[str] = None
) -> List[Dict]:
"""
将消息 map 还原为有序列表
策略
1. 优先基于 parentId 链条追溯 anchor_id 向上回溯到根消息
2. 退化按时间戳排序 anchor_id 或追溯失败时
参数
messages_map: 消息 map格式 {"msg-id": {"role": "user", "content": "...", "parentId": "...", "timestamp": 123456}}
anchor_id: 锚点消息 ID链尾从此消息向上追溯
返回
有序的消息列表每个消息包含 id 字段
"""
if not messages_map:
return []
# 补齐消息的 id 字段
def with_id(message_id: str, message: Dict) -> Dict:
return {**message, **({"id": message_id} if "id" not in message else {})}
# 模式 1基于 parentId 链条追溯
if anchor_id and anchor_id in messages_map:
ordered: List[Dict] = []
current_id: Optional[str] = anchor_id
while current_id:
current_msg = messages_map.get(current_id)
if not current_msg:
break
ordered.insert(0, with_id(current_id, current_msg))
current_id = current_msg.get("parentId")
return ordered
# 模式 2基于时间戳排序
sortable: List[Tuple[int, str, Dict]] = []
for mid, message in messages_map.items():
ts = (
message.get("createdAt")
or message.get("created_at")
or message.get("timestamp")
or 0
)
sortable.append((int(ts), mid, message))
sortable.sort(key=lambda x: x[0])
return [with_id(mid, msg) for _, mid, msg in sortable]
def get_recent_messages_by_user_id(user_id: str, chat_id: str, num: int) -> List[Dict]:
"""
获取指定用户的全局最近 N 条消息按时间顺序
参数
user_id: 用户 ID
num: 需要获取的消息数量<= 0 时返回全部
返回
有序的消息列表最近的 num
"""
all_messages: List[Dict] = []
# 遍历用户的所有聊天
chats = Chats.get_chat_list_by_user_id(user_id, include_archived=True)
for chat in chats:
messages_map = chat.chat.get("history", {}).get("messages", {}) or {}
for mid, msg in messages_map.items():
# 跳过空内容
if msg.get("content", "") == "":
continue
ts = (
msg.get("createdAt")
or msg.get("created_at")
or msg.get("timestamp")
or 0
)
entry = {**msg, "id": mid}
entry.setdefault("chat_id", chat.id)
entry.setdefault("timestamp", int(ts))
all_messages.append(entry)
# 按时间戳排序
all_messages.sort(key=lambda m: m.get("timestamp", 0))
if num <= 0:
return all_messages
return all_messages[-num:]
def slice_messages_with_summary(
messages_map: Dict,
boundary_message_id: Optional[str],
anchor_id: Optional[str],
pre_boundary: int = 20,
) -> List[Dict]:
"""
基于摘要边界裁剪消息列表返回摘要前 N + 摘要后全部消息
策略保留摘要边界前 N 条消息提供上下文+ 摘要后全部消息最新对话
目的降低 token 消耗同时保留足够的上下文信息
参数
messages_map: 消息 map
boundary_message_id: 摘要边界消息 IDNone 时返回全量消息
anchor_id: 锚点消息 ID链尾
pre_boundary: 摘要边界前保留的消息数量默认 20
返回
裁剪后的有序消息列表
示例
100 条消息摘要边界在第 50 pre_boundary=20
返回消息 29-99 71
"""
ordered = build_ordered_messages(messages_map, anchor_id)
if boundary_message_id:
try:
# 查找摘要边界消息的索引
boundary_idx = next(
idx for idx, msg in enumerate(ordered) if msg.get("id") == boundary_message_id
)
# 计算裁剪起点
start_idx = max(boundary_idx - pre_boundary, 0)
ordered = ordered[start_idx:]
except StopIteration:
# 边界消息不存在,返回全量
pass
return ordered
def summarize(messages: List[Dict], old_summary: Optional[str] = None) -> str:
"""
生成对话摘要占位接口
参数
messages: 需要摘要的消息列表
old_summary: 旧摘要可选当前未使用
返回
摘要字符串
TODO
- 实现增量摘要逻辑基于 old_summary 生成新摘要
- 支持摘要策略配置长度详细程度
"""
return "\n".join(m.get("content")[:100] for m in messages)
def compute_token_count(messages: List[Dict]) -> int:
"""
计算消息的 token 数量占位实现
当前算法4 字符 1 token粗略估算
TODO接入真实 tokenizer tiktoken for OpenAI models
"""
total_chars = 0
for msg in messages:
total_chars += len(msg['content'])
return max(total_chars // 4, 0)