import time
import logging
import sys
import os
import base64
import textwrap
import asyncio
from aiocache import cached
from typing import Any, Optional
import random
import json
import html
import inspect
import re
import ast
from uuid import uuid4
from concurrent.futures import ThreadPoolExecutor
from fastapi import Request, HTTPException
from fastapi.responses import HTMLResponse
from starlette.responses import Response, StreamingResponse, JSONResponse
from open_webui.models.oauth_sessions import OAuthSessions
from open_webui.models.chats import Chats
from open_webui.models.folders import Folders
from open_webui.models.users import Users
from open_webui.socket.main import (
get_event_call,
get_event_emitter,
get_active_status_by_user_id,
)
from open_webui.routers.tasks import (
generate_queries,
generate_title,
generate_follow_ups,
generate_image_prompt,
generate_chat_tags,
)
from open_webui.routers.retrieval import (
process_web_search,
SearchForm,
)
from open_webui.routers.images import (
load_b64_image_data,
image_generations,
GenerateImageForm,
upload_image,
)
from open_webui.routers.pipelines import (
process_pipeline_inlet_filter,
process_pipeline_outlet_filter,
)
from open_webui.routers.memories import query_memory, QueryMemoryForm
from open_webui.utils.webhook import post_webhook
from open_webui.utils.files import (
get_audio_url_from_base64,
get_file_url_from_base64,
get_image_url_from_base64,
)
from open_webui.models.users import UserModel
from open_webui.models.functions import Functions
from open_webui.models.models import Models
from open_webui.retrieval.utils import get_sources_from_items
from open_webui.utils.chat import generate_chat_completion
from open_webui.utils.task import (
get_task_model_id,
rag_template,
tools_function_calling_generation_template,
)
from open_webui.utils.misc import (
deep_update,
extract_urls,
get_message_list,
add_or_update_system_message,
add_or_update_user_message,
get_last_user_message,
get_last_user_message_item,
get_last_assistant_message,
get_system_message,
prepend_to_first_user_message_content,
convert_logit_bias_input_to_json,
get_content_from_message,
)
from open_webui.utils.tools import get_tools
from open_webui.utils.plugin import load_function_module_by_id
from open_webui.utils.filter import (
get_sorted_filter_ids,
process_filter_functions,
)
from open_webui.utils.code_interpreter import execute_code_jupyter
from open_webui.utils.payload import apply_system_prompt_to_body
from open_webui.utils.mcp.client import MCPClient
from open_webui.config import (
CACHE_DIR,
DEFAULT_TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
DEFAULT_CODE_INTERPRETER_PROMPT,
CODE_INTERPRETER_BLOCKED_MODULES,
)
from open_webui.env import (
SRC_LOG_LEVELS,
GLOBAL_LOG_LEVEL,
CHAT_RESPONSE_STREAM_DELTA_CHUNK_SIZE,
CHAT_RESPONSE_MAX_TOOL_CALL_RETRIES,
BYPASS_MODEL_ACCESS_CONTROL,
ENABLE_REALTIME_CHAT_SAVE,
ENABLE_QUERIES_CACHE,
)
from open_webui.constants import TASKS
# Configure root logging to write to stdout at the global level, then create
# this module's logger and set it to the level stored under the "MAIN" key.
logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MAIN"])
# (opening, closing) marker pairs. Judging by the name, these are the default
# delimiters of a model's reasoning output; their consumer is not in view.
# NOTE(review): the first six pairs are identical empty strings, which looks
# like angle-bracket tags lost in a text-extraction step -- confirm the
# intended values against the upstream file before relying on them.
DEFAULT_REASONING_TAGS = [
("", ""),
("", ""),
("", ""),
("", ""),
("", ""),
("", ""),
("<|begin_of_thought|>", "<|end_of_thought|>"),
("◁think▷", "◁/think▷"),
]
# (opening, closing) marker pair; presumably delimits a model's final answer.
DEFAULT_SOLUTION_TAGS = [("<|begin_of_solution|>", "<|end_of_solution|>")]
# NOTE(review): this pair is empty as well -- likely the same stripped-tag
# problem as above; verify the intended value.
DEFAULT_CODE_INTERPRETER_TAGS = [("", "")]
def process_tool_result(
    request,
    tool_function_name,
    tool_result,
    tool_type,
    direct_tool=False,
    metadata=None,
    user=None,
):
    """Normalize a raw tool result into text plus extracted files and embeds.

    Args:
        request: Forwarded to ``get_file_url_from_base64`` for MCP image/audio
            items.
        tool_function_name: Tool name; used as the prefix of the UI status
            messages.
        tool_result: Raw result. Handled shapes are an ``HTMLResponse``, a
            ``(body, headers)`` pair, a list of items, a dict, or any other
            value (passed through unchanged).
        tool_type: ``"external"`` enables tuple ``(body, headers)`` unpacking
            and ``"mcp"`` selects MCP list handling; every other value gets the
            OpenAPI-style list handling.
        direct_tool: When true, a two-element list is unpacked as
            ``[body, headers]``.
        metadata: Optional mapping; ``chat_id``, ``message_id`` and
            ``session_id`` are read from it for MCP image/audio items.
        user: Forwarded to ``get_file_url_from_base64``.

    Returns:
        A ``(tool_result, tool_result_files, tool_result_embeds)`` tuple.
        Dict and list results are serialized to indented JSON text; lists are
        first wrapped as ``{"results": [...]}``.
    """
    # ``metadata`` defaults to None but the MCP branch calls ``.get`` on it.
    metadata = metadata or {}

    def ui_component(status, message):
        # Placeholder handed to the model in place of an embedded UI payload.
        return {
            "status": status,
            "code": "ui_component",
            "message": f"{tool_function_name}: {message}",
        }

    embedded_ok = "Embedded UI result is active and visible to the user."
    tool_result_embeds = []

    if isinstance(tool_result, HTMLResponse):
        content_disposition = tool_result.headers.get("Content-Disposition", "")
        if "inline" in content_disposition:
            tool_result_embeds.append(tool_result.body.decode("utf-8", "replace"))

            status_code = tool_result.status_code
            if 200 <= status_code < 300:
                tool_result = ui_component("success", embedded_ok)
            elif 400 <= status_code < 500:
                tool_result = ui_component(
                    "error",
                    f"Client error {status_code} from embedded UI result.",
                )
            elif 500 <= status_code < 600:
                tool_result = ui_component(
                    "error",
                    f"Server error {status_code} from embedded UI result.",
                )
            else:
                tool_result = ui_component(
                    "error",
                    f"Unexpected status code {status_code} from embedded UI result.",
                )
        else:
            tool_result = tool_result.body.decode("utf-8", "replace")

    elif (tool_type == "external" and isinstance(tool_result, tuple)) or (
        direct_tool and isinstance(tool_result, list) and len(tool_result) == 2
    ):
        tool_result, tool_response_headers = tool_result

        try:
            if not isinstance(tool_response_headers, dict):
                tool_response_headers = dict(tool_response_headers)
        except Exception as e:
            # Headers that cannot be coerced to a dict are treated as absent.
            tool_response_headers = {}
            log.debug(e)

        if tool_response_headers and isinstance(tool_response_headers, dict):
            # Only the canonical and the all-lowercase spellings are checked.
            content_disposition = tool_response_headers.get(
                "Content-Disposition",
                tool_response_headers.get("content-disposition", ""),
            )

            if "inline" in content_disposition:
                content_type = tool_response_headers.get(
                    "Content-Type",
                    tool_response_headers.get("content-type", ""),
                )
                location = tool_response_headers.get(
                    "Location",
                    tool_response_headers.get("location", ""),
                )

                if "text/html" in content_type:
                    # Display the HTML body itself as an iframe embed.
                    tool_result_embeds.append(tool_result)
                    tool_result = ui_component("success", embedded_ok)
                elif location:
                    # Otherwise embed whatever the Location header points at.
                    tool_result_embeds.append(location)
                    tool_result = ui_component("success", embedded_ok)

    tool_result_files = []

    if isinstance(tool_result, list):
        if tool_type == "mcp":  # MCP
            tool_response = []
            for item in tool_result:
                if not isinstance(item, dict):
                    continue

                item_type = item.get("type")
                if item_type == "text":
                    text = item.get("text", "")
                    if isinstance(text, str):
                        # Text that parses as JSON is kept as the parsed value.
                        try:
                            text = json.loads(text)
                        except json.JSONDecodeError:
                            pass
                    tool_response.append(text)
                elif item_type in ["image", "audio"]:
                    file_url = get_file_url_from_base64(
                        request,
                        f"data:{item.get('mimeType')};base64,{item.get('data', item.get('blob', ''))}",
                        {
                            "chat_id": metadata.get("chat_id", None),
                            "message_id": metadata.get("message_id", None),
                            "session_id": metadata.get("session_id", None),
                            "result": item,
                        },
                        user,
                    )
                    tool_result_files.append(
                        {
                            "type": item.get("type", "data"),
                            "url": file_url,
                        }
                    )
            # A single text item is unwrapped; anything else stays a list.
            tool_result = tool_response[0] if len(tool_response) == 1 else tool_response
        else:  # OpenAPI
            # Partition into a new list rather than calling remove() on the
            # list being iterated: removing during iteration skipped the item
            # following every match and mutated the caller's list.
            remaining_items = []
            for item in tool_result:
                if isinstance(item, str) and item.startswith("data:"):
                    tool_result_files.append(
                        {
                            "type": "data",
                            "content": item,
                        }
                    )
                else:
                    remaining_items.append(item)
            tool_result = remaining_items

        if isinstance(tool_result, list):
            tool_result = {"results": tool_result}

    if isinstance(tool_result, (dict, list)):
        tool_result = json.dumps(tool_result, indent=2, ensure_ascii=False)

    return tool_result, tool_result_files, tool_result_embeds
async def chat_completion_tools_handler(
    request: Request, body: dict, extra_params: dict, user: UserModel, models, tools
) -> tuple[dict, dict]:
    """Run prompt-based tool calling and fold tool output into the chat body.

    The task model is asked, via the function-calling prompt template, which
    tool(s) to call. Each requested tool that exists in ``tools`` is executed,
    its result is appended to the user message in ``body["messages"]`` and
    recorded as a source.

    Args:
        request: Current request; provides ``app.state.config``.
        body: Chat payload; ``model`` and ``messages`` are read, ``messages``
            is updated, and ``metadata["files"]`` is removed when a called
            tool declares ``metadata.file_handler``.
        extra_params: Must contain ``__event_call__``, ``__event_emitter__``
            and ``__metadata__``.
        user: Forwarded to the completion call and to ``process_tool_result``.
        models: Forwarded to ``get_task_model_id``.
        tools: Mapping of tool name to a tool dict (``spec``, and either
            ``callable`` or ``direct``/``server``).

    Returns:
        ``(body, {"sources": [...]})``, or ``(body, {})`` when the task model
        returned no content. Errors while selecting or running tools are
        logged at debug level and never raised.
    """

    async def get_content_from_response(response) -> Optional[str]:
        # Supports both a streaming-style response object and a plain dict.
        content = None
        if hasattr(response, "body_iterator"):
            async for chunk in response.body_iterator:
                data = json.loads(chunk.decode("utf-8", "replace"))
                content = data["choices"][0]["message"]["content"]

            # Cleanup any remaining background tasks if necessary
            if response.background is not None:
                await response.background()
        else:
            content = response["choices"][0]["message"]["content"]
        return content

    def get_tools_function_calling_payload(messages, task_model_id, content):
        # Only the last four messages are included as history.
        user_message = get_last_user_message(messages)
        recent_messages = messages[-4:] if len(messages) > 4 else messages
        chat_history = "\n".join(
            f"{message['role'].upper()}: \"\"\"{get_content_from_message(message)}\"\"\""
            for message in recent_messages
        )
        prompt = f"History:\n{chat_history}\nQuery: {user_message}"

        return {
            "model": task_model_id,
            "messages": [
                {"role": "system", "content": content},
                {"role": "user", "content": f"Query: {prompt}"},
            ],
            "stream": False,
            "metadata": {"task": str(TASKS.FUNCTION_CALLING)},
        }

    event_caller = extra_params["__event_call__"]
    event_emitter = extra_params["__event_emitter__"]
    metadata = extra_params["__metadata__"]

    task_model_id = get_task_model_id(
        body["model"],
        request.app.state.config.TASK_MODEL,
        request.app.state.config.TASK_MODEL_EXTERNAL,
        models,
    )

    skip_files = False
    sources = []

    async def tool_call_handler(tool_call):
        # Execute one requested tool call and record its output.
        nonlocal skip_files

        log.debug(f"{tool_call=}")

        tool_function_name = tool_call.get("name", None)
        if tool_function_name not in tools:
            # Unknown tool names are ignored.
            return

        tool_function_params = tool_call.get("parameters", {})
        tool_type = ""
        direct_tool = False

        try:
            tool = tools[tool_function_name]
            tool_type = tool.get("type", "")
            direct_tool = tool.get("direct", False)

            # Drop any parameter the tool spec does not declare.
            spec = tool.get("spec", {})
            allowed_params = spec.get("parameters", {}).get("properties", {}).keys()
            tool_function_params = {
                k: v for k, v in tool_function_params.items() if k in allowed_params
            }

            if direct_tool:
                # Direct tools are executed through the event-call channel.
                tool_result = await event_caller(
                    {
                        "type": "execute:tool",
                        "data": {
                            "id": str(uuid4()),
                            "name": tool_function_name,
                            "params": tool_function_params,
                            "server": tool.get("server", {}),
                            "session_id": metadata.get("session_id", None),
                        },
                    }
                )
            else:
                tool_function = tool["callable"]
                tool_result = await tool_function(**tool_function_params)
        except Exception as e:
            # A failing tool reports its error text as its result.
            tool_result = str(e)

        tool_result, tool_result_files, tool_result_embeds = process_tool_result(
            request,
            tool_function_name,
            tool_result,
            tool_type,
            direct_tool,
            metadata,
            user,
        )

        if event_emitter:
            if tool_result_files:
                await event_emitter(
                    {
                        "type": "files",
                        "data": {
                            "files": tool_result_files,
                        },
                    }
                )

            if tool_result_embeds:
                await event_emitter(
                    {
                        "type": "embeds",
                        "data": {
                            "embeds": tool_result_embeds,
                        },
                    }
                )

        # Logged at debug level instead of print(): tool output may be large
        # or sensitive and should follow the configured log level.
        log.debug(
            f"Tool {tool_function_name} result: {tool_result} "
            f"files={tool_result_files} embeds={tool_result_embeds}"
        )

        if tool_result:
            tool = tools[tool_function_name]
            tool_id = tool.get("tool_id", "")

            tool_name = (
                f"{tool_id}/{tool_function_name}" if tool_id else f"{tool_function_name}"
            )

            # Record the output as a source.
            sources.append(
                {
                    "source": {
                        "name": (f"{tool_name}"),
                    },
                    "document": [str(tool_result)],
                    "metadata": [
                        {
                            "source": (f"{tool_name}"),
                            "parameters": tool_function_params,
                        }
                    ],
                    "tool_result": True,
                }
            )

            # Also add the output to the user message.
            body["messages"] = add_or_update_user_message(
                f"\nTool `{tool_name}` Output: {tool_result}",
                body["messages"],
            )

            if tool.get("metadata", {}).get("file_handler", False):
                skip_files = True

    specs = [tool["spec"] for tool in tools.values()]
    tools_specs = json.dumps(specs)

    if request.app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE != "":
        template = request.app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
    else:
        template = DEFAULT_TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE

    tools_function_calling_prompt = tools_function_calling_generation_template(
        template, tools_specs
    )
    payload = get_tools_function_calling_payload(
        body["messages"], task_model_id, tools_function_calling_prompt
    )

    try:
        response = await generate_chat_completion(request, form_data=payload, user=user)
        log.debug(f"{response=}")

        content = await get_content_from_response(response)
        log.debug(f"{content=}")

        if not content:
            return body, {}

        # Keep only the outermost {...} span of the reply before parsing.
        content = content[content.find("{") : content.rfind("}") + 1]
        if not content:
            raise Exception("No JSON object found in the response")

        result = json.loads(content)

        # Either a {"tool_calls": [...]} wrapper or a single tool call object.
        if result.get("tool_calls"):
            for tool_call in result.get("tool_calls"):
                await tool_call_handler(tool_call)
        else:
            await tool_call_handler(result)
    except Exception as e:
        # Best effort: a failed selection or parse leaves the body as it is.
        log.debug(f"Error: {e}")

    log.debug(f"tool_contexts: {sources}")

    if skip_files and "files" in body.get("metadata", {}):
        del body["metadata"]["files"]

    return body, {"sources": sources}
async def chat_memory_handler(
    request: Request, form_data: dict, extra_params: dict, user
):
    """Append matching user memories to the system message as "User Context".

    Queries memory with the last user message (top 3 results). A failed query
    is logged at debug level and treated as "no results". The "User Context"
    section is appended to the system message in either case.

    Args:
        request: Forwarded to ``query_memory``.
        form_data: Chat payload; ``messages`` is read and replaced.
        extra_params: Unused here.
        user: Forwarded to ``query_memory``.

    Returns:
        The updated ``form_data``.
    """
    results = None
    try:
        results = await query_memory(
            request,
            QueryMemoryForm(
                content=get_last_user_message(form_data["messages"]) or "",
                k=3,
            ),
            user,
        )
    except Exception as e:
        log.debug(e)

    context_lines = []
    documents = getattr(results, "documents", None) if results else None
    if documents and len(documents) > 0:
        # Only the first result set is used.
        for position, memory_text in enumerate(documents[0], start=1):
            entry_meta = results.metadatas[0][position - 1]

            date_label = "Unknown Date"
            if entry_meta.get("created_at"):
                # Timestamp is rendered as a local-time calendar date.
                date_label = time.strftime(
                    "%Y-%m-%d", time.localtime(entry_meta["created_at"])
                )

            context_lines.append(f"{position}. [{date_label}] {memory_text}\n")

    user_context = "".join(context_lines)
    form_data["messages"] = add_or_update_system_message(
        f"User Context:\n{user_context}\n", form_data["messages"], append=True
    )

    return form_data
async def chat_web_search_handler(
    request: Request, form_data: dict, extra_params: dict, user
):
    """Generate search queries, run a web search and attach results as files.

    Progress and failures are reported through status events on
    ``extra_params["__event_emitter__"]``; search errors are logged and never
    raised.

    Args:
        request: Current request; forwarded to query generation and search.
        form_data: Chat payload; ``model``, ``messages`` and ``files`` are
            read, and ``files`` is updated with the search results.
        extra_params: Must contain ``__event_emitter__``.
        user: Forwarded to query generation and search.

    Returns:
        The updated ``form_data``.
    """
    event_emitter = extra_params["__event_emitter__"]
    await event_emitter(
        {
            "type": "status",
            "data": {
                "action": "web_search",
                "description": "Searching the web",
                "done": False,
            },
        }
    )

    messages = form_data["messages"]
    user_message = get_last_user_message(messages)

    queries = []
    try:
        res = await generate_queries(
            request,
            {
                "model": form_data["model"],
                "messages": messages,
                "prompt": user_message,
                "type": "web_search",
            },
            user,
        )

        response = res["choices"][0]["message"]["content"]

        try:
            # Test the raw find()/rfind() results: the previous code added 1 to
            # rfind() before comparing with -1, so a missing "}" was never
            # detected.
            bracket_start = response.find("{")
            bracket_end = response.rfind("}")

            if bracket_start == -1 or bracket_end < bracket_start:
                raise Exception("No JSON object found in the response")

            response = response[bracket_start : bracket_end + 1]
            queries = json.loads(response)
            queries = queries.get("queries") or []
        except Exception:
            # Fall back to using the reply text itself as the only query.
            queries = [response]

        if ENABLE_QUERIES_CACHE:
            request.state.cached_queries = queries

    except Exception as e:
        log.exception(e)
        queries = [user_message]

    # A single blank (or non-string, e.g. None) query falls back to the user's
    # message; with no user message either, no query remains.
    if len(queries) == 1 and not (
        isinstance(queries[0], str) and queries[0].strip()
    ):
        queries = [user_message] if user_message else []

    # Check if queries are not found
    if len(queries) == 0:
        await event_emitter(
            {
                "type": "status",
                "data": {
                    "action": "web_search",
                    "description": "No search query generated",
                    "done": True,
                },
            }
        )
        return form_data

    await event_emitter(
        {
            "type": "status",
            "data": {
                "action": "web_search_queries_generated",
                "queries": queries,
                "done": False,
            },
        }
    )

    try:
        results = await process_web_search(
            request,
            SearchForm(queries=queries),
            user=user,
        )

        if results:
            # Tolerate an explicit ``"files": None`` in the payload.
            files = form_data.get("files") or []

            if results.get("collection_names"):
                for collection_name in results.get("collection_names"):
                    files.append(
                        {
                            "collection_name": collection_name,
                            "name": ", ".join(queries),
                            "type": "web_search",
                            "urls": results["filenames"],
                            "queries": queries,
                        }
                    )
            elif results.get("docs"):
                # Invoked when bypass embedding and retrieval is set to True
                docs = results["docs"]
                files.append(
                    {
                        "docs": docs,
                        "name": ", ".join(queries),
                        "type": "web_search",
                        "urls": results["filenames"],
                        "queries": queries,
                    }
                )

            form_data["files"] = files

            await event_emitter(
                {
                    "type": "status",
                    "data": {
                        "action": "web_search",
                        "description": "Searched {{count}} sites",
                        "urls": results["filenames"],
                        "items": results.get("items", []),
                        "done": True,
                    },
                }
            )
        else:
            await event_emitter(
                {
                    "type": "status",
                    "data": {
                        "action": "web_search",
                        "description": "No search results found",
                        "done": True,
                        "error": True,
                    },
                }
            )

    except Exception as e:
        log.exception(e)
        await event_emitter(
            {
                "type": "status",
                "data": {
                    "action": "web_search",
                    "description": "An error occurred while searching the web",
                    "queries": queries,
                    "done": True,
                    "error": True,
                },
            }
        )

    return form_data
async def chat_image_generation_handler(
    request: Request, form_data: dict, extra_params: dict, user
):
    """Generate an image for the last user message and report it to the client.

    When ``ENABLE_IMAGE_PROMPT_GENERATION`` is set, a prompt is first requested
    from ``generate_image_prompt``; any failure there falls back to the user's
    message. The outcome is emitted as status/files events and summarized for
    the model in the system message.

    Args:
        request: Current request; provides ``app.state.config``.
        form_data: Chat payload; ``model`` and ``messages`` are read and
            ``messages`` is updated.
        extra_params: Must contain ``__event_emitter__``.
        user: Forwarded to prompt and image generation.

    Returns:
        The updated ``form_data``.
    """
    __event_emitter__ = extra_params["__event_emitter__"]
    await __event_emitter__(
        {
            "type": "status",
            "data": {"description": "Creating image", "done": False},
        }
    )

    messages = form_data["messages"]
    user_message = get_last_user_message(messages)

    prompt = user_message

    if request.app.state.config.ENABLE_IMAGE_PROMPT_GENERATION:
        try:
            res = await generate_image_prompt(
                request,
                {
                    "model": form_data["model"],
                    "messages": messages,
                },
                user,
            )

            response = res["choices"][0]["message"]["content"]

            try:
                # Test the raw find()/rfind() results; adding 1 to rfind()
                # before comparing with -1 made the check unreachable.
                bracket_start = response.find("{")
                bracket_end = response.rfind("}")

                if bracket_start == -1 or bracket_end < bracket_start:
                    raise Exception("No JSON object found in the response")

                response = json.loads(response[bracket_start : bracket_end + 1])
                # A missing or empty "prompt" falls back to the user's message
                # rather than an empty list, which is not a usable prompt.
                prompt = response.get("prompt") or user_message
            except Exception:
                prompt = user_message

        except Exception as e:
            log.exception(e)
            prompt = user_message

    system_message_content = ""

    try:
        images = await image_generations(
            request=request,
            form_data=GenerateImageForm(**{"prompt": prompt}),
            user=user,
        )

        await __event_emitter__(
            {
                "type": "status",
                "data": {"description": "Image created", "done": True},
            }
        )

        await __event_emitter__(
            {
                "type": "files",
                "data": {
                    "files": [
                        {
                            "type": "image",
                            "url": image["url"],
                        }
                        for image in images
                    ]
                },
            }
        )

        system_message_content = "User is shown the generated image, tell the user that the image has been generated"
    except Exception as e:
        log.exception(e)
        await __event_emitter__(
            {
                "type": "status",
                "data": {
                    "description": "An error occurred while generating an image",
                    "done": True,
                },
            }
        )

        system_message_content = "Unable to generate an image, tell the user that an error occurred"

    if system_message_content:
        form_data["messages"] = add_or_update_system_message(
            system_message_content, form_data["messages"]
        )

    return form_data
async def chat_completion_files_handler(
    request: Request, body: dict, extra_params: dict, user: UserModel
) -> tuple[dict, dict[str, list]]:
    """Retrieve sources for the files attached to the chat request.

    Does nothing unless ``body["metadata"]["files"]`` is non-empty. Unless
    every file is in ``"full"`` context mode, retrieval queries are generated
    first; if none result, the last user message is used. Retrieval runs in a
    worker thread and its failures are logged, leaving ``sources`` empty.

    Args:
        request: Current request; provides ``app.state`` configuration and
            the embedding/reranking functions.
        body: Chat payload; ``model``, ``messages`` and ``metadata.files``
            are read.
        extra_params: Must contain ``__event_emitter__``.
        user: Forwarded to query generation and retrieval.

    Returns:
        ``(body, {"sources": sources})``.
    """
    __event_emitter__ = extra_params["__event_emitter__"]
    sources = []

    if files := body.get("metadata", {}).get("files", None):
        # Check if all files are in full context mode
        all_full_context = all(item.get("context") == "full" for item in files)

        queries = []
        if not all_full_context:
            try:
                queries_response = await generate_queries(
                    request,
                    {
                        "model": body["model"],
                        "messages": body["messages"],
                        "type": "retrieval",
                    },
                    user,
                )
                queries_response = queries_response["choices"][0]["message"]["content"]

                try:
                    # Test the raw find()/rfind() results: with the old
                    # "rfind() + 1" form a missing "}" went undetected and the
                    # fallback query became an empty string.
                    bracket_start = queries_response.find("{")
                    bracket_end = queries_response.rfind("}")

                    if bracket_start == -1 or bracket_end < bracket_start:
                        raise Exception("No JSON object found in the response")

                    queries_response = queries_response[bracket_start : bracket_end + 1]
                    queries_response = json.loads(queries_response)
                except Exception:
                    # Use the reply text itself as the only query.
                    queries_response = {"queries": [queries_response]}

                queries = queries_response.get("queries", [])
            except Exception as e:
                # Best effort, but not via a bare ``except``: that would also
                # swallow asyncio.CancelledError and block task cancellation.
                log.debug(e)

            await __event_emitter__(
                {
                    "type": "status",
                    "data": {
                        "action": "queries_generated",
                        "queries": queries,
                        "done": False,
                    },
                }
            )

        if len(queries) == 0:
            queries = [get_last_user_message(body["messages"])]

        try:
            # Offload get_sources_from_items to a separate thread
            loop = asyncio.get_running_loop()
            with ThreadPoolExecutor() as executor:
                sources = await loop.run_in_executor(
                    executor,
                    lambda: get_sources_from_items(
                        request=request,
                        items=files,
                        queries=queries,
                        embedding_function=lambda query, prefix: request.app.state.EMBEDDING_FUNCTION(
                            query, prefix=prefix, user=user
                        ),
                        k=request.app.state.config.TOP_K,
                        reranking_function=(
                            (
                                lambda sentences: request.app.state.RERANKING_FUNCTION(
                                    sentences, user=user
                                )
                            )
                            if request.app.state.RERANKING_FUNCTION
                            else None
                        ),
                        k_reranker=request.app.state.config.TOP_K_RERANKER,
                        r=request.app.state.config.RELEVANCE_THRESHOLD,
                        hybrid_bm25_weight=request.app.state.config.HYBRID_BM25_WEIGHT,
                        hybrid_search=request.app.state.config.ENABLE_RAG_HYBRID_SEARCH,
                        full_context=all_full_context
                        or request.app.state.config.RAG_FULL_CONTEXT,
                        user=user,
                    ),
                )
        except Exception as e:
            log.exception(e)

        log.debug(f"rag_contexts:sources: {sources}")

        # Count distinct sources: prefer each document's metadata "source",
        # then the source's own "id", else the literal "N/A".
        unique_ids = set()
        for source in sources or []:
            if not source or len(source.keys()) == 0:
                continue

            documents = source.get("document") or []
            metadatas = source.get("metadata") or []
            src_info = source.get("source") or {}

            for index, _ in enumerate(documents):
                metadata = metadatas[index] if index < len(metadatas) else None
                _id = (
                    (metadata or {}).get("source")
                    or (src_info or {}).get("id")
                    or "N/A"
                )
                unique_ids.add(_id)

        sources_count = len(unique_ids)
        await __event_emitter__(
            {
                "type": "status",
                "data": {
                    "action": "sources_retrieved",
                    "count": sources_count,
                    "done": True,
                },
            }
        )

    return body, {"sources": sources}
def apply_params_to_form_data(form_data, model):
    """Move ``form_data["params"]`` into the request in the shape the model expects.

    Open WebUI-only parameters are dropped and ``custom_params`` (string
    values parsed as JSON where possible) are merged in. For a model whose
    ``owned_by`` is ``"ollama"`` the result is stored under
    ``form_data["options"]``; otherwise every non-``None`` parameter becomes a
    top-level key and ``logit_bias`` is converted via
    ``convert_logit_bias_input_to_json``.

    Args:
        form_data: Request payload; mutated in place (``params`` is removed).
        model: Model mapping; only ``owned_by`` is read.

    Returns:
        The same ``form_data`` object.
    """
    # A client may send ``"params": null`` / ``"custom_params": null``; treat
    # both like an absent value instead of failing on ``None.pop``.
    params = form_data.pop("params", None) or {}
    custom_params = params.pop("custom_params", None) or {}

    # Parameters consumed by Open WebUI itself, never forwarded to the model.
    open_webui_params = {
        "stream_response",
        "stream_delta_chunk_size",
        "function_calling",
        "reasoning_tags",
        "system",
    }

    for key in list(params.keys()):
        if key in open_webui_params:
            del params[key]

    if custom_params:
        # Attempt to parse custom_params if they are strings
        for key, value in custom_params.items():
            if isinstance(value, str):
                try:
                    # Attempt to parse the string as JSON
                    custom_params[key] = json.loads(value)
                except json.JSONDecodeError:
                    # If it fails, keep the original string
                    pass

        # If custom_params are provided, merge them into params
        params = deep_update(params, custom_params)

    if model.get("owned_by") == "ollama":
        # Ollama specific parameters
        form_data["options"] = params
    else:
        if isinstance(params, dict):
            for key, value in params.items():
                if value is not None:
                    form_data[key] = value

        if "logit_bias" in params and params["logit_bias"] is not None:
            try:
                form_data["logit_bias"] = json.loads(
                    convert_logit_bias_input_to_json(params["logit_bias"])
                )
            except Exception as e:
                # The raw value copied above stays in place on failure.
                log.exception(f"Error parsing logit_bias: {e}")

    return form_data
async def process_chat_payload(request, form_data, user, metadata, model):
"""
处理聊天请求的 Payload - 执行 Pipeline、Filter、功能增强和工具注入
这是聊天请求预处理的核心函数,按以下顺序执行:
1. Pipeline Inlet (管道入口) - 自定义 Python 插件预处理
2. Filter Inlet (过滤器入口) - 函数过滤器预处理
3. Chat Memory (记忆) - 注入历史对话记忆
4. Chat Web Search (网页搜索) - 执行网络搜索并注入结果
5. Chat Image Generation (图像生成) - 处理图像生成请求
6. Chat Code Interpreter (代码解释器) - 注入代码执行提示词
7. Chat Tools Function Calling (工具调用) - 处理函数/工具调用
8. Chat Files (文件处理) - 处理上传文件、知识库文件、RAG 检索
Args:
request: FastAPI Request 对象
form_data: OpenAI 格式的聊天请求数据
user: 用户对象
metadata: 元数据(chat_id, message_id, tool_ids, files 等)
model: 模型配置对象
Returns:
tuple: (form_data, metadata, events)
- form_data: 处理后的请求数据
- metadata: 更新后的元数据
- events: 需要发送给前端的事件列表(如引用来源)
"""
# === 1. 应用模型参数到请求 ===
form_data = apply_params_to_form_data(form_data, model)
log.debug(f"form_data: {form_data}")
# === 2. 处理 System Prompt 变量替换 ===
system_message = get_system_message(form_data.get("messages", []))
if system_message: # Chat Controls/User Settings
try:
# 替换 system prompt 中的变量(如 {{USER_NAME}}, {{CURRENT_DATE}})
form_data = apply_system_prompt_to_body(
system_message.get("content"), form_data, metadata, user, replace=True
)
except:
pass
# === 3. 初始化事件发射器和回调 ===
event_emitter = get_event_emitter(metadata) # WebSocket 事件发射器
event_call = get_event_call(metadata) # 事件调用函数
# === 4. 获取 OAuth Token ===
oauth_token = None
try:
if request.cookies.get("oauth_session_id", None):
oauth_token = await request.app.state.oauth_manager.get_oauth_token(
user.id,
request.cookies.get("oauth_session_id", None),
)
except Exception as e:
log.error(f"Error getting OAuth token: {e}")
# === 5. 构建额外参数(供 Pipeline/Filter/Tools 使用)===
extra_params = {
"__event_emitter__": event_emitter, # 用于向前端发送实时事件
"__event_call__": event_call, # 用于调用事件回调
"__user__": user.model_dump() if isinstance(user, UserModel) else {},
"__metadata__": metadata,
"__request__": request,
"__model__": model,
"__oauth_token__": oauth_token,
}
# === 6. 确定模型列表和任务模型 ===
# Initialize events to store additional event to be sent to the client
# Initialize contexts and citation
if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
# Direct 模式:使用用户直连的模型
models = {
request.state.model["id"]: request.state.model,
}
else:
# 标准模式:使用平台所有模型
models = request.app.state.MODELS
# 获取任务模型 ID(用于工具调用、标题生成等后台任务)
task_model_id = get_task_model_id(
form_data["model"],
request.app.state.config.TASK_MODEL,
request.app.state.config.TASK_MODEL_EXTERNAL,
models,
)
# === 7. 初始化事件和引用来源列表 ===
events = [] # 需要发送给前端的事件(如 sources 引用)
sources = [] # RAG 检索到的文档来源
# === 8. Folder "Project" 处理 - 注入文件夹的 System Prompt 和文件 ===
# Check if the request has chat_id and is inside of a folder
chat_id = metadata.get("chat_id", None)
if False:
if chat_id and user:
chat = Chats.get_chat_by_id_and_user_id(chat_id, user.id)
if chat and chat.folder_id:
folder = Folders.get_folder_by_id_and_user_id(chat.folder_id, user.id)
if folder and folder.data:
# 注入文件夹的 system prompt
if "system_prompt" in folder.data:
form_data = apply_system_prompt_to_body(
folder.data["system_prompt"], form_data, metadata, user
)
# 注入文件夹关联的文件
if "files" in folder.data:
form_data["files"] = [
*folder.data["files"],
*form_data.get("files", []),
]
# === 9. Model "Knowledge" 处理 - 注入模型绑定的知识库 ===
user_message = get_last_user_message(form_data["messages"])
model_knowledge = model.get("info", {}).get("meta", {}).get("knowledge", False)
if model_knowledge:
# 向前端发送知识库搜索状态
await event_emitter(
{
"type": "status",
"data": {
"action": "knowledge_search",
"query": user_message,
"done": False,
},
}
)
knowledge_files = []
for item in model_knowledge:
# 处理旧格式的 collection_name
if item.get("collection_name"):
knowledge_files.append(
{
"id": item.get("collection_name"),
"name": item.get("name"),
"legacy": True,
}
)
# 处理新格式的 collection_names(多个集合)
elif item.get("collection_names"):
knowledge_files.append(
{
"name": item.get("name"),
"type": "collection",
"collection_names": item.get("collection_names"),
"legacy": True,
}
)
else:
knowledge_files.append(item)
# 合并模型知识库文件和用户上传文件
files = form_data.get("files", [])
files.extend(knowledge_files)
form_data["files"] = files
variables = form_data.pop("variables", None)
# === 10. Pipeline Inlet 处理 - 执行自定义 Python 插件 ===
# Process the form_data through the pipeline
try:
form_data = await process_pipeline_inlet_filter(
request, form_data, user, models
)
except Exception as e:
raise e
# === 11. Filter Inlet 处理 - 执行函数过滤器 ===
try:
filter_functions = [
Functions.get_function_by_id(filter_id)
for filter_id in get_sorted_filter_ids(
request, model, metadata.get("filter_ids", [])
)
]
form_data, flags = await process_filter_functions(
request=request,
filter_functions=filter_functions,
filter_type="inlet",
form_data=form_data,
extra_params=extra_params,
)
except Exception as e:
raise Exception(f"{e}")
# === 12. 功能增强处理 (Features) ===
features = form_data.pop("features", None)
if features:
# 12.1 记忆功能 - 注入历史对话记忆
if "memory" in features and features["memory"]:
form_data = await chat_memory_handler(
request, form_data, extra_params, user
)
# 12.2 网页搜索功能 - 执行网络搜索
if "web_search" in features and features["web_search"]:
form_data = await chat_web_search_handler(
request, form_data, extra_params, user
)
# 12.3 图像生成功能 - 处理图像生成请求
if "image_generation" in features and features["image_generation"]:
form_data = await chat_image_generation_handler(
request, form_data, extra_params, user
)
# 12.4 代码解释器功能 - 注入代码执行提示词
if "code_interpreter" in features and features["code_interpreter"]:
form_data["messages"] = add_or_update_user_message(
(
request.app.state.config.CODE_INTERPRETER_PROMPT_TEMPLATE
if request.app.state.config.CODE_INTERPRETER_PROMPT_TEMPLATE != ""
else DEFAULT_CODE_INTERPRETER_PROMPT
),
form_data["messages"],
)
# === 13. 提取工具和文件信息 ===
tool_ids = form_data.pop("tool_ids", None)
files = form_data.pop("files", None)
# === 14. 文件夹类型文件展开 ===
prompt = get_last_user_message(form_data["messages"])
if files:
if not files:
files = []
for file_item in files:
# 如果文件类型是 folder,展开其包含的所有文件
if file_item.get("type", "file") == "folder":
# Get folder files
folder_id = file_item.get("id", None)
if folder_id:
folder = Folders.get_folder_by_id_and_user_id(folder_id, user.id)
if folder and folder.data and "files" in folder.data:
# 移除文件夹项,添加文件夹内的文件
files = [f for f in files if f.get("id", None) != folder_id]
files = [*files, *folder.data["files"]]
# 去重文件(基于文件内容)
files = list({json.dumps(f, sort_keys=True): f for f in files}.values())
# === 15. 更新元数据 ===
metadata = {
**metadata,
"tool_ids": tool_ids,
"files": files,
}
form_data["metadata"] = metadata
# === 16. 准备工具字典 ===
# Server side tools
tool_ids = metadata.get("tool_ids", None) # 服务器端工具 ID
# Client side tools
direct_tool_servers = metadata.get("tool_servers", None) # 客户端直连工具服务器
log.debug(f"{tool_ids=}")
log.debug(f"{direct_tool_servers=}")
tools_dict = {} # 所有工具的字典
mcp_clients = {} # MCP (Model Context Protocol) 客户端
mcp_tools_dict = {} # MCP 工具字典
if tool_ids:
for tool_id in tool_ids:
# === 16.1 处理 MCP (Model Context Protocol) 工具 ===
if tool_id.startswith("server:mcp:"):
try:
server_id = tool_id[len("server:mcp:") :]
# 查找 MCP 服务器连接配置
mcp_server_connection = None
for (
server_connection
) in request.app.state.config.TOOL_SERVER_CONNECTIONS:
if (
server_connection.get("type", "") == "mcp"
and server_connection.get("info", {}).get("id") == server_id
):
mcp_server_connection = server_connection
break
if not mcp_server_connection:
log.error(f"MCP server with id {server_id} not found")
continue
# 处理认证类型
auth_type = mcp_server_connection.get("auth_type", "")
headers = {}
if auth_type == "bearer":
headers["Authorization"] = (
f"Bearer {mcp_server_connection.get('key', '')}"
)
elif auth_type == "none":
# 无需认证
pass
elif auth_type == "session":
headers["Authorization"] = (
f"Bearer {request.state.token.credentials}"
)
elif auth_type == "system_oauth":
oauth_token = extra_params.get("__oauth_token__", None)
if oauth_token:
headers["Authorization"] = (
f"Bearer {oauth_token.get('access_token', '')}"
)
elif auth_type == "oauth_2.1":
try:
splits = server_id.split(":")
server_id = splits[-1] if len(splits) > 1 else server_id
oauth_token = await request.app.state.oauth_client_manager.get_oauth_token(
user.id, f"mcp:{server_id}"
)
if oauth_token:
headers["Authorization"] = (
f"Bearer {oauth_token.get('access_token', '')}"
)
except Exception as e:
log.error(f"Error getting OAuth token: {e}")
oauth_token = None
# 连接到 MCP 服务器
mcp_clients[server_id] = MCPClient()
await mcp_clients[server_id].connect(
url=mcp_server_connection.get("url", ""),
headers=headers if headers else None,
)
# 获取 MCP 工具列表并注册
tool_specs = await mcp_clients[server_id].list_tool_specs()
for tool_spec in tool_specs:
def make_tool_function(client, function_name):
"""为每个 MCP 工具创建异步调用函数"""
async def tool_function(**kwargs):
return await client.call_tool(
function_name,
function_args=kwargs,
)
return tool_function
tool_function = make_tool_function(
mcp_clients[server_id], tool_spec["name"]
)
# 注册 MCP 工具到工具字典
mcp_tools_dict[f"{server_id}_{tool_spec['name']}"] = {
"spec": {
**tool_spec,
"name": f"{server_id}_{tool_spec['name']}",
},
"callable": tool_function,
"type": "mcp",
"client": mcp_clients[server_id],
"direct": False,
}
except Exception as e:
log.debug(e)
continue
# === 16.2 获取标准工具(Function Tools)===
tools_dict = await get_tools(
request,
tool_ids,
user,
{
**extra_params,
"__model__": models[task_model_id],
"__messages__": form_data["messages"],
"__files__": metadata.get("files", []),
},
)
# 合并 MCP 工具
if mcp_tools_dict:
tools_dict = {**tools_dict, **mcp_tools_dict}
# === 16.3 处理客户端直连工具服务器 ===
if direct_tool_servers:
for tool_server in direct_tool_servers:
tool_specs = tool_server.pop("specs", [])
for tool in tool_specs:
tools_dict[tool["name"]] = {
"spec": tool,
"direct": True,
"server": tool_server,
}
# 保存 MCP 客户端到元数据(用于最后清理)
if mcp_clients:
metadata["mcp_clients"] = mcp_clients
# === 17. 工具调用处理 ===
if tools_dict:
if metadata.get("params", {}).get("function_calling") == "native":
# 原生函数调用模式:直接传递给 LLM
metadata["tools"] = tools_dict
form_data["tools"] = [
{"type": "function", "function": tool.get("spec", {})}
for tool in tools_dict.values()
]
else:
# 默认模式:通过 Prompt 实现工具调用
try:
form_data, flags = await chat_completion_tools_handler(
request, form_data, extra_params, user, models, tools_dict
)
sources.extend(flags.get("sources", []))
except Exception as e:
log.exception(e)
# === 18. 文件处理 - RAG 检索 ===
try:
form_data, flags = await chat_completion_files_handler(
request, form_data, extra_params, user
)
sources.extend(flags.get("sources", []))
except Exception as e:
log.exception(e)
# === 19. 构建上下文字符串并注入到消息 ===
# If context is not empty, insert it into the messages
if len(sources) > 0:
context_string = ""
citation_idx_map = {} # 引用索引映射(文档 ID → 引用编号)
# 遍历所有来源,构建上下文字符串
for source in sources:
if "document" in source:
for document_text, document_metadata in zip(
source["document"], source["metadata"]
):
source_name = source.get("source", {}).get("name", None)
source_id = (
document_metadata.get("source", None)
or source.get("source", {}).get("id", None)
or "N/A"
)
# 为每个来源分配唯一的引用编号
if source_id not in citation_idx_map:
citation_idx_map[source_id] = len(citation_idx_map) + 1
# 构建 XML 格式的来源标签
context_string += (
f'{document_text}\n"
)
context_string = context_string.strip()
if prompt is None:
raise Exception("No user message found")
# 使用 RAG 模板将上下文注入到用户消息中
if context_string != "":
form_data["messages"] = add_or_update_user_message(
rag_template(
request.app.state.config.RAG_TEMPLATE,
context_string,
prompt,
),
form_data["messages"],
append=False,
)
# === 20. 整理引用来源并添加到事件 ===
# If there are citations, add them to the data_items
sources = [
source
for source in sources
if source.get("source", {}).get("name", "")
or source.get("source", {}).get("id", "")
]
if len(sources) > 0:
events.append({"sources": sources})
# === 21. 完成知识库搜索状态 ===
if model_knowledge:
await event_emitter(
{
"type": "status",
"data": {
"action": "knowledge_search",
"query": user_message,
"done": True,
"hidden": True,
},
}
)
return form_data, metadata, events
async def process_chat_response(
request, response, form_data, user, metadata, model, events, tasks
):
"""
处理聊天响应 - 规范化/分发聊天完成响应
这是聊天响应后处理的核心函数,负责:
1. 处理流式(SSE/ndjson)和非流式(JSON)响应
2. 通过 WebSocket 发送事件到前端
3. 持久化消息/错误/元数据到数据库
4. 触发后台任务(标题生成、标签生成、Follow-ups)
5. 流式响应时:消费上游数据块、重建内容、处理工具调用、转发增量
Args:
request: FastAPI Request 对象
response: 上游响应(dict/JSONResponse/StreamingResponse)
form_data: 原始请求 payload(可能被上游修改)
user: 已验证的用户对象
metadata: 早期收集的聊天/会话上下文
model: 解析后的模型配置
events: 需要发送的额外事件(如 sources 引用)
tasks: 可选的后台任务(title/tags/follow-ups)
Returns:
- 非流式: dict (OpenAI JSON 格式)
- 流式: StreamingResponse (SSE/ndjson 格式)
"""
# === 内部函数:后台任务处理器 ===
async def background_tasks_handler():
"""
Run the post-response background tasks selected in ``tasks``:
1. Follow-ups generation - suggested follow-up questions
2. Title generation - chat title
3. Tags generation - chat tags
Results are pushed through ``event_emitter`` and, for chats whose id does
not start with "local:", written to the database via ``Chats``.
"""
message = None
messages = []
# Load the message history
if "chat_id" in metadata and not metadata["chat_id"].startswith("local:"):
# Persisted chat: read the stored messages for this chat
messages_map = Chats.get_messages_map_by_chat_id(metadata["chat_id"])
message = messages_map.get(metadata["message_id"]) if messages_map else None
message_list = get_message_list(messages_map, metadata["message_id"])
# Clean the message contents: strip details markup and Markdown images.
# Per the original note, get_message_list builds a new list, so the
# stored messages are presumably not affected - verify against the helper.
messages = []
# NOTE(review): the loop variable below rebinds the outer "message"
# looked up above; after the loop it refers to the last list entry.
for message in message_list:
content = message.get("content", "")
# Multimodal content (list of parts): keep the first text part only
if isinstance(content, list):
for item in content:
if item.get("type") == "text":
content = item["text"]
break
# Remove details blocks and Markdown images.
# NOTE(review): the pattern starts with "]*>" - the opening-tag part
# looks truncated; confirm against version control.
if isinstance(content, str):
content = re.sub(
r"]*>.*?<\/details>|!\[.*?\]\(.*?\)",
"",
content,
flags=re.S | re.I,
).strip()
messages.append(
{
**message,
"role": message.get("role", "assistant"), # safe fallback
"content": content,
}
)
else:
# Temporary chat ("local:" id) or no chat id: use the request payload
message = get_last_user_message_item(form_data.get("messages", []))
messages = form_data.get("messages", [])
if message:
message["model"] = form_data.get("model")
# Run the background tasks
if message and "model" in message:
if tasks and messages:
# === Task 1: follow-ups generation ===
if (
TASKS.FOLLOW_UP_GENERATION in tasks
and tasks[TASKS.FOLLOW_UP_GENERATION]
):
res = await generate_follow_ups(
request,
{
"model": message["model"],
"messages": messages,
"message_id": metadata["message_id"],
"chat_id": metadata["chat_id"],
},
user,
)
if res and isinstance(res, dict):
if len(res.get("choices", [])) == 1:
response_message = res.get("choices", [])[0].get(
"message", {}
)
follow_ups_string = response_message.get(
"content"
) or response_message.get("reasoning_content", "")
else:
follow_ups_string = ""
# Keep only the JSON object (first "{" through last "}")
follow_ups_string = follow_ups_string[
follow_ups_string.find("{") : follow_ups_string.rfind("}")
+ 1
]
try:
follow_ups = json.loads(follow_ups_string).get(
"follow_ups", []
)
# Send the follow-ups through the event emitter
await event_emitter(
{
"type": "chat:message:follow_ups",
"data": {
"follow_ups": follow_ups,
},
}
)
# Persist to the database (non-temporary chats only)
if not metadata.get("chat_id", "").startswith("local:"):
Chats.upsert_message_to_chat_by_id_and_message_id(
metadata["chat_id"],
metadata["message_id"],
{
"followUps": follow_ups,
},
)
# Best-effort: parse/emit/persist failures are silently ignored
except Exception as e:
pass
# === Tasks 2 & 3: title and tags (non-temporary chats only) ===
if not metadata.get("chat_id", "").startswith("local:"):
# Task 2: title generation
if (
TASKS.TITLE_GENERATION in tasks
and tasks[TASKS.TITLE_GENERATION]
):
user_message = get_last_user_message(messages)
# Truncate long user messages used as a fallback title
if user_message and len(user_message) > 100:
user_message = user_message[:100] + "..."
if tasks[TASKS.TITLE_GENERATION]:
res = await generate_title(
request,
{
"model": message["model"],
"messages": messages,
"chat_id": metadata["chat_id"],
},
user,
)
if res and isinstance(res, dict):
if len(res.get("choices", [])) == 1:
response_message = res.get("choices", [])[0].get(
"message", {}
)
title_string = (
response_message.get("content")
or response_message.get(
"reasoning_content",
)
or message.get("content", user_message)
)
else:
title_string = ""
# Keep only the JSON object
title_string = title_string[
title_string.find("{") : title_string.rfind("}") + 1
]
try:
title = json.loads(title_string).get(
"title", user_message
)
except Exception as e:
title = ""
# Fall back to the first message's content when no title was parsed
if not title:
title = messages[0].get("content", user_message)
# Update the database
Chats.update_chat_title_by_id(
metadata["chat_id"], title
)
# Send the title through the event emitter
await event_emitter(
{
"type": "chat:title",
"data": title,
}
)
# With exactly 2 messages (first exchange), use the message content as the title.
# NOTE(review): the stored title uses messages[0] while the emitted one
# uses "message" - confirm this difference is intended.
elif len(messages) == 2:
title = messages[0].get("content", user_message)
Chats.update_chat_title_by_id(metadata["chat_id"], title)
await event_emitter(
{
"type": "chat:title",
"data": message.get("content", user_message),
}
)
# Task 3: tags generation
if TASKS.TAGS_GENERATION in tasks and tasks[TASKS.TAGS_GENERATION]:
res = await generate_chat_tags(
request,
{
"model": message["model"],
"messages": messages,
"chat_id": metadata["chat_id"],
},
user,
)
if res and isinstance(res, dict):
if len(res.get("choices", [])) == 1:
response_message = res.get("choices", [])[0].get(
"message", {}
)
tags_string = response_message.get(
"content"
) or response_message.get("reasoning_content", "")
else:
tags_string = ""
# Keep only the JSON object
tags_string = tags_string[
tags_string.find("{") : tags_string.rfind("}") + 1
]
try:
tags = json.loads(tags_string).get("tags", [])
# Update the database
Chats.update_chat_tags_by_id(
metadata["chat_id"], tags, user
)
# Send the tags through the event emitter
await event_emitter(
{
"type": "chat:tags",
"data": tags,
}
)
# Best-effort: parse/persist/emit failures are silently ignored
except Exception as e:
pass
# === 1. 解析事件发射器(仅在有实时会话时使用)===
event_emitter = None
event_caller = None
if (
"session_id" in metadata
and metadata["session_id"]
and "chat_id" in metadata
and metadata["chat_id"]
and "message_id" in metadata
and metadata["message_id"]
):
event_emitter = get_event_emitter(metadata) # WebSocket 事件发射器
event_caller = get_event_call(metadata) # 事件调用器
# === 2. 非流式响应处理 ===
if not isinstance(response, StreamingResponse):
if event_emitter:
try:
if isinstance(response, dict) or isinstance(response, JSONResponse):
# 处理单项列表(解包)
if isinstance(response, list) and len(response) == 1:
response = response[0]
# 解析 JSONResponse 的 body
if isinstance(response, JSONResponse) and isinstance(
response.body, bytes
):
try:
response_data = json.loads(
response.body.decode("utf-8", "replace")
)
except json.JSONDecodeError:
response_data = {
"error": {"detail": "Invalid JSON response"}
}
else:
response_data = response
# 处理错误响应
if "error" in response_data:
error = response_data.get("error")
if isinstance(error, dict):
error = error.get("detail", error)
else:
error = str(error)
# 保存错误到数据库
Chats.upsert_message_to_chat_by_id_and_message_id(
metadata["chat_id"],
metadata["message_id"],
{
"error": {"content": error},
},
)
# 通过 WebSocket 发送错误事件
if isinstance(error, str) or isinstance(error, dict):
await event_emitter(
{
"type": "chat:message:error",
"data": {"error": {"content": error}},
}
)
if "selected_model_id" in response_data:
Chats.upsert_message_to_chat_by_id_and_message_id(
metadata["chat_id"],
metadata["message_id"],
{
"selectedModelId": response_data["selected_model_id"],
},
)
choices = response_data.get("choices", [])
if choices and choices[0].get("message", {}).get("content"):
content = response_data["choices"][0]["message"]["content"]
if content:
await event_emitter(
{
"type": "chat:completion",
"data": response_data,
}
)
title = Chats.get_chat_title_by_id(metadata["chat_id"])
await event_emitter(
{
"type": "chat:completion",
"data": {
"done": True,
"content": content,
"title": title,
},
}
)
# Save message in the database
Chats.upsert_message_to_chat_by_id_and_message_id(
metadata["chat_id"],
metadata["message_id"],
{
"role": "assistant",
"content": content,
},
)
# Send a webhook notification if the user is not active
if not get_active_status_by_user_id(user.id):
webhook_url = Users.get_user_webhook_url_by_id(user.id)
if webhook_url:
await post_webhook(
request.app.state.WEBUI_NAME,
webhook_url,
f"{title} - {request.app.state.config.WEBUI_URL}/c/{metadata['chat_id']}\n\n{content}",
{
"action": "chat",
"message": content,
"title": title,
"url": f"{request.app.state.config.WEBUI_URL}/c/{metadata['chat_id']}",
},
)
await background_tasks_handler()
if events and isinstance(events, list):
extra_response = {}
for event in events:
if isinstance(event, dict):
extra_response.update(event)
else:
extra_response[event] = True
response_data = {
**extra_response,
**response_data,
}
if isinstance(response, dict):
response = response_data
if isinstance(response, JSONResponse):
response = JSONResponse(
content=response_data,
headers=response.headers,
status_code=response.status_code,
)
except Exception as e:
log.debug(f"Error occurred while processing request: {e}")
pass
return response
else:
if events and isinstance(events, list) and isinstance(response, dict):
extra_response = {}
for event in events:
if isinstance(event, dict):
extra_response.update(event)
else:
extra_response[event] = True
response = {
**extra_response,
**response,
}
return response
# Non standard response (not SSE/ndjson)
if not any(
content_type in response.headers["Content-Type"]
for content_type in ["text/event-stream", "application/x-ndjson"]
):
return response
oauth_token = None
try:
if request.cookies.get("oauth_session_id", None):
oauth_token = await request.app.state.oauth_manager.get_oauth_token(
user.id,
request.cookies.get("oauth_session_id", None),
)
except Exception as e:
log.error(f"Error getting OAuth token: {e}")
extra_params = {
"__event_emitter__": event_emitter,
"__event_call__": event_caller,
"__user__": user.model_dump() if isinstance(user, UserModel) else {},
"__metadata__": metadata,
"__oauth_token__": oauth_token,
"__request__": request,
"__model__": model,
}
filter_functions = [
Functions.get_function_by_id(filter_id)
for filter_id in get_sorted_filter_ids(
request, model, metadata.get("filter_ids", [])
)
]
# Streaming response: consume upstream SSE/ndjson, forward deltas/events, handle tools
if event_emitter and event_caller:
task_id = str(uuid4()) # Create a unique task ID.
model_id = form_data.get("model", "")
def split_content_and_whitespace(content):
    """Split ``content`` into (text without trailing whitespace, the trailing whitespace).

    The second element is the empty string when ``content`` has no trailing
    whitespace; concatenating both elements always yields ``content``.
    """
    body = content.rstrip()
    # Slicing past the stripped length yields "" when nothing was stripped.
    trailing = content[len(body):]
    return body, trailing
def is_opening_code_block(content):
    """Return True when ``content`` contains an odd number of ``` fences.

    An odd count means the last fence opens a block that has not been closed.
    """
    fence_count = content.count("```")
    return fence_count % 2 == 1
# Handle as a background task
async def response_handler(response, events):
def serialize_content_blocks(content_blocks, raw=False):
"""
Render a list of content blocks into one string.
Handles block types "text", "tool_calls", "reasoning" and
"code_interpreter"; any other type is rendered as "<type>: <content>".
Args:
content_blocks: iterable of block dicts with at least "type" and "content".
raw: when True, tool-call display content is omitted and reasoning is
emitted as start_tag + content + end_tag instead of display markup.
Returns:
The concatenated content with surrounding whitespace stripped.
NOTE(review): several display f-strings below are split across two lines
and appear to have lost their markup; restore them from version control.
"""
content = ""
for block in content_blocks:
if block["type"] == "text":
# Plain text: append the stripped text followed by a newline
block_content = block["content"].strip()
if block_content:
content = f"{content}{block_content}\n"
elif block["type"] == "tool_calls":
attributes = block.get("attributes", {})
tool_calls = block.get("content", [])
results = block.get("results", [])
# Start the tool-call section on its own line
if content and not content.endswith("\n"):
content += "\n"
if results:
# Results available: pair each tool call with its result by id
tool_calls_display_content = ""
for tool_call in tool_calls:
tool_call_id = tool_call.get("id", "")
tool_name = tool_call.get("function", {}).get(
"name", ""
)
tool_arguments = tool_call.get("function", {}).get(
"arguments", ""
)
tool_result = None
tool_result_files = None
for result in results:
if tool_call_id == result.get("tool_call_id", ""):
tool_result = result.get("content", None)
tool_result_files = result.get("files", None)
break
if tool_result is not None:
# "result" is the entry left bound by the loop above
tool_result_embeds = result.get("embeds", "")
# NOTE(review): literal looks truncated (spans two lines)
tool_calls_display_content = f'{tool_calls_display_content}\nTool Executed
\n \n'
else:
# NOTE(review): literal looks truncated (spans two lines)
tool_calls_display_content = f'{tool_calls_display_content}\nExecuting...
\n \n'
# Display content is only included in non-raw output
if not raw:
content = f"{content}{tool_calls_display_content}"
else:
# No results yet: every tool call is shown as executing
tool_calls_display_content = ""
for tool_call in tool_calls:
tool_call_id = tool_call.get("id", "")
tool_name = tool_call.get("function", {}).get(
"name", ""
)
tool_arguments = tool_call.get("function", {}).get(
"arguments", ""
)
# NOTE(review): literal looks truncated (spans two lines)
tool_calls_display_content = f'{tool_calls_display_content}\n\nExecuting...
\n \n'
if not raw:
content = f"{content}{tool_calls_display_content}"
elif block["type"] == "reasoning":
# Prefix each reasoning line with "> " unless it already starts with ">"
reasoning_display_content = "\n".join(
(f"> {line}" if not line.startswith(">") else line)
for line in block["content"].splitlines()
)
reasoning_duration = block.get("duration", None)
start_tag = block.get("start_tag", "")
end_tag = block.get("end_tag", "")
if content and not content.endswith("\n"):
content += "\n"
if reasoning_duration is not None:
# A duration is recorded on the block
if raw:
content = (
f'{content}{start_tag}{block["content"]}{end_tag}\n'
)
else:
# NOTE(review): literal looks truncated (spans two lines)
content = f'{content}\nThought for {reasoning_duration} seconds
\n{reasoning_display_content}\n \n'
else:
# No duration recorded on the block yet
if raw:
content = (
f'{content}{start_tag}{block["content"]}{end_tag}\n'
)
else:
# NOTE(review): literal looks truncated (spans two lines)
content = f'{content}\nThinking…
\n{reasoning_display_content}\n \n'
elif block["type"] == "code_interpreter":
attributes = block.get("attributes", {})
output = block.get("output", None)
lang = attributes.get("lang", "")
content_stripped, original_whitespace = (
split_content_and_whitespace(content)
)
if is_opening_code_block(content_stripped):
# Remove trailing backticks that would open a new block
content = (
content_stripped.rstrip("`").rstrip()
+ original_whitespace
)
else:
# Keep content as is - either closing backticks or no backticks
content = content_stripped + original_whitespace
if content and not content.endswith("\n"):
content += "\n"
if output:
# Output is JSON-encoded and HTML-escaped before embedding
output = html.escape(json.dumps(output))
if raw:
content = f'{content}\n{block["content"]}\n\n```output\n{output}\n```\n'
else:
# NOTE(review): literal looks truncated (spans two lines)
content = f'{content}\nAnalyzed
\n```{lang}\n{block["content"]}\n```\n \n'
else:
if raw:
content = f'{content}\n{block["content"]}\n\n'
else:
# NOTE(review): literal looks truncated (spans two lines)
content = f'{content}\nAnalyzing...
\n```{lang}\n{block["content"]}\n```\n \n'
else:
# Unknown block type: render as "<type>: <content>"
block_content = str(block["content"]).strip()
if block_content:
content = f"{content}{block['type']}: {block_content}\n"
return content.strip()
def convert_content_blocks_to_messages(content_blocks, raw=False):
    """Turn content blocks into a list of chat messages.

    Blocks preceding a "tool_calls" block are serialized into the content of
    an assistant message that carries the tool calls; each entry in that
    block's "results" becomes a "tool" message. Blocks left after the last
    tool call become a final assistant message when they serialize to a
    non-empty string.

    Args:
        content_blocks: iterable of block dicts with a "type" key.
        raw: forwarded to serialize_content_blocks.

    Returns:
        List of message dicts with "role" and "content" (plus "tool_calls"
        or "tool_call_id" where applicable).
    """
    messages = []
    pending_blocks = []

    for block in content_blocks:
        if block["type"] != "tool_calls":
            pending_blocks.append(block)
            continue

        # Assistant turn that issued the tool calls (content may be empty).
        messages.append(
            {
                "role": "assistant",
                "content": serialize_content_blocks(pending_blocks, raw),
                "tool_calls": block.get("content"),
            }
        )
        # One tool message per recorded result.
        messages.extend(
            {
                "role": "tool",
                "tool_call_id": tool_result["tool_call_id"],
                "content": tool_result.get("content", "") or "",
            }
            for tool_result in block.get("results", [])
        )
        pending_blocks = []

    if pending_blocks:
        trailing_content = serialize_content_blocks(pending_blocks, raw)
        if trailing_content:
            messages.append(
                {
                    "role": "assistant",
                    "content": trailing_content,
                }
            )

    return messages
def tag_content_handler(content_type, tags, content, content_blocks):
    """Detect a tagged section in accumulated text and update ``content_blocks``.

    Two cases, decided by the type of the last block:

    * Last block is "text": search ``content`` for the first start tag in
      ``tags``. On a match, the text block is trimmed to what precedes the
      tag (and removed if that is empty) and a new ``content_type`` block is
      appended holding whatever follows the tag.
    * Last block is ``content_type``: search ``content`` for that block's
      end tag. On a match, the block is finalized (content, "ended_at",
      "duration"), or removed when it has no content; a text block with the
      text following the end tag is appended (not for a non-empty
      "code_interpreter" block), and the tagged section is removed from
      ``content``.

    Args:
        content_type: Block type to open/close (e.g. "reasoning").
        tags: Iterable of (start_tag, end_tag) pairs.
        content: Accumulated text to scan.
        content_blocks: Non-empty list of block dicts; mutated in place.

    Returns:
        Tuple ``(content, content_blocks, end_flag)``; ``end_flag`` is True
        only when an end tag was found during this call.
    """
    end_flag = False

    def extract_attributes(tag_content):
        """Return the key="value" pairs found in ``tag_content`` as a dict."""
        if not tag_content:  # None or empty: nothing to parse
            return {}
        # Only double-quoted values are recognised (single quotes are ignored).
        return dict(re.findall(r'(\w+)\s*=\s*"([^"]+)"', tag_content))

    def start_pattern_for(start_tag):
        """Build the regex for a start tag, allowing attributes on <...> tags."""
        if start_tag.startswith("<") and start_tag.endswith(">"):
            # Drop the surrounding '<' and '>' and allow optional attributes.
            return rf"<{re.escape(start_tag[1:-1])}(\s.*?)?>"
        return re.escape(start_tag)

    last_block_type = content_blocks[-1]["type"]

    if last_block_type == "text":
        for start_tag, end_tag in tags:
            match = re.search(start_pattern_for(start_tag), content)
            if not match:
                continue

            try:
                attr_content = match.group(1) or ""
            except IndexError:
                # Patterns for non-<...> tags have no capture group.
                attr_content = ""
            attributes = extract_attributes(attr_content)

            before_tag = content[: match.start()]  # Content before opening tag
            after_tag = content[match.end() :]  # Content after opening tag

            # Remove the start tag and everything after it from the text block.
            content_blocks[-1]["content"] = content_blocks[-1]["content"].replace(
                match.group(0) + after_tag, ""
            )
            if before_tag:
                content_blocks[-1]["content"] = before_tag
            if not content_blocks[-1]["content"]:
                content_blocks.pop()

            # Open the new block.
            content_blocks.append(
                {
                    "type": content_type,
                    "start_tag": start_tag,
                    "end_tag": end_tag,
                    "attributes": attributes,
                    "content": "",
                    "started_at": time.time(),
                }
            )

            if after_tag:
                content_blocks[-1]["content"] = after_tag
                # Re-scan the remainder so an end tag that arrived in the same
                # text closes the block; the return value is not used here.
                tag_content_handler(content_type, tags, after_tag, content_blocks)

            break

    elif last_block_type == content_type:
        current_block = content_blocks[-1]
        start_tag = current_block["start_tag"]
        end_tag = current_block["end_tag"]
        end_tag_pattern = re.escape(end_tag)

        if re.search(end_tag_pattern, content):
            end_flag = True

            # Strip start and end tags from the block content.
            block_content = re.sub(
                rf"<{re.escape(start_tag)}(.*?)>", "", current_block["content"]
            ).strip()
            split_content = re.compile(end_tag_pattern, re.DOTALL).split(
                block_content, maxsplit=1
            )

            # Content inside the tag, and whatever follows the end tag.
            block_content = split_content[0].strip()
            leftover_content = (
                split_content[1].strip() if len(split_content) > 1 else ""
            )
            leftover_block = {"type": "text", "content": leftover_content}

            if block_content:
                current_block["content"] = block_content
                current_block["ended_at"] = time.time()
                current_block["duration"] = int(
                    current_block["ended_at"] - current_block["started_at"]
                )
                # Continue with a fresh text block, except after code interpreter.
                if content_type != "code_interpreter":
                    content_blocks.append(leftover_block)
            else:
                # Remove the block if its content is empty.
                content_blocks.pop()
                content_blocks.append(leftover_block)

            # Remove the processed tagged section from the accumulated content.
            content = re.sub(
                rf"{start_pattern_for(start_tag)}(.|\n)*?{re.escape(end_tag)}",
                "",
                content,
                flags=re.DOTALL,
            )

    return content, content_blocks, end_flag
message = Chats.get_message_by_id_and_message_id(
metadata["chat_id"], metadata["message_id"]
)
tool_calls = []
last_assistant_message = None
try:
if form_data["messages"][-1]["role"] == "assistant":
last_assistant_message = get_last_assistant_message(
form_data["messages"]
)
except Exception as e:
pass
content = (
message.get("content", "")
if message
else last_assistant_message if last_assistant_message else ""
)
content_blocks = [
{
"type": "text",
"content": content,
}
]
reasoning_tags_param = metadata.get("params", {}).get("reasoning_tags")
DETECT_REASONING_TAGS = reasoning_tags_param is not False
DETECT_CODE_INTERPRETER = metadata.get("features", {}).get(
"code_interpreter", False
)
reasoning_tags = []
if DETECT_REASONING_TAGS:
if (
isinstance(reasoning_tags_param, list)
and len(reasoning_tags_param) == 2
):
reasoning_tags = [
(reasoning_tags_param[0], reasoning_tags_param[1])
]
else:
reasoning_tags = DEFAULT_REASONING_TAGS
try:
for event in events:
await event_emitter(
{
"type": "chat:completion",
"data": event,
}
)
# Save message in the database
Chats.upsert_message_to_chat_by_id_and_message_id(
metadata["chat_id"],
metadata["message_id"],
{
**event,
},
)
async def stream_body_handler(response, form_data):
"""
Consume an upstream streaming response and update the content blocks.
What the body does:
1. Reads "data:" lines from response.body_iterator and parses them as JSON
2. Accumulates incremental tool calls (tool_calls deltas)
3. Accumulates reasoning content (reasoning_content / reasoning / thinking)
4. Detects reasoning, solution and code-interpreter tags in text content
5. Optionally saves the message to the database on every text delta
6. Batches emitted deltas (delta_chunk_size) before sending them
Args:
response: upstream response exposing body_iterator and background
form_data: request payload passed to the stream filters as __body__
"""
nonlocal content # accumulated full text content
nonlocal content_blocks # list of content blocks (text/reasoning/code_interpreter)
response_tool_calls = [] # tool calls accumulated from this stream
# === 1. Delta batching state ===
delta_count = 0 # number of deltas accumulated since the last flush
delta_chunk_size = max(
CHAT_RESPONSE_STREAM_DELTA_CHUNK_SIZE, # module-level chunk size
int(
metadata.get("params", {}).get("stream_delta_chunk_size")
or 1
), # per-request chunk size from metadata params
)
last_delta_data = None # most recent delta payload waiting to be sent
async def flush_pending_delta_data(threshold: int = 0):
"""
Send the pending delta payload, if any.
Args:
threshold: only send when delta_count >= threshold
"""
nonlocal delta_count
nonlocal last_delta_data
if delta_count >= threshold and last_delta_data:
# Emit the latest pending payload and reset the batching state
await event_emitter(
{
"type": "chat:completion",
"data": last_delta_data,
}
)
delta_count = 0
last_delta_data = None
# === 2. Consume the stream ===
async for line in response.body_iterator:
# Decode bytes to str
line = (
line.decode("utf-8", "replace")
if isinstance(line, bytes)
else line
)
data = line
# Skip empty lines
if not data.strip():
continue
# Only lines starting with "data:" are processed
if not data.startswith("data:"):
continue
# Remove the "data:" prefix
data = data[len("data:") :].strip()
try:
# Parse the JSON payload
data = json.loads(data)
# === 3. Run filter functions (filter_type "stream") ===
data, _ = await process_filter_functions(
request=request,
filter_functions=filter_functions,
filter_type="stream",
form_data=data,
extra_params={"__body__": form_data, **extra_params},
)
if data:
# Forward custom events
if "event" in data:
await event_emitter(data.get("event", {}))
# === 4. Selected model id (arena mode, per the original comment) ===
if "selected_model_id" in data:
# NOTE(review): no nonlocal declaration for model_id is visible in
# this function, so this binds a local name - confirm whether the
# enclosing model_id was meant to be updated.
model_id = data["selected_model_id"]
# Persist the selected model id
Chats.upsert_message_to_chat_by_id_and_message_id(
metadata["chat_id"],
metadata["message_id"],
{
"selectedModelId": model_id,
},
)
await event_emitter(
{
"type": "chat:completion",
"data": data,
}
)
else:
choices = data.get("choices", [])
# === 5. Usage and timings ===
usage = data.get("usage", {}) or {}
usage.update(data.get("timings", {})) # llama.cpp timings
if usage:
await event_emitter(
{
"type": "chat:completion",
"data": {
"usage": usage,
},
}
)
# === 6. No choices: forward an error if present, then skip ===
if not choices:
error = data.get("error", {})
if error:
await event_emitter(
{
"type": "chat:completion",
"data": {
"error": error,
},
}
)
continue
# === 7. Extract the delta (incremental content) ===
delta = choices[0].get("delta", {})
delta_tool_calls = delta.get("tool_calls", None)
# === 8. Tool calls ===
if delta_tool_calls:
for delta_tool_call in delta_tool_calls:
tool_call_index = delta_tool_call.get(
"index"
)
if tool_call_index is not None:
# Look for an already-recorded tool call with this index
current_response_tool_call = None
for (
response_tool_call
) in response_tool_calls:
if (
response_tool_call.get("index")
== tool_call_index
):
current_response_tool_call = (
response_tool_call
)
break
if current_response_tool_call is None:
# New tool call: make sure name/arguments exist as strings
delta_tool_call.setdefault(
"function", {}
)
delta_tool_call[
"function"
].setdefault("name", "")
delta_tool_call[
"function"
].setdefault("arguments", "")
response_tool_calls.append(
delta_tool_call
)
else:
# Existing tool call: append name and arguments fragments
delta_name = delta_tool_call.get(
"function", {}
).get("name")
delta_arguments = (
delta_tool_call.get(
"function", {}
).get("arguments")
)
if delta_name:
current_response_tool_call[
"function"
]["name"] += delta_name
if delta_arguments:
current_response_tool_call[
"function"
][
"arguments"
] += delta_arguments
# === 9. Text content ===
value = delta.get("content")
# === 10. Reasoning content ===
reasoning_content = (
delta.get("reasoning_content")
or delta.get("reasoning")
or delta.get("thinking")
)
if reasoning_content:
# Create a reasoning block, or reuse the last one if it is one
if (
not content_blocks
or content_blocks[-1]["type"] != "reasoning"
):
# NOTE(review): start_tag/end_tag are empty strings here;
# confirm that is intended.
reasoning_block = {
"type": "reasoning",
"start_tag": "",
"end_tag": "",
"attributes": {
"type": "reasoning_content"
},
"content": "",
"started_at": time.time(),
}
content_blocks.append(reasoning_block)
else:
reasoning_block = content_blocks[-1]
# Accumulate the reasoning text
reasoning_block["content"] += reasoning_content
data = {
"content": serialize_content_blocks(
content_blocks
)
}
# === 11. Regular text content ===
if value:
# If the last block is a reasoning_content block, close it and
# start a new text block
if (
content_blocks
and content_blocks[-1]["type"]
== "reasoning"
and content_blocks[-1]
.get("attributes", {})
.get("type")
== "reasoning_content"
):
reasoning_block = content_blocks[-1]
reasoning_block["ended_at"] = time.time()
reasoning_block["duration"] = int(
reasoning_block["ended_at"]
- reasoning_block["started_at"]
)
content_blocks.append(
{
"type": "text",
"content": "",
}
)
# Accumulate the text
content = f"{content}{value}"
if not content_blocks:
content_blocks.append(
{
"type": "text",
"content": "",
}
)
content_blocks[-1]["content"] = (
content_blocks[-1]["content"] + value
)
# === 12. Special tag detection ===
# 12.1 Reasoning tags
if DETECT_REASONING_TAGS:
content, content_blocks, _ = (
tag_content_handler(
"reasoning",
reasoning_tags,
content,
content_blocks,
)
)
# Solution tags
content, content_blocks, _ = (
tag_content_handler(
"solution",
DEFAULT_SOLUTION_TAGS,
content,
content_blocks,
)
)
# 12.2 Code interpreter tags
if DETECT_CODE_INTERPRETER:
content, content_blocks, end = (
tag_content_handler(
"code_interpreter",
DEFAULT_CODE_INTERPRETER_TAGS,
content,
content_blocks,
)
)
# Stop consuming the stream once the end tag is seen
if end:
break
# === 13. Realtime save (optional) ===
if ENABLE_REALTIME_CHAT_SAVE:
# Save the serialized content to the database
Chats.upsert_message_to_chat_by_id_and_message_id(
metadata["chat_id"],
metadata["message_id"],
{
"content": serialize_content_blocks(
content_blocks
),
},
)
else:
# Replace the payload with the serialized content
data = {
"content": serialize_content_blocks(
content_blocks
),
}
# === 14. Delta batching ===
if delta:
delta_count += 1
last_delta_data = data
# Flush once the chunk size is reached
if delta_count >= delta_chunk_size:
await flush_pending_delta_data(delta_chunk_size)
else:
# Payloads without a delta are sent immediately
await event_emitter(
{
"type": "chat:completion",
"data": data,
}
)
except Exception as e:
# The end-of-stream marker is not valid JSON; ignore it quietly
done = "data: [DONE]" in line
if done:
pass
else:
log.debug(f"Error: {e}")
continue
# === 15. Flush any remaining delta payload ===
await flush_pending_delta_data()
# === 16. Tidy up the content blocks ===
if content_blocks:
# Strip the last text block
if content_blocks[-1]["type"] == "text":
content_blocks[-1]["content"] = content_blocks[-1][
"content"
].strip()
# Drop it if it is empty
if not content_blocks[-1]["content"]:
content_blocks.pop()
# Always keep at least one (empty) text block
if not content_blocks:
content_blocks.append(
{
"type": "text",
"content": "",
}
)
# Close a trailing reasoning block that has no end time yet
if content_blocks[-1]["type"] == "reasoning":
reasoning_block = content_blocks[-1]
if reasoning_block.get("ended_at") is None:
reasoning_block["ended_at"] = time.time()
reasoning_block["duration"] = int(
reasoning_block["ended_at"]
- reasoning_block["started_at"]
)
# === 17. Record the tool calls from this stream ===
if response_tool_calls:
tool_calls.append(response_tool_calls)
# === 18. Run the response's background callable, if any ===
if response.background:
await response.background()
await stream_body_handler(response, form_data)
tool_call_retries = 0
while (
len(tool_calls) > 0
and tool_call_retries < CHAT_RESPONSE_MAX_TOOL_CALL_RETRIES
):
tool_call_retries += 1
response_tool_calls = tool_calls.pop(0)
content_blocks.append(
{
"type": "tool_calls",
"content": response_tool_calls,
}
)
await event_emitter(
{
"type": "chat:completion",
"data": {
"content": serialize_content_blocks(content_blocks),
},
}
)
tools = metadata.get("tools", {})
results = []
for tool_call in response_tool_calls:
tool_call_id = tool_call.get("id", "")
tool_function_name = tool_call.get("function", {}).get(
"name", ""
)
tool_args = tool_call.get("function", {}).get("arguments", "{}")
tool_function_params = {}
try:
# json.loads cannot be used because some models do not produce valid JSON
tool_function_params = ast.literal_eval(tool_args)
except Exception as e:
log.debug(e)
# Fallback to JSON parsing
try:
tool_function_params = json.loads(tool_args)
except Exception as e:
log.error(
f"Error parsing tool call arguments: {tool_args}"
)
# Mutate the original tool call response params as they are passed back to the passed
# back to the LLM via the content blocks. If they are in a json block and are invalid json,
# this can cause downstream LLM integrations to fail (e.g. bedrock gateway) where response
# params are not valid json.
# Main case so far is no args = "" = invalid json.
log.debug(
f"Parsed args from {tool_args} to {tool_function_params}"
)
tool_call.setdefault("function", {})["arguments"] = json.dumps(
tool_function_params
)
tool_result = None
tool = None
tool_type = None
direct_tool = False
if tool_function_name in tools:
tool = tools[tool_function_name]
spec = tool.get("spec", {})
tool_type = tool.get("type", "")
direct_tool = tool.get("direct", False)
try:
allowed_params = (
spec.get("parameters", {})
.get("properties", {})
.keys()
)
tool_function_params = {
k: v
for k, v in tool_function_params.items()
if k in allowed_params
}
if direct_tool:
tool_result = await event_caller(
{
"type": "execute:tool",
"data": {
"id": str(uuid4()),
"name": tool_function_name,
"params": tool_function_params,
"server": tool.get("server", {}),
"session_id": metadata.get(
"session_id", None
),
},
}
)
else:
tool_function = tool["callable"]
tool_result = await tool_function(
**tool_function_params
)
except Exception as e:
tool_result = str(e)
tool_result, tool_result_files, tool_result_embeds = (
process_tool_result(
request,
tool_function_name,
tool_result,
tool_type,
direct_tool,
metadata,
user,
)
)
results.append(
{
"tool_call_id": tool_call_id,
"content": tool_result or "",
**(
{"files": tool_result_files}
if tool_result_files
else {}
),
**(
{"embeds": tool_result_embeds}
if tool_result_embeds
else {}
),
}
)
content_blocks[-1]["results"] = results
content_blocks.append(
{
"type": "text",
"content": "",
}
)
await event_emitter(
{
"type": "chat:completion",
"data": {
"content": serialize_content_blocks(content_blocks),
},
}
)
try:
new_form_data = {
**form_data,
"model": model_id,
"stream": True,
"messages": [
*form_data["messages"],
*convert_content_blocks_to_messages(
content_blocks, True
),
],
}
res = await generate_chat_completion(
request,
new_form_data,
user,
)
if isinstance(res, StreamingResponse):
await stream_body_handler(res, new_form_data)
else:
break
except Exception as e:
log.debug(e)
break
if DETECT_CODE_INTERPRETER:
MAX_RETRIES = 5
retries = 0
while (
content_blocks[-1]["type"] == "code_interpreter"
and retries < MAX_RETRIES
):
await event_emitter(
{
"type": "chat:completion",
"data": {
"content": serialize_content_blocks(content_blocks),
},
}
)
retries += 1
log.debug(f"Attempt count: {retries}")
output = ""
try:
if content_blocks[-1]["attributes"].get("type") == "code":
code = content_blocks[-1]["content"]
if CODE_INTERPRETER_BLOCKED_MODULES:
blocking_code = textwrap.dedent(
f"""
import builtins
BLOCKED_MODULES = {CODE_INTERPRETER_BLOCKED_MODULES}
_real_import = builtins.__import__
def restricted_import(name, globals=None, locals=None, fromlist=(), level=0):
if name.split('.')[0] in BLOCKED_MODULES:
importer_name = globals.get('__name__') if globals else None
if importer_name == '__main__':
raise ImportError(
f"Direct import of module {{name}} is restricted."
)
return _real_import(name, globals, locals, fromlist, level)
builtins.__import__ = restricted_import
"""
)
code = blocking_code + "\n" + code
if (
request.app.state.config.CODE_INTERPRETER_ENGINE
== "pyodide"
):
output = await event_caller(
{
"type": "execute:python",
"data": {
"id": str(uuid4()),
"code": code,
"session_id": metadata.get(
"session_id", None
),
},
}
)
elif (
request.app.state.config.CODE_INTERPRETER_ENGINE
== "jupyter"
):
output = await execute_code_jupyter(
request.app.state.config.CODE_INTERPRETER_JUPYTER_URL,
code,
(
request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_TOKEN
if request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH
== "token"
else None
),
(
request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD
if request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH
== "password"
else None
),
request.app.state.config.CODE_INTERPRETER_JUPYTER_TIMEOUT,
)
else:
output = {
"stdout": "Code interpreter engine not configured."
}
log.debug(f"Code interpreter output: {output}")
if isinstance(output, dict):
stdout = output.get("stdout", "")
if isinstance(stdout, str):
stdoutLines = stdout.split("\n")
for idx, line in enumerate(stdoutLines):
if "data:image/png;base64" in line:
image_url = get_image_url_from_base64(
request,
line,
metadata,
user,
)
if image_url:
stdoutLines[idx] = (
f""
)
output["stdout"] = "\n".join(stdoutLines)
result = output.get("result", "")
if isinstance(result, str):
resultLines = result.split("\n")
for idx, line in enumerate(resultLines):
if "data:image/png;base64" in line:
image_url = get_image_url_from_base64(
request,
line,
metadata,
user,
)
resultLines[idx] = (
f""
)
output["result"] = "\n".join(resultLines)
except Exception as e:
output = str(e)
content_blocks[-1]["output"] = output
content_blocks.append(
{
"type": "text",
"content": "",
}
)
await event_emitter(
{
"type": "chat:completion",
"data": {
"content": serialize_content_blocks(content_blocks),
},
}
)
try:
new_form_data = {
**form_data,
"model": model_id,
"stream": True,
"messages": [
*form_data["messages"],
{
"role": "assistant",
"content": serialize_content_blocks(
content_blocks, raw=True
),
},
],
}
res = await generate_chat_completion(
request,
new_form_data,
user,
)
if isinstance(res, StreamingResponse):
await stream_body_handler(res, new_form_data)
else:
break
except Exception as e:
log.debug(e)
break
title = Chats.get_chat_title_by_id(metadata["chat_id"])
data = {
"done": True,
"content": serialize_content_blocks(content_blocks),
"title": title,
}
if not ENABLE_REALTIME_CHAT_SAVE:
# Save message in the database
Chats.upsert_message_to_chat_by_id_and_message_id(
metadata["chat_id"],
metadata["message_id"],
{
"content": serialize_content_blocks(content_blocks),
},
)
# Send a webhook notification if the user is not active
if not get_active_status_by_user_id(user.id):
webhook_url = Users.get_user_webhook_url_by_id(user.id)
if webhook_url:
await post_webhook(
request.app.state.WEBUI_NAME,
webhook_url,
f"{title} - {request.app.state.config.WEBUI_URL}/c/{metadata['chat_id']}\n\n{content}",
{
"action": "chat",
"message": content,
"title": title,
"url": f"{request.app.state.config.WEBUI_URL}/c/{metadata['chat_id']}",
},
)
await event_emitter(
{
"type": "chat:completion",
"data": data,
}
)
await background_tasks_handler()
except asyncio.CancelledError:
log.warning("Task was cancelled!")
await event_emitter({"type": "chat:tasks:cancel"})
if not ENABLE_REALTIME_CHAT_SAVE:
# Save message in the database
Chats.upsert_message_to_chat_by_id_and_message_id(
metadata["chat_id"],
metadata["message_id"],
{
"content": serialize_content_blocks(content_blocks),
},
)
if response.background is not None:
await response.background()
return await response_handler(response, events)
else:
# Fallback to the original response
async def stream_wrapper(original_generator, events):
    """Re-stream a response after running every item through the stream filters.

    The queued ``events`` are emitted first, each JSON-encoded and framed
    as ``data: <json>`` followed by a blank line. After that, every chunk
    produced by ``original_generator`` is forwarded as returned by the
    filters, without any extra framing.

    Items for which the filters return a falsy value are dropped.
    ``request``, ``filter_functions`` and ``extra_params`` are taken from
    the enclosing scope.
    """

    async def run_stream_filters(payload):
        # Only the filtered payload is needed; the second element of the
        # returned pair is ignored, as in both call sites it replaces.
        filtered, _ = await process_filter_functions(
            request=request,
            filter_functions=filter_functions,
            filter_type="stream",
            form_data=payload,
            extra_params=extra_params,
        )
        return filtered

    # Phase 1: events collected before streaming started.
    for queued_event in events:
        filtered_event = await run_stream_filters(queued_event)
        if not filtered_event:
            continue
        yield f"data: {json.dumps(filtered_event)}\n\n"

    # Phase 2: the upstream body, chunk by chunk.
    async for chunk in original_generator:
        filtered_chunk = await run_stream_filters(chunk)
        if not filtered_chunk:
            continue
        yield filtered_chunk
return StreamingResponse(
stream_wrapper(response.body_iterator, events),
headers=dict(response.headers),
background=response.background,
)