Mirror of https://github.com/open-webui/open-webui.git (synced 2025-12-13 04:45:19 +00:00)
Add docstrings and comments to several backend chat-completion functions; add a last_process_payload hook in backend/open_webui/memory/cross_window_memory.py
This commit is contained in:
parent
c47df68d56
commit
984bb5888e
7 changed files with 610 additions and 147 deletions
@@ -469,6 +469,7 @@ from open_webui.utils.chat import (
    chat_completed as chat_completed_handler,
    chat_action as chat_action_handler,
)
from open_webui.utils.misc import get_message_list
from open_webui.utils.embeddings import generate_embeddings
from open_webui.utils.middleware import process_chat_payload, process_chat_response
from open_webui.utils.access_control import has_access
@@ -1592,7 +1593,7 @@ async def chat_completion(
        )

        # 8.2 Call the LLM to complete the chat (the core step)
        response = await chat_completion_handler(request, form_data, user)
        response = await chat_completion_handler(request, form_data, user, chatting_completion=True)

        # 8.3 Update the database: save the model ID to the message record
        if metadata.get("chat_id") and metadata.get("message_id"):
@@ -1688,6 +1689,74 @@ generate_chat_completions = chat_completion
generate_chat_completion = chat_completion


@app.post("/api/chat/tutorial")
async def chat_completion_tutorial(
    request: Request, form_data: dict, user=Depends(get_verified_user)
):
    """
    Minimal working example: reuses the core chat pipeline and supports
    direct connections, message history, and streaming responses.

    - The input is largely compatible with /api/chat/completions, but no extra background tasks are run.
    - If chat_id/message_id are provided, history is loaded and WS events may fire; otherwise an SSE stream is returned directly.
    """
    # Direct-connection compatibility
    model_item = form_data.pop("model_item", {})
    if model_item.get("direct"):
        request.state.direct = True
        request.state.model = model_item

    # Make sure models are available
    if not request.app.state.MODELS:
        await get_all_models(request, user=user)

    model_id = form_data.get("model") or next(iter(request.app.state.MODELS.keys()))
    if model_id not in request.app.state.MODELS:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST, detail="Model not found"
        )
    model = request.app.state.MODELS[model_id]
    if user.role == "user":
        check_model_access(user, model)

    # Fetch message history: prefer the provided messages, otherwise load everything by chat_id
    messages = form_data.get("messages") or []
    chat_id = form_data.get("chat_id")
    message_id = form_data.get("id") or form_data.get("message_id")
    if not messages and chat_id:
        messages_map = Chats.get_messages_map_by_chat_id(chat_id)
        if messages_map:
            if message_id and message_id in messages_map:
                messages = get_message_list(messages_map, message_id)
            else:
                messages = list(messages_map.values())

    # Prepare metadata so the downstream event/DB logic can be reused
    metadata = {
        "user_id": user.id,
        "chat_id": chat_id,
        "message_id": message_id,
        "session_id": form_data.get("session_id"),
        "model": model,
        "direct": model_item.get("direct", False),
        "params": {},
    }
    request.state.metadata = metadata
    form_data["metadata"] = metadata

    # Core pipeline: preprocess -> call the model -> handle the streaming/non-streaming response
    form_data["model"] = model_id
    form_data["messages"] = messages
    form_data["stream"] = True  # force streaming to keep the example simple

    form_data, metadata, events = await process_chat_payload(
        request, form_data, user, metadata, model
    )
    response = await chat_completion_handler(request, form_data, user)
    return await process_chat_response(
        request, response, form_data, user, metadata, model, events, tasks=None
    )


@app.post("/api/chat/completed")
async def chat_completed(
    request: Request, form_data: dict, user=Depends(get_verified_user)
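For orientation, this is how a client might exercise the new tutorial endpoint. A minimal sketch only: the base URL, API key, and model id are assumed placeholders, not values from this commit.

import httpx  # any HTTP client that can stream works

with httpx.stream(
    "POST",
    "http://localhost:8080/api/chat/tutorial",      # assumed local deployment
    headers={"Authorization": "Bearer <API_KEY>"},  # hypothetical token
    json={
        "model": "gpt-4o",  # any model id the server knows
        "messages": [{"role": "user", "content": "Hello!"}],
    },
    timeout=None,
) as response:
    # The endpoint forces stream=True, so read SSE lines as they arrive.
    for line in response.iter_lines():
        if line.startswith("data:"):
            print(line[len("data:"):].strip())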
0 backend/open_webui/memory/__init__.py Normal file
21 backend/open_webui/memory/cross_window_memory.py Normal file
@@ -0,0 +1,21 @@
from typing import List, Dict


def last_process_payload(
    user_id: str,
    session_id: str,
    messages: List[Dict],
):
    """
    Post-process the context information right before the LLM API call.

    Args:
        user_id (str): The user's unique ID.
        session_id (str): The ID of this user's current chat/session.
        messages (List[Dict]): The user's chat messages in this session,
            shaped like {"role": "system|user|assistant", "content": "...", "timestamp": 0}.
    """
    print("user_id:", user_id)
    print("session_id:", session_id)
    print("messages:", messages)
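The committed stub only prints its inputs. As a hedged sketch of what a real cross-window memory hook could do, this variant appends each snapshot to a per-user JSONL log; the storage path and record layout are assumptions of mine, not part of the commit.

import json
from pathlib import Path

MEMORY_DIR = Path("data/cross_window_memory")  # hypothetical storage location

def last_process_payload(user_id: str, session_id: str, messages: list[dict]):
    # Append this snapshot of the pre-LLM context to a per-user JSONL log,
    # so memory can later be shared across chat windows.
    MEMORY_DIR.mkdir(parents=True, exist_ok=True)
    record = {"user_id": user_id, "session_id": session_id, "messages": messages}
    with open(MEMORY_DIR / f"{user_id}.jsonl", "a", encoding="utf-8") as f:
        f.write(json.dumps(record, ensure_ascii=False) + "\n")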
@@ -34,6 +34,8 @@ from open_webui.env import (
    BYPASS_MODEL_ACCESS_CONTROL,
)
from open_webui.models.users import UserModel
from open_webui.memory.cross_window_memory import last_process_payload
from open_webui.utils.misc import extract_timestamped_messages

from open_webui.constants import ERROR_MESSAGES
from open_webui.env import SRC_LOG_LEVELS
@@ -807,61 +809,123 @@ async def generate_chat_completion(
    form_data: dict,
    user=Depends(get_verified_user),
    bypass_filter: Optional[bool] = False,
    chatting_completion: bool = False,
):
    """
    OpenAI-compatible chat completion endpoint - forwards the request directly
    to the OpenAI API or a compatible service.

    This is the low-level API-calling function in the OpenAI router. It:
    1. Applies the model configuration (base_model_id, system prompt, parameter overrides)
    2. Verifies user permissions (model access control)
    3. Handles the special Azure OpenAI format conversion
    4. Handles reasoning-model-specific logic
    5. Forwards the HTTP request to the upstream API (streaming and non-streaming)

    Args:
        request: FastAPI Request object
        form_data: chat request in OpenAI format
            - model: model ID
            - messages: message list
            - stream: whether to stream the response
            - temperature, max_tokens, and other parameters
        user: the verified user object
        bypass_filter: whether to bypass permission checks

    Returns:
        - Streaming: StreamingResponse (SSE)
        - Non-streaming: dict (OpenAI JSON format)

    Raises:
        HTTPException 403: no permission to access the model
        HTTPException 404: model not found
        HTTPException 500: upstream API connection failure
    """

    # print("user:", user)
    # user:
    # id='55f85fb0-4aca-48bc-aea1-afce50ac989e'
    # name='gaofeng1'
    # email='h.summit1628935449@gmail.com'
    # username=None
    # role='user'
    # profile_image_url='data:image/png;base64,iVBORw0KGgo…'
    # bio=None
    # gender=None
    # date_of_birth=None
    # info=None
    # settings=UserSettings(ui={'memory': True})
    # api_key=None
    # oauth_sub=None
    # last_active_at=1763997832
    # updated_at=1763971141
    # created_at=1763874812

    # === 1. Permission check configuration ===
    if BYPASS_MODEL_ACCESS_CONTROL:
        bypass_filter = True

    idx = 0
    idx = 0  # identifies which OPENAI_API_BASE_URL to use

    # === 2. Prepare the payload and extract metadata ===
    payload = {**form_data}
    metadata = payload.pop("metadata", None)
    metadata = payload.pop("metadata", None)  # strip internal metadata; it is not sent to the upstream API

    model_id = form_data.get("model")
    model_info = Models.get_model_by_id(model_id)

    # === 3. Apply the model configuration and permission checks ===
    # Check model info and override the payload
    if model_info:
        # 3.1 If base_model_id is configured, substitute the underlying model ID
        # e.g. custom model "my-gpt4" → actually calls "gpt-4-turbo"
        if model_info.base_model_id:
            payload["model"] = model_info.base_model_id
            model_id = model_info.base_model_id

        # 3.2 Apply model parameters (temperature, max_tokens, etc.)
        params = model_info.params.model_dump()

        if params:
            system = params.pop("system", None)
            system = params.pop("system", None)  # extract the system prompt

            # Apply model parameters to the payload (overriding user-supplied ones)
            payload = apply_model_params_to_body_openai(params, payload)
            # Inject or replace the system prompt
            payload = apply_system_prompt_to_body(system, payload, metadata, user)

        # 3.3 Permission check: verify the user may access this model
        # Check if user has access to the model
        if not bypass_filter and user.role == "user":
            if not (
                user.id == model_info.user_id
                user.id == model_info.user_id  # the user created the model
                or has_access(
                    user.id, type="read", access_control=model_info.access_control
                )
            )  # or the user is on the access control list
            ):
                raise HTTPException(
                    status_code=403,
                    detail="Model not found",
                )
    elif not bypass_filter:
        # If there is no model info and the filter is not bypassed, only admins may access
        if user.role != "admin":
            raise HTTPException(
                status_code=403,
                detail="Model not found",
            )

    await get_all_models(request, user=user)
    # === 4. Look up the OpenAI API configuration ===
    await get_all_models(request, user=user)  # refresh the model list
    model = request.app.state.OPENAI_MODELS.get(model_id)
    if model:
        idx = model["urlIdx"]
        idx = model["urlIdx"]  # index of the API base URL
    else:
        raise HTTPException(
            status_code=404,
            detail="Model not found",
        )

    # === 5. Get the API config and handle prefix_id ===
    # Get the API config for the model
    api_config = request.app.state.config.OPENAI_API_CONFIGS.get(
        str(idx),
@@ -870,10 +934,13 @@ async def generate_chat_completion(
        ),  # Legacy support
    )

    # Strip the model ID prefix (if prefix_id is configured)
    # e.g. model ID "custom.gpt-4" → "gpt-4" is what gets sent to the API
    prefix_id = api_config.get("prefix_id", None)
    if prefix_id:
        payload["model"] = payload["model"].replace(f"{prefix_id}.", "")

    # === 6. Pipeline mode: inject user info ===
    # Add user info to the payload if the model is a pipeline
    if "pipeline" in model and model.get("pipeline"):
        payload["user"] = {
@@ -886,32 +953,40 @@ async def generate_chat_completion(
    url = request.app.state.config.OPENAI_API_BASE_URLS[idx]
    key = request.app.state.config.OPENAI_API_KEYS[idx]

    # === 7. Reasoning-model special handling ===
    # Check if model is a reasoning model that needs special handling
    if is_openai_reasoning_model(payload["model"]):
        # Reasoning models (e.g. o1) use max_completion_tokens instead of max_tokens
        payload = openai_reasoning_model_handler(payload)
    elif "api.openai.com" not in url:
        # Non-official OpenAI APIs: for backward compatibility, convert max_completion_tokens to max_tokens
        # Remove "max_completion_tokens" from the payload for backward compatibility
        if "max_completion_tokens" in payload:
            payload["max_tokens"] = payload["max_completion_tokens"]
            del payload["max_completion_tokens"]

        # Avoid having both max_tokens and max_completion_tokens at the same time
        if "max_tokens" in payload and "max_completion_tokens" in payload:
            del payload["max_tokens"]

    # === 8. Convert the logit_bias format ===
    # Convert the modified body back to JSON
    if "logit_bias" in payload:
        payload["logit_bias"] = json.loads(
            convert_logit_bias_input_to_json(payload["logit_bias"])
        )

    # === 9. Prepare request headers and cookies ===
    headers, cookies = await get_headers_and_cookies(
        request, url, key, api_config, metadata, user=user
    )

    # === 10. Azure OpenAI special handling ===
    if api_config.get("azure", False):
        api_version = api_config.get("api_version", "2023-03-15-preview")
        request_url, payload = convert_to_azure_payload(url, payload, api_version)

        # Only set the api-key header when not using Azure Entra ID authentication
        # Only set api-key header if not using Azure Entra ID authentication
        auth_type = api_config.get("auth_type", "bearer")
        if auth_type not in ("azure_ad", "microsoft_entra_id"):
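The token-limit shim in step 7 is self-contained enough to restate as a pure function; a hedged sketch (the function name is mine, not the commit's):

def normalize_token_limits(payload: dict, upstream_is_openai: bool) -> dict:
    # Reasoning models on api.openai.com expect max_completion_tokens;
    # most other OpenAI-compatible servers still expect max_tokens.
    payload = dict(payload)
    if not upstream_is_openai and "max_completion_tokens" in payload:
        payload["max_tokens"] = payload.pop("max_completion_tokens")
    # Never send both keys at once.
    if "max_tokens" in payload and "max_completion_tokens" in payload:
        del payload["max_tokens"]
    return payload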
@@ -920,16 +995,34 @@ async def generate_chat_completion(
        headers["api-version"] = api_version
        request_url = f"{request_url}/chat/completions?api-version={api_version}"
    else:
        # Standard OpenAI-compatible API
        request_url = f"{url}/chat/completions"

    payload = json.dumps(payload)
    if chatting_completion:
        try:
            # Optional hook: record/audit the payload before it is sent upstream;
            # last_process_payload is left for you to implement
            log.debug(
                f"chatting_completion hook user={user.id} chat_id={metadata.get('chat_id')} model={payload.get('model')}"
            )

            last_process_payload(
                user_id=user.id,
                session_id=metadata.get("chat_id"),
                messages=extract_timestamped_messages(payload.get("messages", [])),
            )
        except Exception as e:
            log.debug(f"chatting_completion hook failed: {e}")

    payload = json.dumps(payload)  # serialize to a JSON string

    # === 11. Initialize request-state variables ===
    r = None
    session = None
    streaming = False
    response = None

    try:
        # === 12. Send the HTTP request to the upstream API ===
        session = aiohttp.ClientSession(
            trust_env=True, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
        )
@@ -943,8 +1036,10 @@ async def generate_chat_completion(
            ssl=AIOHTTP_CLIENT_SESSION_SSL,
        )

        # === 13. Handle the response ===
        # Check if response is SSE
        if "text/event-stream" in r.headers.get("Content-Type", ""):
            # Streaming response: forward the SSE stream as-is
            streaming = True
            return StreamingResponse(
                r.content,
@@ -955,19 +1050,22 @@ async def generate_chat_completion(
                ),
            )
        else:
            # Non-streaming response: parse the JSON
            try:
                response = await r.json()
            except Exception as e:
                log.error(e)
                response = await r.text()
                response = await r.text()  # fall back to plain text if JSON parsing fails

            # Handle error responses
            if r.status >= 400:
                if isinstance(response, (dict, list)):
                    return JSONResponse(status_code=r.status, content=response)
                else:
                    return PlainTextResponse(status_code=r.status, content=response)

            return response
            return response  # success
    except Exception as e:
        log.exception(e)
@@ -976,6 +1074,8 @@ async def generate_chat_completion(
            detail="CyberLover: Server Connection Error",
        )
    finally:
        # === 14. Clean up resources ===
        # Non-streaming responses must close the connection manually
        # (streaming responses are handled in a BackgroundTask)
        if not streaming:
            await cleanup_response(r, session)
@@ -167,124 +167,182 @@ async def generate_chat_completion(
    form_data: dict,
    user: Any,
    bypass_filter: bool = False,
    chatting_completion: bool = False,
):
    log.debug(f"generate_chat_completion: {form_data}")
    if BYPASS_MODEL_ACCESS_CONTROL:
        bypass_filter = True
    """
    Chat completion generator - dispatches to a different low-level API handler
    depending on the model type.

    This is the core routing function for chat completions. It:
    1. Validates that the model exists and the user has permission
    2. Handles Direct mode (direct connection to an external API)
    3. Handles Arena mode (randomly selects a model for comparison)
    4. Dispatches to the matching handler by model type:
        - Pipe: Pipeline plugin functions
        - Ollama: local Ollama models
        - OpenAI: OpenAI-compatible APIs (including Claude, Gemini, etc.)

    Args:
        request: FastAPI Request object
        form_data: chat request data in OpenAI format
        user: the user object
        bypass_filter: whether to bypass permission and Pipeline Filter checks

    Returns:
        - Streaming: StreamingResponse (SSE format)
        - Non-streaming: dict (OpenAI-compatible format)

    Raises:
        Exception: the model does not exist or access is denied
    """
    log.debug(f"generate_chat_completion: {form_data}")

    # === 1. Permission check configuration ===
    if BYPASS_MODEL_ACCESS_CONTROL:
        bypass_filter = True  # global setting: bypass all permission checks

    # === 2. Merge metadata ===
    # Pull metadata passed down from upstream (chat_id, user_id, etc.) off request.state.metadata
    if hasattr(request.state, "metadata"):
        if "metadata" not in form_data:
            form_data["metadata"] = request.state.metadata
        else:
            # Merge; request.state.metadata takes precedence
            form_data["metadata"] = {
                **form_data["metadata"],
                **request.state.metadata,
            }

    # === 3. Determine the model list source ===
    if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
        # Direct mode: the user supplies the external API configuration (e.g. an OpenAI API key)
        models = {
            request.state.model["id"]: request.state.model,
        }
        log.debug(f"direct connection to model: {models}")
    else:
        # Standard mode: use the platform's built-in model list
        models = request.app.state.MODELS

    # === 4. Validate that the model exists ===
    model_id = form_data["model"]
    if model_id not in models:
        raise Exception("Model not found")

    model = models[model_id]

    # === 5. Direct-mode branch: call the external API directly ===
    if getattr(request.state, "direct", False):
        return await generate_direct_chat_completion(
            request, form_data, user=user, models=models
        )
    else:
        # Check if user has access to the model
        # === 6. Standard mode: check user permissions ===
        if not bypass_filter and user.role == "user":
            try:
                check_model_access(user, model)
                check_model_access(user, model)  # verify RBAC permissions
            except Exception as e:
                raise e

        if model.get("owned_by") == "arena":
            model_ids = model.get("info", {}).get("meta", {}).get("model_ids")
            filter_mode = model.get("info", {}).get("meta", {}).get("filter_mode")
            if model_ids and filter_mode == "exclude":
                model_ids = [
                    model["id"]
                    for model in list(request.app.state.MODELS.values())
                    if model.get("owned_by") != "arena" and model["id"] not in model_ids
                ]
        # === 7. Arena mode: randomly select a model for blind comparison ===
        if False:
            if model.get("owned_by") == "arena":
                # Get the candidate model list
                model_ids = model.get("info", {}).get("meta", {}).get("model_ids")
                filter_mode = model.get("info", {}).get("meta", {}).get("filter_mode")

            selected_model_id = None
            if isinstance(model_ids, list) and model_ids:
                selected_model_id = random.choice(model_ids)
            else:
                model_ids = [
                    model["id"]
                    for model in list(request.app.state.MODELS.values())
                    if model.get("owned_by") != "arena"
                ]
                selected_model_id = random.choice(model_ids)
                # In exclude mode, invert the model list
                if model_ids and filter_mode == "exclude":
                    model_ids = [
                        model["id"]
                        for model in list(request.app.state.MODELS.values())
                        if model.get("owned_by") != "arena" and model["id"] not in model_ids
                    ]

            form_data["model"] = selected_model_id
                # Randomly pick one model
                selected_model_id = None
                if isinstance(model_ids, list) and model_ids:
                    selected_model_id = random.choice(model_ids)
                else:
                    # If none are specified, pick at random from all non-Arena models
                    model_ids = [
                        model["id"]
                        for model in list(request.app.state.MODELS.values())
                        if model.get("owned_by") != "arena"
                    ]
                    selected_model_id = random.choice(model_ids)

            if form_data.get("stream") == True:
                # Swap in the selected model ID
                form_data["model"] = selected_model_id

                async def stream_wrapper(stream):
                    yield f"data: {json.dumps({'selected_model_id': selected_model_id})}\n\n"
                    async for chunk in stream:
                        yield chunk
                # Streaming response: inject selected_model_id into the first chunk
                if form_data.get("stream") == True:

                response = await generate_chat_completion(
                    request, form_data, user, bypass_filter=True
                    async def stream_wrapper(stream):
                        """Prepend the selected model ID to the streaming response"""
                        yield f"data: {json.dumps({'selected_model_id': selected_model_id})}\n\n"
                        async for chunk in stream:
                            yield chunk

                    # Recursive call into this function, bypassing the Arena logic
                    response = await generate_chat_completion(
                        request, form_data, user, bypass_filter=True
                    )
                    return StreamingResponse(
                        stream_wrapper(response.body_iterator),
                        media_type="text/event-stream",
                        background=response.background,
                    )
                else:
                    # Non-streaming response: add selected_model_id directly to the result
                    return {
                        **(
                            await generate_chat_completion(
                                request, form_data, user, bypass_filter=True
                            )
                        ),
                        "selected_model_id": selected_model_id,
                    }

        # === 8. Pipeline mode: call a custom Python function ===
        if False:
            if model.get("pipe"):
                return await generate_function_chat_completion(
                    request, form_data, user=user, models=models
                )
                return StreamingResponse(
                    stream_wrapper(response.body_iterator),
                    media_type="text/event-stream",
                    background=response.background,
                )
            else:
                return {
                    **(
                        await generate_chat_completion(
                            request, form_data, user, bypass_filter=True
                        )
                    ),
                    "selected_model_id": selected_model_id,
                }

        if model.get("pipe"):
            # Below does not require bypass_filter because this is the only route that uses this function and it is already bypassing the filter
            return await generate_function_chat_completion(
                request, form_data, user=user, models=models
            )
        if model.get("owned_by") == "ollama":
            # Using /ollama/api/chat endpoint
            form_data = convert_payload_openai_to_ollama(form_data)
            response = await generate_ollama_chat_completion(
                request=request,
                form_data=form_data,
                user=user,
                bypass_filter=bypass_filter,
            )
            if form_data.get("stream"):
                response.headers["content-type"] = "text/event-stream"
                return StreamingResponse(
                    convert_streaming_response_ollama_to_openai(response),
                    headers=dict(response.headers),
                    background=response.background,
        # === 9. Ollama mode: call the local Ollama service ===
        if False:
            if model.get("owned_by") == "ollama":
                # Convert OpenAI format → Ollama format
                form_data = convert_payload_openai_to_ollama(form_data)
                response = await generate_ollama_chat_completion(
                    request=request,
                    form_data=form_data,
                    user=user,
                    bypass_filter=bypass_filter,
                )
            else:
                return convert_response_ollama_to_openai(response)
        else:
            return await generate_openai_chat_completion(
                request=request,
                form_data=form_data,
                user=user,
                bypass_filter=bypass_filter,
            )

                # Streaming response: convert Ollama SSE → OpenAI SSE
                if form_data.get("stream"):
                    response.headers["content-type"] = "text/event-stream"
                    return StreamingResponse(
                        convert_streaming_response_ollama_to_openai(response),
                        headers=dict(response.headers),
                        background=response.background,
                    )
                else:
                    # Non-streaming response: convert Ollama JSON → OpenAI JSON
                    return convert_response_ollama_to_openai(response)

        # === 10. OpenAI-compatible mode: call the OpenAI API or a compatible service ===
        # >>>
        return await generate_openai_chat_completion(
            request=request,
            form_data=form_data,
            user=user,
            bypass_filter=bypass_filter,
            chatting_completion=chatting_completion,
        )


chat_completion = generate_chat_completion
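The stream_wrapper pattern in the Arena branch (prepend one synthetic SSE event, then forward the upstream stream untouched) is generic; a minimal self-contained sketch, independent of FastAPI:

import asyncio
import json

async def prepend_event(stream, event: dict):
    # Emit one synthetic SSE event, then forward the wrapped stream as-is.
    yield f"data: {json.dumps(event)}\n\n"
    async for chunk in stream:
        yield chunk

async def demo():
    async def upstream():
        for token in ("Hel", "lo"):
            yield f"data: {json.dumps({'delta': token})}\n\n"

    async for chunk in prepend_event(upstream(), {"selected_model_id": "model-a"}):
        print(chunk, end="")

asyncio.run(demo())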
@@ -994,25 +994,52 @@ def apply_params_to_form_data(form_data, model):


async def process_chat_payload(request, form_data, user, metadata, model):
    # Pipeline Inlet -> Filter Inlet -> Chat Memory -> Chat Web Search -> Chat Image Generation
    # -> Chat Code Interpreter (Form Data Update) -> (Default) Chat Tools Function Calling
    # -> Chat Files
    """
    Process the chat request payload - runs Pipelines, Filters, feature enrichment, and tool injection.

    This is the core preprocessing function for chat requests. It runs, in order:
    1. Pipeline Inlet - custom Python plugin preprocessing
    2. Filter Inlet - function-filter preprocessing
    3. Chat Memory - injects memory from past conversations
    4. Chat Web Search - runs a web search and injects the results
    5. Chat Image Generation - handles image generation requests
    6. Chat Code Interpreter - injects the code-execution prompt
    7. Chat Tools Function Calling - handles function/tool calls
    8. Chat Files - handles uploaded files, knowledge-base files, and RAG retrieval

    Args:
        request: FastAPI Request object
        form_data: chat request data in OpenAI format
        user: the user object
        metadata: metadata (chat_id, message_id, tool_ids, files, etc.)
        model: the model configuration object

    Returns:
        tuple: (form_data, metadata, events)
            - form_data: the processed request data
            - metadata: the updated metadata
            - events: events to send to the frontend (e.g. citation sources)
    """
    # === 1. Apply model parameters to the request ===
    form_data = apply_params_to_form_data(form_data, model)
    log.debug(f"form_data: {form_data}")

    # === 2. Handle system prompt variable substitution ===
    system_message = get_system_message(form_data.get("messages", []))
    if system_message:  # Chat Controls/User Settings
        try:
            # Substitute variables in the system prompt (e.g. {{USER_NAME}}, {{CURRENT_DATE}})
            form_data = apply_system_prompt_to_body(
                system_message.get("content"), form_data, metadata, user, replace=True
            )  # Required to handle system prompt variables
            )
        except:
            pass

    event_emitter = get_event_emitter(metadata)
    event_call = get_event_call(metadata)
    # === 3. Initialize the event emitter and callback ===
    event_emitter = get_event_emitter(metadata)  # WebSocket event emitter
    event_call = get_event_call(metadata)  # event-call function

    # === 4. Get the OAuth token ===
    oauth_token = None
    try:
        if request.cookies.get("oauth_session_id", None):
@@ -1023,9 +1050,10 @@ async def process_chat_payload(request, form_data, user, metadata, model):
    except Exception as e:
        log.error(f"Error getting OAuth token: {e}")

    # === 5. Build extra params (for Pipelines/Filters/Tools) ===
    extra_params = {
        "__event_emitter__": event_emitter,
        "__event_call__": event_call,
        "__event_emitter__": event_emitter,  # sends real-time events to the frontend
        "__event_call__": event_call,  # invokes event callbacks
        "__user__": user.model_dump() if isinstance(user, UserModel) else {},
        "__metadata__": metadata,
        "__request__": request,
@@ -1033,15 +1061,19 @@ async def process_chat_payload(request, form_data, user, metadata, model):
        "__oauth_token__": oauth_token,
    }

    # === 6. Determine the model list and the task model ===
    # Initialize events to store additional events to be sent to the client
    # Initialize contexts and citation
    if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
        # Direct mode: use the model the user connected directly
        models = {
            request.state.model["id"]: request.state.model,
        }
    else:
        # Standard mode: use all of the platform's models
        models = request.app.state.MODELS

    # Get the task model ID (used for background tasks such as tool calling and title generation)
    task_model_id = get_task_model_id(
        form_data["model"],
        request.app.state.config.TASK_MODEL,
@@ -1049,10 +1081,11 @@ async def process_chat_payload(request, form_data, user, metadata, model):
        models,
    )

    events = []
    sources = []
    # === 7. Initialize the event and citation-source lists ===
    events = []  # events to send to the frontend (e.g. sources citations)
    sources = []  # document sources retrieved via RAG

    # Folder "Project" handling
    # === 8. Folder "Project" handling - inject the folder's system prompt and files ===
    # Check if the request has chat_id and is inside of a folder
    chat_id = metadata.get("chat_id", None)
    if chat_id and user:
@@ -1061,21 +1094,24 @@ async def process_chat_payload(request, form_data, user, metadata, model):
            folder = Folders.get_folder_by_id_and_user_id(chat.folder_id, user.id)

            if folder and folder.data:
                # Inject the folder's system prompt
                if "system_prompt" in folder.data:
                    form_data = apply_system_prompt_to_body(
                        folder.data["system_prompt"], form_data, metadata, user
                    )
                # Inject the folder's associated files
                if "files" in folder.data:
                    form_data["files"] = [
                        *folder.data["files"],
                        *form_data.get("files", []),
                    ]

    # Model "Knowledge" handling
    # === 9. Model "Knowledge" handling - inject the model's bound knowledge bases ===
    user_message = get_last_user_message(form_data["messages"])
    model_knowledge = model.get("info", {}).get("meta", {}).get("knowledge", False)

    if model_knowledge:
        # Send the knowledge-base search status to the frontend
        await event_emitter(
            {
                "type": "status",
@@ -1089,6 +1125,7 @@ async def process_chat_payload(request, form_data, user, metadata, model):

        knowledge_files = []
        for item in model_knowledge:
            # Handle the legacy collection_name format
            if item.get("collection_name"):
                knowledge_files.append(
                    {
@@ -1097,6 +1134,7 @@ async def process_chat_payload(request, form_data, user, metadata, model):
                        "legacy": True,
                    }
                )
            # Handle the new collection_names format (multiple collections)
            elif item.get("collection_names"):
                knowledge_files.append(
                    {
@@ -1109,12 +1147,14 @@ async def process_chat_payload(request, form_data, user, metadata, model):
            else:
                knowledge_files.append(item)

        # Merge the model's knowledge-base files with the user's uploaded files
        files = form_data.get("files", [])
        files.extend(knowledge_files)
        form_data["files"] = files

    variables = form_data.pop("variables", None)

    # === 10. Pipeline Inlet handling - run custom Python plugins ===
    # Process the form_data through the pipeline
    try:
        form_data = await process_pipeline_inlet_filter(
@@ -1123,6 +1163,7 @@ async def process_chat_payload(request, form_data, user, metadata, model):
    except Exception as e:
        raise e

    # === 11. Filter Inlet handling - run function filters ===
    try:
        filter_functions = [
            Functions.get_function_by_id(filter_id)
@@ -1141,23 +1182,28 @@ async def process_chat_payload(request, form_data, user, metadata, model):
    except Exception as e:
        raise Exception(f"{e}")

    # === 12. Feature enrichment (Features) ===
    features = form_data.pop("features", None)
    if features:
        # 12.1 Memory - inject memory from past conversations
        if "memory" in features and features["memory"]:
            form_data = await chat_memory_handler(
                request, form_data, extra_params, user
            )

        # 12.2 Web search - run a web search
        if "web_search" in features and features["web_search"]:
            form_data = await chat_web_search_handler(
                request, form_data, extra_params, user
            )

        # 12.3 Image generation - handle image generation requests
        if "image_generation" in features and features["image_generation"]:
            form_data = await chat_image_generation_handler(
                request, form_data, extra_params, user
            )

        # 12.4 Code interpreter - inject the code-execution prompt
        if "code_interpreter" in features and features["code_interpreter"]:
            form_data["messages"] = add_or_update_user_message(
                (
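For context, the feature flags consumed in step 12 arrive from the client in the request body; a representative (assumed) shape:

form_data = {
    "model": "gpt-4o",  # hypothetical model id
    "messages": [{"role": "user", "content": "What's new in Rust?"}],
    "features": {
        "memory": True,             # 12.1 inject chat memory
        "web_search": True,         # 12.2 run a web search first
        "image_generation": False,  # 12.3 image generation off
        "code_interpreter": False,  # 12.4 code interpreter off
    },
}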
@@ -1168,33 +1214,33 @@ async def process_chat_payload(request, form_data, user, metadata, model):
                form_data["messages"],
            )

    # === 13. Extract tool and file information ===
    tool_ids = form_data.pop("tool_ids", None)
    files = form_data.pop("files", None)

    # === 14. Expand folder-type files ===
    prompt = get_last_user_message(form_data["messages"])
    # TODO: re-enable URL extraction from prompt
    # urls = []
    # if prompt and len(prompt or "") < 500 and (not files or len(files) == 0):
    #     urls = extract_urls(prompt)

    if files:
    if not files:
        files = []

    for file_item in files:
        # If the file is of type folder, expand all the files it contains
        if file_item.get("type", "file") == "folder":
            # Get folder files
            folder_id = file_item.get("id", None)
            if folder_id:
                folder = Folders.get_folder_by_id_and_user_id(folder_id, user.id)
                if folder and folder.data and "files" in folder.data:
                    # Remove the folder entry and add the files inside it
                    files = [f for f in files if f.get("id", None) != folder_id]
                    files = [*files, *folder.data["files"]]

    # files = [*files, *[{"type": "url", "url": url, "name": url} for url in urls]]
    # Remove duplicate files based on their content
    # Deduplicate files (based on file content)
    files = list({json.dumps(f, sort_keys=True): f for f in files}.values())

    # === 15. Update the metadata ===
    metadata = {
        **metadata,
        "tool_ids": tool_ids,
@@ -1202,25 +1248,28 @@ async def process_chat_payload(request, form_data, user, metadata, model):
    }
    form_data["metadata"] = metadata

    # === 16. Prepare the tools dictionary ===
    # Server side tools
    tool_ids = metadata.get("tool_ids", None)
    tool_ids = metadata.get("tool_ids", None)  # server-side tool IDs
    # Client side tools
    direct_tool_servers = metadata.get("tool_servers", None)
    direct_tool_servers = metadata.get("tool_servers", None)  # client-side direct tool servers

    log.debug(f"{tool_ids=}")
    log.debug(f"{direct_tool_servers=}")

    tools_dict = {}
    tools_dict = {}  # dictionary of all tools

    mcp_clients = {}
    mcp_tools_dict = {}
    mcp_clients = {}  # MCP (Model Context Protocol) clients
    mcp_tools_dict = {}  # MCP tools dictionary

    if tool_ids:
        for tool_id in tool_ids:
            # === 16.1 Handle MCP (Model Context Protocol) tools ===
            if tool_id.startswith("server:mcp:"):
                try:
                    server_id = tool_id[len("server:mcp:") :]

                    # Look up the MCP server connection configuration
                    mcp_server_connection = None
                    for (
                        server_connection
@@ -1236,6 +1285,7 @@ async def process_chat_payload(request, form_data, user, metadata, model):
                        log.error(f"MCP server with id {server_id} not found")
                        continue

                    # Handle the auth type
                    auth_type = mcp_server_connection.get("auth_type", "")

                    headers = {}
@@ -1244,7 +1294,7 @@ async def process_chat_payload(request, form_data, user, metadata, model):
                            f"Bearer {mcp_server_connection.get('key', '')}"
                        )
                    elif auth_type == "none":
                        # No authentication
                        # No authentication required
                        pass
                    elif auth_type == "session":
                        headers["Authorization"] = (
@@ -1273,16 +1323,19 @@ async def process_chat_payload(request, form_data, user, metadata, model):
                        log.error(f"Error getting OAuth token: {e}")
                        oauth_token = None

                    # Connect to the MCP server
                    mcp_clients[server_id] = MCPClient()
                    await mcp_clients[server_id].connect(
                        url=mcp_server_connection.get("url", ""),
                        headers=headers if headers else None,
                    )

                    # Fetch the MCP tool list and register the tools
                    tool_specs = await mcp_clients[server_id].list_tool_specs()
                    for tool_spec in tool_specs:

                        def make_tool_function(client, function_name):
                            """Create an async call function for each MCP tool"""
                            async def tool_function(**kwargs):
                                return await client.call_tool(
                                    function_name,
@@ -1295,6 +1348,7 @@ async def process_chat_payload(request, form_data, user, metadata, model):
                            mcp_clients[server_id], tool_spec["name"]
                        )

                        # Register the MCP tool in the tools dictionary
                        mcp_tools_dict[f"{server_id}_{tool_spec['name']}"] = {
                            "spec": {
                                **tool_spec,
@@ -1309,6 +1363,7 @@ async def process_chat_payload(request, form_data, user, metadata, model):
                    log.debug(e)
                    continue

        # === 16.2 Get the standard tools (Function Tools) ===
        tools_dict = await get_tools(
            request,
            tool_ids,
@@ -1320,9 +1375,11 @@ async def process_chat_payload(request, form_data, user, metadata, model):
                "__files__": metadata.get("files", []),
            },
        )
        # Merge in the MCP tools
        if mcp_tools_dict:
            tools_dict = {**tools_dict, **mcp_tools_dict}

    # === 16.3 Handle client-side direct tool servers ===
    if direct_tool_servers:
        for tool_server in direct_tool_servers:
            tool_specs = tool_server.pop("specs", [])
@@ -1334,19 +1391,21 @@ async def process_chat_payload(request, form_data, user, metadata, model):
                    "server": tool_server,
                }

    # Save the MCP clients in the metadata (for cleanup at the end)
    if mcp_clients:
        metadata["mcp_clients"] = mcp_clients

    # === 17. Tool-calling handling ===
    if tools_dict:
        if metadata.get("params", {}).get("function_calling") == "native":
            # If the function calling is native, then call the tools function calling handler
            # Native function-calling mode: pass the tools straight to the LLM
            metadata["tools"] = tools_dict
            form_data["tools"] = [
                {"type": "function", "function": tool.get("spec", {})}
                for tool in tools_dict.values()
            ]
        else:
            # If the function calling is not native, then call the tools function calling handler
            # Default mode: implement tool calling via prompting
            try:
                form_data, flags = await chat_completion_tools_handler(
                    request, form_data, extra_params, user, models, tools_dict
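In native mode the list built above follows the OpenAI tools schema; a representative (assumed) element:

form_data["tools"] = [
    {
        "type": "function",
        "function": {
            "name": "get_weather",  # hypothetical tool
            "description": "Get the current weather for a city.",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        },
    },
]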
@@ -1355,6 +1414,7 @@ async def process_chat_payload(request, form_data, user, metadata, model):
            except Exception as e:
                log.exception(e)

    # === 18. File handling - RAG retrieval ===
    try:
        form_data, flags = await chat_completion_files_handler(
            request, form_data, extra_params, user
@@ -1363,11 +1423,13 @@ async def process_chat_payload(request, form_data, user, metadata, model):
    except Exception as e:
        log.exception(e)

    # === 19. Build the context string and inject it into the messages ===
    # If context is not empty, insert it into the messages
    if len(sources) > 0:
        context_string = ""
        citation_idx_map = {}
        citation_idx_map = {}  # citation index map (document ID → citation number)

        # Walk all sources and build the context string
        for source in sources:
            if "document" in source:
                for document_text, document_metadata in zip(
@@ -1380,9 +1442,11 @@ async def process_chat_payload(request, form_data, user, metadata, model):
                        or "N/A"
                    )

                    # Assign each source a unique citation number
                    if source_id not in citation_idx_map:
                        citation_idx_map[source_id] = len(citation_idx_map) + 1

                    # Build the XML-style source tag
                    context_string += (
                        f'<source id="{citation_idx_map[source_id]}"'
                        + (f' name="{source_name}"' if source_name else "")
@@ -1393,6 +1457,7 @@ async def process_chat_payload(request, form_data, user, metadata, model):
        if prompt is None:
            raise Exception("No user message found")

        # Inject the context into the user message via the RAG template
        if context_string != "":
            form_data["messages"] = add_or_update_user_message(
                rag_template(
@@ -1404,6 +1469,7 @@ async def process_chat_payload(request, form_data, user, metadata, model):
                append=False,
            )

    # === 20. Collect the citation sources and add them to the events ===
    # If there are citations, add them to the data_items
    sources = [
        source
@@ -1415,6 +1481,7 @@ async def process_chat_payload(request, form_data, user, metadata, model):
    if len(sources) > 0:
        events.append({"sources": sources})

    # === 21. Finish the knowledge-base search status ===
    if model_knowledge:
        await event_emitter(
            {
@@ -1434,29 +1501,63 @@ async def process_chat_response(
async def process_chat_response(
    request, response, form_data, user, metadata, model, events, tasks
):
    """
    Process the chat response - normalizes and dispatches chat completion responses.

    This is the core post-processing function for chat responses. It:
    1. Handles streaming (SSE/ndjson) and non-streaming (JSON) responses
    2. Sends events to the frontend over WebSocket
    3. Persists messages/errors/metadata to the database
    4. Triggers background tasks (title generation, tag generation, follow-ups)
    5. For streaming responses: consumes upstream chunks, rebuilds content, handles tool calls, forwards deltas

    Args:
        request: FastAPI Request object
        response: the upstream response (dict/JSONResponse/StreamingResponse)
        form_data: the original request payload (possibly modified upstream)
        user: the verified user object
        metadata: the chat/session context collected earlier
        model: the resolved model configuration
        events: extra events to send (e.g. sources citations)
        tasks: optional background tasks (title/tags/follow-ups)

    Returns:
        - Non-streaming: dict (OpenAI JSON format)
        - Streaming: StreamingResponse (SSE/ndjson format)
    """

    # === Inner function: background task handler ===
    async def background_tasks_handler():
        """
        Run background tasks once the response has completed:
        1. Follow-ups generation - suggest follow-up questions
        2. Title generation - auto-generate the chat title
        3. Tags generation - auto-generate chat tags
        """
        message = None
        messages = []

        # Fetch the message history
        if "chat_id" in metadata and not metadata["chat_id"].startswith("local:"):
            # Load the persisted chat history from the database
            messages_map = Chats.get_messages_map_by_chat_id(metadata["chat_id"])
            message = messages_map.get(metadata["message_id"]) if messages_map else None

            message_list = get_message_list(messages_map, metadata["message_id"])

            # Remove details tags and files from the messages.
            # as get_message_list creates a new list, it does not affect
            # the original messages outside of this handler

            # Clean the message content: remove details tags and files
            # (get_message_list creates a new list, so the originals are untouched)
            messages = []
            for message in message_list:
                content = message.get("content", "")
                # Handle multimodal content (image + text)
                if isinstance(content, list):
                    for item in content:
                        if item.get("type") == "text":
                            content = item["text"]
                            break

                # Strip <details> tags and Markdown images
                if isinstance(content, str):
                    content = re.sub(
                        r"<details\b[^>]*>.*?<\/details>|!\[.*?\]\(.*?\)",
@@ -1468,21 +1569,21 @@ async def process_chat_response(
                messages.append(
                    {
                        **message,
                        "role": message.get(
                            "role", "assistant"
                        ),  # Safe fallback for missing role
                        "role": message.get("role", "assistant"),  # safe fallback
                        "content": content,
                    }
                )
        else:
            # Local temp chat, get the model and message from the form_data
            # Temporary chat (local:): read from form_data
            message = get_last_user_message_item(form_data.get("messages", []))
            messages = form_data.get("messages", [])
            if message:
                message["model"] = form_data.get("model")

        # Run the background tasks
        if message and "model" in message:
            if tasks and messages:
                # === Task 1: Follow-ups generation ===
                if (
                    TASKS.FOLLOW_UP_GENERATION in tasks
                    and tasks[TASKS.FOLLOW_UP_GENERATION]
@@ -1510,6 +1611,7 @@ async def process_chat_response(
                    else:
                        follow_ups_string = ""

                    # Extract the JSON object (from the first { to the last })
                    follow_ups_string = follow_ups_string[
                        follow_ups_string.find("{") : follow_ups_string.rfind("}")
                        + 1
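This slice-out-the-JSON idiom recurs for follow-ups, titles, and tags below; a small self-contained restatement (the helper name is mine, not the commit's):

import json

def extract_json_object(text: str) -> dict:
    # Keep only the span from the first '{' to the last '}' and parse it;
    # tolerates models that wrap JSON in prose or code fences.
    start, end = text.find("{"), text.rfind("}")
    if start == -1 or end == -1:
        return {}
    try:
        return json.loads(text[start : end + 1])
    except json.JSONDecodeError:
        return {}

assert extract_json_object('Sure! ```json\n{"tags": ["rust"]}\n```') == {"tags": ["rust"]}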
@@ -1519,6 +1621,7 @@ async def process_chat_response(
                        follow_ups = json.loads(follow_ups_string).get(
                            "follow_ups", []
                        )
                        # Send the follow-ups over WebSocket
                        await event_emitter(
                            {
                                "type": "chat:message:follow_ups",
@@ -1528,6 +1631,7 @@ async def process_chat_response(
                            }
                        )

                        # Persist to the database
                        if not metadata.get("chat_id", "").startswith("local:"):
                            Chats.upsert_message_to_chat_by_id_and_message_id(
                                metadata["chat_id"],
@@ -1540,9 +1644,9 @@ async def process_chat_response(
                    except Exception as e:
                        pass

                if not metadata.get("chat_id", "").startswith(
                    "local:"
                ):  # Only update titles and tags for non-temp chats
                # === Tasks 2 & 3: title and tag generation (non-temporary chats only) ===
                if not metadata.get("chat_id", "").startswith("local:"):
                    # Task 2: title generation
                    if (
                        TASKS.TITLE_GENERATION in tasks
                        and tasks[TASKS.TITLE_GENERATION]
@@ -1579,6 +1683,7 @@ async def process_chat_response(
                        else:
                            title_string = ""

                        # Extract the JSON object
                        title_string = title_string[
                            title_string.find("{") : title_string.rfind("}") + 1
                        ]
@@ -1593,16 +1698,19 @@ async def process_chat_response(
                            if not title:
                                title = messages[0].get("content", user_message)

                            # Update the database
                            Chats.update_chat_title_by_id(
                                metadata["chat_id"], title
                            )

                            # Send the title over WebSocket
                            await event_emitter(
                                {
                                    "type": "chat:title",
                                    "data": title,
                                }
                            )
                        # If there are only 2 messages (the first exchange), use the user message as the title
                        elif len(messages) == 2:
                            title = messages[0].get("content", user_message)

@@ -1615,6 +1723,7 @@ async def process_chat_response(
                                }
                            )

                    # Task 3: tag generation
                    if TASKS.TAGS_GENERATION in tasks and tasks[TASKS.TAGS_GENERATION]:
                        res = await generate_chat_tags(
                            request,
@@ -1638,16 +1747,19 @@ async def process_chat_response(
                        else:
                            tags_string = ""

                        # Extract the JSON object
                        tags_string = tags_string[
                            tags_string.find("{") : tags_string.rfind("}") + 1
                        ]

                        try:
                            tags = json.loads(tags_string).get("tags", [])
                            # Update the database
                            Chats.update_chat_tags_by_id(
                                metadata["chat_id"], tags, user
                            )

                            # Send the tags over WebSocket
                            await event_emitter(
                                {
                                    "type": "chat:tags",
@@ -1657,6 +1769,7 @@ async def process_chat_response(
                        except Exception as e:
                            pass

    # === 1. Resolve the event emitters (only used when there is a live session) ===
    event_emitter = None
    event_caller = None
    if (
@@ -1667,18 +1780,19 @@ async def process_chat_response(
        and "message_id" in metadata
        and metadata["message_id"]
    ):
        event_emitter = get_event_emitter(metadata)
        event_caller = get_event_call(metadata)
        event_emitter = get_event_emitter(metadata)  # WebSocket event emitter
        event_caller = get_event_call(metadata)  # event caller

    # Non-streaming response
    # === 2. Non-streaming response handling ===
    if not isinstance(response, StreamingResponse):
        if event_emitter:
            try:
                if isinstance(response, dict) or isinstance(response, JSONResponse):
                    # Unwrap single-item lists
                    if isinstance(response, list) and len(response) == 1:
                        # If the response is a single-item list, unwrap it #17213
                        response = response[0]

                    # Parse the body of a JSONResponse
                    if isinstance(response, JSONResponse) and isinstance(
                        response.body, bytes
                    ):
@@ -1693,6 +1807,7 @@ async def process_chat_response(
                    else:
                        response_data = response

                    # Handle error responses
                    if "error" in response_data:
                        error = response_data.get("error")

@@ -1701,6 +1816,7 @@ async def process_chat_response(
                        else:
                            error = str(error)

                        # Save the error to the database
                        Chats.upsert_message_to_chat_by_id_and_message_id(
                            metadata["chat_id"],
                            metadata["message_id"],
@@ -1708,6 +1824,7 @@ async def process_chat_response(
                                "error": {"content": error},
                            },
                        )
                        # Send the error event over WebSocket
                        if isinstance(error, str) or isinstance(error, dict):
                            await event_emitter(
                                {
@@ -1821,7 +1938,7 @@ async def process_chat_response(

        return response

    # Non standard response
    # Non standard response (not SSE/ndjson)
    if not any(
        content_type in response.headers["Content-Type"]
        for content_type in ["text/event-stream", "application/x-ndjson"]
@@ -1854,7 +1971,7 @@ async def process_chat_response(
            )
        ]

    # Streaming response
    # Streaming response: consume upstream SSE/ndjson, forward deltas/events, handle tools
    if event_emitter and event_caller:
        task_id = str(uuid4())  # Create a unique task ID.
        model_id = form_data.get("model", "")
@@ -2289,26 +2406,49 @@ async def process_chat_response(
            )

        async def stream_body_handler(response, form_data):
            nonlocal content
            nonlocal content_blocks
            """
            Streaming body handler - consumes the upstream SSE stream and manages content blocks.

            response_tool_calls = []
            Core responsibilities:
            1. Parse the SSE (Server-Sent Events) stream
            2. Handle incremental tool-call (tool_calls) updates
            3. Handle reasoning content (reasoning_content), e.g. <think> tags
            4. Handle solution and code_interpreter tags
            5. Optionally save messages to the database in real time
            6. Throttle streamed deltas to avoid overloading the WebSocket

            delta_count = 0
            Args:
                response: the upstream StreamingResponse object
                form_data: the original request data
            """
            nonlocal content  # the accumulated full text content
            nonlocal content_blocks  # list of content blocks (text/reasoning/code_interpreter)

            response_tool_calls = []  # accumulated tool calls

            # === 1. Initialize streaming push control ===
            delta_count = 0  # number of deltas accumulated so far
            delta_chunk_size = max(
                CHAT_RESPONSE_STREAM_DELTA_CHUNK_SIZE,
                CHAT_RESPONSE_STREAM_DELTA_CHUNK_SIZE,  # globally configured chunk size
                int(
                    metadata.get("params", {}).get("stream_delta_chunk_size")
                    or 1
                ),
                ),  # user-configured chunk size
            )
            last_delta_data = None
            last_delta_data = None  # the last pending delta payload

            async def flush_pending_delta_data(threshold: int = 0):
                """
                Flush the pending delta data.

                Args:
                    threshold: only send once delta_count >= threshold
                """
                nonlocal delta_count
                nonlocal last_delta_data

                if delta_count >= threshold and last_delta_data:
                    # Send the accumulated delta over WebSocket
                    await event_emitter(
                        {
                            "type": "chat:completion",
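Delta throttling as used above, batching N small deltas into one WebSocket emit, is easy to isolate; a minimal sketch (the class and attribute names are mine, not the commit's):

class DeltaThrottle:
    # Batch small streaming deltas so the WebSocket sees one message
    # per `chunk_size` deltas instead of one per token.
    def __init__(self, emit, chunk_size: int = 8):
        self.emit = emit            # async callable(data)
        self.chunk_size = chunk_size
        self.count = 0
        self.pending = None

    async def push(self, data):
        self.count += 1
        self.pending = data         # later deltas carry the full content so far
        if self.count >= self.chunk_size:
            await self.flush()

    async def flush(self):
        if self.pending is not None:
            await self.emit(self.pending)
            self.count, self.pending = 0, None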
@@ -2318,7 +2458,9 @@ async def process_chat_response(
                    delta_count = 0
                    last_delta_data = None

            # === 2. Consume the SSE stream ===
            async for line in response.body_iterator:
                # Decode the byte stream into a string
                line = (
                    line.decode("utf-8", "replace")
                    if isinstance(line, bytes)
@@ -2326,20 +2468,22 @@ async def process_chat_response(
                )
                data = line

                # Skip empty lines
                if not data.strip():
                    continue

                # "data:" is the prefix for each event
                # SSE format: every event starts with "data:"
                if not data.startswith("data:"):
                    continue

                # Remove the "data:" prefix
                data = data[len("data:") :].strip()

                try:
                    # Parse the JSON data
                    data = json.loads(data)

                    # === 3. Run the filter functions (stream type) ===
                    data, _ = await process_filter_functions(
                        request=request,
                        filter_functions=filter_functions,
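The per-line SSE handling above reduces to a small pure function; a self-contained restatement (the function name is mine):

import json
from typing import Optional

def parse_sse_line(line: bytes | str) -> Optional[dict]:
    # Decode, skip blanks and non-"data:" lines, strip the prefix, parse JSON.
    text = line.decode("utf-8", "replace") if isinstance(line, bytes) else line
    text = text.strip()
    if not text.startswith("data:"):
        return None
    payload = text[len("data:"):].strip()
    if payload == "[DONE]":  # end-of-stream marker
        return None
    try:
        return json.loads(payload)
    except json.JSONDecodeError:
        return None

assert parse_sse_line('data: {"choices": []}') == {"choices": []}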
@@ -2349,11 +2493,14 @@ async def process_chat_response(
                    )

                    if data:
                        # Handle custom events
                        if "event" in data:
                            await event_emitter(data.get("event", {}))

                        # === 4. Handle the Arena-mode model selection ===
                        if "selected_model_id" in data:
                            model_id = data["selected_model_id"]
                            # Save the selected model ID to the database
                            Chats.upsert_message_to_chat_by_id_and_message_id(
                                metadata["chat_id"],
                                metadata["message_id"],
@@ -2370,9 +2517,9 @@ async def process_chat_response(
                        else:
                            choices = data.get("choices", [])

                            # 17421
                            # === 5. Handle usage and timings information ===
                            usage = data.get("usage", {}) or {}
                            usage.update(data.get("timings", {}))  # llama.cpp
                            usage.update(data.get("timings", {}))  # llama.cpp timings
                            if usage:
                                await event_emitter(
                                    {
@@ -2383,6 +2530,7 @@ async def process_chat_response(
                                    }
                                )

                            # === 6. Handle error responses ===
                            if not choices:
                                error = data.get("error", {})
                                if error:
@@ -2396,9 +2544,11 @@ async def process_chat_response(
                                    )
                                continue

                            # === 7. Extract the delta (incremental content) ===
                            delta = choices[0].get("delta", {})
                            delta_tool_calls = delta.get("tool_calls", None)

                            # === 8. Handle tool calls ===
                            if delta_tool_calls:
                                for delta_tool_call in delta_tool_calls:
                                    tool_call_index = delta_tool_call.get(
@@ -2406,7 +2556,7 @@ async def process_chat_response(
                                    )

                                    if tool_call_index is not None:
                                        # Check if the tool call already exists
                                        # Look up an existing tool call
                                        current_response_tool_call = None
                                        for (
                                            response_tool_call
@@ -2421,7 +2571,7 @@ async def process_chat_response(
                                                break

                                        if current_response_tool_call is None:
                                            # Add the new tool call
                                            # Add a new tool call
                                            delta_tool_call.setdefault(
                                                "function", {}
                                            )
@@ -2435,7 +2585,7 @@ async def process_chat_response(
                                                delta_tool_call
                                            )
                                        else:
                                            # Update the existing tool call
                                            # Update the existing tool call (accumulate name and arguments)
                                            delta_name = delta_tool_call.get(
                                                "function", {}
                                            ).get("name")
|
|||
"arguments"
|
||||
] += delta_arguments
|
||||
|
||||
# === 9. 处理文本内容 ===
|
||||
value = delta.get("content")
|
||||
|
||||
# === 10. 处理推理内容(Reasoning Content)===
|
||||
reasoning_content = (
|
||||
delta.get("reasoning_content")
|
||||
or delta.get("reasoning")
|
||||
or delta.get("thinking")
|
||||
)
|
||||
if reasoning_content:
|
||||
# 创建或更新 reasoning 内容块
|
||||
if (
|
||||
not content_blocks
|
||||
or content_blocks[-1]["type"] != "reasoning"
|
||||
|
|
@@ -2483,6 +2636,7 @@ async def process_chat_response(
                                else:
                                    reasoning_block = content_blocks[-1]

                                # Accumulate the reasoning content
                                reasoning_block["content"] += reasoning_content

                                data = {
@@ -2491,7 +2645,9 @@ async def process_chat_response(
                                    )
                                }

                            # === 11. Handle plain text content ===
                            if value:
                                # If the previous block is reasoning, mark it as ended and start a new text block
                                if (
                                    content_blocks
                                    and content_blocks[-1]["type"]
|
@ -2515,6 +2671,7 @@ async def process_chat_response(
|
|||
}
|
||||
)
|
||||
|
||||
# 累积文本内容
|
||||
content = f"{content}{value}"
|
||||
if not content_blocks:
|
||||
content_blocks.append(
|
||||
|
|
@@ -2528,6 +2685,8 @@ async def process_chat_response(
                                        content_blocks[-1]["content"] + value
                                    )

                                # === 12. Detect and handle special tags ===
                                # 12.1 Reasoning tag detection (e.g. <think>)
                                if DETECT_REASONING_TAGS:
                                    content, content_blocks, _ = (
                                        tag_content_handler(
@@ -2538,6 +2697,7 @@ async def process_chat_response(
                                        )
                                    )

                                    # Solution tag detection
                                    content, content_blocks, _ = (
                                        tag_content_handler(
                                            "solution",
@@ -2547,6 +2707,7 @@ async def process_chat_response(
                                        )
                                    )

                                # 12.2 Code Interpreter tag detection
                                if DETECT_CODE_INTERPRETER:
                                    content, content_blocks, end = (
                                        tag_content_handler(
@@ -2557,11 +2718,13 @@ async def process_chat_response(
                                        )
                                    )

                                    # Stop streaming once the closing tag is detected
                                    if end:
                                        break

                            # === 13. Real-time message saving (optional) ===
                            if ENABLE_REALTIME_CHAT_SAVE:
                                # Save message in the database
                                Chats.upsert_message_to_chat_by_id_and_message_id(
                                    metadata["chat_id"],
                                    metadata["message_id"],
@@ -2572,18 +2735,22 @@ async def process_chat_response(
                                    },
                                )
                            else:
                                # Prepare the data to send
                                data = {
                                    "content": serialize_content_blocks(
                                        content_blocks
                                    ),
                                }

                            # === 14. Streamed-delta throttling ===
                            if delta:
                                delta_count += 1
                                last_delta_data = data
                                # Flush once the threshold is reached
                                if delta_count >= delta_chunk_size:
                                    await flush_pending_delta_data(delta_chunk_size)
                            else:
                                # Non-delta data is sent immediately
                                await event_emitter(
                                    {
                                        "type": "chat:completion",
@@ -2591,24 +2758,30 @@ async def process_chat_response(
                                    }
                                )
                except Exception as e:
                    # Handle the end-of-stream marker
                    done = "data: [DONE]" in line
                    if done:
                        pass
                    else:
                        log.debug(f"Error: {e}")
                        continue

            # === 15. Flush any remaining delta data ===
            await flush_pending_delta_data()

            # === 16. Clean up the content blocks ===
            if content_blocks:
                # Clean up the last text block (strip trailing whitespace)
                if content_blocks[-1]["type"] == "text":
                    content_blocks[-1]["content"] = content_blocks[-1][
                        "content"
                    ].strip()

                    # Remove it if empty
                    if not content_blocks[-1]["content"]:
                        content_blocks.pop()

                        # Make sure there is at least one empty text block
                        if not content_blocks:
                            content_blocks.append(
                                {
@@ -2617,6 +2790,7 @@ async def process_chat_response(
                                }
                            )

                # Mark the last reasoning block as ended
                if content_blocks[-1]["type"] == "reasoning":
                    reasoning_block = content_blocks[-1]
                    if reasoning_block.get("ended_at") is None:
@@ -2626,9 +2800,11 @@ async def process_chat_response(
                        - reasoning_block["started_at"]
                    )

            # === 17. Save the tool calls ===
            if response_tool_calls:
                tool_calls.append(response_tool_calls)

            # === 18. Run the response cleanup (close the connection) ===
            if response.background:
                await response.background()
@@ -67,6 +67,45 @@ def get_messages_content(messages: list[dict]) -> str:
    )


def extract_timestamped_messages(raw_msgs: list[dict]) -> list[dict]:
    """
    Convert a message list into a uniform dict structure for downstream persistence/auditing.

    Args:
        raw_msgs (list[dict]): a message list in OpenAI format.

    Returns:
        list[dict]: each message contains role, content, and timestamp fields.
    """
    messages: list[dict] = []
    for msg in raw_msgs:
        if not isinstance(msg, dict):
            continue
        ts = (
            msg.get("createdAt")
            or msg.get("created_at")
            or msg.get("timestamp")
            or msg.get("updated_at")
            or msg.get("updatedAt")
            or 0
        )
        content = msg.get("content", "") or ""
        if isinstance(content, list):
            for item in content:
                if item.get("type") == "text":
                    content = item.get("text", "")
                    break
        messages.append(
            {
                "role": msg.get("role", "assistant"),
                "content": str(content),
                "timestamp": int(ts),
            }
        )

    return messages


def get_last_user_message_item(messages: list[dict]) -> Optional[dict]:
    for message in reversed(messages):
        if message["role"] == "user":
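A quick usage sketch of the helper above (the sample data is illustrative):

raw = [
    {"role": "user", "content": "hi", "createdAt": 1763997832},
    {"role": "assistant", "content": [{"type": "text", "text": "hello"}]},  # multimodal shape
    "not-a-dict",  # silently skipped
]
print(extract_timestamped_messages(raw))
# [{'role': 'user', 'content': 'hi', 'timestamp': 1763997832},
#  {'role': 'assistant', 'content': 'hello', 'timestamp': 0}]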