diff --git a/backend/open_webui/config.py b/backend/open_webui/config.py
index e32fc1a171..df96344ab9 100644
--- a/backend/open_webui/config.py
+++ b/backend/open_webui/config.py
@@ -1055,6 +1055,32 @@ JSON format: { "tags": ["tag1", "tag2", "tag3"] }
{{MESSAGES:END:6}}
"""
+IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE = PersistentConfig(
+ "IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE",
+ "task.image.prompt_template",
+ os.environ.get("IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE", ""),
+)
+
+DEFAULT_IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE = """### Task:
+Generate a detailed prompt for am image generation task based on the given language and context. Describe the image as if you were explaining it to someone who cannot see it. Include relevant details, colors, shapes, and any other important elements.
+
+### Guidelines:
+- Be descriptive and detailed, focusing on the most important aspects of the image.
+- Avoid making assumptions or adding information not present in the image.
+- Use the chat's primary language; default to English if multilingual.
+- If the image is too complex, focus on the most prominent elements.
+
+### Output:
+Strictly return in JSON format:
+{
+ "prompt": "Your detailed description here."
+}
+
+### Chat History:
+
+{{MESSAGES:END:6}}
+"""
+
ENABLE_TAGS_GENERATION = PersistentConfig(
"ENABLE_TAGS_GENERATION",
"task.tags.enable",
diff --git a/backend/open_webui/constants.py b/backend/open_webui/constants.py
index c5fdfabfb8..cb65e0d771 100644
--- a/backend/open_webui/constants.py
+++ b/backend/open_webui/constants.py
@@ -113,6 +113,7 @@ class TASKS(str, Enum):
TAGS_GENERATION = "tags_generation"
EMOJI_GENERATION = "emoji_generation"
QUERY_GENERATION = "query_generation"
+ IMAGE_PROMPT_GENERATION = "image_prompt_generation"
AUTOCOMPLETE_GENERATION = "autocomplete_generation"
FUNCTION_CALLING = "function_calling"
MOA_RESPONSE_GENERATION = "moa_response_generation"
diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py
index 6414baccad..b13f957a5f 100644
--- a/backend/open_webui/main.py
+++ b/backend/open_webui/main.py
@@ -255,6 +255,7 @@ from open_webui.config import (
ENABLE_AUTOCOMPLETE_GENERATION,
TITLE_GENERATION_PROMPT_TEMPLATE,
TAGS_GENERATION_PROMPT_TEMPLATE,
+ IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE,
TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
QUERY_GENERATION_PROMPT_TEMPLATE,
AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE,
@@ -644,6 +645,10 @@ app.state.config.ENABLE_TAGS_GENERATION = ENABLE_TAGS_GENERATION
app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = TITLE_GENERATION_PROMPT_TEMPLATE
app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE = TAGS_GENERATION_PROMPT_TEMPLATE
+app.state.config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE = (
+ IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE
+)
+
app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = (
TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
)
diff --git a/backend/open_webui/routers/tasks.py b/backend/open_webui/routers/tasks.py
index 7d14a9d18d..6d7343c8a4 100644
--- a/backend/open_webui/routers/tasks.py
+++ b/backend/open_webui/routers/tasks.py
@@ -9,6 +9,7 @@ from open_webui.utils.chat import generate_chat_completion
from open_webui.utils.task import (
title_generation_template,
query_generation_template,
+ image_prompt_generation_template,
autocomplete_generation_template,
tags_generation_template,
emoji_generation_template,
@@ -23,6 +24,7 @@ from open_webui.utils.task import get_task_model_id
from open_webui.config import (
DEFAULT_TITLE_GENERATION_PROMPT_TEMPLATE,
DEFAULT_TAGS_GENERATION_PROMPT_TEMPLATE,
+ DEFAULT_IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE,
DEFAULT_QUERY_GENERATION_PROMPT_TEMPLATE,
DEFAULT_AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE,
DEFAULT_EMOJI_GENERATION_PROMPT_TEMPLATE,
@@ -50,6 +52,7 @@ async def get_task_config(request: Request, user=Depends(get_verified_user)):
"TASK_MODEL": request.app.state.config.TASK_MODEL,
"TASK_MODEL_EXTERNAL": request.app.state.config.TASK_MODEL_EXTERNAL,
"TITLE_GENERATION_PROMPT_TEMPLATE": request.app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE,
+ "IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE": request.app.state.config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE,
"ENABLE_AUTOCOMPLETE_GENERATION": request.app.state.config.ENABLE_AUTOCOMPLETE_GENERATION,
"AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH": request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH,
"TAGS_GENERATION_PROMPT_TEMPLATE": request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE,
@@ -65,6 +68,7 @@ class TaskConfigForm(BaseModel):
TASK_MODEL: Optional[str]
TASK_MODEL_EXTERNAL: Optional[str]
TITLE_GENERATION_PROMPT_TEMPLATE: str
+ IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE: str
ENABLE_AUTOCOMPLETE_GENERATION: bool
AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH: int
TAGS_GENERATION_PROMPT_TEMPLATE: str
@@ -114,6 +118,7 @@ async def update_task_config(
"TASK_MODEL": request.app.state.config.TASK_MODEL,
"TASK_MODEL_EXTERNAL": request.app.state.config.TASK_MODEL_EXTERNAL,
"TITLE_GENERATION_PROMPT_TEMPLATE": request.app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE,
+ "IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE": request.app.state.config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE,
"ENABLE_AUTOCOMPLETE_GENERATION": request.app.state.config.ENABLE_AUTOCOMPLETE_GENERATION,
"AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH": request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH,
"TAGS_GENERATION_PROMPT_TEMPLATE": request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE,
@@ -256,6 +261,66 @@ async def generate_chat_tags(
)
+@router.post("/image_prompt/completions")
+async def generate_image_prompt(
+ request: Request, form_data: dict, user=Depends(get_verified_user)
+):
+ models = request.app.state.MODELS
+
+ model_id = form_data["model"]
+ if model_id not in models:
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND,
+ detail="Model not found",
+ )
+
+ # Check if the user has a custom task model
+ # If the user has a custom task model, use that model
+ task_model_id = get_task_model_id(
+ model_id,
+ request.app.state.config.TASK_MODEL,
+ request.app.state.config.TASK_MODEL_EXTERNAL,
+ models,
+ )
+
+ log.debug(
+ f"generating image prompt using model {task_model_id} for user {user.email} "
+ )
+
+ if request.app.state.config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE != "":
+ template = request.app.state.config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE
+ else:
+ template = DEFAULT_IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE
+
+ content = image_prompt_generation_template(
+ template,
+ form_data["messages"],
+ user={
+ "name": user.name,
+ },
+ )
+
+ payload = {
+ "model": task_model_id,
+ "messages": [{"role": "user", "content": content}],
+ "stream": False,
+ "metadata": {
+ "task": str(TASKS.IMAGE_PROMPT_GENERATION),
+ "task_body": form_data,
+ "chat_id": form_data.get("chat_id", None),
+ },
+ }
+
+ try:
+ return await generate_chat_completion(request, form_data=payload, user=user)
+ except Exception as e:
+ log.error("Exception occurred", exc_info=True)
+ return JSONResponse(
+ status_code=status.HTTP_400_BAD_REQUEST,
+ content={"detail": "An internal error has occurred."},
+ )
+
+
@router.post("/queries/completions")
async def generate_queries(
request: Request, form_data: dict, user=Depends(get_verified_user)
diff --git a/backend/open_webui/utils/middleware.py b/backend/open_webui/utils/middleware.py
index 221847d073..61704513df 100644
--- a/backend/open_webui/utils/middleware.py
+++ b/backend/open_webui/utils/middleware.py
@@ -28,6 +28,7 @@ from open_webui.socket.main import (
from open_webui.routers.tasks import (
generate_queries,
generate_title,
+ generate_image_prompt,
generate_chat_tags,
)
from open_webui.routers.retrieval import process_web_search, SearchForm
@@ -503,12 +504,44 @@ async def chat_image_generation_handler(
messages = form_data["messages"]
user_message = get_last_user_message(messages)
+ prompt = ""
+ negative_prompt = ""
+
+ try:
+ res = await generate_image_prompt(
+ request,
+ {
+ "model": form_data["model"],
+ "messages": messages,
+ },
+ user,
+ )
+
+ response = res["choices"][0]["message"]["content"]
+
+ try:
+ bracket_start = response.find("{")
+ bracket_end = response.rfind("}") + 1
+
+ if bracket_start == -1 or bracket_end == -1:
+ raise Exception("No JSON object found in the response")
+
+ response = response[bracket_start:bracket_end]
+ response = json.loads(response)
+ prompt = response.get("prompt", [])
+ except Exception as e:
+ prompt = user_message
+
+ except Exception as e:
+ log.exception(e)
+ prompt = user_message
+
system_message_content = ""
try:
images = await image_generations(
request=request,
- form_data=GenerateImageForm(**{"prompt": user_message}),
+ form_data=GenerateImageForm(**{"prompt": prompt}),
user=user,
)
diff --git a/backend/open_webui/utils/task.py b/backend/open_webui/utils/task.py
index ebb7483bad..f5ba75ebec 100644
--- a/backend/open_webui/utils/task.py
+++ b/backend/open_webui/utils/task.py
@@ -217,6 +217,24 @@ def tags_generation_template(
return template
+def image_prompt_generation_template(
+ template: str, messages: list[dict], user: Optional[dict] = None
+) -> str:
+ prompt = get_last_user_message(messages)
+ template = replace_prompt_variable(template, prompt)
+ template = replace_messages_variable(template, messages)
+
+ template = prompt_template(
+ template,
+ **(
+ {"user_name": user.get("name"), "user_location": user.get("location")}
+ if user
+ else {}
+ ),
+ )
+ return template
+
+
def emoji_generation_template(
template: str, prompt: str, user: Optional[dict] = None
) -> str:
diff --git a/src/lib/components/admin/Settings/Interface.svelte b/src/lib/components/admin/Settings/Interface.svelte
index 9c669dae53..055acbf800 100644
--- a/src/lib/components/admin/Settings/Interface.svelte
+++ b/src/lib/components/admin/Settings/Interface.svelte
@@ -24,6 +24,7 @@
TASK_MODEL: '',
TASK_MODEL_EXTERNAL: '',
TITLE_GENERATION_PROMPT_TEMPLATE: '',
+ IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE: '',
ENABLE_AUTOCOMPLETE_GENERATION: true,
AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH: -1,
TAGS_GENERATION_PROMPT_TEMPLATE: '',
@@ -140,6 +141,22 @@
+
+
{$i18n.t('Image Prompt Generation Prompt')}
+
+
+
+
+
+