mirror of
https://github.com/open-webui/open-webui.git
synced 2025-12-15 05:45:19 +00:00
添加mem0的search_and_add实现
This commit is contained in:
parent
043d6cb56b
commit
c95add4ca4
2 changed files with 30 additions and 4 deletions
|
|
@ -2,6 +2,8 @@
|
||||||
from mem0 import MemoryClient
|
from mem0 import MemoryClient
|
||||||
import os
|
import os
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
|
from typing import List, Dict, Optional
|
||||||
|
|
||||||
log = getLogger(__name__)
|
log = getLogger(__name__)
|
||||||
|
|
||||||
mem0_api_key = os.getenv("MEM0_API_KEY")
|
mem0_api_key = os.getenv("MEM0_API_KEY")
|
||||||
|
|
@ -9,7 +11,6 @@ memory_client = MemoryClient(api_key=mem0_api_key)
|
||||||
|
|
||||||
async def mem0_search(user_id: str, chat_id: str, last_message: str) -> list[str]:
|
async def mem0_search(user_id: str, chat_id: str, last_message: str) -> list[str]:
|
||||||
"""
|
"""
|
||||||
预留的 Mem0 检索接口:当前为占位实现。
|
|
||||||
未来可替换为实际检索逻辑,返回若干相关记忆条目(字符串)。
|
未来可替换为实际检索逻辑,返回若干相关记忆条目(字符串)。
|
||||||
增加 chat_id 便于按会话窗口区分/隔离记忆。
|
增加 chat_id 便于按会话窗口区分/隔离记忆。
|
||||||
"""
|
"""
|
||||||
|
|
@ -27,6 +28,31 @@ async def mem0_search(user_id: str, chat_id: str, last_message: str) -> list[str
|
||||||
log.debug(f"Mem0 search failed: {e}")
|
log.debug(f"Mem0 search failed: {e}")
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
async def mem0_search_and_add(user_id: str, chat_id: str, last_message: str) -> list[str]:
|
||||||
|
"""
|
||||||
|
检索并添加记忆,添加记忆使用mem0 的add功能,返回若干相关记忆条目(字符串)。
|
||||||
|
增加 chat_id 便于按会话窗口区分/隔离记忆。
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# TODO: 接入真实 Mem0 检索
|
||||||
|
log.info(f"mem0_search called with user_id: {user_id}, chat_id: {chat_id}, last_message: {last_message}")
|
||||||
|
serach_rst = memory_client.search(
|
||||||
|
query=last_message,
|
||||||
|
filters={"user_id": user_id}
|
||||||
|
)
|
||||||
|
if "results" not in serach_rst:
|
||||||
|
log.info("mem0_search_and_add no results found, skipping add")
|
||||||
|
memories=[]
|
||||||
|
else:
|
||||||
|
log.info(f"mem0_search_and_add found {len(serach_rst['results'])} results")
|
||||||
|
memories=[item.get("memory", item.get("text", "")) for item in serach_rst["results"]]
|
||||||
|
added_messages= [{"role": "user", "content": last_message}]
|
||||||
|
memory_client.add(added_messages, user_id=user_id,enable_graph=True,async_mode=False, metadata={"session_id": chat_id})
|
||||||
|
log.info(f"mem0_add added message for user_id: {user_id}")
|
||||||
|
return memories
|
||||||
|
except Exception as e:
|
||||||
|
log.debug(f"Mem0 search and add failed: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
async def mem0_delete(user_id: str, chat_id: str) -> bool:
|
async def mem0_delete(user_id: str, chat_id: str) -> bool:
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@ -55,7 +55,7 @@ from open_webui.routers.pipelines import (
|
||||||
process_pipeline_outlet_filter,
|
process_pipeline_outlet_filter,
|
||||||
)
|
)
|
||||||
from open_webui.models.memories import Memories
|
from open_webui.models.memories import Memories
|
||||||
from open_webui.memory.mem0 import mem0_search
|
from open_webui.memory.mem0 import mem0_search, mem0_search_and_add
|
||||||
|
|
||||||
from open_webui.utils.webhook import post_webhook
|
from open_webui.utils.webhook import post_webhook
|
||||||
from open_webui.utils.files import (
|
from open_webui.utils.files import (
|
||||||
|
|
@ -534,7 +534,7 @@ async def chat_memory_handler(
|
||||||
memories = Memories.get_memories_by_user_id(user.id) or []
|
memories = Memories.get_memories_by_user_id(user.id) or []
|
||||||
|
|
||||||
# === 2. 预留的 Mem0 检索结果 ===
|
# === 2. 预留的 Mem0 检索结果 ===
|
||||||
mem0_results = await mem0_search(user.id, metadata.get("chat_id"), user_message)
|
mem0_results = await mem0_search_and_add(user.id, metadata.get("chat_id"), user_message)
|
||||||
|
|
||||||
# === 3. 格式化记忆条目 ===
|
# === 3. 格式化记忆条目 ===
|
||||||
entries = []
|
entries = []
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue