feat: image edit support

This commit is contained in:
Timothy Jaeryang Baek 2025-11-05 03:31:37 -05:00
parent 8d34fcb586
commit 72f8539fd2
5 changed files with 623 additions and 119 deletions

View file

@ -3074,16 +3074,30 @@ EXTERNAL_WEB_LOADER_API_KEY = PersistentConfig(
# Images # Images
#################################### ####################################
ENABLE_IMAGE_GENERATION = PersistentConfig(
"ENABLE_IMAGE_GENERATION",
"image_generation.enable",
os.environ.get("ENABLE_IMAGE_GENERATION", "").lower() == "true",
)
IMAGE_GENERATION_ENGINE = PersistentConfig( IMAGE_GENERATION_ENGINE = PersistentConfig(
"IMAGE_GENERATION_ENGINE", "IMAGE_GENERATION_ENGINE",
"image_generation.engine", "image_generation.engine",
os.getenv("IMAGE_GENERATION_ENGINE", "openai"), os.getenv("IMAGE_GENERATION_ENGINE", "openai"),
) )
ENABLE_IMAGE_GENERATION = PersistentConfig( IMAGE_GENERATION_MODEL = PersistentConfig(
"ENABLE_IMAGE_GENERATION", "IMAGE_GENERATION_MODEL",
"image_generation.enable", "image_generation.model",
os.environ.get("ENABLE_IMAGE_GENERATION", "").lower() == "true", os.getenv("IMAGE_GENERATION_MODEL", ""),
)
IMAGE_SIZE = PersistentConfig(
"IMAGE_SIZE", "image_generation.size", os.getenv("IMAGE_SIZE", "512x512")
)
IMAGE_STEPS = PersistentConfig(
"IMAGE_STEPS", "image_generation.steps", int(os.getenv("IMAGE_STEPS", 50))
) )
ENABLE_IMAGE_PROMPT_GENERATION = PersistentConfig( ENABLE_IMAGE_PROMPT_GENERATION = PersistentConfig(
@ -3285,20 +3299,52 @@ IMAGES_GEMINI_ENDPOINT_METHOD = PersistentConfig(
os.getenv("IMAGES_GEMINI_ENDPOINT_METHOD", ""), os.getenv("IMAGES_GEMINI_ENDPOINT_METHOD", ""),
) )
IMAGE_SIZE = PersistentConfig(
"IMAGE_SIZE", "image_generation.size", os.getenv("IMAGE_SIZE", "512x512") IMAGE_EDIT_ENGINE = PersistentConfig(
"IMAGE_EDIT_ENGINE",
"images.edit.engine",
os.getenv("IMAGE_EDIT_ENGINE", "openai"),
) )
IMAGE_STEPS = PersistentConfig( IMAGE_EDIT_MODEL = PersistentConfig(
"IMAGE_STEPS", "image_generation.steps", int(os.getenv("IMAGE_STEPS", 50)) "IMAGE_EDIT_MODEL",
"images.edit.model",
os.getenv("IMAGE_EDIT_MODEL", ""),
) )
IMAGE_GENERATION_MODEL = PersistentConfig( IMAGE_EDIT_SIZE = PersistentConfig(
"IMAGE_GENERATION_MODEL", "IMAGE_EDIT_SIZE", "images.edit.size", os.getenv("IMAGE_EDIT_SIZE", "")
"image_generation.model",
os.getenv("IMAGE_GENERATION_MODEL", ""),
) )
# OpenAI-compatible endpoint used for image edits. When the
# IMAGES_EDIT_OPENAI_API_BASE_URL environment variable is not set, the default
# is the general OPENAI_API_BASE_URL value.
IMAGES_EDIT_OPENAI_API_BASE_URL = PersistentConfig(
"IMAGES_EDIT_OPENAI_API_BASE_URL",
"images.edit.openai.api_base_url",
os.getenv("IMAGES_EDIT_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL),
)
# Optional API version for the edit endpoint; defaults to an empty string.
IMAGES_EDIT_OPENAI_API_VERSION = PersistentConfig(
"IMAGES_EDIT_OPENAI_API_VERSION",
"images.edit.openai.api_version",
os.getenv("IMAGES_EDIT_OPENAI_API_VERSION", ""),
)
# API key for the edit endpoint; defaults to the general OPENAI_API_KEY value.
IMAGES_EDIT_OPENAI_API_KEY = PersistentConfig(
"IMAGES_EDIT_OPENAI_API_KEY",
"images.edit.openai.api_key",
os.getenv("IMAGES_EDIT_OPENAI_API_KEY", OPENAI_API_KEY),
)
# Gemini endpoint used for image edits; defaults to GEMINI_API_BASE_URL.
IMAGES_EDIT_GEMINI_API_BASE_URL = PersistentConfig(
"IMAGES_EDIT_GEMINI_API_BASE_URL",
"images.edit.gemini.api_base_url",
os.getenv("IMAGES_EDIT_GEMINI_API_BASE_URL", GEMINI_API_BASE_URL),
)
# Gemini API key used for image edits; defaults to GEMINI_API_KEY.
IMAGES_EDIT_GEMINI_API_KEY = PersistentConfig(
"IMAGES_EDIT_GEMINI_API_KEY",
"images.edit.gemini.api_key",
os.getenv("IMAGES_EDIT_GEMINI_API_KEY", GEMINI_API_KEY),
)
#################################### ####################################
# Audio # Audio
#################################### ####################################

View file

@ -163,6 +163,14 @@ from open_webui.config import (
IMAGES_GEMINI_API_BASE_URL, IMAGES_GEMINI_API_BASE_URL,
IMAGES_GEMINI_API_KEY, IMAGES_GEMINI_API_KEY,
IMAGES_GEMINI_ENDPOINT_METHOD, IMAGES_GEMINI_ENDPOINT_METHOD,
IMAGE_EDIT_ENGINE,
IMAGE_EDIT_MODEL,
IMAGE_EDIT_SIZE,
IMAGES_EDIT_OPENAI_API_BASE_URL,
IMAGES_EDIT_OPENAI_API_KEY,
IMAGES_EDIT_OPENAI_API_VERSION,
IMAGES_EDIT_GEMINI_API_BASE_URL,
IMAGES_EDIT_GEMINI_API_KEY,
# Audio # Audio
AUDIO_STT_ENGINE, AUDIO_STT_ENGINE,
AUDIO_STT_MODEL, AUDIO_STT_MODEL,
@ -1078,7 +1086,6 @@ app.state.config.IMAGES_GEMINI_API_BASE_URL = IMAGES_GEMINI_API_BASE_URL
app.state.config.IMAGES_GEMINI_API_KEY = IMAGES_GEMINI_API_KEY app.state.config.IMAGES_GEMINI_API_KEY = IMAGES_GEMINI_API_KEY
app.state.config.IMAGES_GEMINI_ENDPOINT_METHOD = IMAGES_GEMINI_ENDPOINT_METHOD app.state.config.IMAGES_GEMINI_ENDPOINT_METHOD = IMAGES_GEMINI_ENDPOINT_METHOD
app.state.config.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL app.state.config.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
app.state.config.AUTOMATIC1111_API_AUTH = AUTOMATIC1111_API_AUTH app.state.config.AUTOMATIC1111_API_AUTH = AUTOMATIC1111_API_AUTH
app.state.config.AUTOMATIC1111_PARAMS = AUTOMATIC1111_PARAMS app.state.config.AUTOMATIC1111_PARAMS = AUTOMATIC1111_PARAMS
@ -1089,6 +1096,16 @@ app.state.config.COMFYUI_WORKFLOW = COMFYUI_WORKFLOW
app.state.config.COMFYUI_WORKFLOW_NODES = COMFYUI_WORKFLOW_NODES app.state.config.COMFYUI_WORKFLOW_NODES = COMFYUI_WORKFLOW_NODES
app.state.config.IMAGE_EDIT_ENGINE = IMAGE_EDIT_ENGINE
app.state.config.IMAGE_EDIT_MODEL = IMAGE_EDIT_MODEL
app.state.config.IMAGE_EDIT_SIZE = IMAGE_EDIT_SIZE
app.state.config.IMAGES_EDIT_OPENAI_API_BASE_URL = IMAGES_EDIT_OPENAI_API_BASE_URL
app.state.config.IMAGES_EDIT_OPENAI_API_KEY = IMAGES_EDIT_OPENAI_API_KEY
app.state.config.IMAGES_EDIT_OPENAI_API_VERSION = IMAGES_EDIT_OPENAI_API_VERSION
app.state.config.IMAGES_EDIT_GEMINI_API_BASE_URL = IMAGES_EDIT_GEMINI_API_BASE_URL
app.state.config.IMAGES_EDIT_GEMINI_API_KEY = IMAGES_EDIT_GEMINI_API_KEY
######################################## ########################################
# #
# AUDIO # AUDIO

View file

@ -1,5 +1,6 @@
import asyncio import asyncio
import base64 import base64
import uuid
import io import io
import json import json
import logging import logging
@ -10,18 +11,13 @@ from typing import Optional
from urllib.parse import quote from urllib.parse import quote
import requests import requests
from fastapi import ( from fastapi import APIRouter, Depends, HTTPException, Request, UploadFile
APIRouter, from fastapi.responses import FileResponse
Depends,
HTTPException,
Request,
UploadFile,
)
from open_webui.config import CACHE_DIR from open_webui.config import CACHE_DIR
from open_webui.constants import ERROR_MESSAGES from open_webui.constants import ERROR_MESSAGES
from open_webui.env import ENABLE_FORWARD_USER_INFO_HEADERS, SRC_LOG_LEVELS from open_webui.env import ENABLE_FORWARD_USER_INFO_HEADERS, SRC_LOG_LEVELS
from open_webui.routers.files import upload_file_handler from open_webui.routers.files import upload_file_handler, get_file_content_by_id
from open_webui.utils.auth import get_admin_user, get_verified_user from open_webui.utils.auth import get_admin_user, get_verified_user
from open_webui.utils.headers import include_user_info_headers from open_webui.utils.headers import include_user_info_headers
from open_webui.utils.images.comfyui import ( from open_webui.utils.images.comfyui import (
@ -121,6 +117,16 @@ class ImagesConfig(BaseModel):
IMAGES_GEMINI_API_KEY: str IMAGES_GEMINI_API_KEY: str
IMAGES_GEMINI_ENDPOINT_METHOD: str IMAGES_GEMINI_ENDPOINT_METHOD: str
IMAGE_EDIT_ENGINE: str
IMAGE_EDIT_MODEL: str
IMAGE_EDIT_SIZE: Optional[str]
IMAGES_EDIT_OPENAI_API_BASE_URL: str
IMAGES_EDIT_OPENAI_API_KEY: str
IMAGES_EDIT_OPENAI_API_VERSION: str
IMAGES_EDIT_GEMINI_API_BASE_URL: str
IMAGES_EDIT_GEMINI_API_KEY: str
@router.get("/config", response_model=ImagesConfig) @router.get("/config", response_model=ImagesConfig)
async def get_config(request: Request, user=Depends(get_admin_user)): async def get_config(request: Request, user=Depends(get_admin_user)):
@ -144,6 +150,14 @@ async def get_config(request: Request, user=Depends(get_admin_user)):
"IMAGES_GEMINI_API_BASE_URL": request.app.state.config.IMAGES_GEMINI_API_BASE_URL, "IMAGES_GEMINI_API_BASE_URL": request.app.state.config.IMAGES_GEMINI_API_BASE_URL,
"IMAGES_GEMINI_API_KEY": request.app.state.config.IMAGES_GEMINI_API_KEY, "IMAGES_GEMINI_API_KEY": request.app.state.config.IMAGES_GEMINI_API_KEY,
"IMAGES_GEMINI_ENDPOINT_METHOD": request.app.state.config.IMAGES_GEMINI_ENDPOINT_METHOD, "IMAGES_GEMINI_ENDPOINT_METHOD": request.app.state.config.IMAGES_GEMINI_ENDPOINT_METHOD,
"IMAGE_EDIT_ENGINE": request.app.state.config.IMAGE_EDIT_ENGINE,
"IMAGE_EDIT_MODEL": request.app.state.config.IMAGE_EDIT_MODEL,
"IMAGE_EDIT_SIZE": request.app.state.config.IMAGE_EDIT_SIZE,
"IMAGES_EDIT_OPENAI_API_BASE_URL": request.app.state.config.IMAGES_EDIT_OPENAI_API_BASE_URL,
"IMAGES_EDIT_OPENAI_API_KEY": request.app.state.config.IMAGES_EDIT_OPENAI_API_KEY,
"IMAGES_EDIT_OPENAI_API_VERSION": request.app.state.config.IMAGES_EDIT_OPENAI_API_VERSION,
"IMAGES_EDIT_GEMINI_API_BASE_URL": request.app.state.config.IMAGES_EDIT_GEMINI_API_BASE_URL,
"IMAGES_EDIT_GEMINI_API_KEY": request.app.state.config.IMAGES_EDIT_GEMINI_API_KEY,
} }
@ -152,6 +166,8 @@ async def update_config(
request: Request, form_data: ImagesConfig, user=Depends(get_admin_user) request: Request, form_data: ImagesConfig, user=Depends(get_admin_user)
): ):
request.app.state.config.ENABLE_IMAGE_GENERATION = form_data.ENABLE_IMAGE_GENERATION request.app.state.config.ENABLE_IMAGE_GENERATION = form_data.ENABLE_IMAGE_GENERATION
# Create Image
request.app.state.config.ENABLE_IMAGE_PROMPT_GENERATION = ( request.app.state.config.ENABLE_IMAGE_PROMPT_GENERATION = (
form_data.ENABLE_IMAGE_PROMPT_GENERATION form_data.ENABLE_IMAGE_PROMPT_GENERATION
) )
@ -215,6 +231,28 @@ async def update_config(
form_data.IMAGES_GEMINI_ENDPOINT_METHOD form_data.IMAGES_GEMINI_ENDPOINT_METHOD
) )
# Edit Image
# Every edit setting is copied from its own IMAGE_EDIT_* / IMAGES_EDIT_* form
# field. The OpenAI base URL and API key previously read the image-generation
# fields (IMAGES_OPENAI_API_BASE_URL / IMAGES_OPENAI_API_KEY), so the submitted
# edit values were discarded and the generation values stored in their place.
request.app.state.config.IMAGE_EDIT_ENGINE = form_data.IMAGE_EDIT_ENGINE
request.app.state.config.IMAGE_EDIT_MODEL = form_data.IMAGE_EDIT_MODEL
request.app.state.config.IMAGE_EDIT_SIZE = form_data.IMAGE_EDIT_SIZE
request.app.state.config.IMAGES_EDIT_OPENAI_API_BASE_URL = (
    form_data.IMAGES_EDIT_OPENAI_API_BASE_URL
)
request.app.state.config.IMAGES_EDIT_OPENAI_API_KEY = (
    form_data.IMAGES_EDIT_OPENAI_API_KEY
)
request.app.state.config.IMAGES_EDIT_OPENAI_API_VERSION = (
    form_data.IMAGES_EDIT_OPENAI_API_VERSION
)
request.app.state.config.IMAGES_EDIT_GEMINI_API_BASE_URL = (
    form_data.IMAGES_EDIT_GEMINI_API_BASE_URL
)
request.app.state.config.IMAGES_EDIT_GEMINI_API_KEY = (
    form_data.IMAGES_EDIT_GEMINI_API_KEY
)
return { return {
"ENABLE_IMAGE_GENERATION": request.app.state.config.ENABLE_IMAGE_GENERATION, "ENABLE_IMAGE_GENERATION": request.app.state.config.ENABLE_IMAGE_GENERATION,
"ENABLE_IMAGE_PROMPT_GENERATION": request.app.state.config.ENABLE_IMAGE_PROMPT_GENERATION, "ENABLE_IMAGE_PROMPT_GENERATION": request.app.state.config.ENABLE_IMAGE_PROMPT_GENERATION,
@ -235,6 +273,14 @@ async def update_config(
"IMAGES_GEMINI_API_BASE_URL": request.app.state.config.IMAGES_GEMINI_API_BASE_URL, "IMAGES_GEMINI_API_BASE_URL": request.app.state.config.IMAGES_GEMINI_API_BASE_URL,
"IMAGES_GEMINI_API_KEY": request.app.state.config.IMAGES_GEMINI_API_KEY, "IMAGES_GEMINI_API_KEY": request.app.state.config.IMAGES_GEMINI_API_KEY,
"IMAGES_GEMINI_ENDPOINT_METHOD": request.app.state.config.IMAGES_GEMINI_ENDPOINT_METHOD, "IMAGES_GEMINI_ENDPOINT_METHOD": request.app.state.config.IMAGES_GEMINI_ENDPOINT_METHOD,
"IMAGE_EDIT_ENGINE": request.app.state.config.IMAGE_EDIT_ENGINE,
"IMAGE_EDIT_MODEL": request.app.state.config.IMAGE_EDIT_MODEL,
"IMAGE_EDIT_SIZE": request.app.state.config.IMAGE_EDIT_SIZE,
"IMAGES_EDIT_OPENAI_API_BASE_URL": request.app.state.config.IMAGES_EDIT_OPENAI_API_BASE_URL,
"IMAGES_EDIT_OPENAI_API_KEY": request.app.state.config.IMAGES_EDIT_OPENAI_API_KEY,
"IMAGES_EDIT_OPENAI_API_VERSION": request.app.state.config.IMAGES_EDIT_OPENAI_API_VERSION,
"IMAGES_EDIT_GEMINI_API_BASE_URL": request.app.state.config.IMAGES_EDIT_GEMINI_API_BASE_URL,
"IMAGES_EDIT_GEMINI_API_KEY": request.app.state.config.IMAGES_EDIT_GEMINI_API_KEY,
} }
@ -674,3 +720,255 @@ async def image_generations(
if "error" in data: if "error" in data:
error = data["error"]["message"] error = data["error"]["message"]
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(error)) raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(error))
class EditImageForm(BaseModel):
    """Request body for the POST /edit image-editing endpoint."""

    # One source image or a list of them; each entry is a base64 data URL,
    # an http(s) URL, or an /api/v1/files/... path.
    image: str | list[str]  # base64-encoded image(s) or URL(s)
    # Instruction describing the edit to perform.
    prompt: str
    # Per-request model override; None means the configured IMAGE_EDIT_MODEL.
    model: Optional[str] = None
    # Per-request output size in "WIDTHxHEIGHT" form.
    size: Optional[str] = None
    # Number of images to request.
    n: Optional[int] = None
    # Only read by the ComfyUI branch of the edit handler.
    negative_prompt: Optional[str] = None
@router.post("/edit")
async def image_edits(
    request: Request,
    form_data: EditImageForm,
    user=Depends(get_verified_user),
):
    """Edit one or more images with the configured image-edit engine.

    Args:
        request: Incoming request; all engine settings are read from
            ``request.app.state.config``.
        form_data: Prompt, source image(s) and optional per-request overrides
            for model, size and count.
        user: The verified user issuing the request.

    Returns:
        A list of ``{"url": ...}`` dicts, one per image stored through
        ``upload_image``.

    Raises:
        HTTPException: status 400 when the size is malformed, the configured
            engine is not supported, a source image cannot be loaded, or the
            upstream service reports an error.
    """
    config = request.app.state.config

    size, width, height = _resolve_edit_size(form_data.size, config.IMAGE_EDIT_SIZE)
    model = config.IMAGE_EDIT_MODEL if form_data.model is None else form_data.model

    r = None
    try:
        engine = config.IMAGE_EDIT_ENGINE

        # Normalise to a list so every engine branch handles one shape, and
        # resolve URLs / file references into base64 data URLs.
        if isinstance(form_data.image, str):
            raw_images = [form_data.image]
        else:
            raw_images = list(form_data.image)
        source_images = [await _load_edit_image(image, user) for image in raw_images]

        if engine == "openai":
            headers = {
                "Authorization": f"Bearer {config.IMAGES_EDIT_OPENAI_API_KEY}",
            }
            if ENABLE_FORWARD_USER_INFO_HEADERS:
                headers = include_user_info_headers(headers, user)

            data = {"model": model, "prompt": form_data.prompt}
            if form_data.n:
                data["n"] = form_data.n
            if size:
                data["size"] = size
            # response_format is left out for gpt-image-1 models. The check
            # must look at the model actually sent with this edit request, not
            # at the image *generation* model.
            if "gpt-image-1" not in (model or ""):
                data["response_format"] = "b64_json"

            files = []
            for image in source_images:
                mime_type, encoded = _split_data_url(image)
                files.append(
                    (
                        "image",
                        (
                            f"{uuid.uuid4()}.png",
                            io.BytesIO(base64.b64decode(encoded)),
                            mime_type,
                        ),
                    )
                )

            url_search_params = ""
            if config.IMAGES_EDIT_OPENAI_API_VERSION:
                url_search_params += (
                    f"?api-version={config.IMAGES_EDIT_OPENAI_API_VERSION}"
                )

            # requests is blocking; run it in a worker thread.
            r = await asyncio.to_thread(
                requests.post,
                url=f"{config.IMAGES_EDIT_OPENAI_API_BASE_URL}/images/edits{url_search_params}",
                headers=headers,
                files=files,
                data=data,
            )
            r.raise_for_status()
            res = r.json()

            images = []
            for image in res["data"]:
                if image_url := image.get("url", None):
                    image_data, content_type = get_image_data(image_url, headers)
                else:
                    image_data, content_type = get_image_data(image["b64_json"])

                url = upload_image(request, image_data, content_type, data, user)
                images.append({"url": url})
            return images

        elif engine == "gemini":
            headers = {
                "Content-Type": "application/json",
                "x-goog-api-key": config.IMAGES_EDIT_GEMINI_API_KEY,
            }

            parts = [{"text": form_data.prompt}]
            for image in source_images:
                # Forward the real MIME type of each image instead of always
                # labelling it image/png.
                mime_type, encoded = _split_data_url(image)
                parts.append(
                    {"inline_data": {"mime_type": mime_type, "data": encoded}}
                )
            data = {"contents": [{"parts": parts}]}

            r = await asyncio.to_thread(
                requests.post,
                url=f"{config.IMAGES_EDIT_GEMINI_API_BASE_URL}/models/{model}:generateContent",
                json=data,
                headers=headers,
            )
            r.raise_for_status()
            res = r.json()

            images = []
            for candidate in res["candidates"]:
                for part in candidate["content"]["parts"]:
                    if part.get("inlineData", {}).get("data"):
                        image_data, content_type = get_image_data(
                            part["inlineData"]["data"]
                        )
                        url = upload_image(
                            request, image_data, content_type, data, user
                        )
                        images.append({"url": url})
            return images

        elif engine == "comfyui":
            # NOTE(review): this branch reuses the generation workflow settings
            # and never forwards the source image(s) to ComfyUI.
            data = {
                "prompt": form_data.prompt,
                "width": width,
                "height": height,
                "n": form_data.n,
            }

            # IMAGE_EDIT_STEPS is not among the edit settings registered next
            # to IMAGE_EDIT_ENGINE / MODEL / SIZE, so a missing value is
            # treated as "unset" instead of failing the whole request.
            try:
                steps = config.IMAGE_EDIT_STEPS
            except (AttributeError, KeyError):
                steps = None
            if steps is not None:
                data["steps"] = steps

            if form_data.negative_prompt is not None:
                data["negative_prompt"] = form_data.negative_prompt

            comfyui_form = ComfyUICreateImageForm(
                **{
                    "workflow": ComfyUIWorkflow(
                        **{
                            "workflow": config.COMFYUI_WORKFLOW,
                            "nodes": config.COMFYUI_WORKFLOW_NODES,
                        }
                    ),
                    **data,
                }
            )
            res = await comfyui_create_image(
                model,
                comfyui_form,
                user.id,
                config.COMFYUI_BASE_URL,
                config.COMFYUI_API_KEY,
            )
            log.debug(f"res: {res}")

            headers = None
            if config.COMFYUI_API_KEY:
                headers = {"Authorization": f"Bearer {config.COMFYUI_API_KEY}"}

            images = []
            for image in res["data"]:
                image_data, content_type = get_image_data(image["url"], headers)
                url = upload_image(
                    request,
                    image_data,
                    content_type,
                    comfyui_form.model_dump(exclude_none=True),
                    user,
                )
                images.append({"url": url})
            return images

        else:
            # Previously an unknown engine fell through and returned None.
            raise ValueError(f"Unsupported image edit engine: {engine}")

    except Exception as e:
        log.exception(e)
        error = e
        if r is not None:
            # Prefer the upstream error message when the response carries one.
            data = r.text
            try:
                data = json.loads(data)
                if "error" in data:
                    error = data["error"]["message"]
            except (ValueError, KeyError, TypeError):
                error = data
        raise HTTPException(
            status_code=400, detail=ERROR_MESSAGES.DEFAULT(error)
        ) from e


def _resolve_edit_size(requested, configured):
    """Return ``(size, width, height)`` from the first "WxH"-style candidate.

    The per-request size wins over the configured one. Candidates without an
    "x" are skipped; when none qualifies all three values are ``None``.
    Raises HTTPException(400) when a candidate contains "x" but is not two
    integers.
    """
    for candidate in (requested, configured):
        if candidate and "x" in candidate:
            try:
                width, height = map(int, candidate.split("x"))
            except ValueError as err:
                raise HTTPException(
                    status_code=400,
                    detail=ERROR_MESSAGES.DEFAULT(f"Invalid image size: {candidate}"),
                ) from err
            return candidate, width, height
    return None, None, None


def _split_data_url(data_url):
    """Split ``data:<mime>;base64,<payload>`` into ``(mime_type, payload)``.

    Falls back to ``image/png`` when the header carries no MIME type.
    """
    header, encoded = data_url.split(",", 1)
    # removeprefix, not lstrip: lstrip("data:") strips any leading run of the
    # characters d/a/t/: and would corrupt types such as "application/...".
    mime_type = header.split(";")[0].removeprefix("data:")
    return mime_type or "image/png", encoded


async def _load_edit_image(source, user):
    """Turn an http(s) URL or /api/v1/files path into a base64 data URL.

    Any other string is returned unchanged.
    """
    import inspect  # local import so the file's import block stays untouched

    if source.startswith(("http://", "https://")):
        # NOTE(review): this fetches a caller-supplied URL from the server;
        # confirm that is acceptable (SSRF) for this deployment.
        response = await asyncio.to_thread(requests.get, source)
        response.raise_for_status()
        encoded = base64.b64encode(response.content).decode("utf-8")
        content_type = response.headers.get("content-type", "image/png")
        return f"data:{content_type};base64,{encoded}"

    if source.startswith("/api/v1/files"):
        file_id = source.split("/api/v1/files/")[1].split("/content")[0]
        file_response = get_file_content_by_id(file_id, user)
        # The files handler may be a coroutine function; await it if so.
        if inspect.isawaitable(file_response):
            file_response = await file_response

        if isinstance(file_response, FileResponse):
            # FileResponse streams from disk and has no ``body`` attribute;
            # read the bytes from the path it was built with.
            with open(file_response.path, "rb") as f:
                file_bytes = f.read()
            mime_type = file_response.headers.get("content-type", "image/png")
            encoded = base64.b64encode(file_bytes).decode("utf-8")
            return f"data:{mime_type};base64,{encoded}"

    return source

View file

@ -47,7 +47,8 @@ from open_webui.routers.retrieval import (
from open_webui.routers.images import ( from open_webui.routers.images import (
image_generations, image_generations,
CreateImageForm, CreateImageForm,
upload_image, image_edits,
EditImageForm,
) )
from open_webui.routers.pipelines import ( from open_webui.routers.pipelines import (
process_pipeline_inlet_filter, process_pipeline_inlet_filter,
@ -717,9 +718,31 @@ async def chat_web_search_handler(
return form_data return form_data
def get_last_images(message_list):
    """Return the image URLs attached to the most recent message that has any.

    Messages are scanned from newest to oldest and the search stops at the
    first message that contributes at least one image URL.

    Args:
        message_list: Chat messages in chronological order; each is a dict
            that may carry a ``files`` list of ``{"type": ..., "url": ...}``.

    Returns:
        The image URLs of that message in their original order, or an empty
        list when no message has an image with a URL.
    """
    for message in reversed(message_list):
        # ``files`` may be absent or explicitly None; image entries without a
        # URL are skipped so callers never receive None in place of a URL.
        urls = [
            file["url"]
            for file in (message.get("files") or [])
            if file.get("type") == "image" and file.get("url")
        ]
        if urls:
            return urls
    return []
async def chat_image_generation_handler( async def chat_image_generation_handler(
request: Request, form_data: dict, extra_params: dict, user request: Request, form_data: dict, extra_params: dict, user
): ):
metadata = extra_params.get("__metadata__", {})
chat_id = metadata.get("chat_id", None)
if not chat_id:
return form_data
chat = Chats.get_chat_by_id_and_user_id(chat_id, user.id)
__event_emitter__ = extra_params["__event_emitter__"] __event_emitter__ = extra_params["__event_emitter__"]
await __event_emitter__( await __event_emitter__(
{ {
@ -728,87 +751,151 @@ async def chat_image_generation_handler(
} }
) )
messages = form_data["messages"] messages_map = chat.chat.get("history", {}).get("messages", {})
user_message = get_last_user_message(messages) message_id = chat.chat.get("history", {}).get("currentId")
message_list = get_message_list(messages_map, message_id)
user_message = get_last_user_message(message_list)
prompt = user_message prompt = user_message
negative_prompt = "" input_images = get_last_images(message_list)
if request.app.state.config.ENABLE_IMAGE_PROMPT_GENERATION:
try:
res = await generate_image_prompt(
request,
{
"model": form_data["model"],
"messages": messages,
},
user,
)
response = res["choices"][0]["message"]["content"]
try:
bracket_start = response.find("{")
bracket_end = response.rfind("}") + 1
if bracket_start == -1 or bracket_end == -1:
raise Exception("No JSON object found in the response")
response = response[bracket_start:bracket_end]
response = json.loads(response)
prompt = response.get("prompt", [])
except Exception as e:
prompt = user_message
except Exception as e:
log.exception(e)
prompt = user_message
system_message_content = "" system_message_content = ""
if len(input_images) == 0:
# Create image(s)
if request.app.state.config.ENABLE_IMAGE_PROMPT_GENERATION:
try:
res = await generate_image_prompt(
request,
{
"model": form_data["model"],
"messages": form_data["messages"],
},
user,
)
try: response = res["choices"][0]["message"]["content"]
images = await image_generations(
request=request,
form_data=CreateImageForm(**{"prompt": prompt}),
user=user,
)
await __event_emitter__( try:
{ bracket_start = response.find("{")
"type": "status", bracket_end = response.rfind("}") + 1
"data": {"description": "Image created", "done": True},
}
)
await __event_emitter__( if bracket_start == -1 or bracket_end == -1:
{ raise Exception("No JSON object found in the response")
"type": "files",
"data": {
"files": [
{
"type": "image",
"url": image["url"],
}
for image in images
]
},
}
)
system_message_content = "<context>User is shown the generated image, tell the user that the image has been generated</context>" response = response[bracket_start:bracket_end]
except Exception as e: response = json.loads(response)
log.exception(e) prompt = response.get("prompt", [])
await __event_emitter__( except Exception as e:
{ prompt = user_message
"type": "status",
"data": {
"description": f"An error occurred while generating an image",
"done": True,
},
}
)
system_message_content = "<context>Unable to generate an image, tell the user that an error occurred</context>" except Exception as e:
log.exception(e)
prompt = user_message
try:
images = await image_generations(
request=request,
form_data=CreateImageForm(**{"prompt": prompt}),
user=user,
)
await __event_emitter__(
{
"type": "status",
"data": {"description": "Image created", "done": True},
}
)
await __event_emitter__(
{
"type": "files",
"data": {
"files": [
{
"type": "image",
"url": image["url"],
}
for image in images
]
},
}
)
system_message_content = "<context>The requested image has been created and is now being shown to the user. Let them know that it has been generated.</context>"
except Exception as e:
log.debug(e)
error_message = ""
if isinstance(e, HTTPException):
if e.detail and isinstance(e.detail, dict):
error_message = e.detail.get("message", str(e.detail))
else:
error_message = str(e.detail)
await __event_emitter__(
{
"type": "status",
"data": {
"description": f"An error occurred while generating an image",
"done": True,
},
}
)
system_message_content = f"<context>Image generation was attempted but failed. The system is currently unable to generate the image. Tell the user that an error occurred: {error_message}</context>"
else:
# Edit image(s)
try:
images = await image_edits(
request=request,
form_data=EditImageForm(**{"prompt": prompt, "image": input_images}),
user=user,
)
await __event_emitter__(
{
"type": "status",
"data": {"description": "Image created", "done": True},
}
)
await __event_emitter__(
{
"type": "files",
"data": {
"files": [
{
"type": "image",
"url": image["url"],
}
for image in images
]
},
}
)
system_message_content = "<context>The requested image has been created and is now being shown to the user. Let them know that it has been generated.</context>"
except Exception as e:
log.debug(e)
error_message = ""
if isinstance(e, HTTPException):
if e.detail and isinstance(e.detail, dict):
error_message = e.detail.get("message", str(e.detail))
else:
error_message = str(e.detail)
await __event_emitter__(
{
"type": "status",
"data": {
"description": f"An error occurred while generating an image",
"done": True,
},
}
)
system_message_content = f"<context>Image generation was attempted but failed. The system is currently unable to generate the image. Tell the user that an error occurred: {error_message}</context>"
if system_message_content: if system_message_content:
form_data["messages"] = add_or_update_system_message( form_data["messages"] = add_or_update_system_message(

View file

@ -29,7 +29,7 @@
let config = null; let config = null;
let showComfyUIWorkflowEditor = false; let showComfyUIWorkflowEditor = false;
let requiredWorkflowNodes = [ let REQUIRED_WORKFLOW_NODES = [
{ {
type: 'prompt', type: 'prompt',
key: 'text', key: 'text',
@ -62,6 +62,29 @@
} }
]; ];
// Node descriptors listed in the image-edit ComfyUI workflow form.
// Each entry has a `type`, the workflow input `key` it refers to, and a
// `node_ids` string that starts out empty.
let REQUIRED_EDIT_WORKFLOW_NODES = [
{
type: 'prompt',
key: 'text',
node_ids: ''
},
{
type: 'model',
key: 'ckpt_name',
node_ids: ''
},
{
type: 'width',
key: 'width',
node_ids: ''
},
{
type: 'height',
key: 'height',
node_ids: ''
}
];
const getModels = async () => { const getModels = async () => {
models = await getImageGenerationModels(localStorage.token).catch((error) => { models = await getImageGenerationModels(localStorage.token).catch((error) => {
toast.error(`${error}`); toast.error(`${error}`);
@ -137,7 +160,7 @@
} }
if (config?.COMFYUI_WORKFLOW) { if (config?.COMFYUI_WORKFLOW) {
config.COMFYUI_WORKFLOW_NODES = requiredWorkflowNodes.map((node) => { config.COMFYUI_WORKFLOW_NODES = REQUIRED_WORKFLOW_NODES.map((node) => {
return { return {
type: node.type, type: node.type,
key: node.key, key: node.key,
@ -178,7 +201,7 @@
} }
} }
requiredWorkflowNodes = requiredWorkflowNodes.map((node) => { REQUIRED_WORKFLOW_NODES = REQUIRED_WORKFLOW_NODES.map((node) => {
const n = config.COMFYUI_WORKFLOW_NODES.find((n) => n.type === node.type) ?? node; const n = config.COMFYUI_WORKFLOW_NODES.find((n) => n.type === node.type) ?? node;
console.debug(n); console.debug(n);
@ -665,7 +688,7 @@
</div> </div>
<div class="mt-1 text-xs flex flex-col gap-1.5"> <div class="mt-1 text-xs flex flex-col gap-1.5">
{#each requiredWorkflowNodes as node} {#each REQUIRED_WORKFLOW_NODES as node}
<div class="flex w-full flex-col"> <div class="flex w-full flex-col">
<div class="shrink-0"> <div class="shrink-0">
<div class=" capitalize line-clamp-1 w-20 text-gray-400 dark:text-gray-500"> <div class=" capitalize line-clamp-1 w-20 text-gray-400 dark:text-gray-500">
@ -791,13 +814,13 @@
placeholder={$i18n.t('Select Engine')} placeholder={$i18n.t('Select Engine')}
> >
<option value="openai">{$i18n.t('Default (Open AI)')}</option> <option value="openai">{$i18n.t('Default (Open AI)')}</option>
<option value="comfyui">{$i18n.t('ComfyUI')}</option> <!-- <option value="comfyui">{$i18n.t('ComfyUI')}</option> -->
<option value="comfyui">{$i18n.t('Gemini')}</option> <option value="gemini">{$i18n.t('Gemini')}</option>
</select> </select>
</div> </div>
</div> </div>
{#if config.ENABLE_IMAGE_EDIT} {#if config.ENABLE_IMAGE_GENERATION}
<div class="mb-2.5"> <div class="mb-2.5">
<div class="flex w-full justify-between items-center"> <div class="flex w-full justify-between items-center">
<div class="text-xs pr-2"> <div class="text-xs pr-2">
@ -918,7 +941,7 @@
<input <input
class="w-full text-sm bg-transparent outline-hidden text-right" class="w-full text-sm bg-transparent outline-hidden text-right"
placeholder={$i18n.t('Enter URL (e.g. http://127.0.0.1:7860/)')} placeholder={$i18n.t('Enter URL (e.g. http://127.0.0.1:7860/)')}
bind:value={config.COMFYUI_BASE_URL} bind:value={config.IMAGES_EDIT_COMFYUI_BASE_URL}
/> />
</div> </div>
<button <button
@ -967,7 +990,7 @@
<SensitiveInput <SensitiveInput
inputClassName="text-right w-full" inputClassName="text-right w-full"
placeholder={$i18n.t('sk-1234')} placeholder={$i18n.t('sk-1234')}
bind:value={config.COMFYUI_API_KEY} bind:value={config.IMAGES_EDIT_COMFYUI_API_KEY}
required={false} required={false}
/> />
</div> </div>
@ -977,7 +1000,7 @@
<div class="mb-2.5"> <div class="mb-2.5">
<input <input
id="upload-comfyui-workflow-input" id="upload-comfyui-edit-workflow-input"
hidden hidden
type="file" type="file"
accept=".json" accept=".json"
@ -986,7 +1009,7 @@
const reader = new FileReader(); const reader = new FileReader();
reader.onload = (e) => { reader.onload = (e) => {
config.COMFYUI_WORKFLOW = e.target.result; config.IMAGES_EDIT_COMFYUI_WORKFLOW = e.target.result;
e.target.value = null; e.target.value = null;
}; };
@ -1002,7 +1025,7 @@
<div class="flex w-full"> <div class="flex w-full">
<div class="flex-1 mr-2 justify-end flex gap-1"> <div class="flex-1 mr-2 justify-end flex gap-1">
{#if config.COMFYUI_WORKFLOW} {#if config.IMAGES_EDIT_COMFYUI_WORKFLOW}
<button <button
class="text-xs text-gray-700 dark:text-gray-400 underline" class="text-xs text-gray-700 dark:text-gray-400 underline"
type="button" type="button"
@ -1022,7 +1045,7 @@
type="button" type="button"
aria-label={$i18n.t('Click here to upload a workflow.json file.')} aria-label={$i18n.t('Click here to upload a workflow.json file.')}
on:click={() => { on:click={() => {
document.getElementById('upload-comfyui-workflow-input')?.click(); document.getElementById('upload-comfyui-edit-workflow-input')?.click();
}} }}
> >
{$i18n.t('Upload')} {$i18n.t('Upload')}
@ -1035,28 +1058,20 @@
<div class="mt-1 text-xs text-gray-400 dark:text-gray-500"> <div class="mt-1 text-xs text-gray-400 dark:text-gray-500">
<CodeEditorModal <CodeEditorModal
bind:show={showComfyUIWorkflowEditor} bind:show={showComfyUIWorkflowEditor}
value={config.COMFYUI_WORKFLOW} value={config.IMAGES_EDIT_COMFYUI_WORKFLOW}
lang="json" lang="json"
onChange={(e) => { onChange={(e) => {
config.COMFYUI_WORKFLOW = e; config.IMAGES_EDIT_COMFYUI_WORKFLOW = e;
}} }}
onSave={() => { onSave={() => {
console.log('Saved'); console.log('Saved');
}} }}
/> />
<!-- {#if config.COMFYUI_WORKFLOW}
<Textarea
class="w-full rounded-lg my-1 py-2 px-3 text-xs bg-gray-50 dark:text-gray-300 dark:bg-gray-850 outline-hidden disabled:text-gray-600 resize-none"
rows="10"
bind:value={config.COMFYUI_WORKFLOW}
required
/>
{/if} -->
{$i18n.t('Make sure to export a workflow.json file as API format from ComfyUI.')} {$i18n.t('Make sure to export a workflow.json file as API format from ComfyUI.')}
</div> </div>
</div> </div>
{#if config.COMFYUI_WORKFLOW} {#if config.IMAGES_EDIT_COMFYUI_WORKFLOW}
<div class="mb-2.5"> <div class="mb-2.5">
<div class="flex w-full justify-between items-center"> <div class="flex w-full justify-between items-center">
<div class="text-xs pr-2 shrink-0"> <div class="text-xs pr-2 shrink-0">
@ -1067,7 +1082,7 @@
</div> </div>
<div class="mt-1 text-xs flex flex-col gap-1.5"> <div class="mt-1 text-xs flex flex-col gap-1.5">
{#each requiredWorkflowNodes as node} {#each REQUIRED_EDIT_WORKFLOW_NODES as node}
<div class="flex w-full flex-col"> <div class="flex w-full flex-col">
<div class="shrink-0"> <div class="shrink-0">
<div class=" capitalize line-clamp-1 w-20 text-gray-400 dark:text-gray-500"> <div class=" capitalize line-clamp-1 w-20 text-gray-400 dark:text-gray-500">
@ -1111,6 +1126,47 @@
</div> </div>
</div> </div>
{/if} {/if}
{:else if config?.IMAGE_GENERATION_ENGINE === 'gemini'}
<div class="mb-2.5">
<div class="flex w-full justify-between items-center">
<div class="text-xs pr-2 shrink-0">
<div class="">
{$i18n.t('Gemini Base URL')}
</div>
</div>
<div class="flex w-full">
<div class="flex-1">
<input
class="w-full text-sm bg-transparent outline-hidden text-right"
placeholder={$i18n.t('API Base URL')}
bind:value={config.IMAGES_EDIT_GEMINI_API_BASE_URL}
/>
</div>
</div>
</div>
</div>
<div class="mb-2.5">
<div class="flex w-full justify-between items-center">
<div class="text-xs pr-2 shrink-0">
<div class="">
{$i18n.t('Gemini API Key')}
</div>
</div>
<div class="flex w-full">
<div class="flex-1">
<SensitiveInput
inputClassName="text-right w-full"
placeholder={$i18n.t('API Key')}
bind:value={config.IMAGES_EDIT_GEMINI_API_KEY}
required={true}
/>
</div>
</div>
</div>
</div>
{/if} {/if}
</div> </div>
</div> </div>