open-webui/backend/open_webui/routers/tasks.py
Yethe Samartaka 17ea99e758 fix: tasks run without proper metadata
- Fixed tasks running without proper metadata being included when those task models are used. This creates issue for example for Langfuse where title_generation or tag_generation and so on is not properly retrieved because it is never triggered as it should
2025-12-11 12:22:51 +01:00

1023 lines
35 KiB
Python

from fastapi import APIRouter, Depends, HTTPException, Response, status, Request
from fastapi.responses import JSONResponse, RedirectResponse
from pydantic import BaseModel
from typing import Optional
from uuid import uuid4
import logging
import re
from open_webui.utils.chat import generate_chat_completion
from open_webui.utils.task import (
title_generation_template,
follow_up_generation_template,
query_generation_template,
image_prompt_generation_template,
autocomplete_generation_template,
tags_generation_template,
emoji_generation_template,
moa_response_generation_template,
)
from open_webui.utils.auth import get_admin_user, get_verified_user
from open_webui.constants import TASKS
from open_webui.routers.pipelines import (
process_pipeline_inlet_filter,
process_pipeline_outlet_filter,
)
from open_webui.utils.task import get_task_model_id
from open_webui.utils.filter import get_sorted_filter_ids, process_filter_functions
from open_webui.models.functions import Functions
from copy import deepcopy
from open_webui.config import (
DEFAULT_TITLE_GENERATION_PROMPT_TEMPLATE,
DEFAULT_FOLLOW_UP_GENERATION_PROMPT_TEMPLATE,
DEFAULT_TAGS_GENERATION_PROMPT_TEMPLATE,
DEFAULT_IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE,
DEFAULT_QUERY_GENERATION_PROMPT_TEMPLATE,
DEFAULT_AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE,
DEFAULT_EMOJI_GENERATION_PROMPT_TEMPLATE,
DEFAULT_MOA_GENERATION_PROMPT_TEMPLATE,
DEFAULT_VOICE_MODE_PROMPT_TEMPLATE,
)
from open_webui.env import SRC_LOG_LEVELS
# Module logger; verbosity follows the shared MODELS log-level setting.
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])

# Router hosting all task endpoints (title/tags/queries/... generation).
router = APIRouter()
def _attach_model_metadata(payload: dict, models: dict, model_id: str) -> None:
"""
Ensure payload.metadata carries the full model object so downstream filters
(e.g., Responses/Langfuse) can resolve model name/id for telemetry.
"""
try:
model_item = models.get(model_id)
if payload.get("metadata") is None:
payload["metadata"] = {}
if model_item and isinstance(payload["metadata"], dict):
payload["metadata"]["model"] = model_item
except Exception:
pass
def _ensure_task_identifiers(payload: dict, form_data: dict) -> dict:
"""
Guarantee chat/session identifiers live in both root and metadata so filters
that rely on them (e.g., Langfuse) can create traces for task runs.
"""
metadata = payload.get("metadata") or {}
form_meta = form_data.get("metadata") or {}
chat_id = (
metadata.get("chat_id")
or form_data.get("chat_id")
or form_meta.get("chat_id")
or f"task-{uuid4()}"
)
session_id = (
metadata.get("session_id")
or form_data.get("session_id")
or form_meta.get("session_id")
or chat_id
)
message_id = metadata.get("message_id") or form_data.get("message_id")
if not isinstance(metadata, dict):
metadata = {}
metadata.setdefault("chat_id", chat_id)
metadata.setdefault("session_id", session_id)
if message_id:
metadata.setdefault("message_id", message_id)
payload["metadata"] = metadata
payload["chat_id"] = chat_id
payload["session_id"] = session_id
if message_id:
payload.setdefault("id", message_id)
return payload
async def _run_task_filters(
    request: Request,
    payload: dict,
    user,
    models: dict,
    response: Optional[dict] = None,
):
    """
    Run filter/pipeline outlet hooks for task payloads so telemetry filters (e.g. Langfuse)
    capture non-chat generations such as titles, tags, or search queries.

    When ``response`` is None the configured filter functions run as *inlet*
    hooks and the (possibly mutated) body is returned. When ``response`` is
    given, the assistant reply (and usage, if present) is appended to a copy of
    the messages and the filters run as *outlet* hooks; in that case the
    ORIGINAL ``payload`` is returned unchanged — the outlet run exists only for
    its telemetry side effects.

    Filter failures are logged at debug level and never propagate.
    """
    model_id = payload.get("model")
    # Unknown/missing model: nothing to resolve filters against.
    if not model_id or model_id not in models:
        return payload
    # payload doubles as its own form_data: ids may live at the root or in
    # payload["metadata"], and _ensure_task_identifiers checks both.
    payload = _ensure_task_identifiers(payload, payload)
    model = models[model_id]
    metadata = payload.get("metadata") or {}
    filter_ids = metadata.get("filter_ids", [])
    filter_functions = [
        Functions.get_function_by_id(filter_id)
        for filter_id in get_sorted_filter_ids(request, model, filter_ids)
    ]
    # Drop ids that did not resolve to a stored function.
    filter_functions = [f for f in filter_functions if f]
    if not filter_functions:
        return payload
    # Task runs have no client/websocket attached, so event hooks are no-ops.
    async def _noop_event_emitter(event):
        return None
    async def _noop_event_call(event):
        return None
    extra_params = {
        "__event_emitter__": _noop_event_emitter,
        "__event_call__": _noop_event_call,
        "__user__": user.model_dump() if hasattr(user, "model_dump") else {},
        "__metadata__": metadata,
        "__request__": request,
        "__model__": model,
        "__task__": metadata.get("task"),
        "__task_body__": metadata.get("task_body"),
    }
    # Deep-copy so filters can mutate freely without touching the caller's payload.
    body = deepcopy(payload)
    filter_type = "inlet"
    if response is not None:
        # Outlet mode: rebuild the conversation including the assistant reply
        # (plus usage, for cost tracking) so filters see the full exchange.
        assistant_message = None
        if isinstance(response, dict):
            choices = response.get("choices", [])
            if choices:
                assistant_message = dict(choices[0].get("message") or {})
                if response.get("usage"):
                    assistant_message["usage"] = response["usage"]
        messages = [
            dict(message)
            for message in body.get("messages", [])
            if isinstance(message, dict)
        ]
        if assistant_message:
            messages.append(assistant_message)
        body = {
            "model": body.get("model"),
            "messages": messages,
            "metadata": metadata,
            "chat_id": metadata.get("chat_id"),
            "session_id": metadata.get("session_id"),
            **({"filter_ids": filter_ids} if filter_ids else {}),
        }
        # Pipeline outlet failures are non-fatal for task telemetry.
        try:
            body = await process_pipeline_outlet_filter(request, body, user, models)
        except Exception as e:
            log.debug(f"Task pipeline outlet filter failed: {e}")
        filter_type = "outlet"
    try:
        body, _ = await process_filter_functions(
            request=request,
            filter_functions=filter_functions,
            filter_type=filter_type,
            form_data=body,
            extra_params=extra_params,
        )
    except Exception as e:
        log.debug(f"Task filter {filter_type} failed: {e}")
    # Inlet: hand back the filtered body. Outlet: hand back the original payload.
    return body if response is None else payload
##################################
#
# Task Endpoints
#
##################################
@router.get("/config")
async def get_task_config(request: Request, user=Depends(get_verified_user)):
    """Return the current task-generation configuration for the client UI."""
    config = request.app.state.config
    # Keys are listed explicitly (and in this order) to pin the response shape.
    keys = (
        "TASK_MODEL",
        "TASK_MODEL_EXTERNAL",
        "TITLE_GENERATION_PROMPT_TEMPLATE",
        "IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE",
        "ENABLE_AUTOCOMPLETE_GENERATION",
        "AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH",
        "TAGS_GENERATION_PROMPT_TEMPLATE",
        "FOLLOW_UP_GENERATION_PROMPT_TEMPLATE",
        "ENABLE_FOLLOW_UP_GENERATION",
        "ENABLE_TAGS_GENERATION",
        "ENABLE_TITLE_GENERATION",
        "ENABLE_SEARCH_QUERY_GENERATION",
        "ENABLE_RETRIEVAL_QUERY_GENERATION",
        "QUERY_GENERATION_PROMPT_TEMPLATE",
        "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE",
        "VOICE_MODE_PROMPT_TEMPLATE",
    )
    return {key: getattr(config, key) for key in keys}
class TaskConfigForm(BaseModel):
    """Admin request body for POST /config/update; fields map 1:1 onto app config."""

    # Task model ids; resolution semantics handled by get_task_model_id.
    TASK_MODEL: Optional[str]
    TASK_MODEL_EXTERNAL: Optional[str]
    ENABLE_TITLE_GENERATION: bool
    TITLE_GENERATION_PROMPT_TEMPLATE: str
    IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE: str
    ENABLE_AUTOCOMPLETE_GENERATION: bool
    # Max prompt length accepted by /auto/completions (<= 0 disables the check).
    AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH: int
    TAGS_GENERATION_PROMPT_TEMPLATE: str
    FOLLOW_UP_GENERATION_PROMPT_TEMPLATE: str
    ENABLE_FOLLOW_UP_GENERATION: bool
    ENABLE_TAGS_GENERATION: bool
    ENABLE_SEARCH_QUERY_GENERATION: bool
    ENABLE_RETRIEVAL_QUERY_GENERATION: bool
    QUERY_GENERATION_PROMPT_TEMPLATE: str
    TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE: str
    VOICE_MODE_PROMPT_TEMPLATE: Optional[str]
@router.post("/config/update")
async def update_task_config(
    request: Request, form_data: TaskConfigForm, user=Depends(get_admin_user)
):
    """Persist the submitted task configuration and echo the stored values back."""
    config = request.app.state.config
    # Form field names match the config attribute names exactly, so the whole
    # form can be copied over in one pass.
    for field, value in form_data.model_dump().items():
        setattr(config, field, value)
    # Response keys listed explicitly (and in this order) to pin the payload shape.
    response_keys = (
        "TASK_MODEL",
        "TASK_MODEL_EXTERNAL",
        "ENABLE_TITLE_GENERATION",
        "TITLE_GENERATION_PROMPT_TEMPLATE",
        "IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE",
        "ENABLE_AUTOCOMPLETE_GENERATION",
        "AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH",
        "TAGS_GENERATION_PROMPT_TEMPLATE",
        "ENABLE_TAGS_GENERATION",
        "ENABLE_FOLLOW_UP_GENERATION",
        "FOLLOW_UP_GENERATION_PROMPT_TEMPLATE",
        "ENABLE_SEARCH_QUERY_GENERATION",
        "ENABLE_RETRIEVAL_QUERY_GENERATION",
        "QUERY_GENERATION_PROMPT_TEMPLATE",
        "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE",
        "VOICE_MODE_PROMPT_TEMPLATE",
    )
    return {key: getattr(config, key) for key in response_keys}
@router.post("/title/completions")
async def generate_title(
    request: Request, form_data: dict, user=Depends(get_verified_user)
):
    """
    Generate a chat title from ``form_data["messages"]``.

    Resolves the configured task model, builds the title prompt, runs the
    pipeline inlet plus task filters (so telemetry such as Langfuse sees the
    run), and returns the raw chat-completion result.

    :raises HTTPException: 404 if the requested model is unknown.
    """
    if not request.app.state.config.ENABLE_TITLE_GENERATION:
        return JSONResponse(
            status_code=status.HTTP_200_OK,
            content={"detail": "Title generation is disabled"},
        )

    # Direct connections expose a single model on request.state; otherwise use
    # the app-wide model registry.
    if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
        models = {
            request.state.model["id"]: request.state.model,
        }
    else:
        models = request.app.state.MODELS

    model_id = form_data["model"]
    if model_id not in models:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Model not found",
        )

    # Check if the user has a custom task model; if so, use that model.
    task_model_id = get_task_model_id(
        model_id,
        request.app.state.config.TASK_MODEL,
        request.app.state.config.TASK_MODEL_EXTERNAL,
        models,
    )

    log.debug(
        f"generating chat title using model {task_model_id} for user {user.email} "
    )

    if request.app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE != "":
        template = request.app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE
    else:
        template = DEFAULT_TITLE_GENERATION_PROMPT_TEMPLATE

    content = title_generation_template(template, form_data["messages"], user)

    # Cap completion length; param name differs between Ollama and OpenAI-style backends.
    max_tokens = (
        models[task_model_id].get("info", {}).get("params", {}).get("max_tokens", 1000)
    )

    payload = {
        "model": task_model_id,
        "messages": [{"role": "user", "content": content}],
        "stream": False,
        **(
            {"max_tokens": max_tokens}
            if models[task_model_id].get("owned_by") == "ollama"
            else {
                "max_completion_tokens": max_tokens,
            }
        ),
        "metadata": {
            **(request.state.metadata if hasattr(request.state, "metadata") else {}),
            "task": str(TASKS.TITLE_GENERATION),
            "task_body": form_data,
            "chat_id": form_data.get("chat_id", None),
        },
    }
    payload = _ensure_task_identifiers(payload, form_data)
    _attach_model_metadata(payload, models, task_model_id)

    # Pipeline inlet failures propagate to the caller (the previous
    # try/except merely re-raised, so it was removed).
    payload = await process_pipeline_inlet_filter(request, payload, user, models)

    try:
        payload = await _run_task_filters(request, payload, user, models)
    except Exception as e:
        log.debug(f"Task inlet filter execution failed: {e}")

    try:
        res = await generate_chat_completion(request, form_data=payload, user=user)
        if isinstance(res, dict):
            # Outlet filters run for telemetry only; failures are non-fatal.
            try:
                await _run_task_filters(request, payload, user, models, response=res)
            except Exception as e:
                log.debug(f"Task outlet filter execution failed: {e}")
        return res
    except Exception as e:
        log.error("Exception occurred", exc_info=True)
        return JSONResponse(
            status_code=status.HTTP_400_BAD_REQUEST,
            content={"detail": "An internal error has occurred."},
        )
@router.post("/follow_up/completions")
async def generate_follow_ups(
    request: Request, form_data: dict, user=Depends(get_verified_user)
):
    """
    Generate follow-up suggestions from ``form_data["messages"]``.

    Resolves the configured task model, builds the follow-up prompt, runs the
    pipeline inlet plus task filters, and returns the raw chat-completion result.

    :raises HTTPException: 404 if the requested model is unknown.
    """
    if not request.app.state.config.ENABLE_FOLLOW_UP_GENERATION:
        return JSONResponse(
            status_code=status.HTTP_200_OK,
            content={"detail": "Follow-up generation is disabled"},
        )

    # Direct connections expose a single model on request.state; otherwise use
    # the app-wide model registry.
    if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
        models = {
            request.state.model["id"]: request.state.model,
        }
    else:
        models = request.app.state.MODELS

    model_id = form_data["model"]
    if model_id not in models:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Model not found",
        )

    # Check if the user has a custom task model; if so, use that model.
    task_model_id = get_task_model_id(
        model_id,
        request.app.state.config.TASK_MODEL,
        request.app.state.config.TASK_MODEL_EXTERNAL,
        models,
    )

    # Fixed copy-paste log message (previously said "generating chat title").
    log.debug(
        f"generating follow-ups using model {task_model_id} for user {user.email} "
    )

    if request.app.state.config.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE != "":
        template = request.app.state.config.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE
    else:
        template = DEFAULT_FOLLOW_UP_GENERATION_PROMPT_TEMPLATE

    content = follow_up_generation_template(template, form_data["messages"], user)

    payload = {
        "model": task_model_id,
        "messages": [{"role": "user", "content": content}],
        "stream": False,
        "metadata": {
            **(request.state.metadata if hasattr(request.state, "metadata") else {}),
            "task": str(TASKS.FOLLOW_UP_GENERATION),
            "task_body": form_data,
            "chat_id": form_data.get("chat_id", None),
        },
    }
    payload = _ensure_task_identifiers(payload, form_data)
    _attach_model_metadata(payload, models, task_model_id)

    # Pipeline inlet failures propagate to the caller (the previous
    # try/except merely re-raised, so it was removed).
    payload = await process_pipeline_inlet_filter(request, payload, user, models)

    try:
        payload = await _run_task_filters(request, payload, user, models)
    except Exception as e:
        log.debug(f"Task inlet filter execution failed: {e}")

    try:
        res = await generate_chat_completion(request, form_data=payload, user=user)
        if isinstance(res, dict):
            # Outlet filters run for telemetry only; failures are non-fatal.
            try:
                await _run_task_filters(request, payload, user, models, response=res)
            except Exception as e:
                log.debug(f"Task outlet filter execution failed: {e}")
        return res
    except Exception as e:
        log.error("Exception occurred", exc_info=True)
        return JSONResponse(
            status_code=status.HTTP_400_BAD_REQUEST,
            content={"detail": "An internal error has occurred."},
        )
@router.post("/tags/completions")
async def generate_chat_tags(
    request: Request, form_data: dict, user=Depends(get_verified_user)
):
    """
    Generate chat tags from ``form_data["messages"]``.

    Resolves the configured task model, builds the tags prompt, runs the
    pipeline inlet plus task filters, and returns the raw chat-completion result.

    :raises HTTPException: 404 if the requested model is unknown.
    """
    if not request.app.state.config.ENABLE_TAGS_GENERATION:
        return JSONResponse(
            status_code=status.HTTP_200_OK,
            content={"detail": "Tags generation is disabled"},
        )

    # Direct connections expose a single model on request.state; otherwise use
    # the app-wide model registry.
    if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
        models = {
            request.state.model["id"]: request.state.model,
        }
    else:
        models = request.app.state.MODELS

    model_id = form_data["model"]
    if model_id not in models:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Model not found",
        )

    # Check if the user has a custom task model; if so, use that model.
    task_model_id = get_task_model_id(
        model_id,
        request.app.state.config.TASK_MODEL,
        request.app.state.config.TASK_MODEL_EXTERNAL,
        models,
    )

    log.debug(
        f"generating chat tags using model {task_model_id} for user {user.email} "
    )

    if request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE != "":
        template = request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE
    else:
        template = DEFAULT_TAGS_GENERATION_PROMPT_TEMPLATE

    content = tags_generation_template(template, form_data["messages"], user)

    payload = {
        "model": task_model_id,
        "messages": [{"role": "user", "content": content}],
        "stream": False,
        "metadata": {
            **(request.state.metadata if hasattr(request.state, "metadata") else {}),
            "task": str(TASKS.TAGS_GENERATION),
            "task_body": form_data,
            "chat_id": form_data.get("chat_id", None),
        },
    }
    payload = _ensure_task_identifiers(payload, form_data)
    _attach_model_metadata(payload, models, task_model_id)

    # Pipeline inlet failures propagate to the caller (the previous
    # try/except merely re-raised, so it was removed).
    payload = await process_pipeline_inlet_filter(request, payload, user, models)

    try:
        payload = await _run_task_filters(request, payload, user, models)
    except Exception as e:
        log.debug(f"Task inlet filter execution failed: {e}")

    try:
        res = await generate_chat_completion(request, form_data=payload, user=user)
        if isinstance(res, dict):
            # Outlet filters run for telemetry only; failures are non-fatal.
            try:
                await _run_task_filters(request, payload, user, models, response=res)
            except Exception as e:
                log.debug(f"Task outlet filter execution failed: {e}")
        return res
    except Exception as e:
        log.error(f"Error generating chat completion: {e}")
        return JSONResponse(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            content={"detail": "An internal error has occurred."},
        )
@router.post("/image_prompt/completions")
async def generate_image_prompt(
    request: Request, form_data: dict, user=Depends(get_verified_user)
):
    """
    Generate an image-generation prompt from ``form_data["messages"]``.

    Resolves the configured task model, builds the image-prompt template, runs
    the pipeline inlet plus task filters, and returns the raw completion result.

    :raises HTTPException: 404 if the requested model is unknown.
    """
    # Direct connections expose a single model on request.state; otherwise use
    # the app-wide model registry.
    if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
        models = {
            request.state.model["id"]: request.state.model,
        }
    else:
        models = request.app.state.MODELS

    model_id = form_data["model"]
    if model_id not in models:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Model not found",
        )

    # Check if the user has a custom task model; if so, use that model.
    task_model_id = get_task_model_id(
        model_id,
        request.app.state.config.TASK_MODEL,
        request.app.state.config.TASK_MODEL_EXTERNAL,
        models,
    )

    log.debug(
        f"generating image prompt using model {task_model_id} for user {user.email} "
    )

    if request.app.state.config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE != "":
        template = request.app.state.config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE
    else:
        template = DEFAULT_IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE

    content = image_prompt_generation_template(template, form_data["messages"], user)

    payload = {
        "model": task_model_id,
        "messages": [{"role": "user", "content": content}],
        "stream": False,
        "metadata": {
            **(request.state.metadata if hasattr(request.state, "metadata") else {}),
            "task": str(TASKS.IMAGE_PROMPT_GENERATION),
            "task_body": form_data,
            "chat_id": form_data.get("chat_id", None),
        },
    }
    payload = _ensure_task_identifiers(payload, form_data)
    _attach_model_metadata(payload, models, task_model_id)

    # Pipeline inlet failures propagate to the caller (the previous
    # try/except merely re-raised, so it was removed).
    payload = await process_pipeline_inlet_filter(request, payload, user, models)

    try:
        payload = await _run_task_filters(request, payload, user, models)
    except Exception as e:
        log.debug(f"Task inlet filter execution failed: {e}")

    try:
        res = await generate_chat_completion(request, form_data=payload, user=user)
        if isinstance(res, dict):
            # Outlet filters run for telemetry only; failures are non-fatal.
            try:
                await _run_task_filters(request, payload, user, models, response=res)
            except Exception as e:
                log.debug(f"Task outlet filter execution failed: {e}")
        return res
    except Exception as e:
        log.error("Exception occurred", exc_info=True)
        return JSONResponse(
            status_code=status.HTTP_400_BAD_REQUEST,
            content={"detail": "An internal error has occurred."},
        )
@router.post("/queries/completions")
async def generate_queries(
    request: Request, form_data: dict, user=Depends(get_verified_user)
):
    """
    Generate search/retrieval queries from ``form_data["messages"]``.

    ``form_data["type"]`` selects between "web_search" and "retrieval"; each is
    independently toggleable in config. Cached queries on request.state are
    reused when present.

    :raises HTTPException: 400 if the requested generation type is disabled,
        404 if the requested model is unknown.
    """
    type = form_data.get("type")
    if type == "web_search":
        if not request.app.state.config.ENABLE_SEARCH_QUERY_GENERATION:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Search query generation is disabled",
            )
    elif type == "retrieval":
        if not request.app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Query generation is disabled",
            )

    # Reuse queries already computed earlier in this request's lifecycle.
    if getattr(request.state, "cached_queries", None):
        log.info(f"Reusing cached queries: {request.state.cached_queries}")
        return request.state.cached_queries

    # Direct connections expose a single model on request.state; otherwise use
    # the app-wide model registry.
    if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
        models = {
            request.state.model["id"]: request.state.model,
        }
    else:
        models = request.app.state.MODELS

    model_id = form_data["model"]
    if model_id not in models:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Model not found",
        )

    # Check if the user has a custom task model; if so, use that model.
    task_model_id = get_task_model_id(
        model_id,
        request.app.state.config.TASK_MODEL,
        request.app.state.config.TASK_MODEL_EXTERNAL,
        models,
    )

    log.debug(
        f"generating {type} queries using model {task_model_id} for user {user.email}"
    )

    if (request.app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE).strip() != "":
        template = request.app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE
    else:
        template = DEFAULT_QUERY_GENERATION_PROMPT_TEMPLATE

    content = query_generation_template(template, form_data["messages"], user)

    payload = {
        "model": task_model_id,
        "messages": [{"role": "user", "content": content}],
        "stream": False,
        "metadata": {
            **(request.state.metadata if hasattr(request.state, "metadata") else {}),
            "task": str(TASKS.QUERY_GENERATION),
            "task_body": form_data,
            "chat_id": form_data.get("chat_id", None),
        },
    }
    payload = _ensure_task_identifiers(payload, form_data)
    _attach_model_metadata(payload, models, task_model_id)

    # Pipeline inlet failures propagate to the caller (the previous
    # try/except merely re-raised, so it was removed).
    payload = await process_pipeline_inlet_filter(request, payload, user, models)

    try:
        payload = await _run_task_filters(request, payload, user, models)
    except Exception as e:
        log.debug(f"Task inlet filter execution failed: {e}")

    try:
        res = await generate_chat_completion(request, form_data=payload, user=user)
        if isinstance(res, dict):
            # Outlet filters run for telemetry only; failures are non-fatal.
            try:
                await _run_task_filters(request, payload, user, models, response=res)
            except Exception as e:
                log.debug(f"Task outlet filter execution failed: {e}")
        return res
    except Exception as e:
        # Log the real error; return a generic message instead of leaking
        # internal exception details (consistent with the other task endpoints).
        log.error("Exception occurred", exc_info=True)
        return JSONResponse(
            status_code=status.HTTP_400_BAD_REQUEST,
            content={"detail": "An internal error has occurred."},
        )
@router.post("/auto/completions")
async def generate_autocompletion(
    request: Request, form_data: dict, user=Depends(get_verified_user)
):
    """
    Generate an autocompletion for ``form_data["prompt"]``.

    Enforces the configured input length limit, resolves the task model, runs
    the pipeline inlet plus task filters, and returns the raw completion result.

    :raises HTTPException: 400 if generation is disabled or the prompt is too
        long, 404 if the requested model is unknown.
    """
    if not request.app.state.config.ENABLE_AUTOCOMPLETE_GENERATION:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Autocompletion generation is disabled",
        )

    type = form_data.get("type")
    prompt = form_data.get("prompt")
    messages = form_data.get("messages")

    if request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH > 0:
        # Guard against a missing prompt: len(None) would raise TypeError.
        if (
            len(prompt or "")
            > request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH
        ):
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"Input prompt exceeds maximum length of {request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH}",
            )

    # Direct connections expose a single model on request.state; otherwise use
    # the app-wide model registry.
    if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
        models = {
            request.state.model["id"]: request.state.model,
        }
    else:
        models = request.app.state.MODELS

    model_id = form_data["model"]
    if model_id not in models:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Model not found",
        )

    # Check if the user has a custom task model; if so, use that model.
    task_model_id = get_task_model_id(
        model_id,
        request.app.state.config.TASK_MODEL,
        request.app.state.config.TASK_MODEL_EXTERNAL,
        models,
    )

    log.debug(
        f"generating autocompletion using model {task_model_id} for user {user.email}"
    )

    if (request.app.state.config.AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE).strip() != "":
        template = request.app.state.config.AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE
    else:
        template = DEFAULT_AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE

    content = autocomplete_generation_template(template, prompt, messages, type, user)

    payload = {
        "model": task_model_id,
        "messages": [{"role": "user", "content": content}],
        "stream": False,
        "metadata": {
            **(request.state.metadata if hasattr(request.state, "metadata") else {}),
            "task": str(TASKS.AUTOCOMPLETE_GENERATION),
            "task_body": form_data,
            "chat_id": form_data.get("chat_id", None),
        },
    }
    payload = _ensure_task_identifiers(payload, form_data)
    _attach_model_metadata(payload, models, task_model_id)

    # Pipeline inlet failures propagate to the caller (the previous
    # try/except merely re-raised, so it was removed).
    payload = await process_pipeline_inlet_filter(request, payload, user, models)

    try:
        payload = await _run_task_filters(request, payload, user, models)
    except Exception as e:
        log.debug(f"Task inlet filter execution failed: {e}")

    try:
        res = await generate_chat_completion(request, form_data=payload, user=user)
        if isinstance(res, dict):
            # Outlet filters run for telemetry only; failures are non-fatal.
            try:
                await _run_task_filters(request, payload, user, models, response=res)
            except Exception as e:
                log.debug(f"Task outlet filter execution failed: {e}")
        return res
    except Exception as e:
        log.error(f"Error generating chat completion: {e}")
        return JSONResponse(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            content={"detail": "An internal error has occurred."},
        )
@router.post("/emoji/completions")
async def generate_emoji(
    request: Request, form_data: dict, user=Depends(get_verified_user)
):
    """
    Generate a single emoji for ``form_data["prompt"]``.

    Uses the default emoji template with a hard 4-token cap, runs the pipeline
    inlet plus task filters, and returns the raw completion result.

    :raises HTTPException: 404 if the requested model is unknown.
    """
    # Direct connections expose a single model on request.state; otherwise use
    # the app-wide model registry.
    if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
        models = {
            request.state.model["id"]: request.state.model,
        }
    else:
        models = request.app.state.MODELS

    model_id = form_data["model"]
    if model_id not in models:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Model not found",
        )

    # Check if the user has a custom task model; if so, use that model.
    task_model_id = get_task_model_id(
        model_id,
        request.app.state.config.TASK_MODEL,
        request.app.state.config.TASK_MODEL_EXTERNAL,
        models,
    )

    log.debug(f"generating emoji using model {task_model_id} for user {user.email} ")

    template = DEFAULT_EMOJI_GENERATION_PROMPT_TEMPLATE
    content = emoji_generation_template(template, form_data["prompt"], user)

    payload = {
        "model": task_model_id,
        "messages": [{"role": "user", "content": content}],
        "stream": False,
        # 4 tokens is enough for one emoji; param name differs per backend.
        **(
            {"max_tokens": 4}
            if models[task_model_id].get("owned_by") == "ollama"
            else {
                "max_completion_tokens": 4,
            }
        ),
        "metadata": {
            **(request.state.metadata if hasattr(request.state, "metadata") else {}),
            "task": str(TASKS.EMOJI_GENERATION),
            "task_body": form_data,
            "chat_id": form_data.get("chat_id", None),
        },
    }
    payload = _ensure_task_identifiers(payload, form_data)
    _attach_model_metadata(payload, models, task_model_id)

    # Pipeline inlet failures propagate to the caller (the previous
    # try/except merely re-raised, so it was removed).
    payload = await process_pipeline_inlet_filter(request, payload, user, models)

    try:
        payload = await _run_task_filters(request, payload, user, models)
    except Exception as e:
        log.debug(f"Task inlet filter execution failed: {e}")

    try:
        res = await generate_chat_completion(request, form_data=payload, user=user)
        if isinstance(res, dict):
            # Outlet filters run for telemetry only; failures are non-fatal.
            try:
                await _run_task_filters(request, payload, user, models, response=res)
            except Exception as e:
                log.debug(f"Task outlet filter execution failed: {e}")
        return res
    except Exception as e:
        # Log the real error; return a generic message instead of leaking
        # internal exception details (consistent with the other task endpoints).
        log.error("Exception occurred", exc_info=True)
        return JSONResponse(
            status_code=status.HTTP_400_BAD_REQUEST,
            content={"detail": "An internal error has occurred."},
        )
@router.post("/moa/completions")
async def generate_moa_response(
    request: Request, form_data: dict, user=Depends(get_verified_user)
):
    """
    Generate a mixture-of-agents (MOA) synthesis from ``form_data["prompt"]``
    and the candidate ``form_data["responses"]``.

    Unlike the other task endpoints this uses the requested model directly
    (no task-model substitution) and honors the caller's ``stream`` flag.

    :raises HTTPException: 404 if the requested model is unknown.
    """
    # Direct connections expose a single model on request.state; otherwise use
    # the app-wide model registry.
    if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
        models = {
            request.state.model["id"]: request.state.model,
        }
    else:
        models = request.app.state.MODELS

    model_id = form_data["model"]
    if model_id not in models:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Model not found",
        )

    template = DEFAULT_MOA_GENERATION_PROMPT_TEMPLATE
    content = moa_response_generation_template(
        template,
        form_data["prompt"],
        form_data["responses"],
    )

    payload = {
        "model": model_id,
        "messages": [{"role": "user", "content": content}],
        "stream": form_data.get("stream", False),
        "metadata": {
            **(request.state.metadata if hasattr(request.state, "metadata") else {}),
            "chat_id": form_data.get("chat_id", None),
            "task": str(TASKS.MOA_RESPONSE_GENERATION),
            "task_body": form_data,
        },
    }
    payload = _ensure_task_identifiers(payload, form_data)
    _attach_model_metadata(payload, models, model_id)

    # Pipeline inlet failures propagate to the caller (the previous
    # try/except merely re-raised, so it was removed).
    payload = await process_pipeline_inlet_filter(request, payload, user, models)

    try:
        payload = await _run_task_filters(request, payload, user, models)
    except Exception as e:
        log.debug(f"Task inlet filter execution failed: {e}")

    try:
        res = await generate_chat_completion(request, form_data=payload, user=user)
        if isinstance(res, dict):
            # Outlet filters run for telemetry only; failures are non-fatal
            # (streaming responses are not dicts and skip this path).
            try:
                await _run_task_filters(request, payload, user, models, response=res)
            except Exception as e:
                log.debug(f"Task outlet filter execution failed: {e}")
        return res
    except Exception as e:
        # Log the real error; return a generic message instead of leaking
        # internal exception details (consistent with the other task endpoints).
        log.error("Exception occurred", exc_info=True)
        return JSONResponse(
            status_code=status.HTTP_400_BAD_REQUEST,
            content={"detail": "An internal error has occurred."},
        )