Add comments to several backend chat-completion-related functions; add the backend/open_webui/memory/cross_window_memory.py:last_process_payload interface

Gaofeng 2025-11-25 17:24:31 +08:00
parent c47df68d56
commit 984bb5888e
7 changed files with 610 additions and 147 deletions

@ -469,6 +469,7 @@ from open_webui.utils.chat import (
chat_completed as chat_completed_handler,
chat_action as chat_action_handler,
)
from open_webui.utils.misc import get_message_list
from open_webui.utils.embeddings import generate_embeddings
from open_webui.utils.middleware import process_chat_payload, process_chat_response
from open_webui.utils.access_control import has_access
@ -1592,7 +1593,7 @@ async def chat_completion(
)
# 8.2 Call the LLM to complete the chat (core step)
response = await chat_completion_handler(request, form_data, user)
response = await chat_completion_handler(request, form_data, user, chatting_completion=True)
# 8.3 Update the database: save the model ID to the message record
if metadata.get("chat_id") and metadata.get("message_id"):
@ -1688,6 +1689,74 @@ generate_chat_completions = chat_completion
generate_chat_completion = chat_completion
@app.post("/api/chat/tutorial")
async def chat_completion_tutorial(
request: Request, form_data: dict, user=Depends(get_verified_user)
):
"""
Minimal working example: reuses the core chat pipeline, supporting direct connections, history messages, and streaming responses.
- The input is largely compatible with /api/chat/completions, but no extra background tasks are run.
- If chat_id/message_id are provided, history messages are loaded and WS events may fire; otherwise an SSE stream is returned directly.
"""
# Direct-connection compatibility
model_item = form_data.pop("model_item", {})
if model_item.get("direct"):
request.state.direct = True
request.state.model = model_item
# Ensure models are available
if not request.app.state.MODELS:
await get_all_models(request, user=user)
model_id = form_data.get("model") or next(iter(request.app.state.MODELS.keys()))
if model_id not in request.app.state.MODELS:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail="Model not found"
)
model = request.app.state.MODELS[model_id]
if user.role == "user":
check_model_access(user, model)
# Fetch message history: prefer messages passed in; otherwise load the full history by chat_id
messages = form_data.get("messages") or []
chat_id = form_data.get("chat_id")
message_id = form_data.get("id") or form_data.get("message_id")
if not messages and chat_id:
messages_map = Chats.get_messages_map_by_chat_id(chat_id)
if messages_map:
if message_id and message_id in messages_map:
messages = get_message_list(messages_map, message_id)
else:
messages = list(messages_map.values())
# Prepare metadata so downstream event/DB logic can be reused
metadata = {
"user_id": user.id,
"chat_id": chat_id,
"message_id": message_id,
"session_id": form_data.get("session_id"),
"model": model,
"direct": model_item.get("direct", False),
"params": {},
}
request.state.metadata = metadata
form_data["metadata"] = metadata
# Core pipeline: preprocess -> call the model -> handle streaming/non-streaming response
form_data["model"] = model_id
form_data["messages"] = messages
form_data["stream"] = True # 强制流式,便于示例
form_data, metadata, events = await process_chat_payload(
request, form_data, user, metadata, model
)
response = await chat_completion_handler(request, form_data, user)
return await process_chat_response(
request, response, form_data, user, metadata, model, events, tasks=None
)
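# A minimal client sketch (run as a separate script) for exercising this endpoint;
# the base URL, token, and model name below are placeholders, not part of this commit:
import requests

resp = requests.post(
    "http://localhost:8080/api/chat/tutorial",  # placeholder base URL
    headers={"Authorization": "Bearer YOUR_API_KEY"},  # placeholder token
    json={"model": "gpt-4-turbo", "messages": [{"role": "user", "content": "Hello"}]},
    stream=True,
)
for line in resp.iter_lines():
    if line.startswith(b"data:"):
        # each SSE chunk carries an OpenAI-style delta
        print(line.decode()[len("data:"):].strip())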
@app.post("/api/chat/completed")
async def chat_completed(
request: Request, form_data: dict, user=Depends(get_verified_user)


@ -0,0 +1,21 @@
from typing import List, Dict
def last_process_payload(
user_id: str,
session_id: str,
messages: List[Dict],
):
"""
Post-process the context right before the LLM API call.
Args:
user_id (str): the user's unique ID
session_id (str): the ID of this user's current chat/session
messages (List[Dict]): the user's chat messages for this session,
shaped like {"role": "system|user|assistant", "content": "...", "timestamp": 0}
"""
print("user_id:", user_id)
print("session_id:", session_id)
print("messages:", messages)

@ -34,6 +34,8 @@ from open_webui.env import (
BYPASS_MODEL_ACCESS_CONTROL,
)
from open_webui.models.users import UserModel
from open_webui.memory.cross_window_memory import last_process_payload
from open_webui.utils.misc import extract_timestamped_messages
from open_webui.constants import ERROR_MESSAGES
from open_webui.env import SRC_LOG_LEVELS
@ -807,61 +809,123 @@ async def generate_chat_completion(
form_data: dict,
user=Depends(get_verified_user),
bypass_filter: Optional[bool] = False,
chatting_completion: bool = False
):
"""
OpenAI-compatible chat completion endpoint - forwards the request directly to the OpenAI API or a compatible service.
This is the low-level API call in the OpenAI router. It:
1. Applies model configuration (base_model_id, system prompt, parameter overrides)
2. Validates user permissions (model access control)
3. Handles Azure OpenAI's special payload format
4. Handles reasoning-model-specific logic
5. Forwards the HTTP request to the upstream API (streaming and non-streaming)
Args:
request: FastAPI Request object
form_data: OpenAI-format chat request
- model: model ID
- messages: message list
- stream: whether to stream the response
- temperature, max_tokens, and other parameters
user: verified user object
bypass_filter: whether to bypass permission checks
Returns:
- streaming: StreamingResponse (SSE)
- non-streaming: dict (OpenAI JSON format)
Raises:
HTTPException 403: no permission to access the model
HTTPException 404: model not found
HTTPException 500: upstream API connection failed
"""
# print("user:", user)
# user:
# id='55f85fb0-4aca-48bc-aea1-afce50ac989e'
# name='gaofeng1'
# email='h.summit1628935449@gmail.com'
# username=None
# role='user'
# profile_image_url='data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAGQAAABkCAYAAABw4pVUAAAEa0lEQVR4Xu2aW0hUQRzG/7vr7noLSZGgIIMsobAgjRKCrhSJD0EXgpSCopcK6qmCHgyC6Kmo6CGMiAgqeqgQ6SG7EYhoQQpdSCKppBARMXUv7m47C7vr2ZX2XHa378S3b8uZmfPN95tvzpw54xhvK48IfzAOOAgEhkVMCIFg8SAQMB4EQiBoDoDp4TOEQMAcAJPDhBAImANgcpgQAgFzAEwOE0IgYA6AyWFCCATMATA5TAiBgDkAJocJIRAwB8DkMCEEAuYAmBwmhEDAHACTw4QQCJgDYHKYEAIBcwBMDhNCIGAOgMlhQggEzAEwOUwIgVh3wF3TIq55a8VZvlwchRXiKJmvaTQyMSQR34iEfnVLcOC+hIffWL9pnlqwVUK8DRekoKoxDUAmr2Jg+q/K9GBHpqL//LotgDgr66Ro/TVxlFWbNyw0JcFPd8TfddJ8G3moCQ9ETU/ehvMirqI0O8Ij/RIeH5TQcG/imsNTJs6yJeKqXDVrkoLv26ChQANR01PhxutaGNGRPv2tUwJ9lzM+G9wrjolnaXNasvyvj0fTcjsP4934LaCBlOzq1prpHxV/z1nDZhY1dUQXAWsS7kTGBmTiQfK/cdtyVwMWiHqAu5cdSvY8mgx/12nDMOINlOzt10xhvqctkA95WCClzQMi3rkJIIF3FyXQe8700PTUnxHPyhOJ+qEfz2TqyW7T7eWqIiSQVPOyNcWolMTfT0JDL5kQvaMqdc5HXxnp7ZeecpAJSZ2uJh9vzbii0tNZO5SBAxJb6m5JLknVNsjE3Vo7eJkVjXBA1LuDd3VronPq5W/y4YasdNYOjcADmf7aLr7O/XbwMisa/ysgpQdHDJny+0aFofL5KEwg+XDZwD3ggait86n2Rl1dYkJ02WSsUGx3d90lU/tOakHwt1/Bwu2aPS1OWTrZlB74rtnhzZZxhZtvScGipoSKbLWrs1u6isFNWUp18Y4X4qxIvntY3ceKO0EgusZEeqHUnd5s7WURiEkgqlrq9om/p1WCfVcstChCIBbsm+17iO/5YUs7tARiAYiqmvbF0MJHqhjgmn05WSxY7KamOuRDPa5QnTYp3nZP86FKXdN7rEfVd1fvEffinWltqHa4yjIxlNTurxrdqYfhVFNqJzg0/FbCY58lEhhLtO6qrBfnnCrNSk1za3VQ4ssj8b06YkJRbqtAJ2RmUrx1p8S1YJNlN/Smy/KNTDZgCyDxvqm0uGuPat62dfU7mojQzy4JfrhpaVGg614WC9kKyMzEqGeDenl0qrO9qScao8eFYt/ORz/GDtFZXS5b9NhQdVsCMdRDmxUmEDBgBEIgYA6AyWFCCATMATA5TAiBgDkAJocJIRAwB8DkMCEEAuYAmBwmhEDAHACTw4QQCJgDYHKYEAIBcwBMDhNCIGAOgMlhQggEzAEwOUwIgYA5ACaHCSEQMAfA5DAhBALmAJgcJoRAwBwAk8OEEAiYA2BymBACAXMATA4TQiBgDoDJYUIIBMwBMDl/AP79PANEXNbbAAAAAElFTkSuQmCC'
# bio=None
# gender=None
# date_of_birth=None
# info=None
# settings=UserSettings(ui={'memory': True})
# api_key=None
# oauth_sub=None
# last_active_at=1763997832
# updated_at=1763971141
# created_at=1763874812
# === 1. Permission check configuration ===
if BYPASS_MODEL_ACCESS_CONTROL:
bypass_filter = True
idx = 0
idx = 0 # identifies which OPENAI_API_BASE_URL to use
# === 2. Prepare the payload and extract metadata ===
payload = {**form_data}
metadata = payload.pop("metadata", None)
metadata = payload.pop("metadata", None) # 移除内部元数据,不发送给上游 API
model_id = form_data.get("model")
model_info = Models.get_model_by_id(model_id)
# === 3. Apply model config and permission checks ===
# Check model info and override the payload
if model_info:
# 3.1 If base_model_id is configured, substitute the underlying model ID
# e.g. custom model "my-gpt4" → actually calls "gpt-4-turbo"
if model_info.base_model_id:
payload["model"] = model_info.base_model_id
model_id = model_info.base_model_id
# 3.2 Apply model params (temperature, max_tokens, etc.)
params = model_info.params.model_dump()
if params:
system = params.pop("system", None)
system = params.pop("system", None) # 提取 system prompt
# Apply model params to the payload (overriding user-supplied values)
payload = apply_model_params_to_body_openai(params, payload)
# Inject or replace the system prompt
payload = apply_system_prompt_to_body(system, payload, metadata, user)
# 3.3 Permission check: verify the user may access this model
# Check if user has access to the model
if not bypass_filter and user.role == "user":
if not (
user.id == model_info.user_id
user.id == model_info.user_id # user is the model's creator
or has_access(
user.id, type="read", access_control=model_info.access_control
)
) # or the user is on the access-control list
):
raise HTTPException(
status_code=403,
detail="Model not found",
)
elif not bypass_filter:
# If there is no model info and the filter is not bypassed, only admins may access
if user.role != "admin":
raise HTTPException(
status_code=403,
detail="Model not found",
)
await get_all_models(request, user=user)
# === 4. Look up the OpenAI API config ===
await get_all_models(request, user=user) # refresh the model list
model = request.app.state.OPENAI_MODELS.get(model_id)
if model:
idx = model["urlIdx"]
idx = model["urlIdx"] # 获取 API 基础 URL 索引
else:
raise HTTPException(
status_code=404,
detail="Model not found",
)
# === 5. Fetch the API config and handle prefix_id ===
# Get the API config for the model
api_config = request.app.state.config.OPENAI_API_CONFIGS.get(
str(idx),
@ -870,10 +934,13 @@ async def generate_chat_completion(
), # Legacy support
)
# Strip the model ID prefix (if prefix_id is configured)
# e.g. model ID "custom.gpt-4" → "gpt-4" is sent to the API
prefix_id = api_config.get("prefix_id", None)
if prefix_id:
payload["model"] = payload["model"].replace(f"{prefix_id}.", "")
# === 6. Pipeline mode: inject user info ===
# Add user info to the payload if the model is a pipeline
if "pipeline" in model and model.get("pipeline"):
payload["user"] = {
@ -886,32 +953,40 @@ async def generate_chat_completion(
url = request.app.state.config.OPENAI_API_BASE_URLS[idx]
key = request.app.state.config.OPENAI_API_KEYS[idx]
# === 7. Special handling for reasoning models ===
# Check if model is a reasoning model that needs special handling
if is_openai_reasoning_model(payload["model"]):
# Reasoning models (e.g. o1) use max_completion_tokens instead of max_tokens
payload = openai_reasoning_model_handler(payload)
elif "api.openai.com" not in url:
# Non-official OpenAI APIs: convert max_completion_tokens back to max_tokens for backward compatibility
# Remove "max_completion_tokens" from the payload for backward compatibility
if "max_completion_tokens" in payload:
payload["max_tokens"] = payload["max_completion_tokens"]
del payload["max_completion_tokens"]
# Avoid having both max_tokens and max_completion_tokens
if "max_tokens" in payload and "max_completion_tokens" in payload:
del payload["max_tokens"]
# === 8. Convert the logit_bias format ===
# Convert the modified body back to JSON
if "logit_bias" in payload:
payload["logit_bias"] = json.loads(
convert_logit_bias_input_to_json(payload["logit_bias"])
)
# === 9. Prepare request headers and cookies ===
headers, cookies = await get_headers_and_cookies(
request, url, key, api_config, metadata, user=user
)
# === 10. Azure OpenAI special handling ===
if api_config.get("azure", False):
api_version = api_config.get("api_version", "2023-03-15-preview")
request_url, payload = convert_to_azure_payload(url, payload, api_version)
# Only set the api-key header if not using Azure Entra ID authentication
auth_type = api_config.get("auth_type", "bearer")
if auth_type not in ("azure_ad", "microsoft_entra_id"):
@ -920,16 +995,34 @@ async def generate_chat_completion(
headers["api-version"] = api_version
request_url = f"{request_url}/chat/completions?api-version={api_version}"
else:
# Standard OpenAI-compatible API
request_url = f"{url}/chat/completions"
payload = json.dumps(payload)
if chatting_completion:
try:
# Optional hook: log/audit the payload before it is sent upstream (last_process_payload is implemented separately)
log.debug(
f"chatting_completion hook user={user.id} chat_id={metadata.get('chat_id')} model={payload.get('model')}"
)
last_process_payload(
user_id=user.id,
session_id=metadata.get("chat_id"),
messages=extract_timestamped_messages(payload.get("messages", [])),
)
except Exception as e:
log.debug(f"chatting_completion hook failed: {e}")
payload = json.dumps(payload) # serialize to a JSON string
# === 11. Initialize request-state variables ===
r = None
session = None
streaming = False
response = None
try:
# === 12. Send the HTTP request to the upstream API ===
session = aiohttp.ClientSession(
trust_env=True, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
)
@ -943,8 +1036,10 @@ async def generate_chat_completion(
ssl=AIOHTTP_CLIENT_SESSION_SSL,
)
# === 13. Handle the response ===
# Check if response is SSE
if "text/event-stream" in r.headers.get("Content-Type", ""):
# Streaming response: forward the SSE stream as-is
streaming = True
return StreamingResponse(
r.content,
@ -955,19 +1050,22 @@ async def generate_chat_completion(
),
)
else:
# Non-streaming response: parse the JSON
try:
response = await r.json()
except Exception as e:
log.error(e)
response = await r.text()
response = await r.text() # fall back to plain text if JSON parsing fails
# Handle error responses
if r.status >= 400:
if isinstance(response, (dict, list)):
return JSONResponse(status_code=r.status, content=response)
else:
return PlainTextResponse(status_code=r.status, content=response)
return response
return response # success
except Exception as e:
log.exception(e)
@ -976,6 +1074,8 @@ async def generate_chat_completion(
detail="CyberLover: Server Connection Error",
)
finally:
# === 14. Clean up resources ===
# Non-streaming responses must close the connection manually (streaming ones are handled in a BackgroundTask)
if not streaming:
await cleanup_response(r, session)

@ -167,124 +167,182 @@ async def generate_chat_completion(
form_data: dict,
user: Any,
bypass_filter: bool = False,
chatting_completion: bool = False
):
log.debug(f"generate_chat_completion: {form_data}")
if BYPASS_MODEL_ACCESS_CONTROL:
bypass_filter = True
"""
Chat completion generator - dispatches to a low-level API handler based on the model type.
This is the core routing function for chat completion. It:
1. Validates that the model exists and the user has permission
2. Handles Direct mode (direct connection to an external API)
3. Handles Arena mode (randomly selects a model for comparison)
4. Dispatches by model type:
- Pipe: pipeline plugin functions
- Ollama: local Ollama models
- OpenAI: OpenAI-compatible APIs (e.g. Claude, Gemini)
Args:
request: FastAPI Request object
form_data: OpenAI-format chat request data
user: user object
bypass_filter: whether to bypass permission and pipeline filter checks
Returns:
- streaming: StreamingResponse (SSE format)
- non-streaming: dict (OpenAI-compatible format)
Raises:
Exception: model not found or no permission to access it
"""
log.debug(f"generate_chat_completion: {form_data}")
# === 1. Permission check configuration ===
if BYPASS_MODEL_ACCESS_CONTROL:
bypass_filter = True # global config: bypass all permission checks
# === 2. Merge metadata ===
# Pull upstream metadata (chat_id, user_id, etc.) from request.state.metadata
if hasattr(request.state, "metadata"):
if "metadata" not in form_data:
form_data["metadata"] = request.state.metadata
else:
# merge; request.state.metadata takes precedence
form_data["metadata"] = {
**form_data["metadata"],
**request.state.metadata,
}
# === 3. Determine the model list source ===
if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
# Direct mode: the user supplies an external API config (e.g. an OpenAI API key)
models = {
request.state.model["id"]: request.state.model,
}
log.debug(f"direct connection to model: {models}")
else:
# Standard mode: use the platform's built-in model list
models = request.app.state.MODELS
# === 4. Validate that the model exists ===
model_id = form_data["model"]
if model_id not in models:
raise Exception("Model not found")
model = models[model_id]
# === 5. Direct-mode branch: connect to the external API directly ===
if getattr(request.state, "direct", False):
return await generate_direct_chat_completion(
request, form_data, user=user, models=models
)
else:
# Check if user has access to the model
# === 6. Standard mode: check user permissions ===
if not bypass_filter and user.role == "user":
try:
check_model_access(user, model)
check_model_access(user, model) # verify RBAC permissions
except Exception as e:
raise e
if model.get("owned_by") == "arena":
model_ids = model.get("info", {}).get("meta", {}).get("model_ids")
filter_mode = model.get("info", {}).get("meta", {}).get("filter_mode")
if model_ids and filter_mode == "exclude":
model_ids = [
model["id"]
for model in list(request.app.state.MODELS.values())
if model.get("owned_by") != "arena" and model["id"] not in model_ids
]
# === 7. Arena mode: randomly pick a model for blind comparison ===
if False:
if model.get("owned_by") == "arena":
# Fetch the candidate model list
model_ids = model.get("info", {}).get("meta", {}).get("model_ids")
filter_mode = model.get("info", {}).get("meta", {}).get("filter_mode")
selected_model_id = None
if isinstance(model_ids, list) and model_ids:
selected_model_id = random.choice(model_ids)
else:
model_ids = [
model["id"]
for model in list(request.app.state.MODELS.values())
if model.get("owned_by") != "arena"
]
selected_model_id = random.choice(model_ids)
# Exclude mode: invert the model list
if model_ids and filter_mode == "exclude":
model_ids = [
model["id"]
for model in list(request.app.state.MODELS.values())
if model.get("owned_by") != "arena" and model["id"] not in model_ids
]
form_data["model"] = selected_model_id
# Randomly select one model
selected_model_id = None
if isinstance(model_ids, list) and model_ids:
selected_model_id = random.choice(model_ids)
else:
# If none specified, choose randomly from all non-Arena models
model_ids = [
model["id"]
for model in list(request.app.state.MODELS.values())
if model.get("owned_by") != "arena"
]
selected_model_id = random.choice(model_ids)
if form_data.get("stream") == True:
# Replace the model ID
form_data["model"] = selected_model_id
async def stream_wrapper(stream):
yield f"data: {json.dumps({'selected_model_id': selected_model_id})}\n\n"
async for chunk in stream:
yield chunk
# Streaming response: inject selected_model_id into the first chunk
if form_data.get("stream") == True:
response = await generate_chat_completion(
request, form_data, user, bypass_filter=True
async def stream_wrapper(stream):
"""在流式响应前添加选中的模型 ID"""
yield f"data: {json.dumps({'selected_model_id': selected_model_id})}\n\n"
async for chunk in stream:
yield chunk
# Recurse into this function, bypassing the Arena logic
response = await generate_chat_completion(
request, form_data, user, bypass_filter=True
)
return StreamingResponse(
stream_wrapper(response.body_iterator),
media_type="text/event-stream",
background=response.background,
)
else:
# Non-streaming response: add selected_model_id directly to the result
return {
**(
await generate_chat_completion(
request, form_data, user, bypass_filter=True
)
),
"selected_model_id": selected_model_id,
}
# === 8. Pipeline mode: call a custom Python function ===
if False:
if model.get("pipe"):
return await generate_function_chat_completion(
request, form_data, user=user, models=models
)
return StreamingResponse(
stream_wrapper(response.body_iterator),
media_type="text/event-stream",
background=response.background,
)
else:
return {
**(
await generate_chat_completion(
request, form_data, user, bypass_filter=True
)
),
"selected_model_id": selected_model_id,
}
if model.get("pipe"):
# Below does not require bypass_filter because this is the only route that uses this function, and it is already bypassing the filter
return await generate_function_chat_completion(
request, form_data, user=user, models=models
)
if model.get("owned_by") == "ollama":
# Using /ollama/api/chat endpoint
form_data = convert_payload_openai_to_ollama(form_data)
response = await generate_ollama_chat_completion(
request=request,
form_data=form_data,
user=user,
bypass_filter=bypass_filter,
)
if form_data.get("stream"):
response.headers["content-type"] = "text/event-stream"
return StreamingResponse(
convert_streaming_response_ollama_to_openai(response),
headers=dict(response.headers),
background=response.background,
# === 9. Ollama mode: call the local Ollama service ===
if False:
if model.get("owned_by") == "ollama":
# Convert OpenAI format → Ollama format
form_data = convert_payload_openai_to_ollama(form_data)
response = await generate_ollama_chat_completion(
request=request,
form_data=form_data,
user=user,
bypass_filter=bypass_filter,
)
else:
return convert_response_ollama_to_openai(response)
else:
return await generate_openai_chat_completion(
request=request,
form_data=form_data,
user=user,
bypass_filter=bypass_filter,
)
# Streaming response: convert Ollama SSE → OpenAI SSE
if form_data.get("stream"):
response.headers["content-type"] = "text/event-stream"
return StreamingResponse(
convert_streaming_response_ollama_to_openai(response),
headers=dict(response.headers),
background=response.background,
)
else:
# Non-streaming response: convert Ollama JSON → OpenAI JSON
return convert_response_ollama_to_openai(response)
# === 10. OpenAI-compatible mode: call the OpenAI API or a compatible service ===
return await generate_openai_chat_completion(
request=request,
form_data=form_data,
user=user,
bypass_filter=bypass_filter,
chatting_completion=chatting_completion,
)
chat_completion = generate_chat_completion

@ -994,25 +994,52 @@ def apply_params_to_form_data(form_data, model):
async def process_chat_payload(request, form_data, user, metadata, model):
# Pipeline Inlet -> Filter Inlet -> Chat Memory -> Chat Web Search -> Chat Image Generation
# -> Chat Code Interpreter (Form Data Update) -> (Default) Chat Tools Function Calling
# -> Chat Files
"""
Process the chat request payload - run pipelines, filters, feature enhancements, and tool injection.
This is the core preprocessing function for chat requests, executed in the following order:
1. Pipeline Inlet - custom Python plugin preprocessing
2. Filter Inlet - function filter preprocessing
3. Chat Memory - inject conversation memory
4. Chat Web Search - run a web search and inject the results
5. Chat Image Generation - handle image generation requests
6. Chat Code Interpreter - inject the code-execution prompt
7. Chat Tools Function Calling - handle function/tool calls
8. Chat Files - handle uploaded files, knowledge-base files, and RAG retrieval
Args:
request: FastAPI Request object
form_data: OpenAI-format chat request data
user: user object
metadata: metadata (chat_id, message_id, tool_ids, files)
model: model config object
Returns:
tuple: (form_data, metadata, events)
- form_data: the processed request data
- metadata: the updated metadata
- events: events to send to the frontend (e.g. citation sources)
"""
# === 1. Apply model params to the request ===
form_data = apply_params_to_form_data(form_data, model)
log.debug(f"form_data: {form_data}")
# === 2. Handle system prompt variable substitution ===
system_message = get_system_message(form_data.get("messages", []))
if system_message: # Chat Controls/User Settings
try:
# Substitute variables in the system prompt (e.g. {{USER_NAME}}, {{CURRENT_DATE}})
form_data = apply_system_prompt_to_body(
system_message.get("content"), form_data, metadata, user, replace=True
) # Required to handle system prompt variables
)
except:
pass
event_emitter = get_event_emitter(metadata)
event_call = get_event_call(metadata)
# === 3. Initialize the event emitter and callback ===
event_emitter = get_event_emitter(metadata) # WebSocket event emitter
event_call = get_event_call(metadata) # event call function
# === 4. Fetch the OAuth token ===
oauth_token = None
try:
if request.cookies.get("oauth_session_id", None):
@ -1023,9 +1050,10 @@ async def process_chat_payload(request, form_data, user, metadata, model):
except Exception as e:
log.error(f"Error getting OAuth token: {e}")
# === 5. Build extra params (used by pipelines/filters/tools) ===
extra_params = {
"__event_emitter__": event_emitter,
"__event_call__": event_call,
"__event_emitter__": event_emitter, # 用于向前端发送实时事件
"__event_call__": event_call, # 用于调用事件回调
"__user__": user.model_dump() if isinstance(user, UserModel) else {},
"__metadata__": metadata,
"__request__": request,
@ -1033,15 +1061,19 @@ async def process_chat_payload(request, form_data, user, metadata, model):
"__oauth_token__": oauth_token,
}
# === 6. Determine the model list and task model ===
# Initialize events to store additional event to be sent to the client
# Initialize contexts and citation
if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
# Direct mode: use the user's directly connected model
models = {
request.state.model["id"]: request.state.model,
}
else:
# Standard mode: use all platform models
models = request.app.state.MODELS
# Get the task model ID (used for tool calling, title generation, and other background tasks)
task_model_id = get_task_model_id(
form_data["model"],
request.app.state.config.TASK_MODEL,
@ -1049,10 +1081,11 @@ async def process_chat_payload(request, form_data, user, metadata, model):
models,
)
events = []
sources = []
# === 7. Initialize the event and citation-source lists ===
events = [] # events to send to the frontend (e.g. sources citations)
sources = [] # document sources retrieved via RAG
# Folder "Project" handling
# === 8. Folder "Project" handling - inject the folder's system prompt and files ===
# Check if the request has chat_id and is inside of a folder
chat_id = metadata.get("chat_id", None)
if chat_id and user:
@ -1061,21 +1094,24 @@ async def process_chat_payload(request, form_data, user, metadata, model):
folder = Folders.get_folder_by_id_and_user_id(chat.folder_id, user.id)
if folder and folder.data:
# Inject the folder's system prompt
if "system_prompt" in folder.data:
form_data = apply_system_prompt_to_body(
folder.data["system_prompt"], form_data, metadata, user
)
# Inject the folder's attached files
if "files" in folder.data:
form_data["files"] = [
*folder.data["files"],
*form_data.get("files", []),
]
# Model "Knowledge" handling
# === 9. Model "Knowledge" handling - inject the knowledge bases bound to the model ===
user_message = get_last_user_message(form_data["messages"])
model_knowledge = model.get("info", {}).get("meta", {}).get("knowledge", False)
if model_knowledge:
# Send knowledge-base search status to the frontend
await event_emitter(
{
"type": "status",
@ -1089,6 +1125,7 @@ async def process_chat_payload(request, form_data, user, metadata, model):
knowledge_files = []
for item in model_knowledge:
# Handle the legacy collection_name format
if item.get("collection_name"):
knowledge_files.append(
{
@ -1097,6 +1134,7 @@ async def process_chat_payload(request, form_data, user, metadata, model):
"legacy": True,
}
)
# Handle the new collection_names format (multiple collections)
elif item.get("collection_names"):
knowledge_files.append(
{
@ -1109,12 +1147,14 @@ async def process_chat_payload(request, form_data, user, metadata, model):
else:
knowledge_files.append(item)
# Merge the model's knowledge-base files with user-uploaded files
files = form_data.get("files", [])
files.extend(knowledge_files)
form_data["files"] = files
variables = form_data.pop("variables", None)
# === 10. Pipeline Inlet - run custom Python plugins ===
# Process the form_data through the pipeline
try:
form_data = await process_pipeline_inlet_filter(
@ -1123,6 +1163,7 @@ async def process_chat_payload(request, form_data, user, metadata, model):
except Exception as e:
raise e
# === 11. Filter Inlet - run function filters ===
try:
filter_functions = [
Functions.get_function_by_id(filter_id)
@ -1141,23 +1182,28 @@ async def process_chat_payload(request, form_data, user, metadata, model):
except Exception as e:
raise Exception(f"{e}")
# === 12. Feature enhancements ===
features = form_data.pop("features", None)
if features:
# 12.1 Memory - inject conversation memory
if "memory" in features and features["memory"]:
form_data = await chat_memory_handler(
request, form_data, extra_params, user
)
# 12.2 Web search - run a web search
if "web_search" in features and features["web_search"]:
form_data = await chat_web_search_handler(
request, form_data, extra_params, user
)
# 12.3 Image generation - handle image generation requests
if "image_generation" in features and features["image_generation"]:
form_data = await chat_image_generation_handler(
request, form_data, extra_params, user
)
# 12.4 Code interpreter - inject the code-execution prompt
if "code_interpreter" in features and features["code_interpreter"]:
form_data["messages"] = add_or_update_user_message(
(
@ -1168,33 +1214,33 @@ async def process_chat_payload(request, form_data, user, metadata, model):
form_data["messages"],
)
# === 13. Extract tool and file info ===
tool_ids = form_data.pop("tool_ids", None)
files = form_data.pop("files", None)
# === 14. Expand folder-type files ===
prompt = get_last_user_message(form_data["messages"])
# TODO: re-enable URL extraction from prompt
# urls = []
# if prompt and len(prompt or "") < 500 and (not files or len(files) == 0):
# urls = extract_urls(prompt)
if files:
if not files:
files = []
for file_item in files:
# If the file is of type folder, expand all files it contains
if file_item.get("type", "file") == "folder":
# Get folder files
folder_id = file_item.get("id", None)
if folder_id:
folder = Folders.get_folder_by_id_and_user_id(folder_id, user.id)
if folder and folder.data and "files" in folder.data:
# Remove the folder item and add the files inside it
files = [f for f in files if f.get("id", None) != folder_id]
files = [*files, *folder.data["files"]]
# files = [*files, *[{"type": "url", "url": url, "name": url} for url in urls]]
# Remove duplicate files based on their content
# Deduplicate files (by content)
files = list({json.dumps(f, sort_keys=True): f for f in files}.values())
# === 15. Update metadata ===
metadata = {
**metadata,
"tool_ids": tool_ids,
@ -1202,25 +1248,28 @@ async def process_chat_payload(request, form_data, user, metadata, model):
}
form_data["metadata"] = metadata
# === 16. Prepare the tools dict ===
# Server side tools
tool_ids = metadata.get("tool_ids", None)
tool_ids = metadata.get("tool_ids", None) # 服务器端工具 ID
# Client side tools
direct_tool_servers = metadata.get("tool_servers", None)
direct_tool_servers = metadata.get("tool_servers", None) # 客户端直连工具服务器
log.debug(f"{tool_ids=}")
log.debug(f"{direct_tool_servers=}")
tools_dict = {}
tools_dict = {} # dict of all tools
mcp_clients = {}
mcp_tools_dict = {}
mcp_clients = {} # MCP (Model Context Protocol) clients
mcp_tools_dict = {} # MCP tools dict
if tool_ids:
for tool_id in tool_ids:
# === 16.1 Handle MCP (Model Context Protocol) tools ===
if tool_id.startswith("server:mcp:"):
try:
server_id = tool_id[len("server:mcp:") :]
# Look up the MCP server connection config
mcp_server_connection = None
for (
server_connection
@ -1236,6 +1285,7 @@ async def process_chat_payload(request, form_data, user, metadata, model):
log.error(f"MCP server with id {server_id} not found")
continue
# Handle the auth type
auth_type = mcp_server_connection.get("auth_type", "")
headers = {}
@ -1244,7 +1294,7 @@ async def process_chat_payload(request, form_data, user, metadata, model):
f"Bearer {mcp_server_connection.get('key', '')}"
)
elif auth_type == "none":
# No authentication required
pass
elif auth_type == "session":
headers["Authorization"] = (
@ -1273,16 +1323,19 @@ async def process_chat_payload(request, form_data, user, metadata, model):
log.error(f"Error getting OAuth token: {e}")
oauth_token = None
# Connect to the MCP server
mcp_clients[server_id] = MCPClient()
await mcp_clients[server_id].connect(
url=mcp_server_connection.get("url", ""),
headers=headers if headers else None,
)
# Fetch the MCP tool list and register the tools
tool_specs = await mcp_clients[server_id].list_tool_specs()
for tool_spec in tool_specs:
def make_tool_function(client, function_name):
"""为每个 MCP 工具创建异步调用函数"""
async def tool_function(**kwargs):
return await client.call_tool(
function_name,
@ -1295,6 +1348,7 @@ async def process_chat_payload(request, form_data, user, metadata, model):
mcp_clients[server_id], tool_spec["name"]
)
# Register the MCP tool in the tools dict
mcp_tools_dict[f"{server_id}_{tool_spec['name']}"] = {
"spec": {
**tool_spec,
@ -1309,6 +1363,7 @@ async def process_chat_payload(request, form_data, user, metadata, model):
log.debug(e)
continue
# === 16.2 Fetch standard tools (function tools) ===
tools_dict = await get_tools(
request,
tool_ids,
@ -1320,9 +1375,11 @@ async def process_chat_payload(request, form_data, user, metadata, model):
"__files__": metadata.get("files", []),
},
)
# Merge in the MCP tools
if mcp_tools_dict:
tools_dict = {**tools_dict, **mcp_tools_dict}
# === 16.3 Handle client-side direct tool servers ===
if direct_tool_servers:
for tool_server in direct_tool_servers:
tool_specs = tool_server.pop("specs", [])
@ -1334,19 +1391,21 @@ async def process_chat_payload(request, form_data, user, metadata, model):
"server": tool_server,
}
# Save the MCP clients in metadata (for final cleanup)
if mcp_clients:
metadata["mcp_clients"] = mcp_clients
# === 17. Tool-call handling ===
if tools_dict:
if metadata.get("params", {}).get("function_calling") == "native":
# If the function calling is native, then call the tools function calling handler
# Native function calling: pass the tools directly to the LLM
metadata["tools"] = tools_dict
form_data["tools"] = [
{"type": "function", "function": tool.get("spec", {})}
for tool in tools_dict.values()
]
else:
# If the function calling is not native, then call the tools function calling handler
# Default mode: implement tool calling via prompting
try:
form_data, flags = await chat_completion_tools_handler(
request, form_data, extra_params, user, models, tools_dict
@ -1355,6 +1414,7 @@ async def process_chat_payload(request, form_data, user, metadata, model):
except Exception as e:
log.exception(e)
# === 18. File handling - RAG retrieval ===
try:
form_data, flags = await chat_completion_files_handler(
request, form_data, extra_params, user
@ -1363,11 +1423,13 @@ async def process_chat_payload(request, form_data, user, metadata, model):
except Exception as e:
log.exception(e)
# === 19. Build the context string and inject it into the messages ===
# If context is not empty, insert it into the messages
if len(sources) > 0:
context_string = ""
citation_idx_map = {}
citation_idx_map = {} # citation index map (document ID → citation number)
# Iterate over all sources to build the context string
for source in sources:
if "document" in source:
for document_text, document_metadata in zip(
@ -1380,9 +1442,11 @@ async def process_chat_payload(request, form_data, user, metadata, model):
or "N/A"
)
# Assign a unique citation number to each source
if source_id not in citation_idx_map:
citation_idx_map[source_id] = len(citation_idx_map) + 1
# Build the XML-style source tag
context_string += (
f'<source id="{citation_idx_map[source_id]}"'
+ (f' name="{source_name}"' if source_name else "")
@ -1393,6 +1457,7 @@ async def process_chat_payload(request, form_data, user, metadata, model):
if prompt is None:
raise Exception("No user message found")
# Inject the context into the user message via the RAG template
if context_string != "":
form_data["messages"] = add_or_update_user_message(
rag_template(
@ -1404,6 +1469,7 @@ async def process_chat_payload(request, form_data, user, metadata, model):
append=False,
)
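# Illustrative shape of the injected context built above (the closing tag is assumed;
# the exact template comes from rag_template and is not fully shown in this hunk):
# <source id="1" name="manual.pdf">...retrieved chunk...</source><source id="2">...</source>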
# === 20. Collect citation sources and append them to events ===
# If there are citations, add them to the data_items
sources = [
source
@ -1415,6 +1481,7 @@ async def process_chat_payload(request, form_data, user, metadata, model):
if len(sources) > 0:
events.append({"sources": sources})
# === 21. Complete the knowledge-base search status ===
if model_knowledge:
await event_emitter(
{
@ -1434,29 +1501,63 @@ async def process_chat_payload(request, form_data, user, metadata, model):
async def process_chat_response(
request, response, form_data, user, metadata, model, events, tasks
):
"""
Process the chat response - normalize and dispatch chat completion responses.
This is the core post-processing function for chat responses. It:
1. Handles streaming (SSE/ndjson) and non-streaming (JSON) responses
2. Sends events to the frontend over WebSocket
3. Persists messages/errors/metadata to the database
4. Triggers background tasks (title generation, tag generation, follow-ups)
5. For streaming responses: consumes upstream chunks, rebuilds content, handles tool calls, and forwards deltas
Args:
request: FastAPI Request object
response: upstream response (dict/JSONResponse/StreamingResponse)
form_data: the original request payload (possibly modified upstream)
user: verified user object
metadata: chat/session context collected earlier
model: resolved model config
events: extra events to send (e.g. sources citations)
tasks: optional background tasks (title/tags/follow-ups)
Returns:
- non-streaming: dict (OpenAI JSON format)
- streaming: StreamingResponse (SSE/ndjson format)
"""
# === Inner function: background task handler ===
async def background_tasks_handler():
"""
Run background tasks after the response completes:
1. Follow-ups generation - suggest follow-up questions
2. Title generation - auto-generate a chat title
3. Tags generation - auto-generate chat tags
"""
message = None
messages = []
# Fetch the message history
if "chat_id" in metadata and not metadata["chat_id"].startswith("local:"):
# Load the persisted chat history from the database
messages_map = Chats.get_messages_map_by_chat_id(metadata["chat_id"])
message = messages_map.get(metadata["message_id"]) if messages_map else None
message_list = get_message_list(messages_map, metadata["message_id"])
# Remove details tags and files from the messages.
# as get_message_list creates a new list, it does not affect
# the original messages outside of this handler
# Clean the message content: strip details tags and files
# get_message_list creates a new list, so the original messages are unaffected
messages = []
for message in message_list:
content = message.get("content", "")
# Handle multimodal content (image + text)
if isinstance(content, list):
for item in content:
if item.get("type") == "text":
content = item["text"]
break
# Strip <details> tags and Markdown images
if isinstance(content, str):
content = re.sub(
r"<details\b[^>]*>.*?<\/details>|!\[.*?\]\(.*?\)",
@ -1468,21 +1569,21 @@ async def process_chat_response(
messages.append(
{
**message,
"role": message.get(
"role", "assistant"
), # Safe fallback for missing role
"role": message.get("role", "assistant"), # 安全回退
"content": content,
}
)
else:
# Local temp chat, get the model and message from the form_data
# Temp chat (local:): read the model and message from form_data
message = get_last_user_message_item(form_data.get("messages", []))
messages = form_data.get("messages", [])
if message:
message["model"] = form_data.get("model")
# Run the background tasks
if message and "model" in message:
if tasks and messages:
# === Task 1: follow-ups generation ===
if (
TASKS.FOLLOW_UP_GENERATION in tasks
and tasks[TASKS.FOLLOW_UP_GENERATION]
@ -1510,6 +1611,7 @@ async def process_chat_response(
else:
follow_ups_string = ""
# Extract the JSON object (from the first { to the last })
follow_ups_string = follow_ups_string[
follow_ups_string.find("{") : follow_ups_string.rfind("}")
+ 1
@ -1519,6 +1621,7 @@ async def process_chat_response(
follow_ups = json.loads(follow_ups_string).get(
"follow_ups", []
)
# Send the follow-ups over WebSocket
await event_emitter(
{
"type": "chat:message:follow_ups",
@ -1528,6 +1631,7 @@ async def process_chat_response(
}
)
# Persist to the database
if not metadata.get("chat_id", "").startswith("local:"):
Chats.upsert_message_to_chat_by_id_and_message_id(
metadata["chat_id"],
@ -1540,9 +1644,9 @@ async def process_chat_response(
except Exception as e:
pass
if not metadata.get("chat_id", "").startswith(
"local:"
): # Only update titles and tags for non-temp chats
# === Tasks 2 & 3: title and tag generation (non-temp chats only) ===
if not metadata.get("chat_id", "").startswith("local:"):
# Task 2: title generation
if (
TASKS.TITLE_GENERATION in tasks
and tasks[TASKS.TITLE_GENERATION]
@ -1579,6 +1683,7 @@ async def process_chat_response(
else:
title_string = ""
# Extract the JSON object
title_string = title_string[
title_string.find("{") : title_string.rfind("}") + 1
]
@ -1593,16 +1698,19 @@ async def process_chat_response(
if not title:
title = messages[0].get("content", user_message)
# Update the database
Chats.update_chat_title_by_id(
metadata["chat_id"], title
)
# Send the title over WebSocket
await event_emitter(
{
"type": "chat:title",
"data": title,
}
)
# If there are only 2 messages (first exchange), use the user message as the title
elif len(messages) == 2:
title = messages[0].get("content", user_message)
@ -1615,6 +1723,7 @@ async def process_chat_response(
}
)
# Task 3: tag generation
if TASKS.TAGS_GENERATION in tasks and tasks[TASKS.TAGS_GENERATION]:
res = await generate_chat_tags(
request,
@ -1638,16 +1747,19 @@ async def process_chat_response(
else:
tags_string = ""
# Extract the JSON object
tags_string = tags_string[
tags_string.find("{") : tags_string.rfind("}") + 1
]
try:
tags = json.loads(tags_string).get("tags", [])
# Update the database
Chats.update_chat_tags_by_id(
metadata["chat_id"], tags, user
)
# Send the tags over WebSocket
await event_emitter(
{
"type": "chat:tags",
@ -1657,6 +1769,7 @@ async def process_chat_response(
except Exception as e:
pass
# === 1. Resolve event emitters (only used when a live session exists) ===
event_emitter = None
event_caller = None
if (
@ -1667,18 +1780,19 @@ async def process_chat_response(
and "message_id" in metadata
and metadata["message_id"]
):
event_emitter = get_event_emitter(metadata)
event_caller = get_event_call(metadata)
event_emitter = get_event_emitter(metadata) # WebSocket event emitter
event_caller = get_event_call(metadata) # event caller
# Non-streaming response
# === 2. Non-streaming response handling ===
if not isinstance(response, StreamingResponse):
if event_emitter:
try:
if isinstance(response, dict) or isinstance(response, JSONResponse):
# Unwrap single-item lists
if isinstance(response, list) and len(response) == 1:
# If the response is a single-item list, unwrap it #17213
response = response[0]
# Parse the JSONResponse body
if isinstance(response, JSONResponse) and isinstance(
response.body, bytes
):
@ -1693,6 +1807,7 @@ async def process_chat_response(
else:
response_data = response
# Handle error responses
if "error" in response_data:
error = response_data.get("error")
@ -1701,6 +1816,7 @@ async def process_chat_response(
else:
error = str(error)
# Save the error to the database
Chats.upsert_message_to_chat_by_id_and_message_id(
metadata["chat_id"],
metadata["message_id"],
@ -1708,6 +1824,7 @@ async def process_chat_response(
"error": {"content": error},
},
)
# Send the error event over WebSocket
if isinstance(error, str) or isinstance(error, dict):
await event_emitter(
{
@ -1821,7 +1938,7 @@ async def process_chat_response(
return response
# Non standard response
# Non standard response (not SSE/ndjson)
if not any(
content_type in response.headers["Content-Type"]
for content_type in ["text/event-stream", "application/x-ndjson"]
@ -1854,7 +1971,7 @@ async def process_chat_response(
)
]
# Streaming response
# Streaming response: consume upstream SSE/ndjson, forward deltas/events, handle tools
if event_emitter and event_caller:
task_id = str(uuid4()) # Create a unique task ID.
model_id = form_data.get("model", "")
@ -2289,26 +2406,49 @@ async def process_chat_response(
)
async def stream_body_handler(response, form_data):
"""
Stream body handler - consume the upstream SSE stream and process content blocks.
Core responsibilities:
1. Parse the SSE (Server-Sent Events) stream
2. Handle incremental tool_calls updates
3. Handle reasoning content (reasoning_content) - <think> tags
4. Handle solution and code_interpreter tags
5. Save messages to the database in real time (optional)
6. Throttle streamed deltas to avoid overloading the WebSocket
Args:
response: upstream StreamingResponse object
form_data: the original request data
"""
nonlocal content # the accumulated full text content
nonlocal content_blocks # content block list (text/reasoning/code_interpreter)
response_tool_calls = [] # accumulated tool calls
# === 1. Initialize delta throttling ===
delta_count = 0 # number of deltas accumulated so far
delta_chunk_size = max(
CHAT_RESPONSE_STREAM_DELTA_CHUNK_SIZE,
CHAT_RESPONSE_STREAM_DELTA_CHUNK_SIZE, # globally configured chunk size
int(
metadata.get("params", {}).get("stream_delta_chunk_size")
or 1
),
), # user-configured chunk size
)
last_delta_data = None
last_delta_data = None # the last pending delta payload
async def flush_pending_delta_data(threshold: int = 0):
"""
Flush the pending delta data.
Args:
threshold: only send once delta_count >= threshold
"""
nonlocal delta_count
nonlocal last_delta_data
if delta_count >= threshold and last_delta_data:
# Send the accumulated delta over WebSocket
await event_emitter(
{
"type": "chat:completion",
@ -2318,7 +2458,9 @@ async def process_chat_response(
delta_count = 0
last_delta_data = None
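# Worked example of the delta throttling (values illustrative): with delta_chunk_size = 4,
# deltas 1-3 only update last_delta_data; the 4th triggers flush_pending_delta_data(4),
# which emits a single chat:completion event and resets delta_count and last_delta_data.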
# === 2. Consume the SSE stream ===
async for line in response.body_iterator:
# Decode the byte stream to a string
line = (
line.decode("utf-8", "replace")
if isinstance(line, bytes)
@ -2326,20 +2468,22 @@ async def process_chat_response(
)
data = line
# Skip empty lines
if not data.strip():
continue
# "data:" is the prefix for each event
# SSE 格式:每个事件以 "data:" 开头
if not data.startswith("data:"):
continue
# Remove the "data:" prefix
data = data[len("data:") :].strip()
try:
# Parse the JSON payload
data = json.loads(data)
# === 3. Run filter functions (stream type) ===
data, _ = await process_filter_functions(
request=request,
filter_functions=filter_functions,
@ -2349,11 +2493,14 @@ async def process_chat_response(
)
if data:
# Handle custom events
if "event" in data:
await event_emitter(data.get("event", {}))
# === 4. Handle Arena-mode model selection ===
if "selected_model_id" in data:
model_id = data["selected_model_id"]
# Save the selected model ID to the database
Chats.upsert_message_to_chat_by_id_and_message_id(
metadata["chat_id"],
metadata["message_id"],
@ -2370,9 +2517,9 @@ async def process_chat_response(
else:
choices = data.get("choices", [])
# 17421
# === 5. Handle usage and timings info ===
usage = data.get("usage", {}) or {}
usage.update(data.get("timings", {})) # llama.cpp
usage.update(data.get("timings", {})) # llama.cpp timings
if usage:
await event_emitter(
{
@ -2383,6 +2530,7 @@ async def process_chat_response(
}
)
# === 6. Handle error responses ===
if not choices:
error = data.get("error", {})
if error:
@ -2396,9 +2544,11 @@ async def process_chat_response(
)
continue
# === 7. Extract the delta (incremental content) ===
delta = choices[0].get("delta", {})
delta_tool_calls = delta.get("tool_calls", None)
# === 8. Handle tool calls ===
if delta_tool_calls:
for delta_tool_call in delta_tool_calls:
tool_call_index = delta_tool_call.get(
@ -2406,7 +2556,7 @@ async def process_chat_response(
)
if tool_call_index is not None:
# Check whether the tool call already exists
current_response_tool_call = None
for (
response_tool_call
@ -2421,7 +2571,7 @@ async def process_chat_response(
break
if current_response_tool_call is None:
# Add the new tool call
delta_tool_call.setdefault(
"function", {}
)
@ -2435,7 +2585,7 @@ async def process_chat_response(
delta_tool_call
)
else:
# Update the existing tool call (accumulate name and arguments)
delta_name = delta_tool_call.get(
"function", {}
).get("name")
@ -2457,14 +2607,17 @@ async def process_chat_response(
"arguments"
] += delta_arguments
# === 9. Handle text content ===
value = delta.get("content")
# === 10. Handle reasoning content ===
reasoning_content = (
delta.get("reasoning_content")
or delta.get("reasoning")
or delta.get("thinking")
)
if reasoning_content:
# Create or update the reasoning content block
if (
not content_blocks
or content_blocks[-1]["type"] != "reasoning"
@ -2483,6 +2636,7 @@ async def process_chat_response(
else:
reasoning_block = content_blocks[-1]
# Accumulate the reasoning content
reasoning_block["content"] += reasoning_content
data = {
@ -2491,7 +2645,9 @@ async def process_chat_response(
)
}
# === 11. Handle plain text content ===
if value:
# If the previous block is reasoning, mark it ended and start a new text block
if (
content_blocks
and content_blocks[-1]["type"]
@ -2515,6 +2671,7 @@ async def process_chat_response(
}
)
# Accumulate the text content
content = f"{content}{value}"
if not content_blocks:
content_blocks.append(
@ -2528,6 +2685,8 @@ async def process_chat_response(
content_blocks[-1]["content"] + value
)
# === 12. Detect and handle special tags ===
# 12.1 Reasoning tag detection (e.g. <think>)
if DETECT_REASONING_TAGS:
content, content_blocks, _ = (
tag_content_handler(
@ -2538,6 +2697,7 @@ async def process_chat_response(
)
)
# Solution tag detection
content, content_blocks, _ = (
tag_content_handler(
"solution",
@ -2547,6 +2707,7 @@ async def process_chat_response(
)
)
# 12.2 Code-interpreter tag detection
if DETECT_CODE_INTERPRETER:
content, content_blocks, end = (
tag_content_handler(
@ -2557,11 +2718,13 @@ async def process_chat_response(
)
)
# Stop streaming once the closing tag is detected
if end:
break
# === 13. Save messages in real time (optional) ===
if ENABLE_REALTIME_CHAT_SAVE:
# Save the message in the database
Chats.upsert_message_to_chat_by_id_and_message_id(
metadata["chat_id"],
metadata["message_id"],
@ -2572,18 +2735,22 @@ async def process_chat_response(
},
)
else:
# Prepare the data to send
data = {
"content": serialize_content_blocks(
content_blocks
),
}
# === 14. Stream throttling ===
if delta:
delta_count += 1
last_delta_data = data
# Flush once the threshold is reached
if delta_count >= delta_chunk_size:
await flush_pending_delta_data(delta_chunk_size)
else:
# Send non-delta data immediately
await event_emitter(
{
"type": "chat:completion",
@ -2591,24 +2758,30 @@ async def process_chat_response(
}
)
except Exception as e:
# Handle the end-of-stream marker
done = "data: [DONE]" in line
if done:
pass
else:
log.debug(f"Error: {e}")
continue
# === 15. Flush any remaining delta data ===
await flush_pending_delta_data()
# === 16. Clean up content blocks ===
if content_blocks:
# Clean the last text block (strip trailing whitespace)
if content_blocks[-1]["type"] == "text":
content_blocks[-1]["content"] = content_blocks[-1][
"content"
].strip()
# Remove it if empty
if not content_blocks[-1]["content"]:
content_blocks.pop()
# Ensure at least one empty text block remains
if not content_blocks:
content_blocks.append(
{
@ -2617,6 +2790,7 @@ async def process_chat_response(
}
)
# Mark the last reasoning block as ended
if content_blocks[-1]["type"] == "reasoning":
reasoning_block = content_blocks[-1]
if reasoning_block.get("ended_at") is None:
@ -2626,9 +2800,11 @@ async def process_chat_response(
- reasoning_block["started_at"]
)
# === 17. Save tool calls ===
if response_tool_calls:
tool_calls.append(response_tool_calls)
# === 18. Response cleanup (close the connection) ===
if response.background:
await response.background()

@ -67,6 +67,45 @@ def get_messages_content(messages: list[dict]) -> str:
)
def extract_timestamped_messages(raw_msgs: list[dict]) -> list[dict]:
"""
Convert a message list into a uniform dict structure for downstream persistence/auditing.
Args:
raw_msgs (list[dict]): OpenAI-format message list
Returns:
list[dict]: each message carries role, content, and timestamp fields
"""
messages: list[dict] = []
for msg in raw_msgs:
if not isinstance(msg, dict):
continue
ts = (
msg.get("createdAt")
or msg.get("created_at")
or msg.get("timestamp")
or msg.get("updated_at")
or msg.get("updatedAt")
or 0
)
content = msg.get("content", "") or ""
if isinstance(content, list):
for item in content:
if item.get("type") == "text":
content = item.get("text", "")
break
messages.append(
{
"role": msg.get("role", "assistant"),
"content": str(content),
"timestamp": int(ts),
}
)
return messages
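# A quick illustration of the transformation (the sample input is hypothetical):
#   raw = [
#       {"role": "user", "content": "Hi", "createdAt": 1763997000},
#       {"role": "assistant", "content": [{"type": "text", "text": "Hello!"}], "timestamp": 1763997002},
#   ]
#   extract_timestamped_messages(raw)
#   # -> [{'role': 'user', 'content': 'Hi', 'timestamp': 1763997000},
#   #     {'role': 'assistant', 'content': 'Hello!', 'timestamp': 1763997002}]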
def get_last_user_message_item(messages: list[dict]) -> Optional[dict]:
for message in reversed(messages):
if message["role"] == "user":