from typing import Dict, List, Optional, Tuple, Sequence, Any
import json
import re
import os
from dataclasses import dataclass
from logging import getLogger

try:
    from openai import OpenAI
except ImportError:
    OpenAI = None

from open_webui.models.chats import Chats
from open_webui.config import OPENAI_API_KEYS, OPENAI_API_BASE_URLS

log = getLogger(__name__)

# --- Constants & Prompts from persona_extractor ---

SUMMARY_PROMPT = """你是一名“对话历史整理员”,请在保持事实准确的前提下,概括当前为止的聊天记录。
## 要求
1. 最终摘要不得超过 1000 字。
2. 聚焦人物状态、事件节点、情绪/意图等关键信息,将片段整合为连贯文字。
3. 输出需包含 who / how / why / what 四个字段,每项不超过 50 字。
4. 禁止臆测或脏话,所有内容都必须能在聊天中找到对应描述。
5. 目标:帮助后续对话快速回忆上下文与人物设定。

已存在的摘要(如无则写“无”):
{existing_summary}

聊天片段:
---CHATS---
{chat_transcript}
---END---

请严格输出下列 JSON:
{{
  "summary": "不超过1000字的连贯摘要",
  "table": {{
    "who": "不超过50字",
    "how": "不超过50字",
    "why": "不超过50字",
    "what": "不超过50字"
  }}
}}
"""

MERGE_ONLY_PROMPT = """你是一名“对话历史整理员”。
请将以下两段对话摘要(A 和 B)合并为一段连贯的、更新后的对话历史摘要。
摘要 A 是较早的时间段,摘要 B 是较新的时间段。

【摘要 A (旧)】
{summary_a}

【摘要 B (新)】
{summary_b}

## 要求
1. 保持时间线的连贯性,将新发生的事自然接续在旧事之后。
2. 最终摘要不得超过 1000 字。
3. 依然提取 who / how / why / what 四个关键要素(基于合并后的全貌)。
4. 禁止臆测,只基于提供的摘要内容。

请严格输出下列 JSON:
{{
  "summary": "合并后的连贯摘要",
  "table": {{
    "who": "不超过50字",
    "how": "不超过50字",
    "why": "不超过50字",
    "what": "不超过50字"
  }}
}}
"""
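
# Note: the doubled braces ("{{" / "}}") in both prompts are str.format escapes;
# after .format() they become literal "{" / "}" so the JSON skeleton survives
# formatting. Illustrative sketch (hypothetical values):
#
#     prompt = SUMMARY_PROMPT.format(existing_summary="无", chat_transcript="user: 你好")
#     assert '"summary"' in prompt and "{{" not in prompt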


@dataclass
class HistorySummary:
    summary: str
    table: Dict[str, str]


@dataclass(slots=True)
class ChatMessage:
    role: str
    content: str
    timestamp: Optional[Any] = None

    def formatted(self) -> str:
        return f"{self.role}: {self.content}"


class HistorySummarizer:
    def __init__(
        self,
        *,
        client: Optional[Any] = None,
        model: str = "gpt-4.1-mini",
        max_output_tokens: int = 800,
        temperature: float = 0.1,
        max_messages: int = 120,
    ) -> None:
        self._fallback_client = None
        # Hardcoded key for fallback/testing (redacted placeholder; never commit a real key)
        self._hardcoded_key = "sk-proj-REDACTED"

        if client is None:
            if OpenAI is None:
                log.warning("OpenAI client not available. Install openai>=1.0.0.")
            else:
                try:
                    # Try to read the API key and base URL from the app config,
                    # falling back to environment variables.
                    api_keys = OPENAI_API_KEYS.value if hasattr(OPENAI_API_KEYS, "value") else []
                    base_urls = OPENAI_API_BASE_URLS.value if hasattr(OPENAI_API_BASE_URLS, "value") else []

                    api_key = api_keys[0] if api_keys else os.environ.get("OPENAI_API_KEY")
                    base_url = base_urls[0] if base_urls else os.environ.get("OPENAI_API_BASE_URL")

                    if api_key:
                        kwargs = {"api_key": api_key}
                        if base_url:
                            kwargs["base_url"] = base_url
                        client = OpenAI(**kwargs)
                    else:
                        # No configured key found; use the hardcoded key as primary.
                        log.warning("No OpenAI API key found, using HARDCODED key as primary.")
                        client = OpenAI(api_key=self._hardcoded_key)

                    # Initialize the fallback client (hardcoded key + official OpenAI URL).
                    # Skipped when the primary client already uses the hardcoded key.
                    if api_key:
                        try:
                            self._fallback_client = OpenAI(
                                api_key=self._hardcoded_key,
                                base_url="https://api.openai.com/v1",
                            )
                            log.debug("Fallback OpenAI client initialized with official Base URL.")
                        except Exception as e:
                            log.warning(f"Failed to init fallback client: {e}")

                except Exception as e:
                    log.warning(f"Failed to init OpenAI client: {e}")

        self._client = client
        self._model = model
        self._max_output_tokens = max_output_tokens
        self._temperature = temperature
        self._max_messages = max_messages

    def summarize(
        self,
        messages: Sequence[Dict],
        *,
        existing_summary: Optional[str] = None,
        max_tokens: Optional[int] = None,
    ) -> Optional[HistorySummary]:
        if not messages:
            return None

        # Render the trailing window of dict messages as "role: content" lines
        # for the summary prompt.
        trail = messages[-self._max_messages :]
        transcript = "\n".join(f"{m.get('role', 'user')}: {m.get('content', '')}" for m in trail)

        prompt = SUMMARY_PROMPT.format(
            existing_summary=existing_summary.strip() if existing_summary else "无",
            chat_transcript=transcript,
        )

        if not self._client:
            log.error("No OpenAI client available for summarization.")
            return None

        log.info(f"Starting summary generation for {len(messages)} messages...")

        # Try the primary client first.
        try:
            response = self._client.chat.completions.create(
                model=self._model,
                messages=[{"role": "user", "content": prompt}],
                max_tokens=max_tokens or self._max_output_tokens,
                temperature=self._temperature,
            )

            payload = response.choices[0].message.content or ""
            log.info("Summary generation completed successfully (Primary).")
            log.info(f"Primary Summary Content:\n{payload}")
            return self._parse_response(payload)

        except Exception as e:
            log.warning(f"Primary summarization failed: {e}")

        # Fall back to the hardcoded-key client, if available.
        if self._fallback_client:
            log.warning("Attempting fallback to hardcoded OpenAI key...")
            try:
                response = self._fallback_client.chat.completions.create(
                    model="gpt-4o-mini",  # Force a standard model for the fallback
                    messages=[{"role": "user", "content": prompt}],
                    max_tokens=max_tokens or self._max_output_tokens,
                    temperature=self._temperature,
                )
                payload = response.choices[0].message.content or ""
                log.info("Summary generation completed successfully (Fallback).")
                log.info(f"Fallback Summary Content:\n{payload}")
                return self._parse_response(payload)
            except Exception as fb_e:
                log.error(f"Fallback summarization also failed: {fb_e}")
        else:
            log.error("No fallback client available.")

        return None

    def merge_summaries(
        self,
        summary_a: str,
        summary_b: str,
        *,
        max_tokens: Optional[int] = None,
    ) -> Optional[HistorySummary]:
        if not summary_a and not summary_b:
            return None

        prompt = MERGE_ONLY_PROMPT.format(
            summary_a=summary_a or "无",
            summary_b=summary_b or "无",
        )

        if not self._client:
            return None

        log.info(f"Starting summary merge (A len={len(summary_a)}, B len={len(summary_b)})...")

        # Try the primary client first.
        try:
            response = self._client.chat.completions.create(
                model=self._model,
                messages=[{"role": "user", "content": prompt}],
                max_tokens=max_tokens or self._max_output_tokens,
                temperature=self._temperature,
            )

            payload = response.choices[0].message.content or ""
            log.info("Summary merge completed successfully (Primary).")
            return self._parse_response(payload)

        except Exception as e:
            log.warning(f"Primary merge failed: {e}")

        # Fall back to the hardcoded-key client, if available.
        if self._fallback_client:
            log.warning("Attempting merge fallback to hardcoded OpenAI key...")
            try:
                response = self._fallback_client.chat.completions.create(
                    model="gpt-4o-mini",
                    messages=[{"role": "user", "content": prompt}],
                    max_tokens=max_tokens or self._max_output_tokens,
                    temperature=self._temperature,
                )
                payload = response.choices[0].message.content or ""
                log.info("Summary merge completed successfully (Fallback).")
                return self._parse_response(payload)
            except Exception as fb_e:
                log.error(f"Fallback merge also failed: {fb_e}")

        return None

    def _parse_response(self, payload: str) -> HistorySummary:
        data = _safe_json_loads(payload)

        # If parsing produced nothing usable and the payload doesn't even look
        # like JSON, fall back to treating the raw payload as the summary.
        if not isinstance(data, dict) or (not data and not payload.strip().startswith("{")):
            summary = payload.strip()
            table = {}
        else:
            summary = str(data.get("summary", "")).strip()
            table_payload = data.get("table", {}) or {}
            table = {
                "who": str(table_payload.get("who", "")).strip(),
                "how": str(table_payload.get("how", "")).strip(),
                "why": str(table_payload.get("why", "")).strip(),
                "what": str(table_payload.get("what", "")).strip(),
            }

        if not summary:
            summary = payload.strip()

        if len(summary) > 1000:
            summary = summary[:1000].rstrip() + "..."

        return HistorySummary(summary=summary, table=table)
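
# Usage sketch (illustrative only; assumes an OpenAI key is configured, or
# pass a pre-built client via `client=`):
#
#     summarizer = HistorySummarizer(model="gpt-4.1-mini")
#     result = summarizer.summarize(
#         [{"role": "user", "content": "你好"}, {"role": "assistant", "content": "你好!"}]
#     )
#     if result:
#         print(result.summary, result.table)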


def _safe_json_loads(raw: str) -> Dict[str, Any]:
    try:
        return json.loads(raw)
    except json.JSONDecodeError:
        # Fallback: regex out the outermost {...} block (e.g. when the model
        # wraps the JSON in prose or code fences).
        match = re.search(r'(\{.*\})', raw, re.DOTALL)
        if match:
            try:
                return json.loads(match.group(1))
            except json.JSONDecodeError:
                pass
    return {}
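
# Illustrative behaviour (hypothetical payloads):
#
#     _safe_json_loads('{"summary": "ok"}')                 -> {"summary": "ok"}
#     _safe_json_loads('```json\n{"summary": "ok"}\n```')   -> {"summary": "ok"}
#     _safe_json_loads('not json at all')                   -> {}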


# --- Core Logic Modules ---


def build_ordered_messages(
    messages_map: Optional[Dict], anchor_id: Optional[str] = None
) -> List[Dict]:
    """
    Reconstruct an ordered message list from a message map.

    Strategy:
    1. Preferred: walk the parentId chain (from anchor_id back up to the root message).
    2. Fallback: sort by timestamp (when anchor_id is missing or the walk fails).

    Args:
        messages_map: message map, e.g. {"msg-id": {"role": "user", "content": "...",
            "parentId": "...", "timestamp": 123456}}
        anchor_id: anchor message ID (tail of the chain) to walk back from

    Returns:
        Ordered message list; every message carries an "id" field.
    """
    if not messages_map:
        return []

    # Ensure every message carries its own "id" field.
    def with_id(message_id: str, message: Dict) -> Dict:
        return {**message, **({"id": message_id} if "id" not in message else {})}

    # Mode 1: walk the parentId chain.
    if anchor_id and anchor_id in messages_map:
        ordered: List[Dict] = []
        current_id: Optional[str] = anchor_id

        while current_id:
            current_msg = messages_map.get(current_id)
            if not current_msg:
                break
            ordered.insert(0, with_id(current_id, current_msg))
            current_id = current_msg.get("parentId")

        return ordered

    # Mode 2: sort by timestamp.
    sortable: List[Tuple[int, str, Dict]] = []
    for mid, message in messages_map.items():
        ts = (
            message.get("createdAt")
            or message.get("created_at")
            or message.get("timestamp")
            or 0
        )
        sortable.append((int(ts), mid, message))

    sortable.sort(key=lambda x: x[0])
    return [with_id(mid, msg) for _, mid, msg in sortable]
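
# Illustrative example (hypothetical map): walking the parentId chain from the
# anchor "m3" yields the messages in conversation order.
#
#     msgs = {
#         "m3": {"role": "assistant", "content": "hi!", "parentId": "m2"},
#         "m1": {"role": "user", "content": "hello", "parentId": None},
#         "m2": {"role": "user", "content": "hello again", "parentId": "m1"},
#     }
#     ids = [m["id"] for m in build_ordered_messages(msgs, anchor_id="m3")]
#     assert ids == ["m1", "m2", "m3"]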


def get_recent_messages_by_user_id(user_id: str, chat_id: str, num: int) -> List[Dict]:
    """
    Get the user's most recent N messages across all chats, in chronological order.

    Args:
        user_id: user ID
        chat_id: chat ID (accepted for interface symmetry; currently unused,
            since the scan is global across all of the user's chats)
        num: number of messages to fetch (returns all messages when <= 0)

    Returns:
        Ordered message list (the most recent `num` messages).
    """
    all_messages: List[Dict] = []

    # Walk all of the user's chats.
    chats = Chats.get_chat_list_by_user_id(user_id, include_archived=True)
    for chat in chats:
        messages_map = chat.chat.get("history", {}).get("messages", {}) or {}
        for mid, msg in messages_map.items():
            # Skip empty content.
            if msg.get("content", "") == "":
                continue
            ts = (
                msg.get("createdAt")
                or msg.get("created_at")
                or msg.get("timestamp")
                or 0
            )
            entry = {**msg, "id": mid}
            entry.setdefault("chat_id", chat.id)
            entry.setdefault("timestamp", int(ts))
            all_messages.append(entry)

    # Sort by timestamp.
    all_messages.sort(key=lambda m: m.get("timestamp", 0))

    if num <= 0:
        return all_messages

    return all_messages[-num:]
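
# Usage sketch (illustrative; requires a live Chats store):
#
#     recent = get_recent_messages_by_user_id(user_id, chat_id="", num=50)
#     transcript = "\n".join(f"{m.get('role')}: {m.get('content')}" for m in recent)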


def slice_messages_with_summary(
    messages_map: Dict,
    boundary_message_id: Optional[str],
    anchor_id: Optional[str],
    pre_boundary: int = 20,
) -> List[Dict]:
    """
    Trim the message list around a summary boundary (keep the N messages before
    the boundary plus everything after it).

    Strategy: keep the N messages preceding the summary boundary (for context)
    plus all messages after it (the latest conversation).
    Goal: cut token usage while preserving enough context.

    Args:
        messages_map: message map
        boundary_message_id: summary-boundary message ID (None returns all messages)
        anchor_id: anchor message ID (tail of the chain)
        pre_boundary: number of messages to keep before the boundary (default 20)

    Returns:
        Trimmed, ordered message list.

    Example:
        100 messages, summary boundary at message 50, pre_boundary=20
        -> returns messages 29-99 (71 messages)
    """
    ordered = build_ordered_messages(messages_map, anchor_id)

    if boundary_message_id:
        try:
            # Locate the index of the summary-boundary message.
            boundary_idx = next(
                idx for idx, msg in enumerate(ordered) if msg.get("id") == boundary_message_id
            )
            # Compute the slice start.
            start_idx = max(boundary_idx - pre_boundary, 0)
            ordered = ordered[start_idx:]
        except StopIteration:
            # Boundary message not found; return everything.
            pass

    return ordered


def summarize(messages: List[Dict], old_summary: Optional[str] = None) -> str:
    """
    Generate a conversation summary (thin wrapper).

    Args:
        messages: messages to summarize
        old_summary: previous summary (optional; passed through as the
            existing summary so the model can build on it)

    Returns:
        Summary string.

    TODO:
        - Implement true incremental summarization logic (merge old_summary into the new summary)
        - Support configurable summary strategies (length, level of detail)
    """
    summarizer = HistorySummarizer()
    result = summarizer.summarize(messages, existing_summary=old_summary)
    return result.summary if result else ""


def compute_token_count(messages: List[Dict]) -> int:
    """
    Estimate the token count of a message list (placeholder implementation).

    Current heuristic: 4 characters ≈ 1 token (rough estimate).
    TODO: plug in a real tokenizer (e.g. tiktoken for OpenAI models).
    """
    total_chars = 0
    for msg in messages:
        content = msg.get('content')
        if isinstance(content, str):
            total_chars += len(content)
        elif isinstance(content, list):
            # Multimodal content: count only the text parts.
            for item in content:
                if isinstance(item, dict) and 'text' in item:
                    total_chars += len(item['text'])

    return max(total_chars // 4, 0)
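

# Minimal smoke test for the pure helpers (illustrative; uses hypothetical
# data and touches neither the database nor the OpenAI API):
if __name__ == "__main__":
    msgs = {
        "m1": {"role": "user", "content": "hello", "parentId": None, "timestamp": 1},
        "m2": {"role": "assistant", "content": "hi there", "parentId": "m1", "timestamp": 2},
        "m3": {"role": "user", "content": "tell me more", "parentId": "m2", "timestamp": 3},
    }

    ordered = build_ordered_messages(msgs, anchor_id="m3")
    assert [m["id"] for m in ordered] == ["m1", "m2", "m3"]

    # Boundary at "m2" with pre_boundary=1 keeps "m1" onwards (everything here).
    sliced = slice_messages_with_summary(msgs, "m2", "m3", pre_boundary=1)
    assert [m["id"] for m in sliced] == ["m1", "m2", "m3"]

    # 4 chars ~ 1 token: "hello" + "hi there" + "tell me more" = 25 chars -> 6 tokens.
    assert compute_token_count(ordered) == 25 // 4

    assert _safe_json_loads('noise {"summary": "ok"} noise') == {"summary": "ok"}
    print("smoke tests passed")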