mirror of
https://github.com/open-webui/open-webui.git
synced 2025-12-11 20:05:19 +00:00
- Fixed tasks running without the proper metadata being included when task models are used. This caused issues, for example with Langfuse, where title_generation, tag_generation, and so on were not properly retrieved because it was never triggered as it should have been.
1023 lines
35 KiB
Python
1023 lines
35 KiB
Python
from fastapi import APIRouter, Depends, HTTPException, Response, status, Request
|
|
from fastapi.responses import JSONResponse, RedirectResponse
|
|
|
|
from pydantic import BaseModel
|
|
from typing import Optional
|
|
from uuid import uuid4
|
|
import logging
|
|
import re
|
|
|
|
from open_webui.utils.chat import generate_chat_completion
|
|
from open_webui.utils.task import (
|
|
title_generation_template,
|
|
follow_up_generation_template,
|
|
query_generation_template,
|
|
image_prompt_generation_template,
|
|
autocomplete_generation_template,
|
|
tags_generation_template,
|
|
emoji_generation_template,
|
|
moa_response_generation_template,
|
|
)
|
|
from open_webui.utils.auth import get_admin_user, get_verified_user
|
|
from open_webui.constants import TASKS
|
|
|
|
from open_webui.routers.pipelines import (
|
|
process_pipeline_inlet_filter,
|
|
process_pipeline_outlet_filter,
|
|
)
|
|
|
|
from open_webui.utils.task import get_task_model_id
|
|
|
|
from open_webui.utils.filter import get_sorted_filter_ids, process_filter_functions
|
|
from open_webui.models.functions import Functions
|
|
from copy import deepcopy
|
|
|
|
from open_webui.config import (
|
|
DEFAULT_TITLE_GENERATION_PROMPT_TEMPLATE,
|
|
DEFAULT_FOLLOW_UP_GENERATION_PROMPT_TEMPLATE,
|
|
DEFAULT_TAGS_GENERATION_PROMPT_TEMPLATE,
|
|
DEFAULT_IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE,
|
|
DEFAULT_QUERY_GENERATION_PROMPT_TEMPLATE,
|
|
DEFAULT_AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE,
|
|
DEFAULT_EMOJI_GENERATION_PROMPT_TEMPLATE,
|
|
DEFAULT_MOA_GENERATION_PROMPT_TEMPLATE,
|
|
DEFAULT_VOICE_MODE_PROMPT_TEMPLATE,
|
|
)
|
|
from open_webui.env import SRC_LOG_LEVELS
|
|
|
|
|
|
# Router on which the task endpoints of this module are registered.
router = APIRouter()

# Module logger; its level is taken from the "MODELS" entry of SRC_LOG_LEVELS.
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
|
|
|
|
|
|
def _attach_model_metadata(payload: dict, models: dict, model_id: str) -> None:
|
|
"""
|
|
Ensure payload.metadata carries the full model object so downstream filters
|
|
(e.g., Responses/Langfuse) can resolve model name/id for telemetry.
|
|
"""
|
|
try:
|
|
model_item = models.get(model_id)
|
|
if payload.get("metadata") is None:
|
|
payload["metadata"] = {}
|
|
if model_item and isinstance(payload["metadata"], dict):
|
|
payload["metadata"]["model"] = model_item
|
|
except Exception:
|
|
pass
|
|
|
|
|
|
def _ensure_task_identifiers(payload: dict, form_data: dict) -> dict:
|
|
"""
|
|
Guarantee chat/session identifiers live in both root and metadata so filters
|
|
that rely on them (e.g., Langfuse) can create traces for task runs.
|
|
"""
|
|
metadata = payload.get("metadata") or {}
|
|
form_meta = form_data.get("metadata") or {}
|
|
|
|
chat_id = (
|
|
metadata.get("chat_id")
|
|
or form_data.get("chat_id")
|
|
or form_meta.get("chat_id")
|
|
or f"task-{uuid4()}"
|
|
)
|
|
session_id = (
|
|
metadata.get("session_id")
|
|
or form_data.get("session_id")
|
|
or form_meta.get("session_id")
|
|
or chat_id
|
|
)
|
|
message_id = metadata.get("message_id") or form_data.get("message_id")
|
|
|
|
if not isinstance(metadata, dict):
|
|
metadata = {}
|
|
metadata.setdefault("chat_id", chat_id)
|
|
metadata.setdefault("session_id", session_id)
|
|
if message_id:
|
|
metadata.setdefault("message_id", message_id)
|
|
|
|
payload["metadata"] = metadata
|
|
payload["chat_id"] = chat_id
|
|
payload["session_id"] = session_id
|
|
if message_id:
|
|
payload.setdefault("id", message_id)
|
|
|
|
return payload
|
|
|
|
|
|
async def _run_task_filters(
    request: Request,
    payload: dict,
    user,
    models: dict,
    response: Optional[dict] = None,
):
    """
    Apply filter functions (and, for responses, the pipeline outlet filter) to
    a task payload so telemetry filters such as Langfuse also see non-chat
    generations like titles, tags, or search queries.

    With ``response`` left as ``None`` the filters run as "inlet" on a deep
    copy of the payload and the filtered copy is returned. With a ``response``
    the filters run as "outlet" on a body made of the payload messages plus
    the first assistant message of the response, and the original ``payload``
    is returned.

    The payload is returned untouched when its model is unknown or no filter
    function applies. Failures inside the filters are logged at debug level
    and otherwise ignored.
    """
    target_model_id = payload.get("model")
    if not target_model_id or target_model_id not in models:
        return payload

    payload = _ensure_task_identifiers(payload, payload)

    model = models[target_model_id]
    metadata = payload.get("metadata") or {}
    filter_ids = metadata.get("filter_ids", [])

    # Resolve the sorted filter ids to function objects, dropping unknown ones.
    filter_functions = []
    for filter_id in get_sorted_filter_ids(request, model, filter_ids):
        function = Functions.get_function_by_id(filter_id)
        if function:
            filter_functions.append(function)

    if not filter_functions:
        return payload

    # Task runs have no client to talk to, so event hooks are no-ops.
    async def _noop_event_emitter(event):
        return None

    async def _noop_event_call(event):
        return None

    user_info = user.model_dump() if hasattr(user, "model_dump") else {}
    extra_params = {
        "__event_emitter__": _noop_event_emitter,
        "__event_call__": _noop_event_call,
        "__user__": user_info,
        "__metadata__": metadata,
        "__request__": request,
        "__model__": model,
        "__task__": metadata.get("task"),
        "__task_body__": metadata.get("task_body"),
    }

    body = deepcopy(payload)
    is_outlet = response is not None
    filter_type = "outlet" if is_outlet else "inlet"

    if is_outlet:
        # NOTE(review): the usage attachment is assumed to sit inside the
        # `choices` branch; confirm against the original indentation.
        assistant_message = None
        if isinstance(response, dict):
            choices = response.get("choices", [])
            if choices:
                assistant_message = dict(choices[0].get("message") or {})
                if response.get("usage"):
                    assistant_message["usage"] = response["usage"]

        messages = []
        for message in body.get("messages", []):
            if isinstance(message, dict):
                messages.append(dict(message))
        if assistant_message:
            messages.append(assistant_message)

        body = {
            "model": body.get("model"),
            "messages": messages,
            "metadata": metadata,
            "chat_id": metadata.get("chat_id"),
            "session_id": metadata.get("session_id"),
        }
        if filter_ids:
            body["filter_ids"] = filter_ids

        try:
            body = await process_pipeline_outlet_filter(request, body, user, models)
        except Exception as e:
            log.debug(f"Task pipeline outlet filter failed: {e}")

    try:
        body, _ = await process_filter_functions(
            request=request,
            filter_functions=filter_functions,
            filter_type=filter_type,
            form_data=body,
            extra_params=extra_params,
        )
    except Exception as e:
        log.debug(f"Task filter {filter_type} failed: {e}")

    # Outlet runs are observational only: the caller keeps its own payload.
    return payload if is_outlet else body
|
|
|
|
|
|
##################################
|
|
#
|
|
# Task Endpoints
|
|
#
|
|
##################################
|
|
|
|
|
|
@router.get("/config")
async def get_task_config(request: Request, user=Depends(get_verified_user)):
    """Return the task-related settings currently held in ``request.app.state.config``."""
    config = request.app.state.config

    # Keys are emitted in exactly this order.
    exposed_keys = (
        "TASK_MODEL",
        "TASK_MODEL_EXTERNAL",
        "TITLE_GENERATION_PROMPT_TEMPLATE",
        "IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE",
        "ENABLE_AUTOCOMPLETE_GENERATION",
        "AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH",
        "TAGS_GENERATION_PROMPT_TEMPLATE",
        "FOLLOW_UP_GENERATION_PROMPT_TEMPLATE",
        "ENABLE_FOLLOW_UP_GENERATION",
        "ENABLE_TAGS_GENERATION",
        "ENABLE_TITLE_GENERATION",
        "ENABLE_SEARCH_QUERY_GENERATION",
        "ENABLE_RETRIEVAL_QUERY_GENERATION",
        "QUERY_GENERATION_PROMPT_TEMPLATE",
        "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE",
        "VOICE_MODE_PROMPT_TEMPLATE",
    )
    return {key: getattr(config, key) for key in exposed_keys}
|
|
|
|
|
|
class TaskConfigForm(BaseModel):
    # Request body of the POST /config/update endpoint. update_task_config
    # copies every field onto the attribute of the same name on
    # request.app.state.config.

    # Task model ids; the endpoints pass these to get_task_model_id().
    TASK_MODEL: Optional[str]
    TASK_MODEL_EXTERNAL: Optional[str]

    # Title generation: on/off switch and prompt template.
    ENABLE_TITLE_GENERATION: bool
    TITLE_GENERATION_PROMPT_TEMPLATE: str

    # Image prompt generation template.
    IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE: str

    # Autocomplete: on/off switch and maximum accepted prompt length
    # (generate_autocompletion only enforces the limit when it is > 0).
    ENABLE_AUTOCOMPLETE_GENERATION: bool
    AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH: int

    # Tags and follow-up generation: templates and on/off switches.
    TAGS_GENERATION_PROMPT_TEMPLATE: str
    FOLLOW_UP_GENERATION_PROMPT_TEMPLATE: str
    ENABLE_FOLLOW_UP_GENERATION: bool
    ENABLE_TAGS_GENERATION: bool

    # Query generation: switches checked by generate_queries for the
    # "web_search" and "retrieval" request types, plus the shared template.
    ENABLE_SEARCH_QUERY_GENERATION: bool
    ENABLE_RETRIEVAL_QUERY_GENERATION: bool
    QUERY_GENERATION_PROMPT_TEMPLATE: str

    # Templates stored here but not consumed by the endpoints in this module.
    TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE: str
    VOICE_MODE_PROMPT_TEMPLATE: Optional[str]
|
|
|
|
|
|
@router.post("/config/update")
async def update_task_config(
    request: Request, form_data: TaskConfigForm, user=Depends(get_admin_user)
):
    """
    Copy every field of ``form_data`` onto ``request.app.state.config`` and
    return the resulting task settings.

    Requires an admin user (``get_admin_user`` dependency).
    """
    config = request.app.state.config

    # Fields are applied one by one, in this order.
    assignment_order = (
        "TASK_MODEL",
        "TASK_MODEL_EXTERNAL",
        "ENABLE_TITLE_GENERATION",
        "TITLE_GENERATION_PROMPT_TEMPLATE",
        "ENABLE_FOLLOW_UP_GENERATION",
        "FOLLOW_UP_GENERATION_PROMPT_TEMPLATE",
        "IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE",
        "ENABLE_AUTOCOMPLETE_GENERATION",
        "AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH",
        "TAGS_GENERATION_PROMPT_TEMPLATE",
        "ENABLE_TAGS_GENERATION",
        "ENABLE_SEARCH_QUERY_GENERATION",
        "ENABLE_RETRIEVAL_QUERY_GENERATION",
        "QUERY_GENERATION_PROMPT_TEMPLATE",
        "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE",
        "VOICE_MODE_PROMPT_TEMPLATE",
    )
    for field_name in assignment_order:
        setattr(config, field_name, getattr(form_data, field_name))

    # The response reads the values back from the config, in this key order.
    response_keys = (
        "TASK_MODEL",
        "TASK_MODEL_EXTERNAL",
        "ENABLE_TITLE_GENERATION",
        "TITLE_GENERATION_PROMPT_TEMPLATE",
        "IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE",
        "ENABLE_AUTOCOMPLETE_GENERATION",
        "AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH",
        "TAGS_GENERATION_PROMPT_TEMPLATE",
        "ENABLE_TAGS_GENERATION",
        "ENABLE_FOLLOW_UP_GENERATION",
        "FOLLOW_UP_GENERATION_PROMPT_TEMPLATE",
        "ENABLE_SEARCH_QUERY_GENERATION",
        "ENABLE_RETRIEVAL_QUERY_GENERATION",
        "QUERY_GENERATION_PROMPT_TEMPLATE",
        "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE",
        "VOICE_MODE_PROMPT_TEMPLATE",
    )
    return {key: getattr(config, key) for key in response_keys}
|
|
|
|
|
|
@router.post("/title/completions")
async def generate_title(
    request: Request, form_data: dict, user=Depends(get_verified_user)
):
    """
    Generate a chat title with the resolved task model.

    Answers with a 200 "disabled" notice when ENABLE_TITLE_GENERATION is off
    and raises a 404 HTTPException when ``form_data["model"]`` is unknown.
    Otherwise the rendered title prompt goes through the pipeline inlet
    filter, the task inlet filters, ``generate_chat_completion`` and, for dict
    results, the task outlet filters. A failing completion produces a 400
    JSONResponse with a generic error message.
    """
    config = request.app.state.config

    if not config.ENABLE_TITLE_GENERATION:
        return JSONResponse(
            status_code=status.HTTP_200_OK,
            content={"detail": "Title generation is disabled"},
        )

    # Direct connections only know the single model attached to the request.
    use_direct_model = getattr(request.state, "direct", False) and hasattr(
        request.state, "model"
    )
    models = (
        {request.state.model["id"]: request.state.model}
        if use_direct_model
        else request.app.state.MODELS
    )

    model_id = form_data["model"]
    if model_id not in models:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Model not found",
        )

    # A configured task model, if any, takes over from the chat model.
    task_model_id = get_task_model_id(
        model_id,
        config.TASK_MODEL,
        config.TASK_MODEL_EXTERNAL,
        models,
    )

    log.debug(
        f"generating chat title using model {task_model_id} for user {user.email} "
    )

    configured_template = config.TITLE_GENERATION_PROMPT_TEMPLATE
    template = (
        configured_template
        if configured_template != ""
        else DEFAULT_TITLE_GENERATION_PROMPT_TEMPLATE
    )

    content = title_generation_template(template, form_data["messages"], user)

    task_model = models[task_model_id]
    max_tokens = task_model.get("info", {}).get("params", {}).get("max_tokens", 1000)
    # Ollama-owned models take "max_tokens"; all others "max_completion_tokens".
    token_limit_key = (
        "max_tokens"
        if task_model.get("owned_by") == "ollama"
        else "max_completion_tokens"
    )

    payload = {
        "model": task_model_id,
        "messages": [{"role": "user", "content": content}],
        "stream": False,
        token_limit_key: max_tokens,
        "metadata": {
            **getattr(request.state, "metadata", {}),
            "task": str(TASKS.TITLE_GENERATION),
            "task_body": form_data,
            "chat_id": form_data.get("chat_id", None),
        },
    }

    payload = _ensure_task_identifiers(payload, form_data)
    _attach_model_metadata(payload, models, task_model_id)

    # Pipeline inlet failures propagate to the caller unchanged.
    payload = await process_pipeline_inlet_filter(request, payload, user, models)

    try:
        payload = await _run_task_filters(request, payload, user, models)
    except Exception as e:
        log.debug(f"Task inlet filter execution failed: {e}")

    try:
        res = await generate_chat_completion(request, form_data=payload, user=user)
    except Exception:
        log.error("Exception occurred", exc_info=True)
        return JSONResponse(
            status_code=status.HTTP_400_BAD_REQUEST,
            content={"detail": "An internal error has occurred."},
        )

    if isinstance(res, dict):
        try:
            await _run_task_filters(request, payload, user, models, response=res)
        except Exception as e:
            log.debug(f"Task outlet filter execution failed: {e}")
    return res
|
|
|
|
|
|
@router.post("/follow_up/completions")
async def generate_follow_ups(
    request: Request, form_data: dict, user=Depends(get_verified_user)
):
    """
    Generate follow-up suggestions with the resolved task model.

    Answers with a 200 "disabled" notice when ENABLE_FOLLOW_UP_GENERATION is
    off and raises a 404 HTTPException when ``form_data["model"]`` is unknown.
    Otherwise the rendered follow-up prompt goes through the pipeline inlet
    filter, the task inlet filters, ``generate_chat_completion`` and, for dict
    results, the task outlet filters. A failing completion produces a 400
    JSONResponse with a generic error message.
    """
    config = request.app.state.config

    if not config.ENABLE_FOLLOW_UP_GENERATION:
        return JSONResponse(
            status_code=status.HTTP_200_OK,
            content={"detail": "Follow-up generation is disabled"},
        )

    # Direct connections only know the single model attached to the request.
    if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
        models = {
            request.state.model["id"]: request.state.model,
        }
    else:
        models = request.app.state.MODELS

    model_id = form_data["model"]
    if model_id not in models:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Model not found",
        )

    # A configured task model, if any, takes over from the chat model.
    task_model_id = get_task_model_id(
        model_id,
        config.TASK_MODEL,
        config.TASK_MODEL_EXTERNAL,
        models,
    )

    # The message used to say "chat title" (copied from generate_title),
    # which made debug logs for this endpoint misleading.
    log.debug(
        f"generating follow-ups using model {task_model_id} for user {user.email} "
    )

    if config.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE != "":
        template = config.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE
    else:
        template = DEFAULT_FOLLOW_UP_GENERATION_PROMPT_TEMPLATE

    content = follow_up_generation_template(template, form_data["messages"], user)

    payload = {
        "model": task_model_id,
        "messages": [{"role": "user", "content": content}],
        "stream": False,
        "metadata": {
            **(request.state.metadata if hasattr(request.state, "metadata") else {}),
            "task": str(TASKS.FOLLOW_UP_GENERATION),
            "task_body": form_data,
            "chat_id": form_data.get("chat_id", None),
        },
    }

    payload = _ensure_task_identifiers(payload, form_data)
    _attach_model_metadata(payload, models, task_model_id)

    # Pipeline inlet failures propagate to the caller unchanged.
    payload = await process_pipeline_inlet_filter(request, payload, user, models)

    # Task filters are best-effort: a failure must not block the generation.
    try:
        payload = await _run_task_filters(request, payload, user, models)
    except Exception as e:
        log.debug(f"Task inlet filter execution failed: {e}")

    try:
        res = await generate_chat_completion(request, form_data=payload, user=user)
        if isinstance(res, dict):
            try:
                await _run_task_filters(request, payload, user, models, response=res)
            except Exception as e:
                log.debug(f"Task outlet filter execution failed: {e}")
        return res
    except Exception:
        log.error("Exception occurred", exc_info=True)
        return JSONResponse(
            status_code=status.HTTP_400_BAD_REQUEST,
            content={"detail": "An internal error has occurred."},
        )
|
|
|
|
|
|
@router.post("/tags/completions")
async def generate_chat_tags(
    request: Request, form_data: dict, user=Depends(get_verified_user)
):
    """
    Generate chat tags with the resolved task model.

    Answers with a 200 "disabled" notice when ENABLE_TAGS_GENERATION is off
    and raises a 404 HTTPException when ``form_data["model"]`` is unknown.
    Otherwise the rendered tags prompt goes through the pipeline inlet filter,
    the task inlet filters, ``generate_chat_completion`` and, for dict
    results, the task outlet filters. A failing completion produces a 500
    JSONResponse with a generic error message.
    """
    config = request.app.state.config

    if not config.ENABLE_TAGS_GENERATION:
        return JSONResponse(
            status_code=status.HTTP_200_OK,
            content={"detail": "Tags generation is disabled"},
        )

    # Direct connections only know the single model attached to the request.
    use_direct_model = getattr(request.state, "direct", False) and hasattr(
        request.state, "model"
    )
    models = (
        {request.state.model["id"]: request.state.model}
        if use_direct_model
        else request.app.state.MODELS
    )

    model_id = form_data["model"]
    if model_id not in models:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Model not found",
        )

    # A configured task model, if any, takes over from the chat model.
    task_model_id = get_task_model_id(
        model_id,
        config.TASK_MODEL,
        config.TASK_MODEL_EXTERNAL,
        models,
    )

    log.debug(
        f"generating chat tags using model {task_model_id} for user {user.email} "
    )

    configured_template = config.TAGS_GENERATION_PROMPT_TEMPLATE
    template = (
        configured_template
        if configured_template != ""
        else DEFAULT_TAGS_GENERATION_PROMPT_TEMPLATE
    )

    content = tags_generation_template(template, form_data["messages"], user)

    task_metadata = {
        **getattr(request.state, "metadata", {}),
        "task": str(TASKS.TAGS_GENERATION),
        "task_body": form_data,
        "chat_id": form_data.get("chat_id", None),
    }
    payload = {
        "model": task_model_id,
        "messages": [{"role": "user", "content": content}],
        "stream": False,
        "metadata": task_metadata,
    }

    payload = _ensure_task_identifiers(payload, form_data)
    _attach_model_metadata(payload, models, task_model_id)

    # Pipeline inlet failures propagate to the caller unchanged.
    payload = await process_pipeline_inlet_filter(request, payload, user, models)

    try:
        payload = await _run_task_filters(request, payload, user, models)
    except Exception as e:
        log.debug(f"Task inlet filter execution failed: {e}")

    try:
        res = await generate_chat_completion(request, form_data=payload, user=user)
    except Exception as e:
        log.error(f"Error generating chat completion: {e}")
        return JSONResponse(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            content={"detail": "An internal error has occurred."},
        )

    if isinstance(res, dict):
        try:
            await _run_task_filters(request, payload, user, models, response=res)
        except Exception as e:
            log.debug(f"Task outlet filter execution failed: {e}")
    return res
|
|
|
|
|
|
@router.post("/image_prompt/completions")
async def generate_image_prompt(
    request: Request, form_data: dict, user=Depends(get_verified_user)
):
    """
    Generate an image prompt with the resolved task model.

    Raises a 404 HTTPException when ``form_data["model"]`` is unknown.
    Otherwise the rendered prompt goes through the pipeline inlet filter, the
    task inlet filters, ``generate_chat_completion`` and, for dict results,
    the task outlet filters. A failing completion produces a 400 JSONResponse
    with a generic error message.
    """
    # Direct connections only know the single model attached to the request.
    use_direct_model = getattr(request.state, "direct", False) and hasattr(
        request.state, "model"
    )
    models = (
        {request.state.model["id"]: request.state.model}
        if use_direct_model
        else request.app.state.MODELS
    )

    model_id = form_data["model"]
    if model_id not in models:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Model not found",
        )

    config = request.app.state.config

    # A configured task model, if any, takes over from the chat model.
    task_model_id = get_task_model_id(
        model_id,
        config.TASK_MODEL,
        config.TASK_MODEL_EXTERNAL,
        models,
    )

    log.debug(
        f"generating image prompt using model {task_model_id} for user {user.email} "
    )

    configured_template = config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE
    template = (
        configured_template
        if configured_template != ""
        else DEFAULT_IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE
    )

    content = image_prompt_generation_template(template, form_data["messages"], user)

    task_metadata = {
        **getattr(request.state, "metadata", {}),
        "task": str(TASKS.IMAGE_PROMPT_GENERATION),
        "task_body": form_data,
        "chat_id": form_data.get("chat_id", None),
    }
    payload = {
        "model": task_model_id,
        "messages": [{"role": "user", "content": content}],
        "stream": False,
        "metadata": task_metadata,
    }

    payload = _ensure_task_identifiers(payload, form_data)
    _attach_model_metadata(payload, models, task_model_id)

    # Pipeline inlet failures propagate to the caller unchanged.
    payload = await process_pipeline_inlet_filter(request, payload, user, models)

    try:
        payload = await _run_task_filters(request, payload, user, models)
    except Exception as e:
        log.debug(f"Task inlet filter execution failed: {e}")

    try:
        res = await generate_chat_completion(request, form_data=payload, user=user)
    except Exception:
        log.error("Exception occurred", exc_info=True)
        return JSONResponse(
            status_code=status.HTTP_400_BAD_REQUEST,
            content={"detail": "An internal error has occurred."},
        )

    if isinstance(res, dict):
        try:
            await _run_task_filters(request, payload, user, models, response=res)
        except Exception as e:
            log.debug(f"Task outlet filter execution failed: {e}")
    return res
|
|
|
|
|
|
@router.post("/queries/completions")
async def generate_queries(
    request: Request, form_data: dict, user=Depends(get_verified_user)
):
    """
    Generate search/retrieval queries with the resolved task model.

    Raises a 400 HTTPException when the requested ``type`` ("web_search" or
    "retrieval") has its generation switch turned off, and a 404 when
    ``form_data["model"]`` is unknown. Queries cached on ``request.state`` are
    returned as-is. Otherwise the rendered prompt goes through the pipeline
    inlet filter, the task inlet filters, ``generate_chat_completion`` and,
    for dict results, the task outlet filters. A failing completion produces a
    400 JSONResponse carrying the exception text.
    """
    config = request.app.state.config

    query_type = form_data.get("type")
    if query_type == "web_search" and not config.ENABLE_SEARCH_QUERY_GENERATION:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Search query generation is disabled",
        )
    if query_type == "retrieval" and not config.ENABLE_RETRIEVAL_QUERY_GENERATION:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Query generation is disabled",
        )

    cached_queries = getattr(request.state, "cached_queries", None)
    if cached_queries:
        log.info(f"Reusing cached queries: {cached_queries}")
        return cached_queries

    # Direct connections only know the single model attached to the request.
    use_direct_model = getattr(request.state, "direct", False) and hasattr(
        request.state, "model"
    )
    models = (
        {request.state.model["id"]: request.state.model}
        if use_direct_model
        else request.app.state.MODELS
    )

    model_id = form_data["model"]
    if model_id not in models:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Model not found",
        )

    # A configured task model, if any, takes over from the chat model.
    task_model_id = get_task_model_id(
        model_id,
        config.TASK_MODEL,
        config.TASK_MODEL_EXTERNAL,
        models,
    )

    log.debug(
        f"generating {query_type} queries using model {task_model_id} for user {user.email}"
    )

    # A whitespace-only template counts as "not configured".
    configured_template = config.QUERY_GENERATION_PROMPT_TEMPLATE
    template = (
        configured_template
        if configured_template.strip() != ""
        else DEFAULT_QUERY_GENERATION_PROMPT_TEMPLATE
    )

    content = query_generation_template(template, form_data["messages"], user)

    task_metadata = {
        **getattr(request.state, "metadata", {}),
        "task": str(TASKS.QUERY_GENERATION),
        "task_body": form_data,
        "chat_id": form_data.get("chat_id", None),
    }
    payload = {
        "model": task_model_id,
        "messages": [{"role": "user", "content": content}],
        "stream": False,
        "metadata": task_metadata,
    }

    payload = _ensure_task_identifiers(payload, form_data)
    _attach_model_metadata(payload, models, task_model_id)

    # Pipeline inlet failures propagate to the caller unchanged.
    payload = await process_pipeline_inlet_filter(request, payload, user, models)

    try:
        payload = await _run_task_filters(request, payload, user, models)
    except Exception as e:
        log.debug(f"Task inlet filter execution failed: {e}")

    try:
        res = await generate_chat_completion(request, form_data=payload, user=user)
    except Exception as e:
        return JSONResponse(
            status_code=status.HTTP_400_BAD_REQUEST,
            content={"detail": str(e)},
        )

    if isinstance(res, dict):
        try:
            await _run_task_filters(request, payload, user, models, response=res)
        except Exception as e:
            log.debug(f"Task outlet filter execution failed: {e}")
    return res
|
|
|
|
|
|
@router.post("/auto/completions")
async def generate_autocompletion(
    request: Request, form_data: dict, user=Depends(get_verified_user)
):
    """
    Generate an autocompletion for ``form_data["prompt"]`` with the resolved
    task model.

    Raises a 400 HTTPException when ENABLE_AUTOCOMPLETE_GENERATION is off or
    the prompt is longer than AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH (only
    enforced when that limit is > 0), and a 404 when ``form_data["model"]`` is
    unknown. Otherwise the rendered prompt goes through the pipeline inlet
    filter, the task inlet filters, ``generate_chat_completion`` and, for dict
    results, the task outlet filters. A failing completion produces a 500
    JSONResponse with a generic error message.
    """
    config = request.app.state.config

    if not config.ENABLE_AUTOCOMPLETE_GENERATION:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Autocompletion generation is disabled",
        )

    completion_type = form_data.get("type")
    prompt = form_data.get("prompt")
    messages = form_data.get("messages")

    max_input_length = config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH
    # A request without "prompt" yields None; len(None) used to raise a
    # TypeError (unhandled 500) here. Measure a missing prompt as empty so it
    # takes the same path as when the limit is disabled.
    # NOTE(review): prompt is still passed on as None below; confirm that
    # autocomplete_generation_template accepts that.
    if max_input_length > 0 and len(prompt or "") > max_input_length:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Input prompt exceeds maximum length of {max_input_length}",
        )

    # Direct connections only know the single model attached to the request.
    if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
        models = {
            request.state.model["id"]: request.state.model,
        }
    else:
        models = request.app.state.MODELS

    model_id = form_data["model"]
    if model_id not in models:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Model not found",
        )

    # A configured task model, if any, takes over from the chat model.
    task_model_id = get_task_model_id(
        model_id,
        config.TASK_MODEL,
        config.TASK_MODEL_EXTERNAL,
        models,
    )

    log.debug(
        f"generating autocompletion using model {task_model_id} for user {user.email}"
    )

    # A whitespace-only template counts as "not configured".
    if (config.AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE).strip() != "":
        template = config.AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE
    else:
        template = DEFAULT_AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE

    content = autocomplete_generation_template(
        template, prompt, messages, completion_type, user
    )

    payload = {
        "model": task_model_id,
        "messages": [{"role": "user", "content": content}],
        "stream": False,
        "metadata": {
            **(request.state.metadata if hasattr(request.state, "metadata") else {}),
            "task": str(TASKS.AUTOCOMPLETE_GENERATION),
            "task_body": form_data,
            "chat_id": form_data.get("chat_id", None),
        },
    }

    payload = _ensure_task_identifiers(payload, form_data)
    _attach_model_metadata(payload, models, task_model_id)

    # Pipeline inlet failures propagate to the caller unchanged.
    payload = await process_pipeline_inlet_filter(request, payload, user, models)

    # Task filters are best-effort: a failure must not block the generation.
    try:
        payload = await _run_task_filters(request, payload, user, models)
    except Exception as e:
        log.debug(f"Task inlet filter execution failed: {e}")

    try:
        res = await generate_chat_completion(request, form_data=payload, user=user)
        if isinstance(res, dict):
            try:
                await _run_task_filters(request, payload, user, models, response=res)
            except Exception as e:
                log.debug(f"Task outlet filter execution failed: {e}")
        return res
    except Exception as e:
        log.error(f"Error generating chat completion: {e}")
        return JSONResponse(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            content={"detail": "An internal error has occurred."},
        )
|
|
|
|
|
|
@router.post("/emoji/completions")
async def generate_emoji(
    request: Request, form_data: dict, user=Depends(get_verified_user)
):
    """
    Generate an emoji for ``form_data["prompt"]`` with the resolved task model.

    Raises a 404 HTTPException when ``form_data["model"]`` is unknown. The
    completion is capped at 4 tokens. The rendered prompt goes through the
    pipeline inlet filter, the task inlet filters, ``generate_chat_completion``
    and, for dict results, the task outlet filters. A failing completion
    produces a 400 JSONResponse carrying the exception text.
    """
    # Direct connections only know the single model attached to the request.
    use_direct_model = getattr(request.state, "direct", False) and hasattr(
        request.state, "model"
    )
    models = (
        {request.state.model["id"]: request.state.model}
        if use_direct_model
        else request.app.state.MODELS
    )

    model_id = form_data["model"]
    if model_id not in models:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Model not found",
        )

    config = request.app.state.config

    # A configured task model, if any, takes over from the chat model.
    task_model_id = get_task_model_id(
        model_id,
        config.TASK_MODEL,
        config.TASK_MODEL_EXTERNAL,
        models,
    )

    log.debug(f"generating emoji using model {task_model_id} for user {user.email} ")

    content = emoji_generation_template(
        DEFAULT_EMOJI_GENERATION_PROMPT_TEMPLATE, form_data["prompt"], user
    )

    # Ollama-owned models take "max_tokens"; all others "max_completion_tokens".
    token_limit_key = (
        "max_tokens"
        if models[task_model_id].get("owned_by") == "ollama"
        else "max_completion_tokens"
    )

    payload = {
        "model": task_model_id,
        "messages": [{"role": "user", "content": content}],
        "stream": False,
        token_limit_key: 4,
        "metadata": {
            **getattr(request.state, "metadata", {}),
            "task": str(TASKS.EMOJI_GENERATION),
            "task_body": form_data,
            "chat_id": form_data.get("chat_id", None),
        },
    }

    payload = _ensure_task_identifiers(payload, form_data)
    _attach_model_metadata(payload, models, task_model_id)

    # Pipeline inlet failures propagate to the caller unchanged.
    payload = await process_pipeline_inlet_filter(request, payload, user, models)

    try:
        payload = await _run_task_filters(request, payload, user, models)
    except Exception as e:
        log.debug(f"Task inlet filter execution failed: {e}")

    try:
        res = await generate_chat_completion(request, form_data=payload, user=user)
    except Exception as e:
        return JSONResponse(
            status_code=status.HTTP_400_BAD_REQUEST,
            content={"detail": str(e)},
        )

    if isinstance(res, dict):
        try:
            await _run_task_filters(request, payload, user, models, response=res)
        except Exception as e:
            log.debug(f"Task outlet filter execution failed: {e}")
    return res
|
|
|
|
|
|
@router.post("/moa/completions")
async def generate_moa_response(
    request: Request, form_data: dict, user=Depends(get_verified_user)
):
    """
    Generate a MOA response from ``form_data["prompt"]`` and
    ``form_data["responses"]`` using the requested model itself (no task-model
    substitution).

    Raises a 404 HTTPException when ``form_data["model"]`` is unknown.
    Streaming follows ``form_data["stream"]`` (default ``False``). The rendered
    prompt goes through the pipeline inlet filter, the task inlet filters,
    ``generate_chat_completion`` and, for dict results, the task outlet
    filters. A failing completion produces a 400 JSONResponse carrying the
    exception text.
    """
    # Direct connections only know the single model attached to the request.
    use_direct_model = getattr(request.state, "direct", False) and hasattr(
        request.state, "model"
    )
    models = (
        {request.state.model["id"]: request.state.model}
        if use_direct_model
        else request.app.state.MODELS
    )

    model_id = form_data["model"]
    if model_id not in models:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Model not found",
        )

    content = moa_response_generation_template(
        DEFAULT_MOA_GENERATION_PROMPT_TEMPLATE,
        form_data["prompt"],
        form_data["responses"],
    )

    task_metadata = {
        **getattr(request.state, "metadata", {}),
        "chat_id": form_data.get("chat_id", None),
        "task": str(TASKS.MOA_RESPONSE_GENERATION),
        "task_body": form_data,
    }
    payload = {
        "model": model_id,
        "messages": [{"role": "user", "content": content}],
        "stream": form_data.get("stream", False),
        "metadata": task_metadata,
    }

    payload = _ensure_task_identifiers(payload, form_data)
    _attach_model_metadata(payload, models, model_id)

    # Pipeline inlet failures propagate to the caller unchanged.
    payload = await process_pipeline_inlet_filter(request, payload, user, models)

    try:
        payload = await _run_task_filters(request, payload, user, models)
    except Exception as e:
        log.debug(f"Task inlet filter execution failed: {e}")

    try:
        res = await generate_chat_completion(request, form_data=payload, user=user)
    except Exception as e:
        return JSONResponse(
            status_code=status.HTTP_400_BAD_REQUEST,
            content={"detail": str(e)},
        )

    # Only dict (non-streaming) results are handed to the outlet filters.
    if isinstance(res, dict):
        try:
            await _run_task_filters(request, payload, user, models, response=res)
        except Exception as e:
            log.debug(f"Task outlet filter execution failed: {e}")
    return res
|