mirror of
https://github.com/open-webui/open-webui.git
synced 2025-12-12 12:25:20 +00:00
3282 lines
135 KiB
Python
3282 lines
135 KiB
Python
import time
|
||
import logging
|
||
import sys
|
||
import os
|
||
import base64
|
||
import textwrap
|
||
|
||
import asyncio
|
||
from aiocache import cached
|
||
from typing import Any, Optional
|
||
import random
|
||
import json
|
||
import html
|
||
import inspect
|
||
import re
|
||
import ast
|
||
|
||
from uuid import uuid4
|
||
from concurrent.futures import ThreadPoolExecutor
|
||
|
||
|
||
from fastapi import Request, HTTPException
|
||
from fastapi.responses import HTMLResponse
|
||
from starlette.responses import Response, StreamingResponse, JSONResponse
|
||
|
||
|
||
from open_webui.models.oauth_sessions import OAuthSessions
|
||
from open_webui.models.chats import Chats
|
||
from open_webui.models.folders import Folders
|
||
from open_webui.models.users import Users
|
||
from open_webui.socket.main import (
|
||
get_event_call,
|
||
get_event_emitter,
|
||
get_active_status_by_user_id,
|
||
)
|
||
from open_webui.routers.tasks import (
|
||
generate_queries,
|
||
generate_title,
|
||
generate_follow_ups,
|
||
generate_image_prompt,
|
||
generate_chat_tags,
|
||
)
|
||
from open_webui.routers.retrieval import (
|
||
process_web_search,
|
||
SearchForm,
|
||
)
|
||
from open_webui.routers.images import (
|
||
load_b64_image_data,
|
||
image_generations,
|
||
GenerateImageForm,
|
||
upload_image,
|
||
)
|
||
from open_webui.routers.pipelines import (
|
||
process_pipeline_inlet_filter,
|
||
process_pipeline_outlet_filter,
|
||
)
|
||
from open_webui.models.memories import Memories
|
||
from open_webui.memory.mem0 import mem0_search
|
||
|
||
from open_webui.utils.webhook import post_webhook
|
||
from open_webui.utils.files import (
|
||
get_audio_url_from_base64,
|
||
get_file_url_from_base64,
|
||
get_image_url_from_base64,
|
||
)
|
||
|
||
|
||
from open_webui.models.users import UserModel
|
||
from open_webui.models.functions import Functions
|
||
from open_webui.models.models import Models
|
||
|
||
from open_webui.retrieval.utils import get_sources_from_items
|
||
|
||
|
||
from open_webui.utils.chat import generate_chat_completion
|
||
from open_webui.utils.task import (
|
||
get_task_model_id,
|
||
rag_template,
|
||
tools_function_calling_generation_template,
|
||
)
|
||
from open_webui.utils.misc import (
|
||
deep_update,
|
||
extract_urls,
|
||
get_message_list,
|
||
add_or_update_system_message,
|
||
add_or_update_user_message,
|
||
get_last_user_message,
|
||
get_last_user_message_item,
|
||
get_last_assistant_message,
|
||
get_system_message,
|
||
prepend_to_first_user_message_content,
|
||
convert_logit_bias_input_to_json,
|
||
get_content_from_message,
|
||
)
|
||
from open_webui.utils.tools import get_tools
|
||
from open_webui.utils.plugin import load_function_module_by_id
|
||
from open_webui.utils.filter import (
|
||
get_sorted_filter_ids,
|
||
process_filter_functions,
|
||
)
|
||
from open_webui.utils.code_interpreter import execute_code_jupyter
|
||
from open_webui.utils.payload import apply_system_prompt_to_body
|
||
from open_webui.utils.mcp.client import MCPClient
|
||
|
||
|
||
from open_webui.config import (
|
||
CACHE_DIR,
|
||
DEFAULT_TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
|
||
DEFAULT_CODE_INTERPRETER_PROMPT,
|
||
CODE_INTERPRETER_BLOCKED_MODULES,
|
||
)
|
||
from open_webui.env import (
|
||
SRC_LOG_LEVELS,
|
||
GLOBAL_LOG_LEVEL,
|
||
CHAT_RESPONSE_STREAM_DELTA_CHUNK_SIZE,
|
||
CHAT_RESPONSE_MAX_TOOL_CALL_RETRIES,
|
||
BYPASS_MODEL_ACCESS_CONTROL,
|
||
ENABLE_REALTIME_CHAT_SAVE,
|
||
ENABLE_QUERIES_CACHE,
|
||
)
|
||
from open_webui.constants import TASKS
|
||
|
||
|
||
logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
|
||
log = logging.getLogger(__name__)
|
||
log.setLevel(SRC_LOG_LEVELS["MAIN"])
|
||
|
||
|
||
DEFAULT_REASONING_TAGS = [
|
||
("<think>", "</think>"),
|
||
("<thinking>", "</thinking>"),
|
||
("<reason>", "</reason>"),
|
||
("<reasoning>", "</reasoning>"),
|
||
("<thought>", "</thought>"),
|
||
("<Thought>", "</Thought>"),
|
||
("<|begin_of_thought|>", "<|end_of_thought|>"),
|
||
("◁think▷", "◁/think▷"),
|
||
]
|
||
DEFAULT_SOLUTION_TAGS = [("<|begin_of_solution|>", "<|end_of_solution|>")]
|
||
DEFAULT_CODE_INTERPRETER_TAGS = [("<code_interpreter>", "</code_interpreter>")]
|
||
|
||
|
||
def process_tool_result(
|
||
request,
|
||
tool_function_name,
|
||
tool_result,
|
||
tool_type,
|
||
direct_tool=False,
|
||
metadata=None,
|
||
user=None,
|
||
):
|
||
tool_result_embeds = []
|
||
|
||
if isinstance(tool_result, HTMLResponse):
|
||
content_disposition = tool_result.headers.get("Content-Disposition", "")
|
||
if "inline" in content_disposition:
|
||
content = tool_result.body.decode("utf-8", "replace")
|
||
tool_result_embeds.append(content)
|
||
|
||
if 200 <= tool_result.status_code < 300:
|
||
tool_result = {
|
||
"status": "success",
|
||
"code": "ui_component",
|
||
"message": f"{tool_function_name}: Embedded UI result is active and visible to the user.",
|
||
}
|
||
elif 400 <= tool_result.status_code < 500:
|
||
tool_result = {
|
||
"status": "error",
|
||
"code": "ui_component",
|
||
"message": f"{tool_function_name}: Client error {tool_result.status_code} from embedded UI result.",
|
||
}
|
||
elif 500 <= tool_result.status_code < 600:
|
||
tool_result = {
|
||
"status": "error",
|
||
"code": "ui_component",
|
||
"message": f"{tool_function_name}: Server error {tool_result.status_code} from embedded UI result.",
|
||
}
|
||
else:
|
||
tool_result = {
|
||
"status": "error",
|
||
"code": "ui_component",
|
||
"message": f"{tool_function_name}: Unexpected status code {tool_result.status_code} from embedded UI result.",
|
||
}
|
||
else:
|
||
tool_result = tool_result.body.decode("utf-8", "replace")
|
||
|
||
elif (tool_type == "external" and isinstance(tool_result, tuple)) or (
|
||
direct_tool and isinstance(tool_result, list) and len(tool_result) == 2
|
||
):
|
||
tool_result, tool_response_headers = tool_result
|
||
|
||
try:
|
||
if not isinstance(tool_response_headers, dict):
|
||
tool_response_headers = dict(tool_response_headers)
|
||
except Exception as e:
|
||
tool_response_headers = {}
|
||
log.debug(e)
|
||
|
||
if tool_response_headers and isinstance(tool_response_headers, dict):
|
||
content_disposition = tool_response_headers.get(
|
||
"Content-Disposition",
|
||
tool_response_headers.get("content-disposition", ""),
|
||
)
|
||
|
||
if "inline" in content_disposition:
|
||
content_type = tool_response_headers.get(
|
||
"Content-Type",
|
||
tool_response_headers.get("content-type", ""),
|
||
)
|
||
location = tool_response_headers.get(
|
||
"Location",
|
||
tool_response_headers.get("location", ""),
|
||
)
|
||
|
||
if "text/html" in content_type:
|
||
# Display as iframe embed
|
||
tool_result_embeds.append(tool_result)
|
||
tool_result = {
|
||
"status": "success",
|
||
"code": "ui_component",
|
||
"message": f"{tool_function_name}: Embedded UI result is active and visible to the user.",
|
||
}
|
||
elif location:
|
||
tool_result_embeds.append(location)
|
||
tool_result = {
|
||
"status": "success",
|
||
"code": "ui_component",
|
||
"message": f"{tool_function_name}: Embedded UI result is active and visible to the user.",
|
||
}
|
||
|
||
tool_result_files = []
|
||
|
||
if isinstance(tool_result, list):
|
||
if tool_type == "mcp": # MCP
|
||
tool_response = []
|
||
for item in tool_result:
|
||
if isinstance(item, dict):
|
||
if item.get("type") == "text":
|
||
text = item.get("text", "")
|
||
if isinstance(text, str):
|
||
try:
|
||
text = json.loads(text)
|
||
except json.JSONDecodeError:
|
||
pass
|
||
tool_response.append(text)
|
||
elif item.get("type") in ["image", "audio"]:
|
||
file_url = get_file_url_from_base64(
|
||
request,
|
||
f"data:{item.get('mimeType')};base64,{item.get('data', item.get('blob', ''))}",
|
||
{
|
||
"chat_id": metadata.get("chat_id", None),
|
||
"message_id": metadata.get("message_id", None),
|
||
"session_id": metadata.get("session_id", None),
|
||
"result": item,
|
||
},
|
||
user,
|
||
)
|
||
|
||
tool_result_files.append(
|
||
{
|
||
"type": item.get("type", "data"),
|
||
"url": file_url,
|
||
}
|
||
)
|
||
tool_result = tool_response[0] if len(tool_response) == 1 else tool_response
|
||
else: # OpenAPI
|
||
for item in tool_result:
|
||
if isinstance(item, str) and item.startswith("data:"):
|
||
tool_result_files.append(
|
||
{
|
||
"type": "data",
|
||
"content": item,
|
||
}
|
||
)
|
||
tool_result.remove(item)
|
||
|
||
if isinstance(tool_result, list):
|
||
tool_result = {"results": tool_result}
|
||
|
||
if isinstance(tool_result, dict) or isinstance(tool_result, list):
|
||
tool_result = json.dumps(tool_result, indent=2, ensure_ascii=False)
|
||
|
||
return tool_result, tool_result_files, tool_result_embeds
|
||
|
||
|
||
async def chat_completion_tools_handler(
|
||
request: Request, body: dict, extra_params: dict, user: UserModel, models, tools
|
||
) -> tuple[dict, dict]:
|
||
async def get_content_from_response(response) -> Optional[str]:
|
||
content = None
|
||
if hasattr(response, "body_iterator"):
|
||
async for chunk in response.body_iterator:
|
||
data = json.loads(chunk.decode("utf-8", "replace"))
|
||
content = data["choices"][0]["message"]["content"]
|
||
|
||
# Cleanup any remaining background tasks if necessary
|
||
if response.background is not None:
|
||
await response.background()
|
||
else:
|
||
content = response["choices"][0]["message"]["content"]
|
||
return content
|
||
|
||
def get_tools_function_calling_payload(messages, task_model_id, content):
|
||
user_message = get_last_user_message(messages)
|
||
|
||
recent_messages = messages[-4:] if len(messages) > 4 else messages
|
||
chat_history = "\n".join(
|
||
f"{message['role'].upper()}: \"\"\"{get_content_from_message(message)}\"\"\""
|
||
for message in recent_messages
|
||
)
|
||
|
||
prompt = f"History:\n{chat_history}\nQuery: {user_message}"
|
||
|
||
return {
|
||
"model": task_model_id,
|
||
"messages": [
|
||
{"role": "system", "content": content},
|
||
{"role": "user", "content": f"Query: {prompt}"},
|
||
],
|
||
"stream": False,
|
||
"metadata": {"task": str(TASKS.FUNCTION_CALLING)},
|
||
}
|
||
|
||
event_caller = extra_params["__event_call__"]
|
||
event_emitter = extra_params["__event_emitter__"]
|
||
metadata = extra_params["__metadata__"]
|
||
|
||
task_model_id = get_task_model_id(
|
||
body["model"],
|
||
request.app.state.config.TASK_MODEL,
|
||
request.app.state.config.TASK_MODEL_EXTERNAL,
|
||
models,
|
||
)
|
||
|
||
skip_files = False
|
||
sources = []
|
||
|
||
specs = [tool["spec"] for tool in tools.values()]
|
||
tools_specs = json.dumps(specs)
|
||
|
||
if request.app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE != "":
|
||
template = request.app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
|
||
else:
|
||
template = DEFAULT_TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
|
||
|
||
tools_function_calling_prompt = tools_function_calling_generation_template(
|
||
template, tools_specs
|
||
)
|
||
payload = get_tools_function_calling_payload(
|
||
body["messages"], task_model_id, tools_function_calling_prompt
|
||
)
|
||
|
||
try:
|
||
response = await generate_chat_completion(request, form_data=payload, user=user)
|
||
log.debug(f"{response=}")
|
||
content = await get_content_from_response(response)
|
||
log.debug(f"{content=}")
|
||
|
||
if not content:
|
||
return body, {}
|
||
|
||
try:
|
||
content = content[content.find("{") : content.rfind("}") + 1]
|
||
if not content:
|
||
raise Exception("No JSON object found in the response")
|
||
|
||
result = json.loads(content)
|
||
|
||
async def tool_call_handler(tool_call):
|
||
nonlocal skip_files
|
||
|
||
log.debug(f"{tool_call=}")
|
||
|
||
tool_function_name = tool_call.get("name", None)
|
||
if tool_function_name not in tools:
|
||
return body, {}
|
||
|
||
tool_function_params = tool_call.get("parameters", {})
|
||
|
||
tool = None
|
||
tool_type = ""
|
||
direct_tool = False
|
||
|
||
try:
|
||
tool = tools[tool_function_name]
|
||
tool_type = tool.get("type", "")
|
||
direct_tool = tool.get("direct", False)
|
||
|
||
spec = tool.get("spec", {})
|
||
allowed_params = (
|
||
spec.get("parameters", {}).get("properties", {}).keys()
|
||
)
|
||
tool_function_params = {
|
||
k: v
|
||
for k, v in tool_function_params.items()
|
||
if k in allowed_params
|
||
}
|
||
|
||
if tool.get("direct", False):
|
||
tool_result = await event_caller(
|
||
{
|
||
"type": "execute:tool",
|
||
"data": {
|
||
"id": str(uuid4()),
|
||
"name": tool_function_name,
|
||
"params": tool_function_params,
|
||
"server": tool.get("server", {}),
|
||
"session_id": metadata.get("session_id", None),
|
||
},
|
||
}
|
||
)
|
||
else:
|
||
tool_function = tool["callable"]
|
||
tool_result = await tool_function(**tool_function_params)
|
||
|
||
except Exception as e:
|
||
tool_result = str(e)
|
||
|
||
tool_result, tool_result_files, tool_result_embeds = (
|
||
process_tool_result(
|
||
request,
|
||
tool_function_name,
|
||
tool_result,
|
||
tool_type,
|
||
direct_tool,
|
||
metadata,
|
||
user,
|
||
)
|
||
)
|
||
|
||
if event_emitter:
|
||
if tool_result_files:
|
||
await event_emitter(
|
||
{
|
||
"type": "files",
|
||
"data": {
|
||
"files": tool_result_files,
|
||
},
|
||
}
|
||
)
|
||
|
||
if tool_result_embeds:
|
||
await event_emitter(
|
||
{
|
||
"type": "embeds",
|
||
"data": {
|
||
"embeds": tool_result_embeds,
|
||
},
|
||
}
|
||
)
|
||
|
||
print(
|
||
f"Tool {tool_function_name} result: {tool_result}",
|
||
tool_result_files,
|
||
tool_result_embeds,
|
||
)
|
||
|
||
if tool_result:
|
||
tool = tools[tool_function_name]
|
||
tool_id = tool.get("tool_id", "")
|
||
|
||
tool_name = (
|
||
f"{tool_id}/{tool_function_name}"
|
||
if tool_id
|
||
else f"{tool_function_name}"
|
||
)
|
||
|
||
# Citation is enabled for this tool
|
||
sources.append(
|
||
{
|
||
"source": {
|
||
"name": (f"{tool_name}"),
|
||
},
|
||
"document": [str(tool_result)],
|
||
"metadata": [
|
||
{
|
||
"source": (f"{tool_name}"),
|
||
"parameters": tool_function_params,
|
||
}
|
||
],
|
||
"tool_result": True,
|
||
}
|
||
)
|
||
|
||
# Citation is not enabled for this tool
|
||
body["messages"] = add_or_update_user_message(
|
||
f"\nTool `{tool_name}` Output: {tool_result}",
|
||
body["messages"],
|
||
)
|
||
|
||
if (
|
||
tools[tool_function_name]
|
||
.get("metadata", {})
|
||
.get("file_handler", False)
|
||
):
|
||
skip_files = True
|
||
|
||
# check if "tool_calls" in result
|
||
if result.get("tool_calls"):
|
||
for tool_call in result.get("tool_calls"):
|
||
await tool_call_handler(tool_call)
|
||
else:
|
||
await tool_call_handler(result)
|
||
|
||
except Exception as e:
|
||
log.debug(f"Error: {e}")
|
||
content = None
|
||
except Exception as e:
|
||
log.debug(f"Error: {e}")
|
||
content = None
|
||
|
||
log.debug(f"tool_contexts: {sources}")
|
||
|
||
if skip_files and "files" in body.get("metadata", {}):
|
||
del body["metadata"]["files"]
|
||
|
||
return body, {"sources": sources}
|
||
|
||
|
||
async def chat_memory_handler(
|
||
request: Request, form_data: dict, extra_params: dict, user, metadata
|
||
):
|
||
"""
|
||
聊天记忆处理器 - 注入用户手动保存的记忆 + Mem0 检索结果到当前对话上下文
|
||
|
||
新增行为:
|
||
1. memory 特性开启时,直接注入用户所有记忆条目(不再做 RAG Top-K)
|
||
2. 预留 Mem0 检索:使用最后一条用户消息查询 Mem0,返回的条目也一并注入
|
||
3. 上下文注入方式保持不变:统一写入 System Message
|
||
"""
|
||
user_message = get_last_user_message(form_data.get("messages", [])) or ""
|
||
|
||
# === 1. 获取用户全部记忆(不再截断 Top-K) ===
|
||
memories = Memories.get_memories_by_user_id(user.id) or []
|
||
|
||
# === 2. 预留的 Mem0 检索结果 ===
|
||
mem0_results = await mem0_search(user.id, metadata.get("chat_id"), user_message)
|
||
|
||
# === 3. 格式化记忆条目 ===
|
||
entries = []
|
||
|
||
# 3.1 用户记忆库全量
|
||
for mem in memories:
|
||
created_at_date = time.strftime("%Y-%m-%d", time.localtime(mem.created_at)) if mem.created_at else "Unknown Date"
|
||
entries.append(f"[{created_at_date}] {mem.content}")
|
||
|
||
# 3.2 Mem0 检索结果
|
||
for item in mem0_results:
|
||
entries.append(f"[Mem0] {item}")
|
||
|
||
if not entries:
|
||
return form_data
|
||
|
||
user_context = ""
|
||
for idx, entry in enumerate(entries):
|
||
user_context += f"{idx + 1}. {entry}\n"
|
||
|
||
# === 4. 将记忆注入到系统消息中 ===
|
||
form_data["messages"] = add_or_update_system_message(
|
||
f"User Context:\n{user_context}\n", form_data["messages"], append=True
|
||
)
|
||
|
||
return form_data
|
||
|
||
|
||
async def chat_web_search_handler(
|
||
request: Request, form_data: dict, extra_params: dict, user
|
||
):
|
||
event_emitter = extra_params["__event_emitter__"]
|
||
await event_emitter(
|
||
{
|
||
"type": "status",
|
||
"data": {
|
||
"action": "web_search",
|
||
"description": "Searching the web",
|
||
"done": False,
|
||
},
|
||
}
|
||
)
|
||
|
||
messages = form_data["messages"]
|
||
user_message = get_last_user_message(messages)
|
||
|
||
queries = []
|
||
try:
|
||
res = await generate_queries(
|
||
request,
|
||
{
|
||
"model": form_data["model"],
|
||
"messages": messages,
|
||
"prompt": user_message,
|
||
"type": "web_search",
|
||
},
|
||
user,
|
||
)
|
||
|
||
response = res["choices"][0]["message"]["content"]
|
||
|
||
try:
|
||
bracket_start = response.find("{")
|
||
bracket_end = response.rfind("}") + 1
|
||
|
||
if bracket_start == -1 or bracket_end == -1:
|
||
raise Exception("No JSON object found in the response")
|
||
|
||
response = response[bracket_start:bracket_end]
|
||
queries = json.loads(response)
|
||
queries = queries.get("queries", [])
|
||
except Exception as e:
|
||
queries = [response]
|
||
|
||
if ENABLE_QUERIES_CACHE:
|
||
request.state.cached_queries = queries
|
||
|
||
except Exception as e:
|
||
log.exception(e)
|
||
queries = [user_message]
|
||
|
||
# Check if generated queries are empty
|
||
if len(queries) == 1 and queries[0].strip() == "":
|
||
queries = [user_message]
|
||
|
||
# Check if queries are not found
|
||
if len(queries) == 0:
|
||
await event_emitter(
|
||
{
|
||
"type": "status",
|
||
"data": {
|
||
"action": "web_search",
|
||
"description": "No search query generated",
|
||
"done": True,
|
||
},
|
||
}
|
||
)
|
||
return form_data
|
||
|
||
await event_emitter(
|
||
{
|
||
"type": "status",
|
||
"data": {
|
||
"action": "web_search_queries_generated",
|
||
"queries": queries,
|
||
"done": False,
|
||
},
|
||
}
|
||
)
|
||
|
||
try:
|
||
results = await process_web_search(
|
||
request,
|
||
SearchForm(queries=queries),
|
||
user=user,
|
||
)
|
||
|
||
if results:
|
||
files = form_data.get("files", [])
|
||
|
||
if results.get("collection_names"):
|
||
for col_idx, collection_name in enumerate(
|
||
results.get("collection_names")
|
||
):
|
||
files.append(
|
||
{
|
||
"collection_name": collection_name,
|
||
"name": ", ".join(queries),
|
||
"type": "web_search",
|
||
"urls": results["filenames"],
|
||
"queries": queries,
|
||
}
|
||
)
|
||
elif results.get("docs"):
|
||
# Invoked when bypass embedding and retrieval is set to True
|
||
docs = results["docs"]
|
||
files.append(
|
||
{
|
||
"docs": docs,
|
||
"name": ", ".join(queries),
|
||
"type": "web_search",
|
||
"urls": results["filenames"],
|
||
"queries": queries,
|
||
}
|
||
)
|
||
|
||
form_data["files"] = files
|
||
|
||
await event_emitter(
|
||
{
|
||
"type": "status",
|
||
"data": {
|
||
"action": "web_search",
|
||
"description": "Searched {{count}} sites",
|
||
"urls": results["filenames"],
|
||
"items": results.get("items", []),
|
||
"done": True,
|
||
},
|
||
}
|
||
)
|
||
else:
|
||
await event_emitter(
|
||
{
|
||
"type": "status",
|
||
"data": {
|
||
"action": "web_search",
|
||
"description": "No search results found",
|
||
"done": True,
|
||
"error": True,
|
||
},
|
||
}
|
||
)
|
||
|
||
except Exception as e:
|
||
log.exception(e)
|
||
await event_emitter(
|
||
{
|
||
"type": "status",
|
||
"data": {
|
||
"action": "web_search",
|
||
"description": "An error occurred while searching the web",
|
||
"queries": queries,
|
||
"done": True,
|
||
"error": True,
|
||
},
|
||
}
|
||
)
|
||
|
||
return form_data
|
||
|
||
|
||
async def chat_image_generation_handler(
|
||
request: Request, form_data: dict, extra_params: dict, user
|
||
):
|
||
__event_emitter__ = extra_params["__event_emitter__"]
|
||
await __event_emitter__(
|
||
{
|
||
"type": "status",
|
||
"data": {"description": "Creating image", "done": False},
|
||
}
|
||
)
|
||
|
||
messages = form_data["messages"]
|
||
user_message = get_last_user_message(messages)
|
||
|
||
prompt = user_message
|
||
negative_prompt = ""
|
||
|
||
if request.app.state.config.ENABLE_IMAGE_PROMPT_GENERATION:
|
||
try:
|
||
res = await generate_image_prompt(
|
||
request,
|
||
{
|
||
"model": form_data["model"],
|
||
"messages": messages,
|
||
},
|
||
user,
|
||
)
|
||
|
||
response = res["choices"][0]["message"]["content"]
|
||
|
||
try:
|
||
bracket_start = response.find("{")
|
||
bracket_end = response.rfind("}") + 1
|
||
|
||
if bracket_start == -1 or bracket_end == -1:
|
||
raise Exception("No JSON object found in the response")
|
||
|
||
response = response[bracket_start:bracket_end]
|
||
response = json.loads(response)
|
||
prompt = response.get("prompt", [])
|
||
except Exception as e:
|
||
prompt = user_message
|
||
|
||
except Exception as e:
|
||
log.exception(e)
|
||
prompt = user_message
|
||
|
||
system_message_content = ""
|
||
|
||
try:
|
||
images = await image_generations(
|
||
request=request,
|
||
form_data=GenerateImageForm(**{"prompt": prompt}),
|
||
user=user,
|
||
)
|
||
|
||
await __event_emitter__(
|
||
{
|
||
"type": "status",
|
||
"data": {"description": "Image created", "done": True},
|
||
}
|
||
)
|
||
|
||
await __event_emitter__(
|
||
{
|
||
"type": "files",
|
||
"data": {
|
||
"files": [
|
||
{
|
||
"type": "image",
|
||
"url": image["url"],
|
||
}
|
||
for image in images
|
||
]
|
||
},
|
||
}
|
||
)
|
||
|
||
system_message_content = "<context>User is shown the generated image, tell the user that the image has been generated</context>"
|
||
except Exception as e:
|
||
log.exception(e)
|
||
await __event_emitter__(
|
||
{
|
||
"type": "status",
|
||
"data": {
|
||
"description": f"An error occurred while generating an image",
|
||
"done": True,
|
||
},
|
||
}
|
||
)
|
||
|
||
system_message_content = "<context>Unable to generate an image, tell the user that an error occurred</context>"
|
||
|
||
if system_message_content:
|
||
form_data["messages"] = add_or_update_system_message(
|
||
system_message_content, form_data["messages"]
|
||
)
|
||
|
||
return form_data
|
||
|
||
|
||
async def chat_completion_files_handler(
|
||
request: Request, body: dict, extra_params: dict, user: UserModel
|
||
) -> tuple[dict, dict[str, list]]:
|
||
__event_emitter__ = extra_params["__event_emitter__"]
|
||
sources = []
|
||
|
||
if files := body.get("metadata", {}).get("files", None):
|
||
# Check if all files are in full context mode
|
||
all_full_context = all(item.get("context") == "full" for item in files)
|
||
|
||
queries = []
|
||
if not all_full_context:
|
||
try:
|
||
queries_response = await generate_queries(
|
||
request,
|
||
{
|
||
"model": body["model"],
|
||
"messages": body["messages"],
|
||
"type": "retrieval",
|
||
},
|
||
user,
|
||
)
|
||
queries_response = queries_response["choices"][0]["message"]["content"]
|
||
|
||
try:
|
||
bracket_start = queries_response.find("{")
|
||
bracket_end = queries_response.rfind("}") + 1
|
||
|
||
if bracket_start == -1 or bracket_end == -1:
|
||
raise Exception("No JSON object found in the response")
|
||
|
||
queries_response = queries_response[bracket_start:bracket_end]
|
||
queries_response = json.loads(queries_response)
|
||
except Exception as e:
|
||
queries_response = {"queries": [queries_response]}
|
||
|
||
queries = queries_response.get("queries", [])
|
||
except:
|
||
pass
|
||
|
||
await __event_emitter__(
|
||
{
|
||
"type": "status",
|
||
"data": {
|
||
"action": "queries_generated",
|
||
"queries": queries,
|
||
"done": False,
|
||
},
|
||
}
|
||
)
|
||
|
||
if len(queries) == 0:
|
||
queries = [get_last_user_message(body["messages"])]
|
||
|
||
try:
|
||
# Offload get_sources_from_items to a separate thread
|
||
loop = asyncio.get_running_loop()
|
||
with ThreadPoolExecutor() as executor:
|
||
sources = await loop.run_in_executor(
|
||
executor,
|
||
lambda: get_sources_from_items(
|
||
request=request,
|
||
items=files,
|
||
queries=queries,
|
||
embedding_function=lambda query, prefix: request.app.state.EMBEDDING_FUNCTION(
|
||
query, prefix=prefix, user=user
|
||
),
|
||
k=request.app.state.config.TOP_K,
|
||
reranking_function=(
|
||
(
|
||
lambda sentences: request.app.state.RERANKING_FUNCTION(
|
||
sentences, user=user
|
||
)
|
||
)
|
||
if request.app.state.RERANKING_FUNCTION
|
||
else None
|
||
),
|
||
k_reranker=request.app.state.config.TOP_K_RERANKER,
|
||
r=request.app.state.config.RELEVANCE_THRESHOLD,
|
||
hybrid_bm25_weight=request.app.state.config.HYBRID_BM25_WEIGHT,
|
||
hybrid_search=request.app.state.config.ENABLE_RAG_HYBRID_SEARCH,
|
||
full_context=all_full_context
|
||
or request.app.state.config.RAG_FULL_CONTEXT,
|
||
user=user,
|
||
),
|
||
)
|
||
except Exception as e:
|
||
log.exception(e)
|
||
|
||
log.debug(f"rag_contexts:sources: {sources}")
|
||
|
||
unique_ids = set()
|
||
for source in sources or []:
|
||
if not source or len(source.keys()) == 0:
|
||
continue
|
||
|
||
documents = source.get("document") or []
|
||
metadatas = source.get("metadata") or []
|
||
src_info = source.get("source") or {}
|
||
|
||
for index, _ in enumerate(documents):
|
||
metadata = metadatas[index] if index < len(metadatas) else None
|
||
_id = (
|
||
(metadata or {}).get("source")
|
||
or (src_info or {}).get("id")
|
||
or "N/A"
|
||
)
|
||
unique_ids.add(_id)
|
||
|
||
sources_count = len(unique_ids)
|
||
await __event_emitter__(
|
||
{
|
||
"type": "status",
|
||
"data": {
|
||
"action": "sources_retrieved",
|
||
"count": sources_count,
|
||
"done": True,
|
||
},
|
||
}
|
||
)
|
||
|
||
return body, {"sources": sources}
|
||
|
||
|
||
def apply_params_to_form_data(form_data, model):
|
||
params = form_data.pop("params", {})
|
||
custom_params = params.pop("custom_params", {})
|
||
|
||
open_webui_params = {
|
||
"stream_response": bool,
|
||
"stream_delta_chunk_size": int,
|
||
"function_calling": str,
|
||
"reasoning_tags": list,
|
||
"system": str,
|
||
}
|
||
|
||
for key in list(params.keys()):
|
||
if key in open_webui_params:
|
||
del params[key]
|
||
|
||
if custom_params:
|
||
# Attempt to parse custom_params if they are strings
|
||
for key, value in custom_params.items():
|
||
if isinstance(value, str):
|
||
try:
|
||
# Attempt to parse the string as JSON
|
||
custom_params[key] = json.loads(value)
|
||
except json.JSONDecodeError:
|
||
# If it fails, keep the original string
|
||
pass
|
||
|
||
# If custom_params are provided, merge them into params
|
||
params = deep_update(params, custom_params)
|
||
|
||
if model.get("owned_by") == "ollama":
|
||
# Ollama specific parameters
|
||
form_data["options"] = params
|
||
else:
|
||
if isinstance(params, dict):
|
||
for key, value in params.items():
|
||
if value is not None:
|
||
form_data[key] = value
|
||
|
||
if "logit_bias" in params and params["logit_bias"] is not None:
|
||
try:
|
||
form_data["logit_bias"] = json.loads(
|
||
convert_logit_bias_input_to_json(params["logit_bias"])
|
||
)
|
||
except Exception as e:
|
||
log.exception(f"Error parsing logit_bias: {e}")
|
||
|
||
return form_data
|
||
|
||
|
||
async def process_chat_payload(request, form_data, user, metadata, model):
|
||
"""
|
||
处理聊天请求的 Payload - 执行 Pipeline、Filter、功能增强和工具注入
|
||
|
||
这是聊天请求预处理的核心函数,按以下顺序执行:
|
||
1. Pipeline Inlet (管道入口) - 自定义 Python 插件预处理
|
||
2. Filter Inlet (过滤器入口) - 函数过滤器预处理
|
||
3. Chat Memory (记忆) - 注入历史对话记忆
|
||
4. Chat Web Search (网页搜索) - 执行网络搜索并注入结果
|
||
5. Chat Image Generation (图像生成) - 处理图像生成请求
|
||
6. Chat Code Interpreter (代码解释器) - 注入代码执行提示词
|
||
7. Chat Tools Function Calling (工具调用) - 处理函数/工具调用
|
||
8. Chat Files (文件处理) - 处理上传文件、知识库文件、RAG 检索
|
||
|
||
Args:
|
||
request: FastAPI Request 对象
|
||
form_data: OpenAI 格式的聊天请求数据
|
||
user: 用户对象
|
||
metadata: 元数据(chat_id, message_id, tool_ids, files 等)
|
||
model: 模型配置对象
|
||
|
||
Returns:
|
||
tuple: (form_data, metadata, events)
|
||
- form_data: 处理后的请求数据
|
||
- metadata: 更新后的元数据
|
||
- events: 需要发送给前端的事件列表(如引用来源)
|
||
"""
|
||
# === 1. 应用模型参数到请求 ===
|
||
form_data = apply_params_to_form_data(form_data, model)
|
||
log.debug(f"form_data: {form_data}")
|
||
|
||
# === 2. 处理 System Prompt 变量替换 ===
|
||
system_message = get_system_message(form_data.get("messages", []))
|
||
if system_message: # Chat Controls/User Settings
|
||
try:
|
||
# 替换 system prompt 中的变量(如 {{USER_NAME}}, {{CURRENT_DATE}})
|
||
form_data = apply_system_prompt_to_body(
|
||
system_message.get("content"), form_data, metadata, user, replace=True
|
||
)
|
||
except:
|
||
pass
|
||
|
||
# === 3. 初始化事件发射器和回调 ===
|
||
event_emitter = get_event_emitter(metadata) # WebSocket 事件发射器
|
||
event_call = get_event_call(metadata) # 事件调用函数
|
||
|
||
# === 4. 获取 OAuth Token ===
|
||
oauth_token = None
|
||
try:
|
||
if request.cookies.get("oauth_session_id", None):
|
||
oauth_token = await request.app.state.oauth_manager.get_oauth_token(
|
||
user.id,
|
||
request.cookies.get("oauth_session_id", None),
|
||
)
|
||
except Exception as e:
|
||
log.error(f"Error getting OAuth token: {e}")
|
||
|
||
# === 5. 构建额外参数(供 Pipeline/Filter/Tools 使用)===
|
||
extra_params = {
|
||
"__event_emitter__": event_emitter, # 用于向前端发送实时事件
|
||
"__event_call__": event_call, # 用于调用事件回调
|
||
"__user__": user.model_dump() if isinstance(user, UserModel) else {},
|
||
"__metadata__": metadata,
|
||
"__request__": request,
|
||
"__model__": model,
|
||
"__oauth_token__": oauth_token,
|
||
}
|
||
|
||
# === 6. 确定模型列表和任务模型 ===
|
||
# Initialize events to store additional event to be sent to the client
|
||
# Initialize contexts and citation
|
||
if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
|
||
# Direct 模式:使用用户直连的模型
|
||
models = {
|
||
request.state.model["id"]: request.state.model,
|
||
}
|
||
else:
|
||
# 标准模式:使用平台所有模型
|
||
models = request.app.state.MODELS
|
||
|
||
# 获取任务模型 ID(用于工具调用、标题生成等后台任务)
|
||
task_model_id = get_task_model_id(
|
||
form_data["model"],
|
||
request.app.state.config.TASK_MODEL,
|
||
request.app.state.config.TASK_MODEL_EXTERNAL,
|
||
models,
|
||
)
|
||
|
||
# === 7. 初始化事件和引用来源列表 ===
|
||
events = [] # 需要发送给前端的事件(如 sources 引用)
|
||
sources = [] # RAG 检索到的文档来源
|
||
|
||
# === 8. Folder "Project" 处理 - 注入文件夹的 System Prompt 和文件 ===
|
||
# Check if the request has chat_id and is inside of a folder
|
||
chat_id = metadata.get("chat_id", None)
|
||
if False:
|
||
if chat_id and user:
|
||
chat = Chats.get_chat_by_id_and_user_id(chat_id, user.id)
|
||
if chat and chat.folder_id:
|
||
folder = Folders.get_folder_by_id_and_user_id(chat.folder_id, user.id)
|
||
|
||
if folder and folder.data:
|
||
# 注入文件夹的 system prompt
|
||
if "system_prompt" in folder.data:
|
||
form_data = apply_system_prompt_to_body(
|
||
folder.data["system_prompt"], form_data, metadata, user
|
||
)
|
||
# 注入文件夹关联的文件
|
||
if "files" in folder.data:
|
||
form_data["files"] = [
|
||
*folder.data["files"],
|
||
*form_data.get("files", []),
|
||
]
|
||
|
||
# === 9. Model "Knowledge" 处理 - 注入模型绑定的知识库 ===
|
||
user_message = get_last_user_message(form_data["messages"])
|
||
model_knowledge = model.get("info", {}).get("meta", {}).get("knowledge", False)
|
||
|
||
if model_knowledge:
|
||
# 向前端发送知识库搜索状态
|
||
await event_emitter(
|
||
{
|
||
"type": "status",
|
||
"data": {
|
||
"action": "knowledge_search",
|
||
"query": user_message,
|
||
"done": False,
|
||
},
|
||
}
|
||
)
|
||
|
||
knowledge_files = []
|
||
for item in model_knowledge:
|
||
# 处理旧格式的 collection_name
|
||
if item.get("collection_name"):
|
||
knowledge_files.append(
|
||
{
|
||
"id": item.get("collection_name"),
|
||
"name": item.get("name"),
|
||
"legacy": True,
|
||
}
|
||
)
|
||
# 处理新格式的 collection_names(多个集合)
|
||
elif item.get("collection_names"):
|
||
knowledge_files.append(
|
||
{
|
||
"name": item.get("name"),
|
||
"type": "collection",
|
||
"collection_names": item.get("collection_names"),
|
||
"legacy": True,
|
||
}
|
||
)
|
||
else:
|
||
knowledge_files.append(item)
|
||
|
||
# 合并模型知识库文件和用户上传文件
|
||
files = form_data.get("files", [])
|
||
files.extend(knowledge_files)
|
||
form_data["files"] = files
|
||
|
||
variables = form_data.pop("variables", None)
|
||
|
||
# === 10. Pipeline Inlet 处理 - 执行自定义 Python 插件 ===
|
||
# Process the form_data through the pipeline
|
||
try:
|
||
form_data = await process_pipeline_inlet_filter(
|
||
request, form_data, user, models
|
||
)
|
||
except Exception as e:
|
||
raise e
|
||
|
||
# === 11. Filter Inlet 处理 - 执行函数过滤器 ===
|
||
try:
|
||
filter_functions = [
|
||
Functions.get_function_by_id(filter_id)
|
||
for filter_id in get_sorted_filter_ids(
|
||
request, model, metadata.get("filter_ids", [])
|
||
)
|
||
]
|
||
|
||
form_data, flags = await process_filter_functions(
|
||
request=request,
|
||
filter_functions=filter_functions,
|
||
filter_type="inlet",
|
||
form_data=form_data,
|
||
extra_params=extra_params,
|
||
)
|
||
except Exception as e:
|
||
raise Exception(f"{e}")
|
||
|
||
# === 12. 功能增强处理 (Features) ===
|
||
features = form_data.pop("features", None)
|
||
if features:
|
||
# 12.1 记忆功能 - 注入历史对话记忆
|
||
if "memory" in features and features["memory"]:
|
||
form_data = await chat_memory_handler(
|
||
request, form_data, extra_params, user, metadata
|
||
)
|
||
|
||
# 12.2 网页搜索功能 - 执行网络搜索
|
||
if "web_search" in features and features["web_search"]:
|
||
form_data = await chat_web_search_handler(
|
||
request, form_data, extra_params, user
|
||
)
|
||
|
||
# 12.3 图像生成功能 - 处理图像生成请求
|
||
if "image_generation" in features and features["image_generation"]:
|
||
form_data = await chat_image_generation_handler(
|
||
request, form_data, extra_params, user
|
||
)
|
||
|
||
# 12.4 代码解释器功能 - 注入代码执行提示词
|
||
if "code_interpreter" in features and features["code_interpreter"]:
|
||
form_data["messages"] = add_or_update_user_message(
|
||
(
|
||
request.app.state.config.CODE_INTERPRETER_PROMPT_TEMPLATE
|
||
if request.app.state.config.CODE_INTERPRETER_PROMPT_TEMPLATE != ""
|
||
else DEFAULT_CODE_INTERPRETER_PROMPT
|
||
),
|
||
form_data["messages"],
|
||
)
|
||
|
||
# === 13. 提取工具和文件信息 ===
|
||
tool_ids = form_data.pop("tool_ids", None)
|
||
files = form_data.pop("files", None)
|
||
|
||
# === 14. 文件夹类型文件展开 ===
|
||
prompt = get_last_user_message(form_data["messages"])
|
||
|
||
if files:
|
||
if not files:
|
||
files = []
|
||
|
||
for file_item in files:
|
||
# 如果文件类型是 folder,展开其包含的所有文件
|
||
if file_item.get("type", "file") == "folder":
|
||
# Get folder files
|
||
folder_id = file_item.get("id", None)
|
||
if folder_id:
|
||
folder = Folders.get_folder_by_id_and_user_id(folder_id, user.id)
|
||
if folder and folder.data and "files" in folder.data:
|
||
# 移除文件夹项,添加文件夹内的文件
|
||
files = [f for f in files if f.get("id", None) != folder_id]
|
||
files = [*files, *folder.data["files"]]
|
||
|
||
# 去重文件(基于文件内容)
|
||
files = list({json.dumps(f, sort_keys=True): f for f in files}.values())
|
||
|
||
# === 15. 更新元数据 ===
|
||
metadata = {
|
||
**metadata,
|
||
"tool_ids": tool_ids,
|
||
"files": files,
|
||
}
|
||
form_data["metadata"] = metadata
|
||
|
||
# === 16. 准备工具字典 ===
|
||
# Server side tools
|
||
tool_ids = metadata.get("tool_ids", None) # 服务器端工具 ID
|
||
# Client side tools
|
||
direct_tool_servers = metadata.get("tool_servers", None) # 客户端直连工具服务器
|
||
|
||
log.debug(f"{tool_ids=}")
|
||
log.debug(f"{direct_tool_servers=}")
|
||
|
||
tools_dict = {} # 所有工具的字典
|
||
|
||
mcp_clients = {} # MCP (Model Context Protocol) 客户端
|
||
mcp_tools_dict = {} # MCP 工具字典
|
||
|
||
if tool_ids:
|
||
for tool_id in tool_ids:
|
||
# === 16.1 处理 MCP (Model Context Protocol) 工具 ===
|
||
if tool_id.startswith("server:mcp:"):
|
||
try:
|
||
server_id = tool_id[len("server:mcp:") :]
|
||
|
||
# 查找 MCP 服务器连接配置
|
||
mcp_server_connection = None
|
||
for (
|
||
server_connection
|
||
) in request.app.state.config.TOOL_SERVER_CONNECTIONS:
|
||
if (
|
||
server_connection.get("type", "") == "mcp"
|
||
and server_connection.get("info", {}).get("id") == server_id
|
||
):
|
||
mcp_server_connection = server_connection
|
||
break
|
||
|
||
if not mcp_server_connection:
|
||
log.error(f"MCP server with id {server_id} not found")
|
||
continue
|
||
|
||
# 处理认证类型
|
||
auth_type = mcp_server_connection.get("auth_type", "")
|
||
|
||
headers = {}
|
||
if auth_type == "bearer":
|
||
headers["Authorization"] = (
|
||
f"Bearer {mcp_server_connection.get('key', '')}"
|
||
)
|
||
elif auth_type == "none":
|
||
# 无需认证
|
||
pass
|
||
elif auth_type == "session":
|
||
headers["Authorization"] = (
|
||
f"Bearer {request.state.token.credentials}"
|
||
)
|
||
elif auth_type == "system_oauth":
|
||
oauth_token = extra_params.get("__oauth_token__", None)
|
||
if oauth_token:
|
||
headers["Authorization"] = (
|
||
f"Bearer {oauth_token.get('access_token', '')}"
|
||
)
|
||
elif auth_type == "oauth_2.1":
|
||
try:
|
||
splits = server_id.split(":")
|
||
server_id = splits[-1] if len(splits) > 1 else server_id
|
||
|
||
oauth_token = await request.app.state.oauth_client_manager.get_oauth_token(
|
||
user.id, f"mcp:{server_id}"
|
||
)
|
||
|
||
if oauth_token:
|
||
headers["Authorization"] = (
|
||
f"Bearer {oauth_token.get('access_token', '')}"
|
||
)
|
||
except Exception as e:
|
||
log.error(f"Error getting OAuth token: {e}")
|
||
oauth_token = None
|
||
|
||
# 连接到 MCP 服务器
|
||
mcp_clients[server_id] = MCPClient()
|
||
await mcp_clients[server_id].connect(
|
||
url=mcp_server_connection.get("url", ""),
|
||
headers=headers if headers else None,
|
||
)
|
||
|
||
# 获取 MCP 工具列表并注册
|
||
tool_specs = await mcp_clients[server_id].list_tool_specs()
|
||
for tool_spec in tool_specs:
|
||
|
||
def make_tool_function(client, function_name):
|
||
"""为每个 MCP 工具创建异步调用函数"""
|
||
async def tool_function(**kwargs):
|
||
return await client.call_tool(
|
||
function_name,
|
||
function_args=kwargs,
|
||
)
|
||
|
||
return tool_function
|
||
|
||
tool_function = make_tool_function(
|
||
mcp_clients[server_id], tool_spec["name"]
|
||
)
|
||
|
||
# 注册 MCP 工具到工具字典
|
||
mcp_tools_dict[f"{server_id}_{tool_spec['name']}"] = {
|
||
"spec": {
|
||
**tool_spec,
|
||
"name": f"{server_id}_{tool_spec['name']}",
|
||
},
|
||
"callable": tool_function,
|
||
"type": "mcp",
|
||
"client": mcp_clients[server_id],
|
||
"direct": False,
|
||
}
|
||
except Exception as e:
|
||
log.debug(e)
|
||
continue
|
||
|
||
# === 16.2 获取标准工具(Function Tools)===
|
||
tools_dict = await get_tools(
|
||
request,
|
||
tool_ids,
|
||
user,
|
||
{
|
||
**extra_params,
|
||
"__model__": models[task_model_id],
|
||
"__messages__": form_data["messages"],
|
||
"__files__": metadata.get("files", []),
|
||
},
|
||
)
|
||
# 合并 MCP 工具
|
||
if mcp_tools_dict:
|
||
tools_dict = {**tools_dict, **mcp_tools_dict}
|
||
|
||
# === 16.3 处理客户端直连工具服务器 ===
|
||
if direct_tool_servers:
|
||
for tool_server in direct_tool_servers:
|
||
tool_specs = tool_server.pop("specs", [])
|
||
|
||
for tool in tool_specs:
|
||
tools_dict[tool["name"]] = {
|
||
"spec": tool,
|
||
"direct": True,
|
||
"server": tool_server,
|
||
}
|
||
|
||
# 保存 MCP 客户端到元数据(用于最后清理)
|
||
if mcp_clients:
|
||
metadata["mcp_clients"] = mcp_clients
|
||
|
||
# === 17. 工具调用处理 ===
|
||
if tools_dict:
|
||
if metadata.get("params", {}).get("function_calling") == "native":
|
||
# 原生函数调用模式:直接传递给 LLM
|
||
metadata["tools"] = tools_dict
|
||
form_data["tools"] = [
|
||
{"type": "function", "function": tool.get("spec", {})}
|
||
for tool in tools_dict.values()
|
||
]
|
||
else:
|
||
# 默认模式:通过 Prompt 实现工具调用
|
||
try:
|
||
form_data, flags = await chat_completion_tools_handler(
|
||
request, form_data, extra_params, user, models, tools_dict
|
||
)
|
||
sources.extend(flags.get("sources", []))
|
||
except Exception as e:
|
||
log.exception(e)
|
||
|
||
# === 18. 文件处理 - RAG 检索 ===
|
||
try:
|
||
form_data, flags = await chat_completion_files_handler(
|
||
request, form_data, extra_params, user
|
||
)
|
||
sources.extend(flags.get("sources", []))
|
||
except Exception as e:
|
||
log.exception(e)
|
||
|
||
# === 19. 构建上下文字符串并注入到消息 ===
|
||
# If context is not empty, insert it into the messages
|
||
if len(sources) > 0:
|
||
context_string = ""
|
||
citation_idx_map = {} # 引用索引映射(文档 ID → 引用编号)
|
||
|
||
# 遍历所有来源,构建上下文字符串
|
||
for source in sources:
|
||
if "document" in source:
|
||
for document_text, document_metadata in zip(
|
||
source["document"], source["metadata"]
|
||
):
|
||
source_name = source.get("source", {}).get("name", None)
|
||
source_id = (
|
||
document_metadata.get("source", None)
|
||
or source.get("source", {}).get("id", None)
|
||
or "N/A"
|
||
)
|
||
|
||
# 为每个来源分配唯一的引用编号
|
||
if source_id not in citation_idx_map:
|
||
citation_idx_map[source_id] = len(citation_idx_map) + 1
|
||
|
||
# 构建 XML 格式的来源标签
|
||
context_string += (
|
||
f'<source id="{citation_idx_map[source_id]}"'
|
||
+ (f' name="{source_name}"' if source_name else "")
|
||
+ f">{document_text}</source>\n"
|
||
)
|
||
|
||
context_string = context_string.strip()
|
||
if prompt is None:
|
||
raise Exception("No user message found")
|
||
|
||
# 使用 RAG 模板将上下文注入到用户消息中
|
||
if context_string != "":
|
||
form_data["messages"] = add_or_update_user_message(
|
||
rag_template(
|
||
request.app.state.config.RAG_TEMPLATE,
|
||
context_string,
|
||
prompt,
|
||
),
|
||
form_data["messages"],
|
||
append=False,
|
||
)
|
||
|
||
# === 20. 整理引用来源并添加到事件 ===
|
||
# If there are citations, add them to the data_items
|
||
sources = [
|
||
source
|
||
for source in sources
|
||
if source.get("source", {}).get("name", "")
|
||
or source.get("source", {}).get("id", "")
|
||
]
|
||
|
||
if len(sources) > 0:
|
||
events.append({"sources": sources})
|
||
|
||
# === 21. 完成知识库搜索状态 ===
|
||
if model_knowledge:
|
||
await event_emitter(
|
||
{
|
||
"type": "status",
|
||
"data": {
|
||
"action": "knowledge_search",
|
||
"query": user_message,
|
||
"done": True,
|
||
"hidden": True,
|
||
},
|
||
}
|
||
)
|
||
|
||
return form_data, metadata, events
|
||
|
||
|
||
async def process_chat_response(
|
||
request, response, form_data, user, metadata, model, events, tasks
|
||
):
|
||
"""
|
||
处理聊天响应 - 规范化/分发聊天完成响应
|
||
|
||
这是聊天响应后处理的核心函数,负责:
|
||
1. 处理流式(SSE/ndjson)和非流式(JSON)响应
|
||
2. 通过 WebSocket 发送事件到前端
|
||
3. 持久化消息/错误/元数据到数据库
|
||
4. 触发后台任务(标题生成、标签生成、Follow-ups)
|
||
5. 流式响应时:消费上游数据块、重建内容、处理工具调用、转发增量
|
||
|
||
Args:
|
||
request: FastAPI Request 对象
|
||
response: 上游响应(dict/JSONResponse/StreamingResponse)
|
||
form_data: 原始请求 payload(可能被上游修改)
|
||
user: 已验证的用户对象
|
||
metadata: 早期收集的聊天/会话上下文
|
||
model: 解析后的模型配置
|
||
events: 需要发送的额外事件(如 sources 引用)
|
||
tasks: 可选的后台任务(title/tags/follow-ups)
|
||
|
||
Returns:
|
||
- 非流式: dict (OpenAI JSON 格式)
|
||
- 流式: StreamingResponse (SSE/ndjson 格式)
|
||
"""
|
||
|
||
# === 内部函数:后台任务处理器 ===
|
||
async def background_tasks_handler():
|
||
"""
|
||
在响应完成后执行后台任务:
|
||
1. Follow-ups 生成 - 生成后续问题建议
|
||
2. Title 生成 - 自动生成聊天标题
|
||
3. Tags 生成 - 自动生成聊天标签
|
||
"""
|
||
message = None
|
||
messages = []
|
||
|
||
# 获取消息历史
|
||
if "chat_id" in metadata and not metadata["chat_id"].startswith("local:"):
|
||
# 从数据库获取持久化的聊天历史
|
||
messages_map = Chats.get_messages_map_by_chat_id(metadata["chat_id"])
|
||
message = messages_map.get(metadata["message_id"]) if messages_map else None
|
||
|
||
message_list = get_message_list(messages_map, metadata["message_id"])
|
||
|
||
# 清理消息内容:移除 details 标签和文件
|
||
# get_message_list 创建新列表,不影响原始消息
|
||
messages = []
|
||
for message in message_list:
|
||
content = message.get("content", "")
|
||
# 处理多模态内容(图片 + 文本)
|
||
if isinstance(content, list):
|
||
for item in content:
|
||
if item.get("type") == "text":
|
||
content = item["text"]
|
||
break
|
||
|
||
# 移除 <details> 标签和 Markdown 图片
|
||
if isinstance(content, str):
|
||
content = re.sub(
|
||
r"<details\b[^>]*>.*?<\/details>|!\[.*?\]\(.*?\)",
|
||
"",
|
||
content,
|
||
flags=re.S | re.I,
|
||
).strip()
|
||
|
||
messages.append(
|
||
{
|
||
**message,
|
||
"role": message.get("role", "assistant"), # 安全回退
|
||
"content": content,
|
||
}
|
||
)
|
||
else:
|
||
# 临时聊天(local:):从 form_data 获取
|
||
message = get_last_user_message_item(form_data.get("messages", []))
|
||
messages = form_data.get("messages", [])
|
||
if message:
|
||
message["model"] = form_data.get("model")
|
||
|
||
# 执行后台任务
|
||
if message and "model" in message:
|
||
if tasks and messages:
|
||
# === 任务 1: Follow-ups 生成 ===
|
||
if (
|
||
TASKS.FOLLOW_UP_GENERATION in tasks
|
||
and tasks[TASKS.FOLLOW_UP_GENERATION]
|
||
):
|
||
res = await generate_follow_ups(
|
||
request,
|
||
{
|
||
"model": message["model"],
|
||
"messages": messages,
|
||
"message_id": metadata["message_id"],
|
||
"chat_id": metadata["chat_id"],
|
||
},
|
||
user,
|
||
)
|
||
|
||
if res and isinstance(res, dict):
|
||
if len(res.get("choices", [])) == 1:
|
||
response_message = res.get("choices", [])[0].get(
|
||
"message", {}
|
||
)
|
||
|
||
follow_ups_string = response_message.get(
|
||
"content"
|
||
) or response_message.get("reasoning_content", "")
|
||
else:
|
||
follow_ups_string = ""
|
||
|
||
# 提取 JSON 对象(从第一个 { 到最后一个 })
|
||
follow_ups_string = follow_ups_string[
|
||
follow_ups_string.find("{") : follow_ups_string.rfind("}")
|
||
+ 1
|
||
]
|
||
|
||
try:
|
||
follow_ups = json.loads(follow_ups_string).get(
|
||
"follow_ups", []
|
||
)
|
||
# 通过 WebSocket 发送 Follow-ups
|
||
await event_emitter(
|
||
{
|
||
"type": "chat:message:follow_ups",
|
||
"data": {
|
||
"follow_ups": follow_ups,
|
||
},
|
||
}
|
||
)
|
||
|
||
# 持久化到数据库
|
||
if not metadata.get("chat_id", "").startswith("local:"):
|
||
Chats.upsert_message_to_chat_by_id_and_message_id(
|
||
metadata["chat_id"],
|
||
metadata["message_id"],
|
||
{
|
||
"followUps": follow_ups,
|
||
},
|
||
)
|
||
|
||
except Exception as e:
|
||
pass
|
||
|
||
# === 任务 2 & 3: 标题和标签生成(仅非临时聊天)===
|
||
if not metadata.get("chat_id", "").startswith("local:"):
|
||
# 任务 2: 标题生成
|
||
if (
|
||
TASKS.TITLE_GENERATION in tasks
|
||
and tasks[TASKS.TITLE_GENERATION]
|
||
):
|
||
user_message = get_last_user_message(messages)
|
||
if user_message and len(user_message) > 100:
|
||
user_message = user_message[:100] + "..."
|
||
|
||
if tasks[TASKS.TITLE_GENERATION]:
|
||
|
||
res = await generate_title(
|
||
request,
|
||
{
|
||
"model": message["model"],
|
||
"messages": messages,
|
||
"chat_id": metadata["chat_id"],
|
||
},
|
||
user,
|
||
)
|
||
|
||
if res and isinstance(res, dict):
|
||
if len(res.get("choices", [])) == 1:
|
||
response_message = res.get("choices", [])[0].get(
|
||
"message", {}
|
||
)
|
||
|
||
title_string = (
|
||
response_message.get("content")
|
||
or response_message.get(
|
||
"reasoning_content",
|
||
)
|
||
or message.get("content", user_message)
|
||
)
|
||
else:
|
||
title_string = ""
|
||
|
||
# 提取 JSON 对象
|
||
title_string = title_string[
|
||
title_string.find("{") : title_string.rfind("}") + 1
|
||
]
|
||
|
||
try:
|
||
title = json.loads(title_string).get(
|
||
"title", user_message
|
||
)
|
||
except Exception as e:
|
||
title = ""
|
||
|
||
if not title:
|
||
title = messages[0].get("content", user_message)
|
||
|
||
# 更新数据库
|
||
Chats.update_chat_title_by_id(
|
||
metadata["chat_id"], title
|
||
)
|
||
|
||
# 通过 WebSocket 发送标题
|
||
await event_emitter(
|
||
{
|
||
"type": "chat:title",
|
||
"data": title,
|
||
}
|
||
)
|
||
# 如果只有 2 条消息(首次对话),直接用用户消息作为标题
|
||
elif len(messages) == 2:
|
||
title = messages[0].get("content", user_message)
|
||
|
||
Chats.update_chat_title_by_id(metadata["chat_id"], title)
|
||
|
||
await event_emitter(
|
||
{
|
||
"type": "chat:title",
|
||
"data": message.get("content", user_message),
|
||
}
|
||
)
|
||
|
||
# 任务 3: 标签生成
|
||
if TASKS.TAGS_GENERATION in tasks and tasks[TASKS.TAGS_GENERATION]:
|
||
res = await generate_chat_tags(
|
||
request,
|
||
{
|
||
"model": message["model"],
|
||
"messages": messages,
|
||
"chat_id": metadata["chat_id"],
|
||
},
|
||
user,
|
||
)
|
||
|
||
if res and isinstance(res, dict):
|
||
if len(res.get("choices", [])) == 1:
|
||
response_message = res.get("choices", [])[0].get(
|
||
"message", {}
|
||
)
|
||
|
||
tags_string = response_message.get(
|
||
"content"
|
||
) or response_message.get("reasoning_content", "")
|
||
else:
|
||
tags_string = ""
|
||
|
||
# 提取 JSON 对象
|
||
tags_string = tags_string[
|
||
tags_string.find("{") : tags_string.rfind("}") + 1
|
||
]
|
||
|
||
try:
|
||
tags = json.loads(tags_string).get("tags", [])
|
||
# 更新数据库
|
||
Chats.update_chat_tags_by_id(
|
||
metadata["chat_id"], tags, user
|
||
)
|
||
|
||
# 通过 WebSocket 发送标签
|
||
await event_emitter(
|
||
{
|
||
"type": "chat:tags",
|
||
"data": tags,
|
||
}
|
||
)
|
||
except Exception as e:
|
||
pass
|
||
|
||
# === 1. 解析事件发射器(仅在有实时会话时使用)===
|
||
event_emitter = None
|
||
event_caller = None
|
||
if (
|
||
"session_id" in metadata
|
||
and metadata["session_id"]
|
||
and "chat_id" in metadata
|
||
and metadata["chat_id"]
|
||
and "message_id" in metadata
|
||
and metadata["message_id"]
|
||
):
|
||
event_emitter = get_event_emitter(metadata) # WebSocket 事件发射器
|
||
event_caller = get_event_call(metadata) # 事件调用器
|
||
|
||
# === 2. 非流式响应处理 ===
|
||
if not isinstance(response, StreamingResponse):
|
||
if event_emitter:
|
||
try:
|
||
if isinstance(response, dict) or isinstance(response, JSONResponse):
|
||
# 处理单项列表(解包)
|
||
if isinstance(response, list) and len(response) == 1:
|
||
response = response[0]
|
||
|
||
# 解析 JSONResponse 的 body
|
||
if isinstance(response, JSONResponse) and isinstance(
|
||
response.body, bytes
|
||
):
|
||
try:
|
||
response_data = json.loads(
|
||
response.body.decode("utf-8", "replace")
|
||
)
|
||
except json.JSONDecodeError:
|
||
response_data = {
|
||
"error": {"detail": "Invalid JSON response"}
|
||
}
|
||
else:
|
||
response_data = response
|
||
|
||
# 处理错误响应
|
||
if "error" in response_data:
|
||
error = response_data.get("error")
|
||
|
||
if isinstance(error, dict):
|
||
error = error.get("detail", error)
|
||
else:
|
||
error = str(error)
|
||
|
||
# 保存错误到数据库
|
||
Chats.upsert_message_to_chat_by_id_and_message_id(
|
||
metadata["chat_id"],
|
||
metadata["message_id"],
|
||
{
|
||
"error": {"content": error},
|
||
},
|
||
)
|
||
# 通过 WebSocket 发送错误事件
|
||
if isinstance(error, str) or isinstance(error, dict):
|
||
await event_emitter(
|
||
{
|
||
"type": "chat:message:error",
|
||
"data": {"error": {"content": error}},
|
||
}
|
||
)
|
||
|
||
if "selected_model_id" in response_data:
|
||
Chats.upsert_message_to_chat_by_id_and_message_id(
|
||
metadata["chat_id"],
|
||
metadata["message_id"],
|
||
{
|
||
"selectedModelId": response_data["selected_model_id"],
|
||
},
|
||
)
|
||
|
||
choices = response_data.get("choices", [])
|
||
if choices and choices[0].get("message", {}).get("content"):
|
||
content = response_data["choices"][0]["message"]["content"]
|
||
|
||
if content:
|
||
await event_emitter(
|
||
{
|
||
"type": "chat:completion",
|
||
"data": response_data,
|
||
}
|
||
)
|
||
|
||
title = Chats.get_chat_title_by_id(metadata["chat_id"])
|
||
|
||
await event_emitter(
|
||
{
|
||
"type": "chat:completion",
|
||
"data": {
|
||
"done": True,
|
||
"content": content,
|
||
"title": title,
|
||
},
|
||
}
|
||
)
|
||
|
||
# Save message in the database
|
||
Chats.upsert_message_to_chat_by_id_and_message_id(
|
||
metadata["chat_id"],
|
||
metadata["message_id"],
|
||
{
|
||
"role": "assistant",
|
||
"content": content,
|
||
},
|
||
)
|
||
|
||
# Send a webhook notification if the user is not active
|
||
if not get_active_status_by_user_id(user.id):
|
||
webhook_url = Users.get_user_webhook_url_by_id(user.id)
|
||
if webhook_url:
|
||
await post_webhook(
|
||
request.app.state.WEBUI_NAME,
|
||
webhook_url,
|
||
f"{title} - {request.app.state.config.WEBUI_URL}/c/{metadata['chat_id']}\n\n{content}",
|
||
{
|
||
"action": "chat",
|
||
"message": content,
|
||
"title": title,
|
||
"url": f"{request.app.state.config.WEBUI_URL}/c/{metadata['chat_id']}",
|
||
},
|
||
)
|
||
|
||
await background_tasks_handler()
|
||
|
||
if events and isinstance(events, list):
|
||
extra_response = {}
|
||
for event in events:
|
||
if isinstance(event, dict):
|
||
extra_response.update(event)
|
||
else:
|
||
extra_response[event] = True
|
||
|
||
response_data = {
|
||
**extra_response,
|
||
**response_data,
|
||
}
|
||
|
||
if isinstance(response, dict):
|
||
response = response_data
|
||
if isinstance(response, JSONResponse):
|
||
response = JSONResponse(
|
||
content=response_data,
|
||
headers=response.headers,
|
||
status_code=response.status_code,
|
||
)
|
||
|
||
except Exception as e:
|
||
log.debug(f"Error occurred while processing request: {e}")
|
||
pass
|
||
|
||
return response
|
||
else:
|
||
if events and isinstance(events, list) and isinstance(response, dict):
|
||
extra_response = {}
|
||
for event in events:
|
||
if isinstance(event, dict):
|
||
extra_response.update(event)
|
||
else:
|
||
extra_response[event] = True
|
||
|
||
response = {
|
||
**extra_response,
|
||
**response,
|
||
}
|
||
|
||
return response
|
||
|
||
# Non standard response (not SSE/ndjson)
|
||
if not any(
|
||
content_type in response.headers["Content-Type"]
|
||
for content_type in ["text/event-stream", "application/x-ndjson"]
|
||
):
|
||
return response
|
||
|
||
oauth_token = None
|
||
try:
|
||
if request.cookies.get("oauth_session_id", None):
|
||
oauth_token = await request.app.state.oauth_manager.get_oauth_token(
|
||
user.id,
|
||
request.cookies.get("oauth_session_id", None),
|
||
)
|
||
except Exception as e:
|
||
log.error(f"Error getting OAuth token: {e}")
|
||
|
||
extra_params = {
|
||
"__event_emitter__": event_emitter,
|
||
"__event_call__": event_caller,
|
||
"__user__": user.model_dump() if isinstance(user, UserModel) else {},
|
||
"__metadata__": metadata,
|
||
"__oauth_token__": oauth_token,
|
||
"__request__": request,
|
||
"__model__": model,
|
||
}
|
||
filter_functions = [
|
||
Functions.get_function_by_id(filter_id)
|
||
for filter_id in get_sorted_filter_ids(
|
||
request, model, metadata.get("filter_ids", [])
|
||
)
|
||
]
|
||
|
||
# Streaming response: consume upstream SSE/ndjson, forward deltas/events, handle tools
|
||
if event_emitter and event_caller:
|
||
task_id = str(uuid4()) # Create a unique task ID.
|
||
model_id = form_data.get("model", "")
|
||
|
||
def split_content_and_whitespace(content):
|
||
content_stripped = content.rstrip()
|
||
original_whitespace = (
|
||
content[len(content_stripped) :]
|
||
if len(content) > len(content_stripped)
|
||
else ""
|
||
)
|
||
return content_stripped, original_whitespace
|
||
|
||
def is_opening_code_block(content):
|
||
backtick_segments = content.split("```")
|
||
# Even number of segments means the last backticks are opening a new block
|
||
return len(backtick_segments) > 1 and len(backtick_segments) % 2 == 0
|
||
|
||

        # Handle as a background task
        async def response_handler(response, events):
            def serialize_content_blocks(content_blocks, raw=False):
                content = ""

                for block in content_blocks:
                    if block["type"] == "text":
                        block_content = block["content"].strip()
                        if block_content:
                            content = f"{content}{block_content}\n"
                    elif block["type"] == "tool_calls":
                        attributes = block.get("attributes", {})

                        tool_calls = block.get("content", [])
                        results = block.get("results", [])

                        if content and not content.endswith("\n"):
                            content += "\n"

                        if results:
                            tool_calls_display_content = ""
                            for tool_call in tool_calls:
                                tool_call_id = tool_call.get("id", "")
                                tool_name = tool_call.get("function", {}).get("name", "")
                                tool_arguments = tool_call.get("function", {}).get(
                                    "arguments", ""
                                )

                                tool_result = None
                                tool_result_files = None
                                tool_result_embeds = ""
                                for result in results:
                                    if tool_call_id == result.get("tool_call_id", ""):
                                        tool_result = result.get("content", None)
                                        tool_result_files = result.get("files", None)
                                        tool_result_embeds = result.get("embeds", "")
                                        break

                                if tool_result is not None:
                                    tool_calls_display_content = f'{tool_calls_display_content}<details type="tool_calls" done="true" id="{tool_call_id}" name="{tool_name}" arguments="{html.escape(json.dumps(tool_arguments))}" result="{html.escape(json.dumps(tool_result, ensure_ascii=False))}" files="{html.escape(json.dumps(tool_result_files)) if tool_result_files else ""}" embeds="{html.escape(json.dumps(tool_result_embeds))}">\n<summary>Tool Executed</summary>\n</details>\n'
                                else:
                                    tool_calls_display_content = f'{tool_calls_display_content}<details type="tool_calls" done="false" id="{tool_call_id}" name="{tool_name}" arguments="{html.escape(json.dumps(tool_arguments))}">\n<summary>Executing...</summary>\n</details>\n'

                            if not raw:
                                content = f"{content}{tool_calls_display_content}"
                        else:
                            tool_calls_display_content = ""

                            for tool_call in tool_calls:
                                tool_call_id = tool_call.get("id", "")
                                tool_name = tool_call.get("function", {}).get("name", "")
                                tool_arguments = tool_call.get("function", {}).get(
                                    "arguments", ""
                                )

                                tool_calls_display_content = f'{tool_calls_display_content}\n<details type="tool_calls" done="false" id="{tool_call_id}" name="{tool_name}" arguments="{html.escape(json.dumps(tool_arguments))}">\n<summary>Executing...</summary>\n</details>\n'

                            if not raw:
                                content = f"{content}{tool_calls_display_content}"

                    elif block["type"] == "reasoning":
                        reasoning_display_content = "\n".join(
                            (f"> {line}" if not line.startswith(">") else line)
                            for line in block["content"].splitlines()
                        )

                        reasoning_duration = block.get("duration", None)

                        start_tag = block.get("start_tag", "")
                        end_tag = block.get("end_tag", "")

                        if content and not content.endswith("\n"):
                            content += "\n"

                        if reasoning_duration is not None:
                            if raw:
                                content = (
                                    f'{content}{start_tag}{block["content"]}{end_tag}\n'
                                )
                            else:
                                content = f'{content}<details type="reasoning" done="true" duration="{reasoning_duration}">\n<summary>Thought for {reasoning_duration} seconds</summary>\n{reasoning_display_content}\n</details>\n'
                        else:
                            if raw:
                                content = (
                                    f'{content}{start_tag}{block["content"]}{end_tag}\n'
                                )
                            else:
                                content = f'{content}<details type="reasoning" done="false">\n<summary>Thinking…</summary>\n{reasoning_display_content}\n</details>\n'

                    elif block["type"] == "code_interpreter":
                        attributes = block.get("attributes", {})
                        output = block.get("output", None)
                        lang = attributes.get("lang", "")

                        content_stripped, original_whitespace = (
                            split_content_and_whitespace(content)
                        )
                        if is_opening_code_block(content_stripped):
                            # Remove trailing backticks that would open a new block
                            content = (
                                content_stripped.rstrip("`").rstrip()
                                + original_whitespace
                            )
                        else:
                            # Keep content as is - either closing backticks or no backticks
                            content = content_stripped + original_whitespace

                        if content and not content.endswith("\n"):
                            content += "\n"

                        if output:
                            output = html.escape(json.dumps(output))

                            if raw:
                                content = f'{content}<code_interpreter type="code" lang="{lang}">\n{block["content"]}\n</code_interpreter>\n```output\n{output}\n```\n'
                            else:
                                content = f'{content}<details type="code_interpreter" done="true" output="{output}">\n<summary>Analyzed</summary>\n```{lang}\n{block["content"]}\n```\n</details>\n'
                        else:
                            if raw:
                                content = f'{content}<code_interpreter type="code" lang="{lang}">\n{block["content"]}\n</code_interpreter>\n'
                            else:
                                content = f'{content}<details type="code_interpreter" done="false">\n<summary>Analyzing...</summary>\n```{lang}\n{block["content"]}\n```\n</details>\n'

                    else:
                        block_content = str(block["content"]).strip()
                        if block_content:
                            content = f"{content}{block['type']}: {block_content}\n"

                return content.strip()
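
            # serialize_content_blocks renders the accumulated blocks either as
            # UI-facing markup (<details> elements the frontend expands) or, with
            # raw=True, in their original tag form (e.g. <think>...</think>),
            # which is what gets fed back to the model on follow-up passes.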

            def convert_content_blocks_to_messages(content_blocks, raw=False):
                messages = []

                temp_blocks = []
                for idx, block in enumerate(content_blocks):
                    if block["type"] == "tool_calls":
                        messages.append(
                            {
                                "role": "assistant",
                                "content": serialize_content_blocks(temp_blocks, raw),
                                "tool_calls": block.get("content"),
                            }
                        )

                        results = block.get("results", [])

                        for result in results:
                            messages.append(
                                {
                                    "role": "tool",
                                    "tool_call_id": result["tool_call_id"],
                                    "content": result.get("content", "") or "",
                                }
                            )
                        temp_blocks = []
                    else:
                        temp_blocks.append(block)

                if temp_blocks:
                    content = serialize_content_blocks(temp_blocks, raw)
                    if content:
                        messages.append(
                            {
                                "role": "assistant",
                                "content": content,
                            }
                        )

                return messages
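
            # For example, blocks [text, tool_calls(+results), text] become:
            #   {"role": "assistant", "content": "...", "tool_calls": [...]}
            #   {"role": "tool", "tool_call_id": "...", "content": "..."}
            #   {"role": "assistant", "content": "..."}
            # mirroring the OpenAI-style tool-calling message sequence.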

            def tag_content_handler(content_type, tags, content, content_blocks):
                end_flag = False

                def extract_attributes(tag_content):
                    """Extract attributes from a tag if they exist."""
                    attributes = {}
                    if not tag_content:  # Ensure tag_content is not None
                        return attributes
                    # Match attributes in the format: key="value" (ignores single quotes for simplicity)
                    matches = re.findall(r'(\w+)\s*=\s*"([^"]+)"', tag_content)
                    for key, value in matches:
                        attributes[key] = value
                    return attributes

                if content_blocks[-1]["type"] == "text":
                    for start_tag, end_tag in tags:
                        start_tag_pattern = rf"{re.escape(start_tag)}"
                        if start_tag.startswith("<") and start_tag.endswith(">"):
                            # Match start tag e.g., <tag> or <tag attr="value">:
                            # strip '<' and '>' from start_tag and allow attributes
                            start_tag_pattern = rf"<{re.escape(start_tag[1:-1])}(\s.*?)?>"

                        match = re.search(start_tag_pattern, content)
                        if match:
                            try:
                                # Ensure the attribute string is not None
                                attr_content = match.group(1) if match.group(1) else ""
                            except Exception:
                                attr_content = ""

                            # Extract attributes safely
                            attributes = extract_attributes(attr_content)

                            # Capture everything before and after the matched tag
                            before_tag = content[: match.start()]  # Content before opening tag
                            after_tag = content[match.end() :]  # Content after opening tag

                            # Remove the start tag and everything after it from the
                            # text block currently being handled
                            content_blocks[-1]["content"] = content_blocks[-1][
                                "content"
                            ].replace(match.group(0) + after_tag, "")

                            if before_tag:
                                content_blocks[-1]["content"] = before_tag

                            if not content_blocks[-1]["content"]:
                                content_blocks.pop()

                            # Append the new block
                            content_blocks.append(
                                {
                                    "type": content_type,
                                    "start_tag": start_tag,
                                    "end_tag": end_tag,
                                    "attributes": attributes,
                                    "content": "",
                                    "started_at": time.time(),
                                }
                            )

                            if after_tag:
                                content_blocks[-1]["content"] = after_tag
                                tag_content_handler(
                                    content_type, tags, after_tag, content_blocks
                                )

                            break
                elif content_blocks[-1]["type"] == content_type:
                    start_tag = content_blocks[-1]["start_tag"]
                    end_tag = content_blocks[-1]["end_tag"]

                    # The end tag is matched literally whether it is a full tag
                    # (e.g. </tag>) or a bare marker
                    end_tag_pattern = rf"{re.escape(end_tag)}"

                    # Check if the content has the end tag
                    if re.search(end_tag_pattern, content):
                        end_flag = True

                        block_content = content_blocks[-1]["content"]
                        # Strip start and end tags from the content
                        start_tag_pattern = rf"<{re.escape(start_tag)}(.*?)>"
                        block_content = re.sub(
                            start_tag_pattern, "", block_content
                        ).strip()

                        end_tag_regex = re.compile(end_tag_pattern, re.DOTALL)
                        split_content = end_tag_regex.split(block_content, maxsplit=1)

                        # Content inside the tag
                        block_content = (
                            split_content[0].strip() if split_content else ""
                        )

                        # Leftover content (everything after `</tag>`)
                        leftover_content = (
                            split_content[1].strip() if len(split_content) > 1 else ""
                        )

                        if block_content:
                            content_blocks[-1]["content"] = block_content
                            content_blocks[-1]["ended_at"] = time.time()
                            content_blocks[-1]["duration"] = int(
                                content_blocks[-1]["ended_at"]
                                - content_blocks[-1]["started_at"]
                            )

                            # Reset content_blocks by appending a new text block
                            # carrying any leftover content
                            if content_type != "code_interpreter":
                                content_blocks.append(
                                    {
                                        "type": "text",
                                        "content": leftover_content,
                                    }
                                )
                        else:
                            # Remove the block if its content is empty
                            content_blocks.pop()

                            content_blocks.append(
                                {
                                    "type": "text",
                                    "content": leftover_content,
                                }
                            )

                        # Clean processed content
                        start_tag_pattern = rf"{re.escape(start_tag)}"
                        if start_tag.startswith("<") and start_tag.endswith(">"):
                            # Match start tag e.g., <tag> or <tag attr="value">
                            start_tag_pattern = rf"<{re.escape(start_tag[1:-1])}(\s.*?)?>"

                        content = re.sub(
                            rf"{start_tag_pattern}(.|\n)*?{re.escape(end_tag)}",
                            "",
                            content,
                            flags=re.DOTALL,
                        )

                return content, content_blocks, end_flag
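
            # e.g. while streaming "Let me see<think>step 1", the text block is
            # truncated at the opening tag and a "reasoning" block starts
            # accumulating "step 1"; when "</think>" arrives the block is closed,
            # its duration is stamped, and a fresh text block is appended.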

            message = Chats.get_message_by_id_and_message_id(
                metadata["chat_id"], metadata["message_id"]
            )

            tool_calls = []

            last_assistant_message = None
            try:
                if form_data["messages"][-1]["role"] == "assistant":
                    last_assistant_message = get_last_assistant_message(
                        form_data["messages"]
                    )
            except Exception:
                pass

            content = (
                message.get("content", "")
                if message
                else last_assistant_message if last_assistant_message else ""
            )

            content_blocks = [
                {
                    "type": "text",
                    "content": content,
                }
            ]

            reasoning_tags_param = metadata.get("params", {}).get("reasoning_tags")
            DETECT_REASONING_TAGS = reasoning_tags_param is not False
            DETECT_CODE_INTERPRETER = metadata.get("features", {}).get(
                "code_interpreter", False
            )

            reasoning_tags = []
            if DETECT_REASONING_TAGS:
                if (
                    isinstance(reasoning_tags_param, list)
                    and len(reasoning_tags_param) == 2
                ):
                    reasoning_tags = [
                        (reasoning_tags_param[0], reasoning_tags_param[1])
                    ]
                else:
                    reasoning_tags = DEFAULT_REASONING_TAGS
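
            # e.g. params.reasoning_tags = ["<reason>", "</reason>"] yields one
            # custom tag pair; any other value falls back to DEFAULT_REASONING_TAGS,
            # while an explicit False disables reasoning-tag detection entirely.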

            try:
                for event in events:
                    await event_emitter(
                        {
                            "type": "chat:completion",
                            "data": event,
                        }
                    )

                    # Save message in the database
                    Chats.upsert_message_to_chat_by_id_and_message_id(
                        metadata["chat_id"],
                        metadata["message_id"],
                        {
                            **event,
                        },
                    )

                async def stream_body_handler(response, form_data):
                    """
                    Streaming body handler - consumes the upstream SSE stream and
                    processes content blocks.

                    Core responsibilities:
                    1. Parse the SSE (Server-Sent Events) stream
                    2. Handle incremental tool call (tool_calls) updates
                    3. Handle reasoning content (reasoning_content), e.g. <think> tags
                    4. Handle solution and code interpreter (code_interpreter) tags
                    5. Optionally persist the message to the database in real time
                    6. Throttle streamed deltas to avoid overloading the WebSocket

                    Args:
                        response: the upstream StreamingResponse object
                        form_data: the original request payload
                    """
                    nonlocal content  # Accumulated full text content
                    nonlocal content_blocks  # Content blocks (text/reasoning/code_interpreter)

                    response_tool_calls = []  # Accumulated tool calls

                    # === 1. Initialize delta throttling ===
                    delta_count = 0  # Number of deltas accumulated so far
                    delta_chunk_size = max(
                        CHAT_RESPONSE_STREAM_DELTA_CHUNK_SIZE,  # Globally configured chunk size
                        int(
                            metadata.get("params", {}).get("stream_delta_chunk_size")
                            or 1
                        ),  # User-configured chunk size
                    )
                    last_delta_data = None  # Last delta payload awaiting emission
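
                    # e.g. with delta_chunk_size = 5, only every fifth delta is
                    # pushed over the socket; since each payload carries the full
                    # serialized content so far, skipped intermediates are lossless.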

                    async def flush_pending_delta_data(threshold: int = 0):
                        """
                        Flush the pending delta payload.

                        Args:
                            threshold: emit only once delta_count >= threshold
                        """
                        nonlocal delta_count
                        nonlocal last_delta_data

                        if delta_count >= threshold and last_delta_data:
                            # Push the accumulated delta over the WebSocket
                            await event_emitter(
                                {
                                    "type": "chat:completion",
                                    "data": last_delta_data,
                                }
                            )
                            delta_count = 0
                            last_delta_data = None

                    # === 2. Consume the SSE stream ===
                    async for line in response.body_iterator:
                        # Decode the byte stream to a string
                        line = (
                            line.decode("utf-8", "replace")
                            if isinstance(line, bytes)
                            else line
                        )
                        data = line

                        # Skip empty lines
                        if not data.strip():
                            continue

                        # SSE format: each event starts with "data:"
                        if not data.startswith("data:"):
                            continue

                        # Remove the "data:" prefix
                        data = data[len("data:") :].strip()
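
                        # A typical upstream line looks like:
                        #   data: {"choices":[{"delta":{"content":"Hel"}}]}
                        # and the stream ends with the sentinel:
                        #   data: [DONE]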

                        try:
                            # Parse the JSON payload
                            data = json.loads(data)

                            # === 3. Run filter functions (stream type) ===
                            data, _ = await process_filter_functions(
                                request=request,
                                filter_functions=filter_functions,
                                filter_type="stream",
                                form_data=data,
                                extra_params={"__body__": form_data, **extra_params},
                            )

                            if data:
                                # Handle custom events
                                if "event" in data:
                                    await event_emitter(data.get("event", {}))

                                # === 4. Handle arena-mode model selection ===
                                if "selected_model_id" in data:
                                    model_id = data["selected_model_id"]
                                    # Persist the selected model ID
                                    Chats.upsert_message_to_chat_by_id_and_message_id(
                                        metadata["chat_id"],
                                        metadata["message_id"],
                                        {
                                            "selectedModelId": model_id,
                                        },
                                    )
                                    await event_emitter(
                                        {
                                            "type": "chat:completion",
                                            "data": data,
                                        }
                                    )
                                else:
                                    choices = data.get("choices", [])

                                    # === 5. Handle usage and timings info ===
                                    usage = data.get("usage", {}) or {}
                                    usage.update(data.get("timings", {}))  # llama.cpp timings
                                    if usage:
                                        await event_emitter(
                                            {
                                                "type": "chat:completion",
                                                "data": {
                                                    "usage": usage,
                                                },
                                            }
                                        )

                                    # === 6. Handle error responses ===
                                    if not choices:
                                        error = data.get("error", {})
                                        if error:
                                            await event_emitter(
                                                {
                                                    "type": "chat:completion",
                                                    "data": {
                                                        "error": error,
                                                    },
                                                }
                                            )
                                        continue

                                    # === 7. Extract the delta (incremental content) ===
                                    delta = choices[0].get("delta", {})
                                    delta_tool_calls = delta.get("tool_calls", None)

                                    # === 8. Handle tool calls ===
                                    if delta_tool_calls:
                                        for delta_tool_call in delta_tool_calls:
                                            tool_call_index = delta_tool_call.get("index")

                                            if tool_call_index is not None:
                                                # Look up an existing tool call
                                                current_response_tool_call = None
                                                for response_tool_call in response_tool_calls:
                                                    if (
                                                        response_tool_call.get("index")
                                                        == tool_call_index
                                                    ):
                                                        current_response_tool_call = (
                                                            response_tool_call
                                                        )
                                                        break

                                                if current_response_tool_call is None:
                                                    # Add a new tool call
                                                    delta_tool_call.setdefault("function", {})
                                                    delta_tool_call["function"].setdefault(
                                                        "name", ""
                                                    )
                                                    delta_tool_call["function"].setdefault(
                                                        "arguments", ""
                                                    )
                                                    response_tool_calls.append(delta_tool_call)
                                                else:
                                                    # Update the existing tool call
                                                    # (accumulate name and arguments)
                                                    delta_name = delta_tool_call.get(
                                                        "function", {}
                                                    ).get("name")
                                                    delta_arguments = delta_tool_call.get(
                                                        "function", {}
                                                    ).get("arguments")

                                                    if delta_name:
                                                        current_response_tool_call[
                                                            "function"
                                                        ]["name"] += delta_name

                                                    if delta_arguments:
                                                        current_response_tool_call[
                                                            "function"
                                                        ]["arguments"] += delta_arguments
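
                                    # e.g. the fragments {"index": 0, "function": {"arguments": '{"ci'}}
                                    # and {"index": 0, "function": {"arguments": 'ty": "Paris"}'}} are
                                    # stitched into a single call with arguments '{"city": "Paris"}'.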

                                    # === 9. Extract text content ===
                                    value = delta.get("content")

                                    # === 10. Handle reasoning content ===
                                    reasoning_content = (
                                        delta.get("reasoning_content")
                                        or delta.get("reasoning")
                                        or delta.get("thinking")
                                    )
                                    if reasoning_content:
                                        # Create or update the reasoning content block
                                        if (
                                            not content_blocks
                                            or content_blocks[-1]["type"] != "reasoning"
                                        ):
                                            reasoning_block = {
                                                "type": "reasoning",
                                                "start_tag": "<think>",
                                                "end_tag": "</think>",
                                                "attributes": {
                                                    "type": "reasoning_content"
                                                },
                                                "content": "",
                                                "started_at": time.time(),
                                            }
                                            content_blocks.append(reasoning_block)
                                        else:
                                            reasoning_block = content_blocks[-1]

                                        # Accumulate reasoning content
                                        reasoning_block["content"] += reasoning_content

                                        data = {
                                            "content": serialize_content_blocks(
                                                content_blocks
                                            )
                                        }

                                    # === 11. Handle plain text content ===
                                    if value:
                                        # If the previous block holds reasoning content,
                                        # mark it finished and start a new text block
                                        if (
                                            content_blocks
                                            and content_blocks[-1]["type"] == "reasoning"
                                            and content_blocks[-1]
                                            .get("attributes", {})
                                            .get("type")
                                            == "reasoning_content"
                                        ):
                                            reasoning_block = content_blocks[-1]
                                            reasoning_block["ended_at"] = time.time()
                                            reasoning_block["duration"] = int(
                                                reasoning_block["ended_at"]
                                                - reasoning_block["started_at"]
                                            )

                                            content_blocks.append(
                                                {
                                                    "type": "text",
                                                    "content": "",
                                                }
                                            )

                                        # Accumulate text content
                                        content = f"{content}{value}"
                                        if not content_blocks:
                                            content_blocks.append(
                                                {
                                                    "type": "text",
                                                    "content": "",
                                                }
                                            )

                                        content_blocks[-1]["content"] = (
                                            content_blocks[-1]["content"] + value
                                        )

                                        # === 12. Detect and handle special tags ===
                                        # 12.1 Reasoning tags (e.g. <think>)
                                        if DETECT_REASONING_TAGS:
                                            content, content_blocks, _ = (
                                                tag_content_handler(
                                                    "reasoning",
                                                    reasoning_tags,
                                                    content,
                                                    content_blocks,
                                                )
                                            )

                                            # Solution tag detection
                                            content, content_blocks, _ = (
                                                tag_content_handler(
                                                    "solution",
                                                    DEFAULT_SOLUTION_TAGS,
                                                    content,
                                                    content_blocks,
                                                )
                                            )

                                        # 12.2 Code interpreter tag detection
                                        if DETECT_CODE_INTERPRETER:
                                            content, content_blocks, end = (
                                                tag_content_handler(
                                                    "code_interpreter",
                                                    DEFAULT_CODE_INTERPRETER_TAGS,
                                                    content,
                                                    content_blocks,
                                                )
                                            )

                                            # Stop streaming once the closing tag is
                                            # seen, so the code can run before
                                            # generation resumes
                                            if end:
                                                break

                                        # === 13. Optional real-time persistence ===
                                        if ENABLE_REALTIME_CHAT_SAVE:
                                            # Save to the database
                                            Chats.upsert_message_to_chat_by_id_and_message_id(
                                                metadata["chat_id"],
                                                metadata["message_id"],
                                                {
                                                    "content": serialize_content_blocks(
                                                        content_blocks
                                                    ),
                                                },
                                            )
                                        else:
                                            # Prepare the payload to emit
                                            data = {
                                                "content": serialize_content_blocks(
                                                    content_blocks
                                                ),
                                            }

                                    # === 14. Delta throttling ===
                                    if delta:
                                        delta_count += 1
                                        last_delta_data = data
                                        # Flush once the threshold is reached
                                        if delta_count >= delta_chunk_size:
                                            await flush_pending_delta_data(delta_chunk_size)
                                    else:
                                        # Non-delta payloads are emitted immediately
                                        await event_emitter(
                                            {
                                                "type": "chat:completion",
                                                "data": data,
                                            }
                                        )
                        except Exception as e:
                            # Handle the end-of-stream marker
                            done = "data: [DONE]" in line
                            if done:
                                pass
                            else:
                                log.debug(f"Error: {e}")
                                continue

                    # === 15. Flush any remaining delta payload ===
                    await flush_pending_delta_data()

                    # === 16. Clean up content blocks ===
                    if content_blocks:
                        # Strip trailing whitespace from the last text block
                        if content_blocks[-1]["type"] == "text":
                            content_blocks[-1]["content"] = content_blocks[-1][
                                "content"
                            ].strip()

                            # Remove it if empty
                            if not content_blocks[-1]["content"]:
                                content_blocks.pop()

                            # Ensure at least one empty text block remains
                            if not content_blocks:
                                content_blocks.append(
                                    {
                                        "type": "text",
                                        "content": "",
                                    }
                                )

                        # Mark a trailing reasoning block as finished
                        if content_blocks[-1]["type"] == "reasoning":
                            reasoning_block = content_blocks[-1]
                            if reasoning_block.get("ended_at") is None:
                                reasoning_block["ended_at"] = time.time()
                                reasoning_block["duration"] = int(
                                    reasoning_block["ended_at"]
                                    - reasoning_block["started_at"]
                                )

                    # === 17. Record accumulated tool calls ===
                    if response_tool_calls:
                        tool_calls.append(response_tool_calls)

                    # === 18. Run response cleanup (close the connection) ===
                    if response.background:
                        await response.background()

                await stream_body_handler(response, form_data)

                tool_call_retries = 0

                while (
                    len(tool_calls) > 0
                    and tool_call_retries < CHAT_RESPONSE_MAX_TOOL_CALL_RETRIES
                ):
                    tool_call_retries += 1

                    response_tool_calls = tool_calls.pop(0)

                    content_blocks.append(
                        {
                            "type": "tool_calls",
                            "content": response_tool_calls,
                        }
                    )

                    await event_emitter(
                        {
                            "type": "chat:completion",
                            "data": {
                                "content": serialize_content_blocks(content_blocks),
                            },
                        }
                    )

                    tools = metadata.get("tools", {})

                    results = []

                    for tool_call in response_tool_calls:
                        tool_call_id = tool_call.get("id", "")
                        tool_function_name = tool_call.get("function", {}).get(
                            "name", ""
                        )
                        tool_args = tool_call.get("function", {}).get("arguments", "{}")

                        tool_function_params = {}
                        try:
                            # json.loads cannot be used because some models do not produce valid JSON
                            tool_function_params = ast.literal_eval(tool_args)
                        except Exception as e:
                            log.debug(e)
                            # Fallback to JSON parsing
                            try:
                                tool_function_params = json.loads(tool_args)
                            except Exception as e:
                                log.error(
                                    f"Error parsing tool call arguments: {tool_args}"
                                )

                        # Mutate the original tool call's arguments, as they are passed
                        # back to the LLM via the content blocks. If they land in a JSON
                        # block as invalid JSON, downstream integrations that require
                        # valid JSON response params (e.g. bedrock gateway) can fail.
                        # The main case so far is no args, i.e. "" = invalid JSON.
                        log.debug(
                            f"Parsed args from {tool_args} to {tool_function_params}"
                        )
                        tool_call.setdefault("function", {})["arguments"] = json.dumps(
                            tool_function_params
                        )
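
                        # e.g. a model may emit arguments as "{'city': 'Paris'}"
                        # (Python-literal quoting), which ast.literal_eval accepts but
                        # json.loads rejects; valid JSON like '{"city": "Paris"}' is also
                        # a valid Python literal, while JSON-only tokens such as true/null
                        # fail literal_eval and fall through to the json.loads fallback.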

                        tool_result = None
                        tool = None
                        tool_type = None
                        direct_tool = False

                        if tool_function_name in tools:
                            tool = tools[tool_function_name]
                            spec = tool.get("spec", {})

                            tool_type = tool.get("type", "")
                            direct_tool = tool.get("direct", False)

                            try:
                                allowed_params = (
                                    spec.get("parameters", {})
                                    .get("properties", {})
                                    .keys()
                                )

                                tool_function_params = {
                                    k: v
                                    for k, v in tool_function_params.items()
                                    if k in allowed_params
                                }

                                if direct_tool:
                                    tool_result = await event_caller(
                                        {
                                            "type": "execute:tool",
                                            "data": {
                                                "id": str(uuid4()),
                                                "name": tool_function_name,
                                                "params": tool_function_params,
                                                "server": tool.get("server", {}),
                                                "session_id": metadata.get(
                                                    "session_id", None
                                                ),
                                            },
                                        }
                                    )
                                else:
                                    tool_function = tool["callable"]
                                    tool_result = await tool_function(
                                        **tool_function_params
                                    )
                            except Exception as e:
                                tool_result = str(e)

                        tool_result, tool_result_files, tool_result_embeds = (
                            process_tool_result(
                                request,
                                tool_function_name,
                                tool_result,
                                tool_type,
                                direct_tool,
                                metadata,
                                user,
                            )
                        )

                        results.append(
                            {
                                "tool_call_id": tool_call_id,
                                "content": tool_result or "",
                                **(
                                    {"files": tool_result_files}
                                    if tool_result_files
                                    else {}
                                ),
                                **(
                                    {"embeds": tool_result_embeds}
                                    if tool_result_embeds
                                    else {}
                                ),
                            }
                        )

                    content_blocks[-1]["results"] = results
                    content_blocks.append(
                        {
                            "type": "text",
                            "content": "",
                        }
                    )

                    await event_emitter(
                        {
                            "type": "chat:completion",
                            "data": {
                                "content": serialize_content_blocks(content_blocks),
                            },
                        }
                    )

                    try:
                        new_form_data = {
                            **form_data,
                            "model": model_id,
                            "stream": True,
                            "messages": [
                                *form_data["messages"],
                                *convert_content_blocks_to_messages(
                                    content_blocks, True
                                ),
                            ],
                        }

                        res = await generate_chat_completion(
                            request,
                            new_form_data,
                            user,
                        )

                        if isinstance(res, StreamingResponse):
                            await stream_body_handler(res, new_form_data)
                        else:
                            break
                    except Exception as e:
                        log.debug(e)
                        break

                if DETECT_CODE_INTERPRETER:
                    MAX_RETRIES = 5
                    retries = 0

                    while (
                        content_blocks[-1]["type"] == "code_interpreter"
                        and retries < MAX_RETRIES
                    ):
                        await event_emitter(
                            {
                                "type": "chat:completion",
                                "data": {
                                    "content": serialize_content_blocks(content_blocks),
                                },
                            }
                        )

                        retries += 1
                        log.debug(f"Attempt count: {retries}")

                        output = ""
                        try:
                            if content_blocks[-1]["attributes"].get("type") == "code":
                                code = content_blocks[-1]["content"]
                                if CODE_INTERPRETER_BLOCKED_MODULES:
                                    blocking_code = textwrap.dedent(
                                        f"""
                                        import builtins

                                        BLOCKED_MODULES = {CODE_INTERPRETER_BLOCKED_MODULES}

                                        _real_import = builtins.__import__
                                        def restricted_import(name, globals=None, locals=None, fromlist=(), level=0):
                                            if name.split('.')[0] in BLOCKED_MODULES:
                                                importer_name = globals.get('__name__') if globals else None
                                                if importer_name == '__main__':
                                                    raise ImportError(
                                                        f"Direct import of module {{name}} is restricted."
                                                    )
                                            return _real_import(name, globals, locals, fromlist, level)

                                        builtins.__import__ = restricted_import
                                        """
                                    )
                                    code = blocking_code + "\n" + code
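
                                # The prelude above monkey-patches __import__ inside the
                                # sandbox so user code in __main__ cannot import blocked
                                # modules: e.g. with BLOCKED_MODULES = ["os"], a top-level
                                # `import os` raises ImportError while imports made from
                                # inside libraries (non-__main__ modules) still succeed.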

                                if (
                                    request.app.state.config.CODE_INTERPRETER_ENGINE
                                    == "pyodide"
                                ):
                                    output = await event_caller(
                                        {
                                            "type": "execute:python",
                                            "data": {
                                                "id": str(uuid4()),
                                                "code": code,
                                                "session_id": metadata.get(
                                                    "session_id", None
                                                ),
                                            },
                                        }
                                    )
                                elif (
                                    request.app.state.config.CODE_INTERPRETER_ENGINE
                                    == "jupyter"
                                ):
                                    output = await execute_code_jupyter(
                                        request.app.state.config.CODE_INTERPRETER_JUPYTER_URL,
                                        code,
                                        (
                                            request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_TOKEN
                                            if request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH
                                            == "token"
                                            else None
                                        ),
                                        (
                                            request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD
                                            if request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH
                                            == "password"
                                            else None
                                        ),
                                        request.app.state.config.CODE_INTERPRETER_JUPYTER_TIMEOUT,
                                    )
                                else:
                                    output = {
                                        "stdout": "Code interpreter engine not configured."
                                    }

                                log.debug(f"Code interpreter output: {output}")

                                if isinstance(output, dict):
                                    stdout = output.get("stdout", "")

                                    if isinstance(stdout, str):
                                        stdoutLines = stdout.split("\n")
                                        for idx, line in enumerate(stdoutLines):
                                            if "data:image/png;base64" in line:
                                                image_url = get_image_url_from_base64(
                                                    request,
                                                    line,
                                                    metadata,
                                                    user,
                                                )
                                                if image_url:
                                                    # Replace the inline base64 payload
                                                    # with a markdown image reference
                                                    stdoutLines[idx] = (
                                                        f"![Output Image]({image_url})"
                                                    )

                                        output["stdout"] = "\n".join(stdoutLines)

                                    result = output.get("result", "")

                                    if isinstance(result, str):
                                        resultLines = result.split("\n")
                                        for idx, line in enumerate(resultLines):
                                            if "data:image/png;base64" in line:
                                                image_url = get_image_url_from_base64(
                                                    request,
                                                    line,
                                                    metadata,
                                                    user,
                                                )
                                                resultLines[idx] = (
                                                    f"![Output Image]({image_url})"
                                                )

                                        output["result"] = "\n".join(resultLines)
                        except Exception as e:
                            output = str(e)
content_blocks[-1]["output"] = output
|
||
|
||
content_blocks.append(
|
||
{
|
||
"type": "text",
|
||
"content": "",
|
||
}
|
||
)
|
||
|
||
await event_emitter(
|
||
{
|
||
"type": "chat:completion",
|
||
"data": {
|
||
"content": serialize_content_blocks(content_blocks),
|
||
},
|
||
}
|
||
)
|
||
|
||

                        try:
                            new_form_data = {
                                **form_data,
                                "model": model_id,
                                "stream": True,
                                "messages": [
                                    *form_data["messages"],
                                    {
                                        "role": "assistant",
                                        "content": serialize_content_blocks(
                                            content_blocks, raw=True
                                        ),
                                    },
                                ],
                            }

                            res = await generate_chat_completion(
                                request,
                                new_form_data,
                                user,
                            )

                            if isinstance(res, StreamingResponse):
                                await stream_body_handler(res, new_form_data)
                            else:
                                break
                        except Exception as e:
                            log.debug(e)
                            break

                title = Chats.get_chat_title_by_id(metadata["chat_id"])
                data = {
                    "done": True,
                    "content": serialize_content_blocks(content_blocks),
                    "title": title,
                }

                if not ENABLE_REALTIME_CHAT_SAVE:
                    # Save message in the database
                    Chats.upsert_message_to_chat_by_id_and_message_id(
                        metadata["chat_id"],
                        metadata["message_id"],
                        {
                            "content": serialize_content_blocks(content_blocks),
                        },
                    )

                # Send a webhook notification if the user is not active
                if not get_active_status_by_user_id(user.id):
                    webhook_url = Users.get_user_webhook_url_by_id(user.id)
                    if webhook_url:
                        await post_webhook(
                            request.app.state.WEBUI_NAME,
                            webhook_url,
                            f"{title} - {request.app.state.config.WEBUI_URL}/c/{metadata['chat_id']}\n\n{content}",
                            {
                                "action": "chat",
                                "message": content,
                                "title": title,
                                "url": f"{request.app.state.config.WEBUI_URL}/c/{metadata['chat_id']}",
                            },
                        )

                await event_emitter(
                    {
                        "type": "chat:completion",
                        "data": data,
                    }
                )

                await background_tasks_handler()
            except asyncio.CancelledError:
                log.warning("Task was cancelled!")
                await event_emitter({"type": "chat:tasks:cancel"})

                if not ENABLE_REALTIME_CHAT_SAVE:
                    # Save message in the database
                    Chats.upsert_message_to_chat_by_id_and_message_id(
                        metadata["chat_id"],
                        metadata["message_id"],
                        {
                            "content": serialize_content_blocks(content_blocks),
                        },
                    )

            if response.background is not None:
                await response.background()

        return await response_handler(response, events)

    else:
        # Fallback to the original response
        async def stream_wrapper(original_generator, events):
            def wrap_item(item):
                return f"data: {item}\n\n"
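
            # wrap_item frames each prepended event as an SSE message, i.e.
            # 'data: {...}\n\n'; items from the upstream generator are assumed
            # to already be SSE-framed and are passed through untouched.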

            for event in events:
                event, _ = await process_filter_functions(
                    request=request,
                    filter_functions=filter_functions,
                    filter_type="stream",
                    form_data=event,
                    extra_params=extra_params,
                )

                if event:
                    yield wrap_item(json.dumps(event))

            async for data in original_generator:
                data, _ = await process_filter_functions(
                    request=request,
                    filter_functions=filter_functions,
                    filter_type="stream",
                    form_data=data,
                    extra_params=extra_params,
                )

                if data:
                    yield data

        return StreamingResponse(
            stream_wrapper(response.body_iterator, events),
            headers=dict(response.headers),
            background=response.background,
        )