open-webui/backend/open_webui/routers/images.py
2025-11-09 22:36:00 -05:00

1039 lines
40 KiB
Python

import asyncio
import base64
import uuid
import io
import json
import logging
import mimetypes
import re
from pathlib import Path
from typing import Optional
from urllib.parse import quote
import requests
from fastapi import APIRouter, Depends, HTTPException, Request, UploadFile
from fastapi.responses import FileResponse
from open_webui.config import CACHE_DIR
from open_webui.constants import ERROR_MESSAGES
from open_webui.env import ENABLE_FORWARD_USER_INFO_HEADERS, SRC_LOG_LEVELS
from open_webui.routers.files import upload_file_handler, get_file_content_by_id
from open_webui.utils.auth import get_admin_user, get_verified_user
from open_webui.utils.headers import include_user_info_headers
from open_webui.utils.images.comfyui import (
ComfyUICreateImageForm,
ComfyUIEditImageForm,
ComfyUIWorkflow,
comfyui_upload_image,
comfyui_create_image,
comfyui_edit_image,
)
from pydantic import BaseModel
# Every route in this module is registered on this router.
router = APIRouter()

# Module logger; its verbosity follows the "IMAGES" entry of SRC_LOG_LEVELS.
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["IMAGES"])

# Cache directory for generated images, created (with parents) at import time.
IMAGE_CACHE_DIR = CACHE_DIR / "image" / "generations"
IMAGE_CACHE_DIR.mkdir(exist_ok=True, parents=True)
def set_image_model(request: Request, model: str):
    """Store ``model`` as the image generation model and sync AUTOMATIC1111.

    The value is always written to ``IMAGE_GENERATION_MODEL``. When the active
    engine is AUTOMATIC1111 (``"automatic1111"`` or the empty default), the
    server's ``sd_model_checkpoint`` option is also switched to ``model`` if it
    differs.

    Args:
        request: Incoming request; used to reach ``app.state.config``.
        model: Model / checkpoint identifier to activate.

    Returns:
        The stored ``IMAGE_GENERATION_MODEL`` value.

    Raises:
        requests.HTTPError: If AUTOMATIC1111 answers reading or updating its
            options with an error status.
    """
    log.info(f"Setting image model to {model}")
    config = request.app.state.config
    config.IMAGE_GENERATION_MODEL = model

    if config.IMAGE_GENERATION_ENGINE in ["", "automatic1111"]:
        options_url = f"{config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options"
        api_auth = get_automatic1111_api_auth(request)

        r = requests.get(url=options_url, headers={"authorization": api_auth})
        r.raise_for_status()
        options = r.json()

        # Only issue the (potentially slow) checkpoint switch when needed.
        if model != options.get("sd_model_checkpoint"):
            options["sd_model_checkpoint"] = model
            r = requests.post(
                url=options_url,
                json=options,
                headers={"authorization": api_auth},
            )
            # Surface a failed switch instead of silently keeping the old
            # checkpoint loaded while the config claims the new one.
            r.raise_for_status()

    return config.IMAGE_GENERATION_MODEL
def get_image_model(request):
    """Resolve the image generation model for the configured engine.

    For "openai", "gemini" and "comfyui" the configured
    ``IMAGE_GENERATION_MODEL`` is returned, or a per-engine fallback when it is
    empty. For AUTOMATIC1111 ("automatic1111" or "") the checkpoint currently
    loaded on the server is queried. Any other engine yields ``None``.

    Raises:
        HTTPException: 400 if querying AUTOMATIC1111 fails; image generation
            is disabled in that case.
    """
    config = request.app.state.config
    engine = config.IMAGE_GENERATION_ENGINE

    fallback_by_engine = {
        "openai": "dall-e-2",
        "gemini": "imagen-3.0-generate-002",
        "comfyui": "",
    }
    if engine in fallback_by_engine:
        return config.IMAGE_GENERATION_MODEL or fallback_by_engine[engine]

    if engine not in ("automatic1111", ""):
        return None

    try:
        response = requests.get(
            url=f"{config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
            headers={"authorization": get_automatic1111_api_auth(request)},
        )
        return response.json()["sd_model_checkpoint"]
    except Exception as e:
        config.ENABLE_IMAGE_GENERATION = False
        raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
class ImagesConfig(BaseModel):
    """Admin-editable image generation and image editing settings.

    Returned by ``GET /config`` and accepted by ``POST /config/update``; each
    field mirrors the attribute of the same name on ``app.state.config``.
    """

    # --- Image generation ---
    ENABLE_IMAGE_GENERATION: bool
    ENABLE_IMAGE_PROMPT_GENERATION: bool
    # Handled values in this module: "openai", "gemini", "comfyui",
    # "automatic1111", and "" (treated as AUTOMATIC1111).
    IMAGE_GENERATION_ENGINE: str
    IMAGE_GENERATION_MODEL: str
    IMAGE_SIZE: Optional[str]  # "WIDTHxHEIGHT", "" or "auto" (gpt-image-1 only)
    IMAGE_STEPS: Optional[int]  # sent to ComfyUI/AUTOMATIC1111 when set; must be >= 0
    IMAGES_OPENAI_API_BASE_URL: str
    IMAGES_OPENAI_API_KEY: str
    IMAGES_OPENAI_API_VERSION: str  # appended as "?api-version=..." when non-empty
    AUTOMATIC1111_BASE_URL: str
    AUTOMATIC1111_API_AUTH: str  # encoded as HTTP Basic credentials
    # Extra txt2img parameters merged over the generated payload; the merge
    # uses ** so a mapping is expected there — TODO confirm str handling.
    AUTOMATIC1111_PARAMS: Optional[dict | str]
    COMFYUI_BASE_URL: str
    COMFYUI_API_KEY: str  # sent as a Bearer token when non-empty
    COMFYUI_WORKFLOW: str  # workflow JSON, parsed with json.loads
    # Node mappings; entries carry at least "type" and "node_ids".
    COMFYUI_WORKFLOW_NODES: list[dict]
    IMAGES_GEMINI_API_BASE_URL: str
    IMAGES_GEMINI_API_KEY: str
    IMAGES_GEMINI_ENDPOINT_METHOD: str  # "" / "predict" or "generateContent"

    # --- Image editing ---
    IMAGE_EDIT_ENGINE: str  # "openai", "gemini" or "comfyui"
    IMAGE_EDIT_MODEL: str
    IMAGE_EDIT_SIZE: Optional[str]  # "WIDTHxHEIGHT"; other values are not parsed
    IMAGES_EDIT_OPENAI_API_BASE_URL: str
    IMAGES_EDIT_OPENAI_API_KEY: str
    IMAGES_EDIT_OPENAI_API_VERSION: str
    IMAGES_EDIT_GEMINI_API_BASE_URL: str
    IMAGES_EDIT_GEMINI_API_KEY: str
    IMAGES_EDIT_COMFYUI_BASE_URL: str
    IMAGES_EDIT_COMFYUI_API_KEY: str
    IMAGES_EDIT_COMFYUI_WORKFLOW: str
    IMAGES_EDIT_COMFYUI_WORKFLOW_NODES: list[dict]
@router.get("/config", response_model=ImagesConfig)
async def get_config(request: Request, user=Depends(get_admin_user)):
    """Return the current image configuration (admin only).

    Each field declared on ``ImagesConfig`` is read, in declaration order,
    from the attribute of the same name on ``request.app.state.config``.
    """
    config = request.app.state.config
    return {field: getattr(config, field) for field in ImagesConfig.model_fields}
@router.post("/config/update")
async def update_config(
    request: Request, form_data: ImagesConfig, user=Depends(get_admin_user)
):
    """Validate and persist a new image configuration (admin only).

    All checks run before anything is written, so a rejected request leaves
    the stored configuration untouched.

    Args:
        request: Incoming request; settings are written to
            ``request.app.state.config``.
        form_data: The complete new configuration.
        user: The authenticated admin (injected dependency).

    Returns:
        The resulting configuration, in the same shape as ``GET /config``.

    Raises:
        HTTPException: 400 when ``IMAGE_SIZE`` is ``"auto"`` for a model other
            than gpt-image-1, when ``IMAGE_SIZE`` is not empty / ``"auto"`` /
            ``WIDTHxHEIGHT``, or when ``IMAGE_STEPS`` is negative.
    """
    config = request.app.state.config

    # --- Validation (no side effects yet) ---
    # "auto" is only understood by gpt-image-1; this substring test matches
    # the one used when building OpenAI requests elsewhere in this module.
    if form_data.IMAGE_SIZE == "auto" and "gpt-image-1" not in (
        form_data.IMAGE_GENERATION_MODEL or ""
    ):
        raise HTTPException(
            status_code=400,
            detail=ERROR_MESSAGES.INCORRECT_FORMAT(
                " (auto is only allowed with gpt-image-1)."
            ),
        )
    # IMAGE_SIZE is Optional: None and "" both mean "not configured".
    if form_data.IMAGE_SIZE not in (None, "", "auto") and not re.match(
        r"^\d+x\d+$", form_data.IMAGE_SIZE
    ):
        raise HTTPException(
            status_code=400,
            detail=ERROR_MESSAGES.INCORRECT_FORMAT(" (e.g., 512x512)."),
        )
    # IMAGE_STEPS is Optional: None means no "steps" value is sent to backends.
    if form_data.IMAGE_STEPS is not None and form_data.IMAGE_STEPS < 0:
        raise HTTPException(
            status_code=400,
            detail=ERROR_MESSAGES.INCORRECT_FORMAT(" (e.g., 50)."),
        )

    # --- Create image ---
    config.ENABLE_IMAGE_GENERATION = form_data.ENABLE_IMAGE_GENERATION
    config.ENABLE_IMAGE_PROMPT_GENERATION = form_data.ENABLE_IMAGE_PROMPT_GENERATION
    config.IMAGE_GENERATION_ENGINE = form_data.IMAGE_GENERATION_ENGINE
    # Must run after the engine is stored: set_image_model reads it to decide
    # whether the AUTOMATIC1111 checkpoint has to be switched as well.
    set_image_model(request, form_data.IMAGE_GENERATION_MODEL)
    config.IMAGE_SIZE = form_data.IMAGE_SIZE
    config.IMAGE_STEPS = form_data.IMAGE_STEPS

    config.IMAGES_OPENAI_API_BASE_URL = form_data.IMAGES_OPENAI_API_BASE_URL
    config.IMAGES_OPENAI_API_KEY = form_data.IMAGES_OPENAI_API_KEY
    config.IMAGES_OPENAI_API_VERSION = form_data.IMAGES_OPENAI_API_VERSION

    config.AUTOMATIC1111_BASE_URL = form_data.AUTOMATIC1111_BASE_URL
    config.AUTOMATIC1111_API_AUTH = form_data.AUTOMATIC1111_API_AUTH
    config.AUTOMATIC1111_PARAMS = form_data.AUTOMATIC1111_PARAMS

    # Surrounding slashes are stripped so paths can be appended as "/path".
    config.COMFYUI_BASE_URL = form_data.COMFYUI_BASE_URL.strip("/")
    config.COMFYUI_API_KEY = form_data.COMFYUI_API_KEY
    config.COMFYUI_WORKFLOW = form_data.COMFYUI_WORKFLOW
    config.COMFYUI_WORKFLOW_NODES = form_data.COMFYUI_WORKFLOW_NODES

    config.IMAGES_GEMINI_API_BASE_URL = form_data.IMAGES_GEMINI_API_BASE_URL
    config.IMAGES_GEMINI_API_KEY = form_data.IMAGES_GEMINI_API_KEY
    config.IMAGES_GEMINI_ENDPOINT_METHOD = form_data.IMAGES_GEMINI_ENDPOINT_METHOD

    # --- Edit image ---
    config.IMAGE_EDIT_ENGINE = form_data.IMAGE_EDIT_ENGINE
    config.IMAGE_EDIT_MODEL = form_data.IMAGE_EDIT_MODEL
    config.IMAGE_EDIT_SIZE = form_data.IMAGE_EDIT_SIZE

    config.IMAGES_EDIT_OPENAI_API_BASE_URL = form_data.IMAGES_EDIT_OPENAI_API_BASE_URL
    config.IMAGES_EDIT_OPENAI_API_KEY = form_data.IMAGES_EDIT_OPENAI_API_KEY
    config.IMAGES_EDIT_OPENAI_API_VERSION = form_data.IMAGES_EDIT_OPENAI_API_VERSION

    config.IMAGES_EDIT_GEMINI_API_BASE_URL = form_data.IMAGES_EDIT_GEMINI_API_BASE_URL
    config.IMAGES_EDIT_GEMINI_API_KEY = form_data.IMAGES_EDIT_GEMINI_API_KEY

    config.IMAGES_EDIT_COMFYUI_BASE_URL = (
        form_data.IMAGES_EDIT_COMFYUI_BASE_URL.strip("/")
    )
    config.IMAGES_EDIT_COMFYUI_API_KEY = form_data.IMAGES_EDIT_COMFYUI_API_KEY
    config.IMAGES_EDIT_COMFYUI_WORKFLOW = form_data.IMAGES_EDIT_COMFYUI_WORKFLOW
    config.IMAGES_EDIT_COMFYUI_WORKFLOW_NODES = (
        form_data.IMAGES_EDIT_COMFYUI_WORKFLOW_NODES
    )

    # Reuse the GET handler so both endpoints always return the same shape.
    return await get_config(request, user)
def get_automatic1111_api_auth(request: Request):
    """Build the ``authorization`` header value for AUTOMATIC1111 requests.

    ``AUTOMATIC1111_API_AUTH`` is sent as HTTP Basic credentials (presumably
    in ``user:password`` form — TODO confirm against the A1111 ``--api-auth``
    format).

    Returns:
        ``"Basic <base64>"``, or ``""`` when no credentials are configured.
    """
    api_auth = request.app.state.config.AUTOMATIC1111_API_AUTH
    # Treat "" like None: encoding an empty string would yield a bogus
    # "Basic " header that carries no credentials.
    if not api_auth:
        return ""
    encoded = base64.b64encode(api_auth.encode("utf-8")).decode("utf-8")
    return f"Basic {encoded}"
@router.get("/config/url/verify")
async def verify_url(request: Request, user=Depends(get_admin_user)):
    """Check that the configured image generation backend is reachable.

    AUTOMATIC1111 (engine ``"automatic1111"`` or the empty default, matching
    the rest of this module) is probed via ``/sdapi/v1/options``; ComfyUI via
    ``/object_info``. Other engines are not checked.

    Returns:
        ``True`` when the probe succeeds or the engine is not checked.

    Raises:
        HTTPException: 400 with ``INVALID_URL`` if the probe fails; image
            generation is disabled in that case.
    """
    config = request.app.state.config
    engine = config.IMAGE_GENERATION_ENGINE

    if engine in ("automatic1111", ""):
        url = f"{config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options"
        headers = {"authorization": get_automatic1111_api_auth(request)}
    elif engine == "comfyui":
        url = f"{config.COMFYUI_BASE_URL}/object_info"
        headers = None
        if config.COMFYUI_API_KEY:
            headers = {"Authorization": f"Bearer {config.COMFYUI_API_KEY}"}
    else:
        return True

    try:
        # requests is blocking; keep it off the event loop of this async route.
        r = await asyncio.to_thread(requests.get, url=url, headers=headers)
        r.raise_for_status()
        return True
    except Exception:
        config.ENABLE_IMAGE_GENERATION = False
        raise HTTPException(status_code=400, detail=ERROR_MESSAGES.INVALID_URL)
def _list_comfyui_models(config):
    """Return ``{"id", "name"}`` entries for the models ComfyUI offers.

    ``/object_info`` is indexed by node class name. The node tagged
    ``"model"`` in ``COMFYUI_WORKFLOW_NODES`` is resolved to its class in the
    workflow, and the first required input of that class whose name contains
    ``"_name"`` supplies the choices. When no such node or input exists, the
    ``ckpt_name`` choices of ``CheckpointLoaderSimple`` are used instead.

    Raises:
        requests.HTTPError: If ComfyUI answers with an error status.
    """
    # Only send a Bearer token when a key is configured (as verify_url does).
    headers = None
    if config.COMFYUI_API_KEY:
        headers = {"Authorization": f"Bearer {config.COMFYUI_API_KEY}"}
    r = requests.get(url=f"{config.COMFYUI_BASE_URL}/object_info", headers=headers)
    r.raise_for_status()
    info = r.json()
    workflow = json.loads(config.COMFYUI_WORKFLOW)

    model_node_id = next(
        (
            node["node_ids"][0]
            for node in config.COMFYUI_WORKFLOW_NODES
            if node["type"] == "model" and node["node_ids"]
        ),
        None,
    )

    choices = None
    if model_node_id:
        class_type = workflow[model_node_id]["class_type"]
        log.debug(f"ComfyUI model node class: {class_type}")
        required_inputs = info[class_type]["input"]["required"]
        model_list_key = next(
            (key for key in required_inputs if "_name" in key), None
        )
        if model_list_key:
            choices = required_inputs[model_list_key][0]

    # Fall back to the stock checkpoint loader when the workflow gives no list.
    if choices is None:
        choices = info["CheckpointLoaderSimple"]["input"]["required"]["ckpt_name"][0]

    return [{"id": model, "name": model} for model in choices]


@router.get("/models")
def get_models(request: Request, user=Depends(get_verified_user)):
    """List the selectable models for the configured image generation engine.

    Returns:
        A list of ``{"id": ..., "name": ...}`` dicts, or ``None`` when the
        engine is not recognized.

    Raises:
        HTTPException: 400 if the backend cannot be queried; image generation
            is disabled in that case.
    """
    config = request.app.state.config
    engine = config.IMAGE_GENERATION_ENGINE
    try:
        if engine == "openai":
            return [
                {"id": "dall-e-2", "name": "DALL·E 2"},
                {"id": "dall-e-3", "name": "DALL·E 3"},
                {"id": "gpt-image-1", "name": "GPT-IMAGE 1"},
            ]
        elif engine == "gemini":
            return [
                {"id": "imagen-3.0-generate-002", "name": "imagen-3.0 generate-002"},
            ]
        elif engine == "comfyui":
            return _list_comfyui_models(config)
        elif engine in ("automatic1111", ""):
            r = requests.get(
                url=f"{config.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models",
                headers={"authorization": get_automatic1111_api_auth(request)},
            )
            r.raise_for_status()
            return [
                {"id": model["title"], "name": model["model_name"]}
                for model in r.json()
            ]
    except Exception as e:
        config.ENABLE_IMAGE_GENERATION = False
        raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
class CreateImageForm(BaseModel):
    """Request body for ``POST /generations``."""

    model: Optional[str] = None  # in this module, only the AUTOMATIC1111 path applies it
    prompt: str
    size: Optional[str] = None  # e.g. "512x512"; takes precedence over IMAGE_SIZE
    n: int = 1  # number of images requested from the backend
    negative_prompt: Optional[str] = None  # forwarded to ComfyUI / AUTOMATIC1111 only


GenerateImageForm = CreateImageForm  # Alias for backward compatibility
def get_image_data(data: str, headers=None):
    """Load raw image bytes from a URL, a data URL or a bare base64 string.

    Args:
        data: An ``http(s)://`` URL, a ``data:<mime>;base64,<payload>``
            string, or bare base64 (assumed to be PNG).
        headers: Optional HTTP headers sent when ``data`` is a URL.

    Returns:
        ``(image_bytes, mime_type)`` on success, otherwise ``(None, None)``
        (non-image URL or any loading error). It is always a 2-tuple so
        callers can unpack it unconditionally.
    """
    try:
        if data.startswith(("http://", "https://")):
            r = requests.get(data, headers=headers)
            r.raise_for_status()
            mime_type = r.headers.get("content-type", "")
            if mime_type.split("/")[0] == "image":
                return r.content, mime_type
            log.error("Url does not point to an image.")
            return None, None

        if "," in data:
            header, encoded = data.split(",", 1)
            # removeprefix, not lstrip: lstrip("data:") strips a *set of
            # characters* and would turn "data:application/..." into
            # "pplication/...".
            mime_type = header.split(";")[0].removeprefix("data:")
            return base64.b64decode(encoded), mime_type

        return base64.b64decode(data), "image/png"
    except Exception as e:
        log.exception(f"Error loading image data: {e}")
        return None, None
def upload_image(request, image_data, content_type, metadata, user):
    """Store image bytes as an Open WebUI file and return its content URL.

    Args:
        request: Incoming request (used by the file handler and for routing).
        image_data: Raw image bytes.
        content_type: MIME type of ``image_data``; parameters such as
            ``; charset=...`` are ignored when choosing the file extension.
        metadata: Generation parameters stored alongside the file.
        user: Owner of the uploaded file.

    Returns:
        App-relative URL of the ``get_file_content_by_id`` endpoint for the
        stored file.

    Raises:
        ValueError: If ``image_data`` is ``None`` (the image failed to load).
    """
    if image_data is None:
        raise ValueError("Failed to load image data.")

    base_type = (content_type or "").split(";")[0].strip()
    # guess_extension returns None for unknown/empty types; fall back to no
    # extension instead of a literal "generated-imageNone" filename.
    image_format = mimetypes.guess_extension(base_type) or ""

    file = UploadFile(
        file=io.BytesIO(image_data),
        filename=f"generated-image{image_format}",  # will be converted to a unique ID on upload_file
        headers={
            "content-type": content_type,
        },
    )
    file_item = upload_file_handler(
        request,
        file=file,
        metadata=metadata,
        process=False,
        user=user,
    )
    url = request.app.url_path_for("get_file_content_by_id", id=file_item.id)
    return url
@router.post("/generations")
async def image_generations(
    request: Request,
    form_data: CreateImageForm,
    user=Depends(get_verified_user),
):
    """Generate images with the configured engine and store them as files.

    Supported engines: "openai", "gemini", "comfyui" and AUTOMATIC1111
    ("automatic1111" or the empty default). Each generated image is stored
    with ``upload_image``.

    Returns:
        A list of ``{"url": ...}`` dicts, one per stored image (``None`` for an
        unrecognized engine).

    Raises:
        HTTPException: 400 if the size is malformed or generation fails; the
            backend's own error message is used when it sends one.
    """
    config = request.app.state.config

    # ComfyUI and AUTOMATIC1111 need explicit pixel dimensions. Default to
    # 512x512 unless a "WIDTHxHEIGHT" size is configured or requested; this
    # also covers IMAGE_SIZE == "auto", which only gpt-image-1 understands.
    size = "512x512"
    if config.IMAGE_SIZE and "x" in config.IMAGE_SIZE:
        size = config.IMAGE_SIZE
    if form_data.size and "x" in form_data.size:
        size = form_data.size
    try:
        width, height = (int(part) for part in size.split("x"))
    except ValueError as e:
        # e.g. "512xabc" from the client: a bad request, not a server error.
        raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))

    model = get_image_model(request)

    r = None
    try:
        engine = config.IMAGE_GENERATION_ENGINE

        if engine == "openai":
            headers = {
                "Authorization": f"Bearer {config.IMAGES_OPENAI_API_KEY}",
                "Content-Type": "application/json",
            }
            if ENABLE_FORWARD_USER_INFO_HEADERS:
                headers = include_user_info_headers(headers, user)

            data = {
                "model": model,
                "prompt": form_data.prompt,
                "n": form_data.n,
                "size": form_data.size if form_data.size else config.IMAGE_SIZE,
                # gpt-image-1 returns base64 and does not accept
                # response_format. Test the model actually sent (which has a
                # "dall-e-2" fallback), not the raw config value that may be
                # empty or None.
                **(
                    {}
                    if "gpt-image-1" in model
                    else {"response_format": "b64_json"}
                ),
            }

            api_version_query_param = ""
            if config.IMAGES_OPENAI_API_VERSION:
                api_version_query_param = (
                    f"?api-version={config.IMAGES_OPENAI_API_VERSION}"
                )

            # requests is blocking; keep it off the event loop.
            r = await asyncio.to_thread(
                requests.post,
                url=f"{config.IMAGES_OPENAI_API_BASE_URL}/images/generations{api_version_query_param}",
                json=data,
                headers=headers,
            )
            r.raise_for_status()
            res = r.json()

            images = []
            for image in res["data"]:
                if image_url := image.get("url", None):
                    image_data, content_type = get_image_data(image_url, headers)
                else:
                    image_data, content_type = get_image_data(image["b64_json"])
                url = upload_image(request, image_data, content_type, data, user)
                images.append({"url": url})
            return images

        elif engine == "gemini":
            headers = {
                "Content-Type": "application/json",
                "x-goog-api-key": config.IMAGES_GEMINI_API_KEY,
            }

            data = {}
            if config.IMAGES_GEMINI_ENDPOINT_METHOD in ("", "predict"):
                model = f"{model}:predict"
                data = {
                    "instances": {"prompt": form_data.prompt},
                    "parameters": {
                        "sampleCount": form_data.n,
                        "outputOptions": {"mimeType": "image/png"},
                    },
                }
            elif config.IMAGES_GEMINI_ENDPOINT_METHOD == "generateContent":
                model = f"{model}:generateContent"
                data = {"contents": [{"parts": [{"text": form_data.prompt}]}]}

            r = await asyncio.to_thread(
                requests.post,
                url=f"{config.IMAGES_GEMINI_API_BASE_URL}/models/{model}",
                json=data,
                headers=headers,
            )
            r.raise_for_status()
            res = r.json()

            images = []
            if model.endswith(":predict"):
                for image in res["predictions"]:
                    image_data, content_type = get_image_data(
                        image["bytesBase64Encoded"]
                    )
                    url = upload_image(request, image_data, content_type, data, user)
                    images.append({"url": url})
            elif model.endswith(":generateContent"):
                for candidate in res["candidates"]:
                    for part in candidate["content"]["parts"]:
                        if part.get("inlineData", {}).get("data"):
                            image_data, content_type = get_image_data(
                                part["inlineData"]["data"]
                            )
                            url = upload_image(
                                request, image_data, content_type, data, user
                            )
                            images.append({"url": url})
            return images

        elif engine == "comfyui":
            data = {
                "prompt": form_data.prompt,
                "width": width,
                "height": height,
                "n": form_data.n,
            }
            if config.IMAGE_STEPS is not None:
                data["steps"] = config.IMAGE_STEPS
            if form_data.negative_prompt is not None:
                data["negative_prompt"] = form_data.negative_prompt

            comfyui_form = ComfyUICreateImageForm(
                workflow=ComfyUIWorkflow(
                    workflow=config.COMFYUI_WORKFLOW,
                    nodes=config.COMFYUI_WORKFLOW_NODES,
                ),
                **data,
            )
            res = await comfyui_create_image(
                model,
                comfyui_form,
                user.id,
                config.COMFYUI_BASE_URL,
                config.COMFYUI_API_KEY,
            )
            log.debug(f"res: {res}")

            # The same auth header applies to every output URL.
            headers = None
            if config.COMFYUI_API_KEY:
                headers = {"Authorization": f"Bearer {config.COMFYUI_API_KEY}"}

            images = []
            for image in res["data"]:
                image_data, content_type = get_image_data(image["url"], headers)
                url = upload_image(
                    request,
                    image_data,
                    content_type,
                    comfyui_form.model_dump(exclude_none=True),
                    user,
                )
                images.append({"url": url})
            return images

        elif engine in ("automatic1111", ""):
            if form_data.model:
                set_image_model(request, form_data.model)

            data = {
                "prompt": form_data.prompt,
                "batch_size": form_data.n,
                "width": width,
                "height": height,
            }
            if config.IMAGE_STEPS is not None:
                data["steps"] = config.IMAGE_STEPS
            if form_data.negative_prompt is not None:
                data["negative_prompt"] = form_data.negative_prompt
            if config.AUTOMATIC1111_PARAMS:
                data = {**data, **config.AUTOMATIC1111_PARAMS}

            r = await asyncio.to_thread(
                requests.post,
                url=f"{config.AUTOMATIC1111_BASE_URL}/sdapi/v1/txt2img",
                json=data,
                headers={"authorization": get_automatic1111_api_auth(request)},
            )
            # Fail on HTTP errors so the server's error payload is reported
            # instead of a confusing KeyError on "images".
            r.raise_for_status()
            res = r.json()
            log.debug(f"res: {res}")

            images = []
            for image in res["images"]:
                image_data, content_type = get_image_data(image)
                url = upload_image(
                    request,
                    image_data,
                    content_type,
                    {**data, "info": res["info"]},
                    user,
                )
                images.append({"url": url})
            return images

    except Exception as e:
        error = e
        if r is not None:
            # Prefer the backend's own error message, tolerating non-JSON
            # bodies and "error" values that are plain strings.
            try:
                payload = r.json()
            except ValueError:
                payload = None
            if isinstance(payload, dict) and "error" in payload:
                backend_error = payload["error"]
                error = (
                    backend_error.get("message", backend_error)
                    if isinstance(backend_error, dict)
                    else backend_error
                )
        raise HTTPException(
            status_code=400, detail=ERROR_MESSAGES.DEFAULT(error)
        ) from e
class EditImageForm(BaseModel):
    """Request body for ``POST /edit``."""

    # Data URL(s) / base64, http(s) URL(s), or "/api/v1/files/..." path(s).
    image: str | list[str]  # base64-encoded image(s) or URL(s)
    prompt: str
    model: Optional[str] = None  # overrides IMAGE_EDIT_MODEL when provided
    size: Optional[str] = None  # "WIDTHxHEIGHT"; preferred over IMAGE_EDIT_SIZE
    n: Optional[int] = None  # omitted from backend requests when falsy
    # NOTE(review): not forwarded by any edit engine in the code visible here.
    negative_prompt: Optional[str] = None
@router.post("/edit")
async def image_edits(
    request: Request,
    form_data: EditImageForm,
    user=Depends(get_verified_user),
):
    """Edit one or more input images with the configured edit engine.

    Input images may be ``http(s)://`` URLs, Open WebUI file paths
    (``/api/v1/files/<id>/content``), ``data:`` URLs or bare base64; URLs and
    file paths are first converted to data URLs. Supported engines are
    "openai", "gemini" and "comfyui"; each result is stored with
    ``upload_image``.

    Returns:
        A list of ``{"url": ...}`` dicts, one per stored image (``None`` when
        ``IMAGE_EDIT_ENGINE`` is not a supported engine).

    Raises:
        HTTPException: 400 if the size is malformed, an input image cannot be
            loaded, or the edit fails; the backend's error message is used when
            one is available.
    """
    config = request.app.state.config

    # Prefer a "WIDTHxHEIGHT" size from the request, then from the config.
    # Each candidate is checked on its own so that e.g. a requested "auto" is
    # never parsed as dimensions just because the configured size has an "x".
    size = None
    width, height = None, None
    if form_data.size and "x" in form_data.size:
        size = form_data.size
    elif config.IMAGE_EDIT_SIZE and "x" in config.IMAGE_EDIT_SIZE:
        size = config.IMAGE_EDIT_SIZE
    if size:
        try:
            width, height = (int(part) for part in size.split("x"))
        except ValueError as e:
            raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))

    # An explicit model in the request overrides the configured edit model.
    model = config.IMAGE_EDIT_MODEL if form_data.model is None else form_data.model

    def split_data_url(value):
        """Split ``data:<mime>;base64,<payload>`` into ``(mime_type, payload)``.

        A string without a comma is treated as bare base64 and yields
        ``(None, value)``. ``removeprefix`` is used because
        ``lstrip("data:")`` strips a *set of characters* and would turn
        ``"data:application/..."`` into ``"pplication/..."``.
        """
        if "," not in value:
            return None, value
        header, encoded = value.split(",", 1)
        mime_type = header.split(";")[0].removeprefix("data:")
        return mime_type or None, encoded

    async def load_url_image(value):
        """Turn an http(s) URL or Open WebUI file path into a base64 data URL.

        Any other string (assumed to already be a data URL or bare base64) is
        returned unchanged, as is a file path whose lookup does not yield a
        ``FileResponse``.
        """
        if value.startswith(("http://", "https://")):
            # requests is blocking; run it off the event loop.
            response = await asyncio.to_thread(requests.get, value)
            response.raise_for_status()
            encoded = base64.b64encode(response.content).decode("utf-8")
            mime_type = response.headers.get("content-type", "image/png")
            return f"data:{mime_type};base64,{encoded}"
        if value.startswith("/api/v1/files"):
            file_id = value.split("/api/v1/files/")[1].split("/content")[0]
            file_response = await get_file_content_by_id(file_id, user)
            if isinstance(file_response, FileResponse):
                file_path = file_response.path
                with open(file_path, "rb") as f:
                    encoded = base64.b64encode(f.read()).decode("utf-8")
                mime_type, _ = mimetypes.guess_type(file_path)
                # guess_type returns None for unknown extensions; avoid
                # emitting a literal "data:None;base64,..." URL.
                return f"data:{mime_type or 'image/png'};base64,{encoded}"
        return value

    def get_image_file_item(value):
        """Build a ``requests`` multipart entry ``("image", (name, file, mime))``.

        The filename extension follows the image's MIME type (PNG when the
        type is missing or unknown) so it never contradicts the declared type.
        """
        mime_type, encoded = split_data_url(value)
        mime_type = mime_type or "image/png"
        extension = mimetypes.guess_extension(mime_type) or ".png"
        return (
            "image",
            (
                f"{uuid.uuid4()}{extension}",
                io.BytesIO(base64.b64decode(encoded)),
                mime_type,
            ),
        )

    try:
        # Load image(s) from URL(s) / file paths if necessary.
        if isinstance(form_data.image, str):
            form_data.image = await load_url_image(form_data.image)
        elif isinstance(form_data.image, list):
            form_data.image = [await load_url_image(img) for img in form_data.image]
    except Exception as e:
        raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))

    input_images = (
        [form_data.image] if isinstance(form_data.image, str) else form_data.image
    )

    r = None
    try:
        engine = config.IMAGE_EDIT_ENGINE

        if engine == "openai":
            headers = {
                "Authorization": f"Bearer {config.IMAGES_EDIT_OPENAI_API_KEY}",
            }
            if ENABLE_FORWARD_USER_INFO_HEADERS:
                headers = include_user_info_headers(headers, user)

            data = {
                "model": model,
                "prompt": form_data.prompt,
                **({"n": form_data.n} if form_data.n else {}),
                **({"size": size} if size else {}),
                # Decide on the model actually sent (the request may override
                # the configured one); gpt-image-1 rejects response_format.
                **(
                    {}
                    if "gpt-image-1" in (model or "")
                    else {"response_format": "b64_json"}
                ),
            }
            files = [get_image_file_item(img) for img in input_images]

            url_search_params = ""
            if config.IMAGES_EDIT_OPENAI_API_VERSION:
                url_search_params = (
                    f"?api-version={config.IMAGES_EDIT_OPENAI_API_VERSION}"
                )

            r = await asyncio.to_thread(
                requests.post,
                url=f"{config.IMAGES_EDIT_OPENAI_API_BASE_URL}/images/edits{url_search_params}",
                headers=headers,
                files=files,
                data=data,
            )
            r.raise_for_status()
            res = r.json()

            images = []
            for image in res["data"]:
                if image_url := image.get("url", None):
                    image_data, content_type = get_image_data(image_url, headers)
                else:
                    image_data, content_type = get_image_data(image["b64_json"])
                url = upload_image(request, image_data, content_type, data, user)
                images.append({"url": url})
            return images

        elif engine == "gemini":
            headers = {
                "Content-Type": "application/json",
                "x-goog-api-key": config.IMAGES_EDIT_GEMINI_API_KEY,
            }
            model = f"{model}:generateContent"

            parts = [{"text": form_data.prompt}]
            for img in input_images:
                mime_type, encoded = split_data_url(img)
                # Label each image with its real MIME type (e.g. JPEG uploads)
                # instead of assuming PNG.
                parts.append(
                    {
                        "inline_data": {
                            "mime_type": mime_type or "image/png",
                            "data": encoded,
                        }
                    }
                )
            data = {"contents": [{"parts": parts}]}

            r = await asyncio.to_thread(
                requests.post,
                url=f"{config.IMAGES_EDIT_GEMINI_API_BASE_URL}/models/{model}",
                json=data,
                headers=headers,
            )
            r.raise_for_status()
            res = r.json()

            images = []
            for candidate in res["candidates"]:
                for part in candidate["content"]["parts"]:
                    if part.get("inlineData", {}).get("data"):
                        image_data, content_type = get_image_data(
                            part["inlineData"]["data"]
                        )
                        url = upload_image(
                            request, image_data, content_type, data, user
                        )
                        images.append({"url": url})
            return images

        elif engine == "comfyui":
            try:
                files = [get_image_file_item(img) for img in input_images]
                # Upload images to ComfyUI and collect the names it assigns.
                comfyui_images = []
                for file_item in files:
                    res = await comfyui_upload_image(
                        file_item,
                        config.IMAGES_EDIT_COMFYUI_BASE_URL,
                        config.IMAGES_EDIT_COMFYUI_API_KEY,
                    )
                    comfyui_images.append(res.get("name", file_item[1][0]))
            except Exception as e:
                log.exception(f"Error uploading images to ComfyUI: {e}")
                raise Exception("Failed to upload images to ComfyUI.") from e

            data = {
                "image": comfyui_images,
                "prompt": form_data.prompt,
                **({"width": width} if width is not None else {}),
                **({"height": height} if height is not None else {}),
                **({"n": form_data.n} if form_data.n else {}),
            }
            comfyui_form = ComfyUIEditImageForm(
                workflow=ComfyUIWorkflow(
                    workflow=config.IMAGES_EDIT_COMFYUI_WORKFLOW,
                    nodes=config.IMAGES_EDIT_COMFYUI_WORKFLOW_NODES,
                ),
                **data,
            )
            res = await comfyui_edit_image(
                model,
                comfyui_form,
                user.id,
                config.IMAGES_EDIT_COMFYUI_BASE_URL,
                config.IMAGES_EDIT_COMFYUI_API_KEY,
            )
            log.debug(f"res: {res}")

            # De-duplicate while keeping first-seen order (deterministic,
            # unlike a set).
            image_urls = list(dict.fromkeys(image["url"] for image in res["data"]))
            # Prioritize output type URLs if available
            output_type_urls = [url for url in image_urls if "type=output" in url]
            if output_type_urls:
                image_urls = output_type_urls
            log.debug(f"Image URLs: {image_urls}")

            headers = None
            if config.IMAGES_EDIT_COMFYUI_API_KEY:
                headers = {
                    "Authorization": f"Bearer {config.IMAGES_EDIT_COMFYUI_API_KEY}"
                }

            images = []
            for image_url in image_urls:
                image_data, content_type = get_image_data(image_url, headers)
                url = upload_image(
                    request,
                    image_data,
                    content_type,
                    comfyui_form.model_dump(exclude_none=True),
                    user,
                )
                images.append({"url": url})
            return images

    except Exception as e:
        error = e
        if r is not None:
            # Prefer the backend's own error message; tolerate "error" values
            # that are plain strings and report non-JSON bodies verbatim.
            try:
                payload = r.json()
            except ValueError:
                payload = None
                error = r.text or e
            if isinstance(payload, dict) and "error" in payload:
                backend_error = payload["error"]
                error = (
                    backend_error.get("message", backend_error)
                    if isinstance(backend_error, dict)
                    else backend_error
                )
        raise HTTPException(
            status_code=400, detail=ERROR_MESSAGES.DEFAULT(error)
        ) from e