From c95add4ca40ff3ac2457bbeed19b89181ccaa0a0 Mon Sep 17 00:00:00 2001 From: xinyan Date: Sat, 29 Nov 2025 15:15:44 +0800 Subject: [PATCH] =?UTF-8?q?=E6=B7=BB=E5=8A=A0mem0=E7=9A=84search=5Fand=5Fa?= =?UTF-8?q?dd=E5=AE=9E=E7=8E=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/open_webui/memory/mem0.py | 30 ++++++++++++++++++++++++-- backend/open_webui/utils/middleware.py | 4 ++-- 2 files changed, 30 insertions(+), 4 deletions(-) diff --git a/backend/open_webui/memory/mem0.py b/backend/open_webui/memory/mem0.py index 64825781ba..23dd9c032a 100644 --- a/backend/open_webui/memory/mem0.py +++ b/backend/open_webui/memory/mem0.py @@ -2,6 +2,8 @@ from mem0 import MemoryClient import os from logging import getLogger +from typing import List, Dict, Optional + log = getLogger(__name__) mem0_api_key = os.getenv("MEM0_API_KEY") @@ -9,7 +11,6 @@ memory_client = MemoryClient(api_key=mem0_api_key) async def mem0_search(user_id: str, chat_id: str, last_message: str) -> list[str]: """ - 预留的 Mem0 检索接口:当前为占位实现。 未来可替换为实际检索逻辑,返回若干相关记忆条目(字符串)。 增加 chat_id 便于按会话窗口区分/隔离记忆。 """ @@ -27,7 +28,32 @@ async def mem0_search(user_id: str, chat_id: str, last_message: str) -> list[str log.debug(f"Mem0 search failed: {e}") return [] - +async def mem0_search_and_add(user_id: str, chat_id: str, last_message: str) -> list[str]: + """ + 检索并添加记忆,添加记忆使用mem0 的add功能,返回若干相关记忆条目(字符串)。 + 增加 chat_id 便于按会话窗口区分/隔离记忆。 + """ + try: + # TODO: 接入真实 Mem0 检索 + log.info(f"mem0_search called with user_id: {user_id}, chat_id: {chat_id}, last_message: {last_message}") + serach_rst = memory_client.search( + query=last_message, + filters={"user_id": user_id} + ) + if "results" not in serach_rst: + log.info("mem0_search_and_add no results found, skipping add") + memories=[] + else: + log.info(f"mem0_search_and_add found {len(serach_rst['results'])} results") + memories=[item.get("memory", item.get("text", "")) for item in serach_rst["results"]] + added_messages= [{"role": "user", "content": last_message}] + memory_client.add(added_messages, user_id=user_id,enable_graph=True,async_mode=False, metadata={"session_id": chat_id}) + log.info(f"mem0_add added message for user_id: {user_id}") + return memories + except Exception as e: + log.debug(f"Mem0 search and add failed: {e}") + return [] + async def mem0_delete(user_id: str, chat_id: str) -> bool: """ 删除指定用户在指定 chat 窗口下的所有 Mem0 相关记忆(占位实现)。 diff --git a/backend/open_webui/utils/middleware.py b/backend/open_webui/utils/middleware.py index 44e5313bbf..f32bb5a5e3 100644 --- a/backend/open_webui/utils/middleware.py +++ b/backend/open_webui/utils/middleware.py @@ -55,7 +55,7 @@ from open_webui.routers.pipelines import ( process_pipeline_outlet_filter, ) from open_webui.models.memories import Memories -from open_webui.memory.mem0 import mem0_search +from open_webui.memory.mem0 import mem0_search, mem0_search_and_add from open_webui.utils.webhook import post_webhook from open_webui.utils.files import ( @@ -534,7 +534,7 @@ async def chat_memory_handler( memories = Memories.get_memories_by_user_id(user.id) or [] # === 2. 预留的 Mem0 检索结果 === - mem0_results = await mem0_search(user.id, metadata.get("chat_id"), user_message) + mem0_results = await mem0_search_and_add(user.id, metadata.get("chat_id"), user_message) # === 3. 格式化记忆条目 === entries = []