mirror of
https://github.com/open-webui/open-webui.git
synced 2025-12-12 04:15:25 +00:00
feat: image edit support
This commit is contained in:
parent
8d34fcb586
commit
72f8539fd2
5 changed files with 623 additions and 119 deletions
|
|
@ -3074,16 +3074,30 @@ EXTERNAL_WEB_LOADER_API_KEY = PersistentConfig(
|
|||
# Images
|
||||
####################################
|
||||
|
||||
ENABLE_IMAGE_GENERATION = PersistentConfig(
|
||||
"ENABLE_IMAGE_GENERATION",
|
||||
"image_generation.enable",
|
||||
os.environ.get("ENABLE_IMAGE_GENERATION", "").lower() == "true",
|
||||
)
|
||||
|
||||
IMAGE_GENERATION_ENGINE = PersistentConfig(
|
||||
"IMAGE_GENERATION_ENGINE",
|
||||
"image_generation.engine",
|
||||
os.getenv("IMAGE_GENERATION_ENGINE", "openai"),
|
||||
)
|
||||
|
||||
ENABLE_IMAGE_GENERATION = PersistentConfig(
|
||||
"ENABLE_IMAGE_GENERATION",
|
||||
"image_generation.enable",
|
||||
os.environ.get("ENABLE_IMAGE_GENERATION", "").lower() == "true",
|
||||
IMAGE_GENERATION_MODEL = PersistentConfig(
|
||||
"IMAGE_GENERATION_MODEL",
|
||||
"image_generation.model",
|
||||
os.getenv("IMAGE_GENERATION_MODEL", ""),
|
||||
)
|
||||
|
||||
IMAGE_SIZE = PersistentConfig(
|
||||
"IMAGE_SIZE", "image_generation.size", os.getenv("IMAGE_SIZE", "512x512")
|
||||
)
|
||||
|
||||
IMAGE_STEPS = PersistentConfig(
|
||||
"IMAGE_STEPS", "image_generation.steps", int(os.getenv("IMAGE_STEPS", 50))
|
||||
)
|
||||
|
||||
ENABLE_IMAGE_PROMPT_GENERATION = PersistentConfig(
|
||||
|
|
@ -3285,20 +3299,52 @@ IMAGES_GEMINI_ENDPOINT_METHOD = PersistentConfig(
|
|||
os.getenv("IMAGES_GEMINI_ENDPOINT_METHOD", ""),
|
||||
)
|
||||
|
||||
IMAGE_SIZE = PersistentConfig(
|
||||
"IMAGE_SIZE", "image_generation.size", os.getenv("IMAGE_SIZE", "512x512")
|
||||
|
||||
IMAGE_EDIT_ENGINE = PersistentConfig(
|
||||
"IMAGE_EDIT_ENGINE",
|
||||
"images.edit.engine",
|
||||
os.getenv("IMAGE_EDIT_ENGINE", "openai"),
|
||||
)
|
||||
|
||||
IMAGE_STEPS = PersistentConfig(
|
||||
"IMAGE_STEPS", "image_generation.steps", int(os.getenv("IMAGE_STEPS", 50))
|
||||
IMAGE_EDIT_MODEL = PersistentConfig(
|
||||
"IMAGE_EDIT_MODEL",
|
||||
"images.edit.model",
|
||||
os.getenv("IMAGE_EDIT_MODEL", ""),
|
||||
)
|
||||
|
||||
IMAGE_GENERATION_MODEL = PersistentConfig(
|
||||
"IMAGE_GENERATION_MODEL",
|
||||
"image_generation.model",
|
||||
os.getenv("IMAGE_GENERATION_MODEL", ""),
|
||||
IMAGE_EDIT_SIZE = PersistentConfig(
|
||||
"IMAGE_EDIT_SIZE", "images.edit.size", os.getenv("IMAGE_EDIT_SIZE", "")
|
||||
)
|
||||
|
||||
IMAGES_EDIT_OPENAI_API_BASE_URL = PersistentConfig(
|
||||
"IMAGES_EDIT_OPENAI_API_BASE_URL",
|
||||
"images.edit.openai.api_base_url",
|
||||
os.getenv("IMAGES_EDIT_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL),
|
||||
)
|
||||
IMAGES_EDIT_OPENAI_API_VERSION = PersistentConfig(
|
||||
"IMAGES_EDIT_OPENAI_API_VERSION",
|
||||
"images.edit.openai.api_version",
|
||||
os.getenv("IMAGES_EDIT_OPENAI_API_VERSION", ""),
|
||||
)
|
||||
|
||||
IMAGES_EDIT_OPENAI_API_KEY = PersistentConfig(
|
||||
"IMAGES_EDIT_OPENAI_API_KEY",
|
||||
"images.edit.openai.api_key",
|
||||
os.getenv("IMAGES_EDIT_OPENAI_API_KEY", OPENAI_API_KEY),
|
||||
)
|
||||
|
||||
IMAGES_EDIT_GEMINI_API_BASE_URL = PersistentConfig(
|
||||
"IMAGES_EDIT_GEMINI_API_BASE_URL",
|
||||
"images.edit.gemini.api_base_url",
|
||||
os.getenv("IMAGES_EDIT_GEMINI_API_BASE_URL", GEMINI_API_BASE_URL),
|
||||
)
|
||||
IMAGES_EDIT_GEMINI_API_KEY = PersistentConfig(
|
||||
"IMAGES_EDIT_GEMINI_API_KEY",
|
||||
"images.edit.gemini.api_key",
|
||||
os.getenv("IMAGES_EDIT_GEMINI_API_KEY", GEMINI_API_KEY),
|
||||
)
|
||||
|
||||
|
||||
####################################
|
||||
# Audio
|
||||
####################################
|
||||
|
|
|
|||
|
|
@ -163,6 +163,14 @@ from open_webui.config import (
|
|||
IMAGES_GEMINI_API_BASE_URL,
|
||||
IMAGES_GEMINI_API_KEY,
|
||||
IMAGES_GEMINI_ENDPOINT_METHOD,
|
||||
IMAGE_EDIT_ENGINE,
|
||||
IMAGE_EDIT_MODEL,
|
||||
IMAGE_EDIT_SIZE,
|
||||
IMAGES_EDIT_OPENAI_API_BASE_URL,
|
||||
IMAGES_EDIT_OPENAI_API_KEY,
|
||||
IMAGES_EDIT_OPENAI_API_VERSION,
|
||||
IMAGES_EDIT_GEMINI_API_BASE_URL,
|
||||
IMAGES_EDIT_GEMINI_API_KEY,
|
||||
# Audio
|
||||
AUDIO_STT_ENGINE,
|
||||
AUDIO_STT_MODEL,
|
||||
|
|
@ -1078,7 +1086,6 @@ app.state.config.IMAGES_GEMINI_API_BASE_URL = IMAGES_GEMINI_API_BASE_URL
|
|||
app.state.config.IMAGES_GEMINI_API_KEY = IMAGES_GEMINI_API_KEY
|
||||
app.state.config.IMAGES_GEMINI_ENDPOINT_METHOD = IMAGES_GEMINI_ENDPOINT_METHOD
|
||||
|
||||
|
||||
app.state.config.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
|
||||
app.state.config.AUTOMATIC1111_API_AUTH = AUTOMATIC1111_API_AUTH
|
||||
app.state.config.AUTOMATIC1111_PARAMS = AUTOMATIC1111_PARAMS
|
||||
|
|
@ -1089,6 +1096,16 @@ app.state.config.COMFYUI_WORKFLOW = COMFYUI_WORKFLOW
|
|||
app.state.config.COMFYUI_WORKFLOW_NODES = COMFYUI_WORKFLOW_NODES
|
||||
|
||||
|
||||
app.state.config.IMAGE_EDIT_ENGINE = IMAGE_EDIT_ENGINE
|
||||
app.state.config.IMAGE_EDIT_MODEL = IMAGE_EDIT_MODEL
|
||||
app.state.config.IMAGE_EDIT_SIZE = IMAGE_EDIT_SIZE
|
||||
app.state.config.IMAGES_EDIT_OPENAI_API_BASE_URL = IMAGES_EDIT_OPENAI_API_BASE_URL
|
||||
app.state.config.IMAGES_EDIT_OPENAI_API_KEY = IMAGES_EDIT_OPENAI_API_KEY
|
||||
app.state.config.IMAGES_EDIT_OPENAI_API_VERSION = IMAGES_EDIT_OPENAI_API_VERSION
|
||||
app.state.config.IMAGES_EDIT_GEMINI_API_BASE_URL = IMAGES_EDIT_GEMINI_API_BASE_URL
|
||||
app.state.config.IMAGES_EDIT_GEMINI_API_KEY = IMAGES_EDIT_GEMINI_API_KEY
|
||||
|
||||
|
||||
########################################
|
||||
#
|
||||
# AUDIO
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
import asyncio
|
||||
import base64
|
||||
import uuid
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
|
|
@ -10,18 +11,13 @@ from typing import Optional
|
|||
|
||||
from urllib.parse import quote
|
||||
import requests
|
||||
from fastapi import (
|
||||
APIRouter,
|
||||
Depends,
|
||||
HTTPException,
|
||||
Request,
|
||||
UploadFile,
|
||||
)
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, UploadFile
|
||||
from fastapi.responses import FileResponse
|
||||
|
||||
from open_webui.config import CACHE_DIR
|
||||
from open_webui.constants import ERROR_MESSAGES
|
||||
from open_webui.env import ENABLE_FORWARD_USER_INFO_HEADERS, SRC_LOG_LEVELS
|
||||
from open_webui.routers.files import upload_file_handler
|
||||
from open_webui.routers.files import upload_file_handler, get_file_content_by_id
|
||||
from open_webui.utils.auth import get_admin_user, get_verified_user
|
||||
from open_webui.utils.headers import include_user_info_headers
|
||||
from open_webui.utils.images.comfyui import (
|
||||
|
|
@ -121,6 +117,16 @@ class ImagesConfig(BaseModel):
|
|||
IMAGES_GEMINI_API_KEY: str
|
||||
IMAGES_GEMINI_ENDPOINT_METHOD: str
|
||||
|
||||
IMAGE_EDIT_ENGINE: str
|
||||
IMAGE_EDIT_MODEL: str
|
||||
IMAGE_EDIT_SIZE: Optional[str]
|
||||
|
||||
IMAGES_EDIT_OPENAI_API_BASE_URL: str
|
||||
IMAGES_EDIT_OPENAI_API_KEY: str
|
||||
IMAGES_EDIT_OPENAI_API_VERSION: str
|
||||
IMAGES_EDIT_GEMINI_API_BASE_URL: str
|
||||
IMAGES_EDIT_GEMINI_API_KEY: str
|
||||
|
||||
|
||||
@router.get("/config", response_model=ImagesConfig)
|
||||
async def get_config(request: Request, user=Depends(get_admin_user)):
|
||||
|
|
@ -144,6 +150,14 @@ async def get_config(request: Request, user=Depends(get_admin_user)):
|
|||
"IMAGES_GEMINI_API_BASE_URL": request.app.state.config.IMAGES_GEMINI_API_BASE_URL,
|
||||
"IMAGES_GEMINI_API_KEY": request.app.state.config.IMAGES_GEMINI_API_KEY,
|
||||
"IMAGES_GEMINI_ENDPOINT_METHOD": request.app.state.config.IMAGES_GEMINI_ENDPOINT_METHOD,
|
||||
"IMAGE_EDIT_ENGINE": request.app.state.config.IMAGE_EDIT_ENGINE,
|
||||
"IMAGE_EDIT_MODEL": request.app.state.config.IMAGE_EDIT_MODEL,
|
||||
"IMAGE_EDIT_SIZE": request.app.state.config.IMAGE_EDIT_SIZE,
|
||||
"IMAGES_EDIT_OPENAI_API_BASE_URL": request.app.state.config.IMAGES_EDIT_OPENAI_API_BASE_URL,
|
||||
"IMAGES_EDIT_OPENAI_API_KEY": request.app.state.config.IMAGES_EDIT_OPENAI_API_KEY,
|
||||
"IMAGES_EDIT_OPENAI_API_VERSION": request.app.state.config.IMAGES_EDIT_OPENAI_API_VERSION,
|
||||
"IMAGES_EDIT_GEMINI_API_BASE_URL": request.app.state.config.IMAGES_EDIT_GEMINI_API_BASE_URL,
|
||||
"IMAGES_EDIT_GEMINI_API_KEY": request.app.state.config.IMAGES_EDIT_GEMINI_API_KEY,
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -152,6 +166,8 @@ async def update_config(
|
|||
request: Request, form_data: ImagesConfig, user=Depends(get_admin_user)
|
||||
):
|
||||
request.app.state.config.ENABLE_IMAGE_GENERATION = form_data.ENABLE_IMAGE_GENERATION
|
||||
|
||||
# Create Image
|
||||
request.app.state.config.ENABLE_IMAGE_PROMPT_GENERATION = (
|
||||
form_data.ENABLE_IMAGE_PROMPT_GENERATION
|
||||
)
|
||||
|
|
@ -215,6 +231,28 @@ async def update_config(
|
|||
form_data.IMAGES_GEMINI_ENDPOINT_METHOD
|
||||
)
|
||||
|
||||
# Edit Image
|
||||
request.app.state.config.IMAGE_EDIT_ENGINE = form_data.IMAGE_EDIT_ENGINE
|
||||
request.app.state.config.IMAGE_EDIT_MODEL = form_data.IMAGE_EDIT_MODEL
|
||||
request.app.state.config.IMAGE_EDIT_SIZE = form_data.IMAGE_EDIT_SIZE
|
||||
|
||||
request.app.state.config.IMAGES_EDIT_OPENAI_API_BASE_URL = (
|
||||
form_data.IMAGES_OPENAI_API_BASE_URL
|
||||
)
|
||||
request.app.state.config.IMAGES_EDIT_OPENAI_API_KEY = (
|
||||
form_data.IMAGES_OPENAI_API_KEY
|
||||
)
|
||||
request.app.state.config.IMAGES_EDIT_OPENAI_API_VERSION = (
|
||||
form_data.IMAGES_EDIT_OPENAI_API_VERSION
|
||||
)
|
||||
|
||||
request.app.state.config.IMAGES_EDIT_GEMINI_API_BASE_URL = (
|
||||
form_data.IMAGES_EDIT_GEMINI_API_BASE_URL
|
||||
)
|
||||
request.app.state.config.IMAGES_EDIT_GEMINI_API_KEY = (
|
||||
form_data.IMAGES_EDIT_GEMINI_API_KEY
|
||||
)
|
||||
|
||||
return {
|
||||
"ENABLE_IMAGE_GENERATION": request.app.state.config.ENABLE_IMAGE_GENERATION,
|
||||
"ENABLE_IMAGE_PROMPT_GENERATION": request.app.state.config.ENABLE_IMAGE_PROMPT_GENERATION,
|
||||
|
|
@ -235,6 +273,14 @@ async def update_config(
|
|||
"IMAGES_GEMINI_API_BASE_URL": request.app.state.config.IMAGES_GEMINI_API_BASE_URL,
|
||||
"IMAGES_GEMINI_API_KEY": request.app.state.config.IMAGES_GEMINI_API_KEY,
|
||||
"IMAGES_GEMINI_ENDPOINT_METHOD": request.app.state.config.IMAGES_GEMINI_ENDPOINT_METHOD,
|
||||
"IMAGE_EDIT_ENGINE": request.app.state.config.IMAGE_EDIT_ENGINE,
|
||||
"IMAGE_EDIT_MODEL": request.app.state.config.IMAGE_EDIT_MODEL,
|
||||
"IMAGE_EDIT_SIZE": request.app.state.config.IMAGE_EDIT_SIZE,
|
||||
"IMAGES_EDIT_OPENAI_API_BASE_URL": request.app.state.config.IMAGES_EDIT_OPENAI_API_BASE_URL,
|
||||
"IMAGES_EDIT_OPENAI_API_KEY": request.app.state.config.IMAGES_EDIT_OPENAI_API_KEY,
|
||||
"IMAGES_EDIT_OPENAI_API_VERSION": request.app.state.config.IMAGES_EDIT_OPENAI_API_VERSION,
|
||||
"IMAGES_EDIT_GEMINI_API_BASE_URL": request.app.state.config.IMAGES_EDIT_GEMINI_API_BASE_URL,
|
||||
"IMAGES_EDIT_GEMINI_API_KEY": request.app.state.config.IMAGES_EDIT_GEMINI_API_KEY,
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -674,3 +720,255 @@ async def image_generations(
|
|||
if "error" in data:
|
||||
error = data["error"]["message"]
|
||||
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(error))
|
||||
|
||||
|
||||
class EditImageForm(BaseModel):
|
||||
image: str | list[str] # base64-encoded image(s) or URL(s)
|
||||
prompt: str
|
||||
model: Optional[str] = None
|
||||
size: Optional[str] = None
|
||||
n: Optional[int] = None
|
||||
negative_prompt: Optional[str] = None
|
||||
|
||||
|
||||
@router.post("/edit")
|
||||
async def image_edits(
|
||||
request: Request,
|
||||
form_data: EditImageForm,
|
||||
user=Depends(get_verified_user),
|
||||
):
|
||||
|
||||
size = None
|
||||
width, height = None, None
|
||||
if (
|
||||
request.app.state.config.IMAGE_EDIT_SIZE
|
||||
and "x" in request.app.state.config.IMAGE_EDIT_SIZE
|
||||
) or (form_data.size and "x" in form_data.size):
|
||||
size = (
|
||||
form_data.size
|
||||
if form_data.size
|
||||
else request.app.state.config.IMAGE_EDIT_SIZE
|
||||
)
|
||||
width, height = tuple(map(int, size.split("x")))
|
||||
|
||||
model = (
|
||||
request.app.state.config.IMAGE_EDIT_MODEL
|
||||
if form_data.model is None
|
||||
else form_data.model
|
||||
)
|
||||
|
||||
def load_url_image(string):
|
||||
if string.startswith("http://") or string.startswith("https://"):
|
||||
r = requests.get(string)
|
||||
r.raise_for_status()
|
||||
image_data = base64.b64encode(r.content).decode("utf-8")
|
||||
return f"data:{r.headers['content-type']};base64,{image_data}"
|
||||
|
||||
elif string.startswith("/api/v1/files"):
|
||||
file_id = string.split("/api/v1/files/")[1].split("/content")[0]
|
||||
file_response = get_file_content_by_id(file_id, user)
|
||||
|
||||
if isinstance(file_response, FileResponse):
|
||||
file_bytes = file_response.body
|
||||
mime_type = file_response.headers.get("content-type", "image/png")
|
||||
image_data = base64.b64encode(file_bytes).decode("utf-8")
|
||||
return f"data:{mime_type};base64,{image_data}"
|
||||
return string
|
||||
|
||||
# Load image(s) from URL(s) if necessary
|
||||
if isinstance(form_data.image, str):
|
||||
form_data.image = load_url_image(form_data.image)
|
||||
elif isinstance(form_data.image, list):
|
||||
form_data.image = [load_url_image(img) for img in form_data.image]
|
||||
|
||||
r = None
|
||||
try:
|
||||
if request.app.state.config.IMAGE_EDIT_ENGINE == "openai":
|
||||
headers = {
|
||||
"Authorization": f"Bearer {request.app.state.config.IMAGES_EDIT_OPENAI_API_KEY}",
|
||||
}
|
||||
|
||||
if ENABLE_FORWARD_USER_INFO_HEADERS:
|
||||
headers = include_user_info_headers(headers, user)
|
||||
|
||||
data = {
|
||||
"model": model,
|
||||
"prompt": form_data.prompt,
|
||||
**({"n": form_data.n} if form_data.n else {}),
|
||||
**({"size": size} if size else {}),
|
||||
**(
|
||||
{}
|
||||
if "gpt-image-1" in request.app.state.config.IMAGE_GENERATION_MODEL
|
||||
else {"response_format": "b64_json"}
|
||||
),
|
||||
}
|
||||
|
||||
def get_image_file_item(base64_string):
|
||||
data = base64_string
|
||||
header, encoded = data.split(",", 1)
|
||||
mime_type = header.split(";")[0].lstrip("data:")
|
||||
image_data = base64.b64decode(encoded)
|
||||
return (
|
||||
"image",
|
||||
(
|
||||
f"{uuid.uuid4()}.png",
|
||||
io.BytesIO(image_data),
|
||||
mime_type if mime_type else "image/png",
|
||||
),
|
||||
)
|
||||
|
||||
files = []
|
||||
if isinstance(form_data.image, str):
|
||||
files = [get_image_file_item(form_data.image)]
|
||||
elif isinstance(form_data.image, list):
|
||||
for img in form_data.image:
|
||||
files.append(get_image_file_item(img))
|
||||
|
||||
url_search_params = ""
|
||||
if request.app.state.config.IMAGES_EDIT_OPENAI_API_VERSION:
|
||||
url_search_params += f"?api-version={request.app.state.config.IMAGES_EDIT_OPENAI_API_VERSION}"
|
||||
|
||||
# Use asyncio.to_thread for the requests.post call
|
||||
r = await asyncio.to_thread(
|
||||
requests.post,
|
||||
url=f"{request.app.state.config.IMAGES_OPENAI_API_BASE_URL}/images/edits{url_search_params}",
|
||||
headers=headers,
|
||||
files=files,
|
||||
data=data,
|
||||
)
|
||||
|
||||
r.raise_for_status()
|
||||
res = r.json()
|
||||
|
||||
images = []
|
||||
for image in res["data"]:
|
||||
if image_url := image.get("url", None):
|
||||
image_data, content_type = get_image_data(image_url, headers)
|
||||
else:
|
||||
image_data, content_type = get_image_data(image["b64_json"])
|
||||
|
||||
url = upload_image(request, image_data, content_type, data, user)
|
||||
images.append({"url": url})
|
||||
return images
|
||||
|
||||
elif request.app.state.config.IMAGE_GENERATION_ENGINE == "gemini":
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"x-goog-api-key": request.app.state.config.IMAGES_GEMINI_API_KEY,
|
||||
}
|
||||
|
||||
model = f"{model}:generateContent"
|
||||
data = {"contents": [{"parts": [{"text": form_data.prompt}]}]}
|
||||
|
||||
if isinstance(form_data.image, str):
|
||||
data["contents"][0]["parts"].append(
|
||||
{
|
||||
"inline_data": {
|
||||
"mime_type": "image/png",
|
||||
"data": form_data.image.split(",", 1)[1],
|
||||
}
|
||||
}
|
||||
)
|
||||
elif isinstance(form_data.image, list):
|
||||
data["contents"][0]["parts"].extend(
|
||||
[
|
||||
{
|
||||
"inline_data": {
|
||||
"mime_type": "image/png",
|
||||
"data": image.split(",", 1)[1],
|
||||
}
|
||||
}
|
||||
for image in form_data.image
|
||||
]
|
||||
)
|
||||
|
||||
# Use asyncio.to_thread for the requests.post call
|
||||
r = await asyncio.to_thread(
|
||||
requests.post,
|
||||
url=f"{request.app.state.config.IMAGES_GEMINI_API_BASE_URL}/models/{model}",
|
||||
json=data,
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
r.raise_for_status()
|
||||
res = r.json()
|
||||
|
||||
images = []
|
||||
for image in res["candidates"]:
|
||||
for part in image["content"]["parts"]:
|
||||
if part.get("inlineData", {}).get("data"):
|
||||
image_data, content_type = get_image_data(
|
||||
part["inlineData"]["data"]
|
||||
)
|
||||
url = upload_image(
|
||||
request, image_data, content_type, data, user
|
||||
)
|
||||
images.append({"url": url})
|
||||
|
||||
return images
|
||||
|
||||
elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui":
|
||||
data = {
|
||||
"prompt": form_data.prompt,
|
||||
"width": width,
|
||||
"height": height,
|
||||
"n": form_data.n,
|
||||
}
|
||||
|
||||
if request.app.state.config.IMAGE_EDIT_STEPS is not None:
|
||||
data["steps"] = request.app.state.config.IMAGE_EDIT_STEPS
|
||||
|
||||
if form_data.negative_prompt is not None:
|
||||
data["negative_prompt"] = form_data.negative_prompt
|
||||
|
||||
form_data = ComfyUICreateImageForm(
|
||||
**{
|
||||
"workflow": ComfyUIWorkflow(
|
||||
**{
|
||||
"workflow": request.app.state.config.COMFYUI_WORKFLOW,
|
||||
"nodes": request.app.state.config.COMFYUI_WORKFLOW_NODES,
|
||||
}
|
||||
),
|
||||
**data,
|
||||
}
|
||||
)
|
||||
res = await comfyui_create_image(
|
||||
model,
|
||||
form_data,
|
||||
user.id,
|
||||
request.app.state.config.COMFYUI_BASE_URL,
|
||||
request.app.state.config.COMFYUI_API_KEY,
|
||||
)
|
||||
log.debug(f"res: {res}")
|
||||
|
||||
images = []
|
||||
|
||||
for image in res["data"]:
|
||||
headers = None
|
||||
if request.app.state.config.COMFYUI_API_KEY:
|
||||
headers = {
|
||||
"Authorization": f"Bearer {request.app.state.config.COMFYUI_API_KEY}"
|
||||
}
|
||||
|
||||
image_data, content_type = get_image_data(image["url"], headers)
|
||||
url = upload_image(
|
||||
request,
|
||||
image_data,
|
||||
content_type,
|
||||
form_data.model_dump(exclude_none=True),
|
||||
user,
|
||||
)
|
||||
images.append({"url": url})
|
||||
return images
|
||||
except Exception as e:
|
||||
error = e
|
||||
if r != None:
|
||||
data = r.text
|
||||
try:
|
||||
data = json.loads(data)
|
||||
if "error" in data:
|
||||
error = data["error"]["message"]
|
||||
except Exception:
|
||||
error = data
|
||||
|
||||
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(error))
|
||||
|
|
|
|||
|
|
@ -47,7 +47,8 @@ from open_webui.routers.retrieval import (
|
|||
from open_webui.routers.images import (
|
||||
image_generations,
|
||||
CreateImageForm,
|
||||
upload_image,
|
||||
image_edits,
|
||||
EditImageForm,
|
||||
)
|
||||
from open_webui.routers.pipelines import (
|
||||
process_pipeline_inlet_filter,
|
||||
|
|
@ -717,9 +718,31 @@ async def chat_web_search_handler(
|
|||
return form_data
|
||||
|
||||
|
||||
def get_last_images(message_list):
|
||||
images = []
|
||||
for message in reversed(message_list):
|
||||
images_flag = False
|
||||
for file in message.get("files", []):
|
||||
if file.get("type") == "image":
|
||||
images.append(file.get("url"))
|
||||
images_flag = True
|
||||
|
||||
if images_flag:
|
||||
break
|
||||
|
||||
return images
|
||||
|
||||
|
||||
async def chat_image_generation_handler(
|
||||
request: Request, form_data: dict, extra_params: dict, user
|
||||
):
|
||||
metadata = extra_params.get("__metadata__", {})
|
||||
chat_id = metadata.get("chat_id", None)
|
||||
if not chat_id:
|
||||
return form_data
|
||||
|
||||
chat = Chats.get_chat_by_id_and_user_id(chat_id, user.id)
|
||||
|
||||
__event_emitter__ = extra_params["__event_emitter__"]
|
||||
await __event_emitter__(
|
||||
{
|
||||
|
|
@ -728,19 +751,24 @@ async def chat_image_generation_handler(
|
|||
}
|
||||
)
|
||||
|
||||
messages = form_data["messages"]
|
||||
user_message = get_last_user_message(messages)
|
||||
messages_map = chat.chat.get("history", {}).get("messages", {})
|
||||
message_id = chat.chat.get("history", {}).get("currentId")
|
||||
message_list = get_message_list(messages_map, message_id)
|
||||
user_message = get_last_user_message(message_list)
|
||||
|
||||
prompt = user_message
|
||||
negative_prompt = ""
|
||||
input_images = get_last_images(message_list)
|
||||
|
||||
system_message_content = ""
|
||||
if len(input_images) == 0:
|
||||
# Create image(s)
|
||||
if request.app.state.config.ENABLE_IMAGE_PROMPT_GENERATION:
|
||||
try:
|
||||
res = await generate_image_prompt(
|
||||
request,
|
||||
{
|
||||
"model": form_data["model"],
|
||||
"messages": messages,
|
||||
"messages": form_data["messages"],
|
||||
},
|
||||
user,
|
||||
)
|
||||
|
|
@ -764,8 +792,6 @@ async def chat_image_generation_handler(
|
|||
log.exception(e)
|
||||
prompt = user_message
|
||||
|
||||
system_message_content = ""
|
||||
|
||||
try:
|
||||
images = await image_generations(
|
||||
request=request,
|
||||
|
|
@ -795,9 +821,17 @@ async def chat_image_generation_handler(
|
|||
}
|
||||
)
|
||||
|
||||
system_message_content = "<context>User is shown the generated image, tell the user that the image has been generated</context>"
|
||||
system_message_content = "<context>The requested image has been created and is now being shown to the user. Let them know that it has been generated.</context>"
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
log.debug(e)
|
||||
|
||||
error_message = ""
|
||||
if isinstance(e, HTTPException):
|
||||
if e.detail and isinstance(e.detail, dict):
|
||||
error_message = e.detail.get("message", str(e.detail))
|
||||
else:
|
||||
error_message = str(e.detail)
|
||||
|
||||
await __event_emitter__(
|
||||
{
|
||||
"type": "status",
|
||||
|
|
@ -808,7 +842,60 @@ async def chat_image_generation_handler(
|
|||
}
|
||||
)
|
||||
|
||||
system_message_content = "<context>Unable to generate an image, tell the user that an error occurred</context>"
|
||||
system_message_content = f"<context>Image generation was attempted but failed. The system is currently unable to generate the image. Tell the user that an error occurred: {error_message}</context>"
|
||||
else:
|
||||
# Edit image(s)
|
||||
try:
|
||||
images = await image_edits(
|
||||
request=request,
|
||||
form_data=EditImageForm(**{"prompt": prompt, "image": input_images}),
|
||||
user=user,
|
||||
)
|
||||
|
||||
await __event_emitter__(
|
||||
{
|
||||
"type": "status",
|
||||
"data": {"description": "Image created", "done": True},
|
||||
}
|
||||
)
|
||||
|
||||
await __event_emitter__(
|
||||
{
|
||||
"type": "files",
|
||||
"data": {
|
||||
"files": [
|
||||
{
|
||||
"type": "image",
|
||||
"url": image["url"],
|
||||
}
|
||||
for image in images
|
||||
]
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
system_message_content = "<context>The requested image has been created and is now being shown to the user. Let them know that it has been generated.</context>"
|
||||
except Exception as e:
|
||||
log.debug(e)
|
||||
|
||||
error_message = ""
|
||||
if isinstance(e, HTTPException):
|
||||
if e.detail and isinstance(e.detail, dict):
|
||||
error_message = e.detail.get("message", str(e.detail))
|
||||
else:
|
||||
error_message = str(e.detail)
|
||||
|
||||
await __event_emitter__(
|
||||
{
|
||||
"type": "status",
|
||||
"data": {
|
||||
"description": f"An error occurred while generating an image",
|
||||
"done": True,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
system_message_content = f"<context>Image generation was attempted but failed. The system is currently unable to generate the image. Tell the user that an error occurred: {error_message}</context>"
|
||||
|
||||
if system_message_content:
|
||||
form_data["messages"] = add_or_update_system_message(
|
||||
|
|
|
|||
|
|
@ -29,7 +29,7 @@
|
|||
let config = null;
|
||||
|
||||
let showComfyUIWorkflowEditor = false;
|
||||
let requiredWorkflowNodes = [
|
||||
let REQUIRED_WORKFLOW_NODES = [
|
||||
{
|
||||
type: 'prompt',
|
||||
key: 'text',
|
||||
|
|
@ -62,6 +62,29 @@
|
|||
}
|
||||
];
|
||||
|
||||
let REQUIRED_EDIT_WORKFLOW_NODES = [
|
||||
{
|
||||
type: 'prompt',
|
||||
key: 'text',
|
||||
node_ids: ''
|
||||
},
|
||||
{
|
||||
type: 'model',
|
||||
key: 'ckpt_name',
|
||||
node_ids: ''
|
||||
},
|
||||
{
|
||||
type: 'width',
|
||||
key: 'width',
|
||||
node_ids: ''
|
||||
},
|
||||
{
|
||||
type: 'height',
|
||||
key: 'height',
|
||||
node_ids: ''
|
||||
}
|
||||
];
|
||||
|
||||
const getModels = async () => {
|
||||
models = await getImageGenerationModels(localStorage.token).catch((error) => {
|
||||
toast.error(`${error}`);
|
||||
|
|
@ -137,7 +160,7 @@
|
|||
}
|
||||
|
||||
if (config?.COMFYUI_WORKFLOW) {
|
||||
config.COMFYUI_WORKFLOW_NODES = requiredWorkflowNodes.map((node) => {
|
||||
config.COMFYUI_WORKFLOW_NODES = REQUIRED_WORKFLOW_NODES.map((node) => {
|
||||
return {
|
||||
type: node.type,
|
||||
key: node.key,
|
||||
|
|
@ -178,7 +201,7 @@
|
|||
}
|
||||
}
|
||||
|
||||
requiredWorkflowNodes = requiredWorkflowNodes.map((node) => {
|
||||
REQUIRED_WORKFLOW_NODES = REQUIRED_WORKFLOW_NODES.map((node) => {
|
||||
const n = config.COMFYUI_WORKFLOW_NODES.find((n) => n.type === node.type) ?? node;
|
||||
console.debug(n);
|
||||
|
||||
|
|
@ -665,7 +688,7 @@
|
|||
</div>
|
||||
|
||||
<div class="mt-1 text-xs flex flex-col gap-1.5">
|
||||
{#each requiredWorkflowNodes as node}
|
||||
{#each REQUIRED_WORKFLOW_NODES as node}
|
||||
<div class="flex w-full flex-col">
|
||||
<div class="shrink-0">
|
||||
<div class=" capitalize line-clamp-1 w-20 text-gray-400 dark:text-gray-500">
|
||||
|
|
@ -791,13 +814,13 @@
|
|||
placeholder={$i18n.t('Select Engine')}
|
||||
>
|
||||
<option value="openai">{$i18n.t('Default (Open AI)')}</option>
|
||||
<option value="comfyui">{$i18n.t('ComfyUI')}</option>
|
||||
<option value="comfyui">{$i18n.t('Gemini')}</option>
|
||||
<!-- <option value="comfyui">{$i18n.t('ComfyUI')}</option> -->
|
||||
<option value="gemini">{$i18n.t('Gemini')}</option>
|
||||
</select>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{#if config.ENABLE_IMAGE_EDIT}
|
||||
{#if config.ENABLE_IMAGE_GENERATION}
|
||||
<div class="mb-2.5">
|
||||
<div class="flex w-full justify-between items-center">
|
||||
<div class="text-xs pr-2">
|
||||
|
|
@ -918,7 +941,7 @@
|
|||
<input
|
||||
class="w-full text-sm bg-transparent outline-hidden text-right"
|
||||
placeholder={$i18n.t('Enter URL (e.g. http://127.0.0.1:7860/)')}
|
||||
bind:value={config.COMFYUI_BASE_URL}
|
||||
bind:value={config.IMAGES_EDIT_COMFYUI_BASE_URL}
|
||||
/>
|
||||
</div>
|
||||
<button
|
||||
|
|
@ -967,7 +990,7 @@
|
|||
<SensitiveInput
|
||||
inputClassName="text-right w-full"
|
||||
placeholder={$i18n.t('sk-1234')}
|
||||
bind:value={config.COMFYUI_API_KEY}
|
||||
bind:value={config.IMAGES_EDIT_COMFYUI_API_KEY}
|
||||
required={false}
|
||||
/>
|
||||
</div>
|
||||
|
|
@ -977,7 +1000,7 @@
|
|||
|
||||
<div class="mb-2.5">
|
||||
<input
|
||||
id="upload-comfyui-workflow-input"
|
||||
id="upload-comfyui-edit-workflow-input"
|
||||
hidden
|
||||
type="file"
|
||||
accept=".json"
|
||||
|
|
@ -986,7 +1009,7 @@
|
|||
const reader = new FileReader();
|
||||
|
||||
reader.onload = (e) => {
|
||||
config.COMFYUI_WORKFLOW = e.target.result;
|
||||
config.IMAGES_EDIT_COMFYUI_WORKFLOW = e.target.result;
|
||||
e.target.value = null;
|
||||
};
|
||||
|
||||
|
|
@ -1002,7 +1025,7 @@
|
|||
|
||||
<div class="flex w-full">
|
||||
<div class="flex-1 mr-2 justify-end flex gap-1">
|
||||
{#if config.COMFYUI_WORKFLOW}
|
||||
{#if config.IMAGES_EDIT_COMFYUI_WORKFLOW}
|
||||
<button
|
||||
class="text-xs text-gray-700 dark:text-gray-400 underline"
|
||||
type="button"
|
||||
|
|
@ -1022,7 +1045,7 @@
|
|||
type="button"
|
||||
aria-label={$i18n.t('Click here to upload a workflow.json file.')}
|
||||
on:click={() => {
|
||||
document.getElementById('upload-comfyui-workflow-input')?.click();
|
||||
document.getElementById('upload-comfyui-edit-workflow-input')?.click();
|
||||
}}
|
||||
>
|
||||
{$i18n.t('Upload')}
|
||||
|
|
@ -1035,28 +1058,20 @@
|
|||
<div class="mt-1 text-xs text-gray-400 dark:text-gray-500">
|
||||
<CodeEditorModal
|
||||
bind:show={showComfyUIWorkflowEditor}
|
||||
value={config.COMFYUI_WORKFLOW}
|
||||
value={config.IMAGES_EDIT_COMFYUI_WORKFLOW}
|
||||
lang="json"
|
||||
onChange={(e) => {
|
||||
config.COMFYUI_WORKFLOW = e;
|
||||
config.IMAGES_EDIT_COMFYUI_WORKFLOW = e;
|
||||
}}
|
||||
onSave={() => {
|
||||
console.log('Saved');
|
||||
}}
|
||||
/>
|
||||
<!-- {#if config.COMFYUI_WORKFLOW}
|
||||
<Textarea
|
||||
class="w-full rounded-lg my-1 py-2 px-3 text-xs bg-gray-50 dark:text-gray-300 dark:bg-gray-850 outline-hidden disabled:text-gray-600 resize-none"
|
||||
rows="10"
|
||||
bind:value={config.COMFYUI_WORKFLOW}
|
||||
required
|
||||
/>
|
||||
{/if} -->
|
||||
{$i18n.t('Make sure to export a workflow.json file as API format from ComfyUI.')}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{#if config.COMFYUI_WORKFLOW}
|
||||
{#if config.IMAGES_EDIT_COMFYUI_WORKFLOW}
|
||||
<div class="mb-2.5">
|
||||
<div class="flex w-full justify-between items-center">
|
||||
<div class="text-xs pr-2 shrink-0">
|
||||
|
|
@ -1067,7 +1082,7 @@
|
|||
</div>
|
||||
|
||||
<div class="mt-1 text-xs flex flex-col gap-1.5">
|
||||
{#each requiredWorkflowNodes as node}
|
||||
{#each REQUIRED_EDIT_WORKFLOW_NODES as node}
|
||||
<div class="flex w-full flex-col">
|
||||
<div class="shrink-0">
|
||||
<div class=" capitalize line-clamp-1 w-20 text-gray-400 dark:text-gray-500">
|
||||
|
|
@ -1111,6 +1126,47 @@
|
|||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
{:else if config?.IMAGE_GENERATION_ENGINE === 'gemini'}
|
||||
<div class="mb-2.5">
|
||||
<div class="flex w-full justify-between items-center">
|
||||
<div class="text-xs pr-2 shrink-0">
|
||||
<div class="">
|
||||
{$i18n.t('Gemini Base URL')}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="flex w-full">
|
||||
<div class="flex-1">
|
||||
<input
|
||||
class="w-full text-sm bg-transparent outline-hidden text-right"
|
||||
placeholder={$i18n.t('API Base URL')}
|
||||
bind:value={config.IMAGES_EDIT_GEMINI_API_BASE_URL}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="mb-2.5">
|
||||
<div class="flex w-full justify-between items-center">
|
||||
<div class="text-xs pr-2 shrink-0">
|
||||
<div class="">
|
||||
{$i18n.t('Gemini API Key')}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="flex w-full">
|
||||
<div class="flex-1">
|
||||
<SensitiveInput
|
||||
inputClassName="text-right w-full"
|
||||
placeholder={$i18n.t('API Key')}
|
||||
bind:value={config.IMAGES_EDIT_GEMINI_API_KEY}
|
||||
required={true}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
</div>
|
||||
|
|
|
|||
Loading…
Reference in a new issue