diff --git a/backend/open_webui/config.py b/backend/open_webui/config.py
index ea8d36aa9a..3a5b800566 100644
--- a/backend/open_webui/config.py
+++ b/backend/open_webui/config.py
@@ -3074,16 +3074,30 @@ EXTERNAL_WEB_LOADER_API_KEY = PersistentConfig(
# Images
####################################
+ENABLE_IMAGE_GENERATION = PersistentConfig(
+ "ENABLE_IMAGE_GENERATION",
+ "image_generation.enable",
+ os.environ.get("ENABLE_IMAGE_GENERATION", "").lower() == "true",
+)
+
IMAGE_GENERATION_ENGINE = PersistentConfig(
"IMAGE_GENERATION_ENGINE",
"image_generation.engine",
os.getenv("IMAGE_GENERATION_ENGINE", "openai"),
)
-ENABLE_IMAGE_GENERATION = PersistentConfig(
- "ENABLE_IMAGE_GENERATION",
- "image_generation.enable",
- os.environ.get("ENABLE_IMAGE_GENERATION", "").lower() == "true",
+IMAGE_GENERATION_MODEL = PersistentConfig(
+ "IMAGE_GENERATION_MODEL",
+ "image_generation.model",
+ os.getenv("IMAGE_GENERATION_MODEL", ""),
+)
+
+IMAGE_SIZE = PersistentConfig(
+ "IMAGE_SIZE", "image_generation.size", os.getenv("IMAGE_SIZE", "512x512")
+)
+
+IMAGE_STEPS = PersistentConfig(
+ "IMAGE_STEPS", "image_generation.steps", int(os.getenv("IMAGE_STEPS", 50))
)
ENABLE_IMAGE_PROMPT_GENERATION = PersistentConfig(
@@ -3285,20 +3299,52 @@ IMAGES_GEMINI_ENDPOINT_METHOD = PersistentConfig(
os.getenv("IMAGES_GEMINI_ENDPOINT_METHOD", ""),
)
-IMAGE_SIZE = PersistentConfig(
- "IMAGE_SIZE", "image_generation.size", os.getenv("IMAGE_SIZE", "512x512")
+
+IMAGE_EDIT_ENGINE = PersistentConfig(
+ "IMAGE_EDIT_ENGINE",
+ "images.edit.engine",
+ os.getenv("IMAGE_EDIT_ENGINE", "openai"),
)
-IMAGE_STEPS = PersistentConfig(
- "IMAGE_STEPS", "image_generation.steps", int(os.getenv("IMAGE_STEPS", 50))
+IMAGE_EDIT_MODEL = PersistentConfig(
+ "IMAGE_EDIT_MODEL",
+ "images.edit.model",
+ os.getenv("IMAGE_EDIT_MODEL", ""),
)
-IMAGE_GENERATION_MODEL = PersistentConfig(
- "IMAGE_GENERATION_MODEL",
- "image_generation.model",
- os.getenv("IMAGE_GENERATION_MODEL", ""),
+IMAGE_EDIT_SIZE = PersistentConfig(
+ "IMAGE_EDIT_SIZE", "images.edit.size", os.getenv("IMAGE_EDIT_SIZE", "")
)
+IMAGES_EDIT_OPENAI_API_BASE_URL = PersistentConfig(
+ "IMAGES_EDIT_OPENAI_API_BASE_URL",
+ "images.edit.openai.api_base_url",
+ os.getenv("IMAGES_EDIT_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL),
+)
+IMAGES_EDIT_OPENAI_API_VERSION = PersistentConfig(
+ "IMAGES_EDIT_OPENAI_API_VERSION",
+ "images.edit.openai.api_version",
+ os.getenv("IMAGES_EDIT_OPENAI_API_VERSION", ""),
+)
+
+IMAGES_EDIT_OPENAI_API_KEY = PersistentConfig(
+ "IMAGES_EDIT_OPENAI_API_KEY",
+ "images.edit.openai.api_key",
+ os.getenv("IMAGES_EDIT_OPENAI_API_KEY", OPENAI_API_KEY),
+)
+
+IMAGES_EDIT_GEMINI_API_BASE_URL = PersistentConfig(
+ "IMAGES_EDIT_GEMINI_API_BASE_URL",
+ "images.edit.gemini.api_base_url",
+ os.getenv("IMAGES_EDIT_GEMINI_API_BASE_URL", GEMINI_API_BASE_URL),
+)
+IMAGES_EDIT_GEMINI_API_KEY = PersistentConfig(
+ "IMAGES_EDIT_GEMINI_API_KEY",
+ "images.edit.gemini.api_key",
+ os.getenv("IMAGES_EDIT_GEMINI_API_KEY", GEMINI_API_KEY),
+)
+
+
####################################
# Audio
####################################
diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py
index 0512bff1eb..44f0560925 100644
--- a/backend/open_webui/main.py
+++ b/backend/open_webui/main.py
@@ -163,6 +163,14 @@ from open_webui.config import (
IMAGES_GEMINI_API_BASE_URL,
IMAGES_GEMINI_API_KEY,
IMAGES_GEMINI_ENDPOINT_METHOD,
+ IMAGE_EDIT_ENGINE,
+ IMAGE_EDIT_MODEL,
+ IMAGE_EDIT_SIZE,
+ IMAGES_EDIT_OPENAI_API_BASE_URL,
+ IMAGES_EDIT_OPENAI_API_KEY,
+ IMAGES_EDIT_OPENAI_API_VERSION,
+ IMAGES_EDIT_GEMINI_API_BASE_URL,
+ IMAGES_EDIT_GEMINI_API_KEY,
# Audio
AUDIO_STT_ENGINE,
AUDIO_STT_MODEL,
@@ -1078,7 +1086,6 @@ app.state.config.IMAGES_GEMINI_API_BASE_URL = IMAGES_GEMINI_API_BASE_URL
app.state.config.IMAGES_GEMINI_API_KEY = IMAGES_GEMINI_API_KEY
app.state.config.IMAGES_GEMINI_ENDPOINT_METHOD = IMAGES_GEMINI_ENDPOINT_METHOD
-
app.state.config.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
app.state.config.AUTOMATIC1111_API_AUTH = AUTOMATIC1111_API_AUTH
app.state.config.AUTOMATIC1111_PARAMS = AUTOMATIC1111_PARAMS
@@ -1089,6 +1096,16 @@ app.state.config.COMFYUI_WORKFLOW = COMFYUI_WORKFLOW
app.state.config.COMFYUI_WORKFLOW_NODES = COMFYUI_WORKFLOW_NODES
+app.state.config.IMAGE_EDIT_ENGINE = IMAGE_EDIT_ENGINE
+app.state.config.IMAGE_EDIT_MODEL = IMAGE_EDIT_MODEL
+app.state.config.IMAGE_EDIT_SIZE = IMAGE_EDIT_SIZE
+app.state.config.IMAGES_EDIT_OPENAI_API_BASE_URL = IMAGES_EDIT_OPENAI_API_BASE_URL
+app.state.config.IMAGES_EDIT_OPENAI_API_KEY = IMAGES_EDIT_OPENAI_API_KEY
+app.state.config.IMAGES_EDIT_OPENAI_API_VERSION = IMAGES_EDIT_OPENAI_API_VERSION
+app.state.config.IMAGES_EDIT_GEMINI_API_BASE_URL = IMAGES_EDIT_GEMINI_API_BASE_URL
+app.state.config.IMAGES_EDIT_GEMINI_API_KEY = IMAGES_EDIT_GEMINI_API_KEY
+
+
########################################
#
# AUDIO
diff --git a/backend/open_webui/routers/images.py b/backend/open_webui/routers/images.py
index dd7b02550a..469cd448f8 100644
--- a/backend/open_webui/routers/images.py
+++ b/backend/open_webui/routers/images.py
@@ -1,5 +1,6 @@
import asyncio
import base64
+import uuid
import io
import json
import logging
@@ -10,18 +11,13 @@ from typing import Optional
from urllib.parse import quote
import requests
-from fastapi import (
- APIRouter,
- Depends,
- HTTPException,
- Request,
- UploadFile,
-)
+from fastapi import APIRouter, Depends, HTTPException, Request, UploadFile
+from fastapi.responses import FileResponse
from open_webui.config import CACHE_DIR
from open_webui.constants import ERROR_MESSAGES
from open_webui.env import ENABLE_FORWARD_USER_INFO_HEADERS, SRC_LOG_LEVELS
-from open_webui.routers.files import upload_file_handler
+from open_webui.routers.files import upload_file_handler, get_file_content_by_id
from open_webui.utils.auth import get_admin_user, get_verified_user
from open_webui.utils.headers import include_user_info_headers
from open_webui.utils.images.comfyui import (
@@ -121,6 +117,16 @@ class ImagesConfig(BaseModel):
IMAGES_GEMINI_API_KEY: str
IMAGES_GEMINI_ENDPOINT_METHOD: str
+ IMAGE_EDIT_ENGINE: str
+ IMAGE_EDIT_MODEL: str
+ IMAGE_EDIT_SIZE: Optional[str]
+
+ IMAGES_EDIT_OPENAI_API_BASE_URL: str
+ IMAGES_EDIT_OPENAI_API_KEY: str
+ IMAGES_EDIT_OPENAI_API_VERSION: str
+ IMAGES_EDIT_GEMINI_API_BASE_URL: str
+ IMAGES_EDIT_GEMINI_API_KEY: str
+
@router.get("/config", response_model=ImagesConfig)
async def get_config(request: Request, user=Depends(get_admin_user)):
@@ -144,6 +150,14 @@ async def get_config(request: Request, user=Depends(get_admin_user)):
"IMAGES_GEMINI_API_BASE_URL": request.app.state.config.IMAGES_GEMINI_API_BASE_URL,
"IMAGES_GEMINI_API_KEY": request.app.state.config.IMAGES_GEMINI_API_KEY,
"IMAGES_GEMINI_ENDPOINT_METHOD": request.app.state.config.IMAGES_GEMINI_ENDPOINT_METHOD,
+ "IMAGE_EDIT_ENGINE": request.app.state.config.IMAGE_EDIT_ENGINE,
+ "IMAGE_EDIT_MODEL": request.app.state.config.IMAGE_EDIT_MODEL,
+ "IMAGE_EDIT_SIZE": request.app.state.config.IMAGE_EDIT_SIZE,
+ "IMAGES_EDIT_OPENAI_API_BASE_URL": request.app.state.config.IMAGES_EDIT_OPENAI_API_BASE_URL,
+ "IMAGES_EDIT_OPENAI_API_KEY": request.app.state.config.IMAGES_EDIT_OPENAI_API_KEY,
+ "IMAGES_EDIT_OPENAI_API_VERSION": request.app.state.config.IMAGES_EDIT_OPENAI_API_VERSION,
+ "IMAGES_EDIT_GEMINI_API_BASE_URL": request.app.state.config.IMAGES_EDIT_GEMINI_API_BASE_URL,
+ "IMAGES_EDIT_GEMINI_API_KEY": request.app.state.config.IMAGES_EDIT_GEMINI_API_KEY,
}
@@ -152,6 +166,8 @@ async def update_config(
request: Request, form_data: ImagesConfig, user=Depends(get_admin_user)
):
request.app.state.config.ENABLE_IMAGE_GENERATION = form_data.ENABLE_IMAGE_GENERATION
+
+ # Create Image
request.app.state.config.ENABLE_IMAGE_PROMPT_GENERATION = (
form_data.ENABLE_IMAGE_PROMPT_GENERATION
)
@@ -215,6 +231,28 @@ async def update_config(
form_data.IMAGES_GEMINI_ENDPOINT_METHOD
)
+ # Edit Image
+ request.app.state.config.IMAGE_EDIT_ENGINE = form_data.IMAGE_EDIT_ENGINE
+ request.app.state.config.IMAGE_EDIT_MODEL = form_data.IMAGE_EDIT_MODEL
+ request.app.state.config.IMAGE_EDIT_SIZE = form_data.IMAGE_EDIT_SIZE
+
+ request.app.state.config.IMAGES_EDIT_OPENAI_API_BASE_URL = (
+ form_data.IMAGES_OPENAI_API_BASE_URL
+ )
+ request.app.state.config.IMAGES_EDIT_OPENAI_API_KEY = (
+ form_data.IMAGES_OPENAI_API_KEY
+ )
+ request.app.state.config.IMAGES_EDIT_OPENAI_API_VERSION = (
+ form_data.IMAGES_EDIT_OPENAI_API_VERSION
+ )
+
+ request.app.state.config.IMAGES_EDIT_GEMINI_API_BASE_URL = (
+ form_data.IMAGES_EDIT_GEMINI_API_BASE_URL
+ )
+ request.app.state.config.IMAGES_EDIT_GEMINI_API_KEY = (
+ form_data.IMAGES_EDIT_GEMINI_API_KEY
+ )
+
return {
"ENABLE_IMAGE_GENERATION": request.app.state.config.ENABLE_IMAGE_GENERATION,
"ENABLE_IMAGE_PROMPT_GENERATION": request.app.state.config.ENABLE_IMAGE_PROMPT_GENERATION,
@@ -235,6 +273,14 @@ async def update_config(
"IMAGES_GEMINI_API_BASE_URL": request.app.state.config.IMAGES_GEMINI_API_BASE_URL,
"IMAGES_GEMINI_API_KEY": request.app.state.config.IMAGES_GEMINI_API_KEY,
"IMAGES_GEMINI_ENDPOINT_METHOD": request.app.state.config.IMAGES_GEMINI_ENDPOINT_METHOD,
+ "IMAGE_EDIT_ENGINE": request.app.state.config.IMAGE_EDIT_ENGINE,
+ "IMAGE_EDIT_MODEL": request.app.state.config.IMAGE_EDIT_MODEL,
+ "IMAGE_EDIT_SIZE": request.app.state.config.IMAGE_EDIT_SIZE,
+ "IMAGES_EDIT_OPENAI_API_BASE_URL": request.app.state.config.IMAGES_EDIT_OPENAI_API_BASE_URL,
+ "IMAGES_EDIT_OPENAI_API_KEY": request.app.state.config.IMAGES_EDIT_OPENAI_API_KEY,
+ "IMAGES_EDIT_OPENAI_API_VERSION": request.app.state.config.IMAGES_EDIT_OPENAI_API_VERSION,
+ "IMAGES_EDIT_GEMINI_API_BASE_URL": request.app.state.config.IMAGES_EDIT_GEMINI_API_BASE_URL,
+ "IMAGES_EDIT_GEMINI_API_KEY": request.app.state.config.IMAGES_EDIT_GEMINI_API_KEY,
}
@@ -674,3 +720,255 @@ async def image_generations(
if "error" in data:
error = data["error"]["message"]
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(error))
+
+
+class EditImageForm(BaseModel):
+ image: str | list[str] # base64-encoded image(s) or URL(s)
+ prompt: str
+ model: Optional[str] = None
+ size: Optional[str] = None
+ n: Optional[int] = None
+ negative_prompt: Optional[str] = None
+
+
+@router.post("/edit")
+async def image_edits(
+ request: Request,
+ form_data: EditImageForm,
+ user=Depends(get_verified_user),
+):
+
+ size = None
+ width, height = None, None
+ if (
+ request.app.state.config.IMAGE_EDIT_SIZE
+ and "x" in request.app.state.config.IMAGE_EDIT_SIZE
+ ) or (form_data.size and "x" in form_data.size):
+ size = (
+ form_data.size
+ if form_data.size
+ else request.app.state.config.IMAGE_EDIT_SIZE
+ )
+ width, height = tuple(map(int, size.split("x")))
+
+ model = (
+ request.app.state.config.IMAGE_EDIT_MODEL
+ if form_data.model is None
+ else form_data.model
+ )
+
+ def load_url_image(string):
+ if string.startswith("http://") or string.startswith("https://"):
+ r = requests.get(string)
+ r.raise_for_status()
+ image_data = base64.b64encode(r.content).decode("utf-8")
+ return f"data:{r.headers['content-type']};base64,{image_data}"
+
+ elif string.startswith("/api/v1/files"):
+ file_id = string.split("/api/v1/files/")[1].split("/content")[0]
+ file_response = get_file_content_by_id(file_id, user)
+
+ if isinstance(file_response, FileResponse):
+ file_bytes = file_response.body
+ mime_type = file_response.headers.get("content-type", "image/png")
+ image_data = base64.b64encode(file_bytes).decode("utf-8")
+ return f"data:{mime_type};base64,{image_data}"
+ return string
+
+ # Load image(s) from URL(s) if necessary
+ if isinstance(form_data.image, str):
+ form_data.image = load_url_image(form_data.image)
+ elif isinstance(form_data.image, list):
+ form_data.image = [load_url_image(img) for img in form_data.image]
+
+ r = None
+ try:
+ if request.app.state.config.IMAGE_EDIT_ENGINE == "openai":
+ headers = {
+ "Authorization": f"Bearer {request.app.state.config.IMAGES_EDIT_OPENAI_API_KEY}",
+ }
+
+ if ENABLE_FORWARD_USER_INFO_HEADERS:
+ headers = include_user_info_headers(headers, user)
+
+ data = {
+ "model": model,
+ "prompt": form_data.prompt,
+ **({"n": form_data.n} if form_data.n else {}),
+ **({"size": size} if size else {}),
+ **(
+ {}
+ if "gpt-image-1" in request.app.state.config.IMAGE_GENERATION_MODEL
+ else {"response_format": "b64_json"}
+ ),
+ }
+
+ def get_image_file_item(base64_string):
+ data = base64_string
+ header, encoded = data.split(",", 1)
+ mime_type = header.split(";")[0].lstrip("data:")
+ image_data = base64.b64decode(encoded)
+ return (
+ "image",
+ (
+ f"{uuid.uuid4()}.png",
+ io.BytesIO(image_data),
+ mime_type if mime_type else "image/png",
+ ),
+ )
+
+ files = []
+ if isinstance(form_data.image, str):
+ files = [get_image_file_item(form_data.image)]
+ elif isinstance(form_data.image, list):
+ for img in form_data.image:
+ files.append(get_image_file_item(img))
+
+ url_search_params = ""
+ if request.app.state.config.IMAGES_EDIT_OPENAI_API_VERSION:
+ url_search_params += f"?api-version={request.app.state.config.IMAGES_EDIT_OPENAI_API_VERSION}"
+
+ # Use asyncio.to_thread for the requests.post call
+ r = await asyncio.to_thread(
+ requests.post,
+ url=f"{request.app.state.config.IMAGES_OPENAI_API_BASE_URL}/images/edits{url_search_params}",
+ headers=headers,
+ files=files,
+ data=data,
+ )
+
+ r.raise_for_status()
+ res = r.json()
+
+ images = []
+ for image in res["data"]:
+ if image_url := image.get("url", None):
+ image_data, content_type = get_image_data(image_url, headers)
+ else:
+ image_data, content_type = get_image_data(image["b64_json"])
+
+ url = upload_image(request, image_data, content_type, data, user)
+ images.append({"url": url})
+ return images
+
+ elif request.app.state.config.IMAGE_GENERATION_ENGINE == "gemini":
+ headers = {
+ "Content-Type": "application/json",
+ "x-goog-api-key": request.app.state.config.IMAGES_GEMINI_API_KEY,
+ }
+
+ model = f"{model}:generateContent"
+ data = {"contents": [{"parts": [{"text": form_data.prompt}]}]}
+
+ if isinstance(form_data.image, str):
+ data["contents"][0]["parts"].append(
+ {
+ "inline_data": {
+ "mime_type": "image/png",
+ "data": form_data.image.split(",", 1)[1],
+ }
+ }
+ )
+ elif isinstance(form_data.image, list):
+ data["contents"][0]["parts"].extend(
+ [
+ {
+ "inline_data": {
+ "mime_type": "image/png",
+ "data": image.split(",", 1)[1],
+ }
+ }
+ for image in form_data.image
+ ]
+ )
+
+ # Use asyncio.to_thread for the requests.post call
+ r = await asyncio.to_thread(
+ requests.post,
+ url=f"{request.app.state.config.IMAGES_GEMINI_API_BASE_URL}/models/{model}",
+ json=data,
+ headers=headers,
+ )
+
+ r.raise_for_status()
+ res = r.json()
+
+ images = []
+ for image in res["candidates"]:
+ for part in image["content"]["parts"]:
+ if part.get("inlineData", {}).get("data"):
+ image_data, content_type = get_image_data(
+ part["inlineData"]["data"]
+ )
+ url = upload_image(
+ request, image_data, content_type, data, user
+ )
+ images.append({"url": url})
+
+ return images
+
+ elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui":
+ data = {
+ "prompt": form_data.prompt,
+ "width": width,
+ "height": height,
+ "n": form_data.n,
+ }
+
+ if request.app.state.config.IMAGE_EDIT_STEPS is not None:
+ data["steps"] = request.app.state.config.IMAGE_EDIT_STEPS
+
+ if form_data.negative_prompt is not None:
+ data["negative_prompt"] = form_data.negative_prompt
+
+ form_data = ComfyUICreateImageForm(
+ **{
+ "workflow": ComfyUIWorkflow(
+ **{
+ "workflow": request.app.state.config.COMFYUI_WORKFLOW,
+ "nodes": request.app.state.config.COMFYUI_WORKFLOW_NODES,
+ }
+ ),
+ **data,
+ }
+ )
+ res = await comfyui_create_image(
+ model,
+ form_data,
+ user.id,
+ request.app.state.config.COMFYUI_BASE_URL,
+ request.app.state.config.COMFYUI_API_KEY,
+ )
+ log.debug(f"res: {res}")
+
+ images = []
+
+ for image in res["data"]:
+ headers = None
+ if request.app.state.config.COMFYUI_API_KEY:
+ headers = {
+ "Authorization": f"Bearer {request.app.state.config.COMFYUI_API_KEY}"
+ }
+
+ image_data, content_type = get_image_data(image["url"], headers)
+ url = upload_image(
+ request,
+ image_data,
+ content_type,
+ form_data.model_dump(exclude_none=True),
+ user,
+ )
+ images.append({"url": url})
+ return images
+ except Exception as e:
+ error = e
+ if r != None:
+ data = r.text
+ try:
+ data = json.loads(data)
+ if "error" in data:
+ error = data["error"]["message"]
+ except Exception:
+ error = data
+
+ raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(error))
diff --git a/backend/open_webui/utils/middleware.py b/backend/open_webui/utils/middleware.py
index e4cace689a..e3cad01222 100644
--- a/backend/open_webui/utils/middleware.py
+++ b/backend/open_webui/utils/middleware.py
@@ -47,7 +47,8 @@ from open_webui.routers.retrieval import (
from open_webui.routers.images import (
image_generations,
CreateImageForm,
- upload_image,
+ image_edits,
+ EditImageForm,
)
from open_webui.routers.pipelines import (
process_pipeline_inlet_filter,
@@ -717,9 +718,31 @@ async def chat_web_search_handler(
return form_data
+def get_last_images(message_list):
+ images = []
+ for message in reversed(message_list):
+ images_flag = False
+ for file in message.get("files", []):
+ if file.get("type") == "image":
+ images.append(file.get("url"))
+ images_flag = True
+
+ if images_flag:
+ break
+
+ return images
+
+
async def chat_image_generation_handler(
request: Request, form_data: dict, extra_params: dict, user
):
+ metadata = extra_params.get("__metadata__", {})
+ chat_id = metadata.get("chat_id", None)
+ if not chat_id:
+ return form_data
+
+ chat = Chats.get_chat_by_id_and_user_id(chat_id, user.id)
+
__event_emitter__ = extra_params["__event_emitter__"]
await __event_emitter__(
{
@@ -728,87 +751,151 @@ async def chat_image_generation_handler(
}
)
- messages = form_data["messages"]
- user_message = get_last_user_message(messages)
+ messages_map = chat.chat.get("history", {}).get("messages", {})
+ message_id = chat.chat.get("history", {}).get("currentId")
+ message_list = get_message_list(messages_map, message_id)
+ user_message = get_last_user_message(message_list)
prompt = user_message
- negative_prompt = ""
-
- if request.app.state.config.ENABLE_IMAGE_PROMPT_GENERATION:
- try:
- res = await generate_image_prompt(
- request,
- {
- "model": form_data["model"],
- "messages": messages,
- },
- user,
- )
-
- response = res["choices"][0]["message"]["content"]
-
- try:
- bracket_start = response.find("{")
- bracket_end = response.rfind("}") + 1
-
- if bracket_start == -1 or bracket_end == -1:
- raise Exception("No JSON object found in the response")
-
- response = response[bracket_start:bracket_end]
- response = json.loads(response)
- prompt = response.get("prompt", [])
- except Exception as e:
- prompt = user_message
-
- except Exception as e:
- log.exception(e)
- prompt = user_message
+ input_images = get_last_images(message_list)
system_message_content = ""
+ if len(input_images) == 0:
+ # Create image(s)
+ if request.app.state.config.ENABLE_IMAGE_PROMPT_GENERATION:
+ try:
+ res = await generate_image_prompt(
+ request,
+ {
+ "model": form_data["model"],
+ "messages": form_data["messages"],
+ },
+ user,
+ )
- try:
- images = await image_generations(
- request=request,
- form_data=CreateImageForm(**{"prompt": prompt}),
- user=user,
- )
+ response = res["choices"][0]["message"]["content"]
- await __event_emitter__(
- {
- "type": "status",
- "data": {"description": "Image created", "done": True},
- }
- )
+ try:
+ bracket_start = response.find("{")
+ bracket_end = response.rfind("}") + 1
- await __event_emitter__(
- {
- "type": "files",
- "data": {
- "files": [
- {
- "type": "image",
- "url": image["url"],
- }
- for image in images
- ]
- },
- }
- )
+ if bracket_start == -1 or bracket_end == -1:
+ raise Exception("No JSON object found in the response")
- system_message_content = "