mirror of
https://github.com/open-webui/open-webui.git
synced 2025-12-17 06:45:24 +00:00
1. 添加 mem0_search 和 mem0_delete TODO实现函数. 2. 在收到新消息时,调用 mem0_search 获得相关记忆,封装到上下文 3. 在删除聊天框时候,调用 mem0_delete 以删除相关聊天框所抽取的记忆
This commit is contained in:
parent
242bcf7c29
commit
d61bb99ffa
4 changed files with 78 additions and 29 deletions
|
|
@ -1588,6 +1588,7 @@ async def chat_completion(
|
||||||
"""处理完整的聊天流程:Payload 处理 → LLM 调用 → 响应处理"""
|
"""处理完整的聊天流程:Payload 处理 → LLM 调用 → 响应处理"""
|
||||||
try:
|
try:
|
||||||
# 8.1 Payload 预处理:执行 Pipeline Filters、工具注入、RAG 检索等
|
# 8.1 Payload 预处理:执行 Pipeline Filters、工具注入、RAG 检索等
|
||||||
|
# remark:并不涉及消息的持久化,只涉及发送给 LLM 前,上下文的封装
|
||||||
form_data, metadata, events = await process_chat_payload(
|
form_data, metadata, events = await process_chat_payload(
|
||||||
request, form_data, user, metadata, model
|
request, form_data, user, metadata, model
|
||||||
)
|
)
|
||||||
|
|
|
||||||
33
backend/open_webui/memory/mem0.py
Normal file
33
backend/open_webui/memory/mem0.py
Normal file
|
|
@ -0,0 +1,33 @@
|
||||||
|
|
||||||
|
async def mem0_search(user_id: str, chat_id: str, last_message: str) -> list[str]:
|
||||||
|
"""
|
||||||
|
预留的 Mem0 检索接口:当前为占位实现。
|
||||||
|
未来可替换为实际检索逻辑,返回若干相关记忆条目(字符串)。
|
||||||
|
增加 chat_id 便于按会话窗口区分/隔离记忆。
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# TODO: 接入真实 Mem0 检索
|
||||||
|
print("mem0_search")
|
||||||
|
print("user_id:", user_id)
|
||||||
|
print("chat_id:", chat_id)
|
||||||
|
print("last_message:", last_message)
|
||||||
|
return []
|
||||||
|
except Exception as e:
|
||||||
|
log.debug(f"Mem0 search failed: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
async def mem0_delete(user_id: str, chat_id: str) -> bool:
|
||||||
|
"""
|
||||||
|
删除指定用户在指定 chat 窗口下的所有 Mem0 相关记忆(占位实现)。
|
||||||
|
未来可替换为实际删除逻辑。
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# TODO: 接入真实删除逻辑(如按 chat_id 过滤)
|
||||||
|
print("mem0_delete")
|
||||||
|
print("user_id:", user_id)
|
||||||
|
print("chat_id:", chat_id)
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
log.debug(f"Mem0 delete failed: {e}")
|
||||||
|
return False
|
||||||
|
|
@ -19,6 +19,7 @@ from open_webui.constants import ERROR_MESSAGES
|
||||||
from open_webui.env import SRC_LOG_LEVELS
|
from open_webui.env import SRC_LOG_LEVELS
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
from open_webui.memory.mem0 import mem0_delete
|
||||||
|
|
||||||
|
|
||||||
from open_webui.utils.auth import get_admin_user, get_verified_user
|
from open_webui.utils.auth import get_admin_user, get_verified_user
|
||||||
|
|
@ -655,6 +656,9 @@ async def delete_chat_by_id(request: Request, id: str, user=Depends(get_verified
|
||||||
if user.role == "admin":
|
if user.role == "admin":
|
||||||
chat = Chats.get_chat_by_id(id)
|
chat = Chats.get_chat_by_id(id)
|
||||||
|
|
||||||
|
# 清理该聊天的 Mem0 记忆条目
|
||||||
|
await mem0_delete(chat.user_id, id)
|
||||||
|
|
||||||
# 清理孤立标签(仅被该聊天使用的标签)
|
# 清理孤立标签(仅被该聊天使用的标签)
|
||||||
for tag in chat.meta.get("tags", []):
|
for tag in chat.meta.get("tags", []):
|
||||||
if Chats.count_chats_by_tag_name_and_user_id(tag, user.id) == 1:
|
if Chats.count_chats_by_tag_name_and_user_id(tag, user.id) == 1:
|
||||||
|
|
@ -676,6 +680,9 @@ async def delete_chat_by_id(request: Request, id: str, user=Depends(get_verified
|
||||||
|
|
||||||
chat = Chats.get_chat_by_id(id)
|
chat = Chats.get_chat_by_id(id)
|
||||||
|
|
||||||
|
# 清理该聊天的 Mem0 记忆条目
|
||||||
|
await mem0_delete(user.id, id)
|
||||||
|
|
||||||
# 清理孤立标签
|
# 清理孤立标签
|
||||||
for tag in chat.meta.get("tags", []):
|
for tag in chat.meta.get("tags", []):
|
||||||
if Chats.count_chats_by_tag_name_and_user_id(tag, user.id) == 1:
|
if Chats.count_chats_by_tag_name_and_user_id(tag, user.id) == 1:
|
||||||
|
|
|
||||||
|
|
@ -54,7 +54,8 @@ from open_webui.routers.pipelines import (
|
||||||
process_pipeline_inlet_filter,
|
process_pipeline_inlet_filter,
|
||||||
process_pipeline_outlet_filter,
|
process_pipeline_outlet_filter,
|
||||||
)
|
)
|
||||||
from open_webui.routers.memories import query_memory, QueryMemoryForm
|
from open_webui.models.memories import Memories
|
||||||
|
from open_webui.memory.mem0 import mem0_search
|
||||||
|
|
||||||
from open_webui.utils.webhook import post_webhook
|
from open_webui.utils.webhook import post_webhook
|
||||||
from open_webui.utils.files import (
|
from open_webui.utils.files import (
|
||||||
|
|
@ -517,37 +518,44 @@ async def chat_completion_tools_handler(
|
||||||
|
|
||||||
|
|
||||||
async def chat_memory_handler(
|
async def chat_memory_handler(
|
||||||
request: Request, form_data: dict, extra_params: dict, user
|
request: Request, form_data: dict, extra_params: dict, user, metadata
|
||||||
):
|
):
|
||||||
try:
|
"""
|
||||||
results = await query_memory(
|
聊天记忆处理器 - 注入用户手动保存的记忆 + Mem0 检索结果到当前对话上下文
|
||||||
request,
|
|
||||||
QueryMemoryForm(
|
新增行为:
|
||||||
**{
|
1. memory 特性开启时,直接注入用户所有记忆条目(不再做 RAG Top-K)
|
||||||
"content": get_last_user_message(form_data["messages"]) or "",
|
2. 预留 Mem0 检索:使用最后一条用户消息查询 Mem0,返回的条目也一并注入
|
||||||
"k": 3,
|
3. 上下文注入方式保持不变:统一写入 System Message
|
||||||
}
|
"""
|
||||||
),
|
user_message = get_last_user_message(form_data.get("messages", [])) or ""
|
||||||
user,
|
|
||||||
)
|
# === 1. 获取用户全部记忆(不再截断 Top-K) ===
|
||||||
except Exception as e:
|
memories = Memories.get_memories_by_user_id(user.id) or []
|
||||||
log.debug(e)
|
|
||||||
results = None
|
# === 2. 预留的 Mem0 检索结果 ===
|
||||||
|
mem0_results = await mem0_search(user.id, metadata.get("chat_id"), user_message)
|
||||||
|
|
||||||
|
# === 3. 格式化记忆条目 ===
|
||||||
|
entries = []
|
||||||
|
|
||||||
|
# 3.1 用户记忆库全量
|
||||||
|
for mem in memories:
|
||||||
|
created_at_date = time.strftime("%Y-%m-%d", time.localtime(mem.created_at)) if mem.created_at else "Unknown Date"
|
||||||
|
entries.append(f"[{created_at_date}] {mem.content}")
|
||||||
|
|
||||||
|
# 3.2 Mem0 检索结果
|
||||||
|
for item in mem0_results:
|
||||||
|
entries.append(f"[Mem0] {item}")
|
||||||
|
|
||||||
|
if not entries:
|
||||||
|
return form_data
|
||||||
|
|
||||||
user_context = ""
|
user_context = ""
|
||||||
if results and hasattr(results, "documents"):
|
for idx, entry in enumerate(entries):
|
||||||
if results.documents and len(results.documents) > 0:
|
user_context += f"{idx + 1}. {entry}\n"
|
||||||
for doc_idx, doc in enumerate(results.documents[0]):
|
|
||||||
created_at_date = "Unknown Date"
|
|
||||||
|
|
||||||
if results.metadatas[0][doc_idx].get("created_at"):
|
|
||||||
created_at_timestamp = results.metadatas[0][doc_idx]["created_at"]
|
|
||||||
created_at_date = time.strftime(
|
|
||||||
"%Y-%m-%d", time.localtime(created_at_timestamp)
|
|
||||||
)
|
|
||||||
|
|
||||||
user_context += f"{doc_idx + 1}. [{created_at_date}] {doc}\n"
|
|
||||||
|
|
||||||
|
# === 4. 将记忆注入到系统消息中 ===
|
||||||
form_data["messages"] = add_or_update_system_message(
|
form_data["messages"] = add_or_update_system_message(
|
||||||
f"User Context:\n{user_context}\n", form_data["messages"], append=True
|
f"User Context:\n{user_context}\n", form_data["messages"], append=True
|
||||||
)
|
)
|
||||||
|
|
@ -1189,7 +1197,7 @@ async def process_chat_payload(request, form_data, user, metadata, model):
|
||||||
# 12.1 记忆功能 - 注入历史对话记忆
|
# 12.1 记忆功能 - 注入历史对话记忆
|
||||||
if "memory" in features and features["memory"]:
|
if "memory" in features and features["memory"]:
|
||||||
form_data = await chat_memory_handler(
|
form_data = await chat_memory_handler(
|
||||||
request, form_data, extra_params, user
|
request, form_data, extra_params, user, metadata
|
||||||
)
|
)
|
||||||
|
|
||||||
# 12.2 网页搜索功能 - 执行网络搜索
|
# 12.2 网页搜索功能 - 执行网络搜索
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue