open-webui/backend/open_webui/routers/images.py
2025-11-05 01:59:16 -05:00

676 lines
25 KiB
Python

import asyncio
import base64
import io
import json
import logging
import mimetypes
import re
from pathlib import Path
from typing import Optional
from urllib.parse import quote
import requests
from fastapi import (
APIRouter,
Depends,
HTTPException,
Request,
UploadFile,
)
from open_webui.config import CACHE_DIR
from open_webui.constants import ERROR_MESSAGES
from open_webui.env import ENABLE_FORWARD_USER_INFO_HEADERS, SRC_LOG_LEVELS
from open_webui.routers.files import upload_file_handler
from open_webui.utils.auth import get_admin_user, get_verified_user
from open_webui.utils.headers import include_user_info_headers
from open_webui.utils.images.comfyui import (
ComfyUICreateImageForm,
ComfyUIWorkflow,
comfyui_create_image,
)
from pydantic import BaseModel
# Module-level logger; its level comes from the "IMAGES" entry of SRC_LOG_LEVELS.
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["IMAGES"])
# Directory for image generations under CACHE_DIR; created at import time
# (including missing parents) if it does not exist yet.
IMAGE_CACHE_DIR = CACHE_DIR / "image" / "generations"
IMAGE_CACHE_DIR.mkdir(parents=True, exist_ok=True)
# Router on which the endpoints of this module are registered.
router = APIRouter()
def set_image_model(request: Request, model: str):
    """Store ``model`` as the configured image model and return the stored value.

    When the configured engine is AUTOMATIC1111 (``"automatic1111"`` or the
    empty string), the backend's ``/sdapi/v1/options`` endpoint is queried and,
    if its ``sd_model_checkpoint`` differs from ``model``, the full options
    object is posted back with the checkpoint replaced.
    """
    log.info(f"Setting image model to {model}")

    config = request.app.state.config
    config.IMAGE_GENERATION_MODEL = model

    # Only AUTOMATIC1111 keeps the active checkpoint on the backend side.
    if config.IMAGE_GENERATION_ENGINE not in ["", "automatic1111"]:
        return config.IMAGE_GENERATION_MODEL

    api_auth = get_automatic1111_api_auth(request)
    options_url = f"{config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options"

    current_options = requests.get(
        url=options_url,
        headers={"authorization": api_auth},
    ).json()

    if model != current_options["sd_model_checkpoint"]:
        current_options["sd_model_checkpoint"] = model
        requests.post(
            url=options_url,
            json=current_options,
            headers={"authorization": api_auth},
        )

    return config.IMAGE_GENERATION_MODEL
def get_image_model(request):
    """Return the image model to use for the currently configured engine.

    For the ``openai``, ``gemini`` and ``comfyui`` engines the configured
    ``IMAGE_GENERATION_MODEL`` is returned, or an engine-specific fallback when
    it is empty. For AUTOMATIC1111 (``"automatic1111"`` or ``""``) the active
    checkpoint is read from the backend's ``/sdapi/v1/options`` endpoint.
    Any other engine value yields ``None``.

    Raises:
        HTTPException: status 400 when the AUTOMATIC1111 lookup fails; image
            generation is disabled in the config before raising.
    """
    config = request.app.state.config
    engine = config.IMAGE_GENERATION_ENGINE

    # Fallback model per engine, used when no model has been configured.
    fallback_models = {
        "openai": "dall-e-2",
        "gemini": "imagen-3.0-generate-002",
        "comfyui": "",
    }
    if engine in fallback_models:
        return config.IMAGE_GENERATION_MODEL or fallback_models[engine]

    if engine not in ("automatic1111", ""):
        return None

    try:
        response = requests.get(
            url=f"{config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
            headers={"authorization": get_automatic1111_api_auth(request)},
        )
        return response.json()["sd_model_checkpoint"]
    except Exception as e:
        config.ENABLE_IMAGE_GENERATION = False
        raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
class ImagesConfig(BaseModel):
    # Schema of the admin image settings: response model of GET /config and
    # request body of POST /config/update.

    # --- General generation settings ---
    ENABLE_IMAGE_GENERATION: bool
    ENABLE_IMAGE_PROMPT_GENERATION: bool
    # Engine values handled in this module: "openai", "gemini", "comfyui",
    # "automatic1111" and "" (treated like "automatic1111").
    IMAGE_GENERATION_ENGINE: str
    IMAGE_GENERATION_MODEL: str
    # update_config accepts "auto", "" or "<digits>x<digits>" (e.g. "512x512").
    IMAGE_SIZE: Optional[str]
    # update_config requires a value >= 0.
    IMAGE_STEPS: Optional[int]

    # --- OpenAI engine ---
    IMAGES_OPENAI_API_BASE_URL: str
    IMAGES_OPENAI_API_KEY: str
    # When non-empty, appended to the generation URL as "?api-version=...".
    IMAGES_OPENAI_API_VERSION: str

    # --- AUTOMATIC1111 engine ---
    AUTOMATIC1111_BASE_URL: str
    # Base64-encoded into a "Basic ..." authorization header value.
    AUTOMATIC1111_API_AUTH: str
    # When set, merged over the txt2img request payload.
    AUTOMATIC1111_PARAMS: Optional[dict | str]

    # --- ComfyUI engine ---
    COMFYUI_BASE_URL: str
    # When set, sent as "Authorization: Bearer <key>".
    COMFYUI_API_KEY: str
    # JSON text; get_models parses it with json.loads.
    COMFYUI_WORKFLOW: str
    # Entries are read via their "type" and "node_ids" keys in get_models.
    COMFYUI_WORKFLOW_NODES: list[dict]

    # --- Gemini engine ---
    IMAGES_GEMINI_API_BASE_URL: str
    IMAGES_GEMINI_API_KEY: str
    # "" or "predict" selects ":predict"; "generateContent" selects
    # ":generateContent".
    IMAGES_GEMINI_ENDPOINT_METHOD: str
@router.get("/config", response_model=ImagesConfig)
async def get_config(request: Request, user=Depends(get_admin_user)):
    """Return the current image settings, keyed by setting name (admin only)."""
    config = request.app.state.config

    # Same keys, in the same order, as the fields of ImagesConfig.
    setting_names = (
        "ENABLE_IMAGE_GENERATION",
        "ENABLE_IMAGE_PROMPT_GENERATION",
        "IMAGE_GENERATION_ENGINE",
        "IMAGE_GENERATION_MODEL",
        "IMAGE_SIZE",
        "IMAGE_STEPS",
        "IMAGES_OPENAI_API_BASE_URL",
        "IMAGES_OPENAI_API_KEY",
        "IMAGES_OPENAI_API_VERSION",
        "AUTOMATIC1111_BASE_URL",
        "AUTOMATIC1111_API_AUTH",
        "AUTOMATIC1111_PARAMS",
        "COMFYUI_BASE_URL",
        "COMFYUI_API_KEY",
        "COMFYUI_WORKFLOW",
        "COMFYUI_WORKFLOW_NODES",
        "IMAGES_GEMINI_API_BASE_URL",
        "IMAGES_GEMINI_API_KEY",
        "IMAGES_GEMINI_ENDPOINT_METHOD",
    )
    return {name: getattr(config, name) for name in setting_names}
@router.post("/config/update")
async def update_config(
    request: Request, form_data: ImagesConfig, user=Depends(get_admin_user)
):
    """Validate and store new image settings, then return the stored values.

    The submitted form is validated before anything is written, so a rejected
    request leaves the stored configuration unchanged.

    Validation rules:
        * ``IMAGE_SIZE == "auto"`` is only accepted together with the
          ``gpt-image-1`` model.
        * ``IMAGE_SIZE`` must be ``None``, ``""``, ``"auto"`` or match
          ``<digits>x<digits>``.
        * ``IMAGE_STEPS`` must be ``None`` or ``>= 0``.

    ``None`` is accepted for both optional fields because the schema declares
    them ``Optional`` and ``image_generations`` already checks for ``None``
    before using the stored values.

    Raises:
        HTTPException: status 400 when a validation rule is violated.
    """
    config = request.app.state.config

    # ---- Validation (no state is modified until all checks have passed) ----
    if (
        form_data.IMAGE_SIZE == "auto"
        and form_data.IMAGE_GENERATION_MODEL != "gpt-image-1"
    ):
        raise HTTPException(
            status_code=400,
            detail=ERROR_MESSAGES.INCORRECT_FORMAT(
                " (auto is only allowed with gpt-image-1)."
            ),
        )

    pattern = r"^\d+x\d+$"
    size_is_valid = (
        form_data.IMAGE_SIZE is None
        or form_data.IMAGE_SIZE == "auto"
        or form_data.IMAGE_SIZE == ""
        or re.match(pattern, form_data.IMAGE_SIZE)
    )
    if not size_is_valid:
        raise HTTPException(
            status_code=400,
            detail=ERROR_MESSAGES.INCORRECT_FORMAT(" (e.g., 512x512)."),
        )

    if form_data.IMAGE_STEPS is not None and form_data.IMAGE_STEPS < 0:
        raise HTTPException(
            status_code=400,
            detail=ERROR_MESSAGES.INCORRECT_FORMAT(" (e.g., 50)."),
        )

    # ---- Apply ----
    config.ENABLE_IMAGE_GENERATION = form_data.ENABLE_IMAGE_GENERATION
    config.ENABLE_IMAGE_PROMPT_GENERATION = form_data.ENABLE_IMAGE_PROMPT_GENERATION
    config.IMAGE_GENERATION_ENGINE = form_data.IMAGE_GENERATION_ENGINE

    config.IMAGE_SIZE = form_data.IMAGE_SIZE
    config.IMAGE_STEPS = form_data.IMAGE_STEPS

    config.IMAGES_OPENAI_API_BASE_URL = form_data.IMAGES_OPENAI_API_BASE_URL
    config.IMAGES_OPENAI_API_KEY = form_data.IMAGES_OPENAI_API_KEY
    config.IMAGES_OPENAI_API_VERSION = form_data.IMAGES_OPENAI_API_VERSION

    config.AUTOMATIC1111_BASE_URL = form_data.AUTOMATIC1111_BASE_URL
    config.AUTOMATIC1111_API_AUTH = form_data.AUTOMATIC1111_API_AUTH
    config.AUTOMATIC1111_PARAMS = form_data.AUTOMATIC1111_PARAMS

    config.COMFYUI_BASE_URL = form_data.COMFYUI_BASE_URL.strip("/")
    config.COMFYUI_API_KEY = form_data.COMFYUI_API_KEY
    config.COMFYUI_WORKFLOW = form_data.COMFYUI_WORKFLOW
    config.COMFYUI_WORKFLOW_NODES = form_data.COMFYUI_WORKFLOW_NODES

    config.IMAGES_GEMINI_API_BASE_URL = form_data.IMAGES_GEMINI_API_BASE_URL
    config.IMAGES_GEMINI_API_KEY = form_data.IMAGES_GEMINI_API_KEY
    config.IMAGES_GEMINI_ENDPOINT_METHOD = form_data.IMAGES_GEMINI_ENDPOINT_METHOD

    # set_image_model reads the engine, AUTOMATIC1111 base URL and API auth
    # from the stored config, so it runs only after the submitted connection
    # settings are in place; otherwise it would contact the previous backend.
    set_image_model(request, form_data.IMAGE_GENERATION_MODEL)

    return {
        "ENABLE_IMAGE_GENERATION": config.ENABLE_IMAGE_GENERATION,
        "ENABLE_IMAGE_PROMPT_GENERATION": config.ENABLE_IMAGE_PROMPT_GENERATION,
        "IMAGE_GENERATION_ENGINE": config.IMAGE_GENERATION_ENGINE,
        "IMAGE_GENERATION_MODEL": config.IMAGE_GENERATION_MODEL,
        "IMAGE_SIZE": config.IMAGE_SIZE,
        "IMAGE_STEPS": config.IMAGE_STEPS,
        "IMAGES_OPENAI_API_BASE_URL": config.IMAGES_OPENAI_API_BASE_URL,
        "IMAGES_OPENAI_API_KEY": config.IMAGES_OPENAI_API_KEY,
        "IMAGES_OPENAI_API_VERSION": config.IMAGES_OPENAI_API_VERSION,
        "AUTOMATIC1111_BASE_URL": config.AUTOMATIC1111_BASE_URL,
        "AUTOMATIC1111_API_AUTH": config.AUTOMATIC1111_API_AUTH,
        "AUTOMATIC1111_PARAMS": config.AUTOMATIC1111_PARAMS,
        "COMFYUI_BASE_URL": config.COMFYUI_BASE_URL,
        "COMFYUI_API_KEY": config.COMFYUI_API_KEY,
        "COMFYUI_WORKFLOW": config.COMFYUI_WORKFLOW,
        "COMFYUI_WORKFLOW_NODES": config.COMFYUI_WORKFLOW_NODES,
        "IMAGES_GEMINI_API_BASE_URL": config.IMAGES_GEMINI_API_BASE_URL,
        "IMAGES_GEMINI_API_KEY": config.IMAGES_GEMINI_API_KEY,
        "IMAGES_GEMINI_ENDPOINT_METHOD": config.IMAGES_GEMINI_ENDPOINT_METHOD,
    }
def get_automatic1111_api_auth(request: Request):
    """Build the ``authorization`` header value for AUTOMATIC1111 requests.

    Returns:
        ``""`` when no credentials are configured (``None`` or an empty
        string), otherwise ``"Basic <base64 of the UTF-8 credentials>"``.
    """
    credentials = request.app.state.config.AUTOMATIC1111_API_AUTH

    # An empty string is the normal "unset" value for this str setting;
    # encoding it would produce the malformed header value "Basic ".
    if not credentials:
        return ""

    encoded = base64.b64encode(credentials.encode("utf-8")).decode("utf-8")
    return f"Basic {encoded}"
@router.get("/config/url/verify")
async def verify_url(request: Request, user=Depends(get_admin_user)):
    """Check that the configured image backend answers a probe request.

    AUTOMATIC1111 is probed at ``/sdapi/v1/options`` and ComfyUI at
    ``/object_info``. Every other engine is reported as valid without a
    request being made.

    Raises:
        HTTPException: status 400 when the probe fails; image generation is
            disabled in the config before raising.
    """
    config = request.app.state.config
    engine = config.IMAGE_GENERATION_ENGINE

    if engine != "automatic1111" and engine != "comfyui":
        return True

    # ComfyUI only gets an Authorization header when an API key is configured.
    comfyui_headers = None
    if engine == "comfyui" and config.COMFYUI_API_KEY:
        comfyui_headers = {"Authorization": f"Bearer {config.COMFYUI_API_KEY}"}

    try:
        if engine == "automatic1111":
            response = requests.get(
                url=f"{config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
                headers={"authorization": get_automatic1111_api_auth(request)},
            )
        else:
            response = requests.get(
                url=f"{config.COMFYUI_BASE_URL}/object_info",
                headers=comfyui_headers,
            )
        response.raise_for_status()
        return True
    except Exception:
        config.ENABLE_IMAGE_GENERATION = False
        raise HTTPException(status_code=400, detail=ERROR_MESSAGES.INVALID_URL)
@router.get("/models")
def get_models(request: Request, user=Depends(get_verified_user)):
    """List the image models available for the configured engine.

    Returns:
        A list of ``{"id": ..., "name": ...}`` dicts. ``openai`` and ``gemini``
        return fixed lists; ``comfyui`` and AUTOMATIC1111 query the backend.
        An unrecognised engine value yields ``None``.

    Raises:
        HTTPException: status 400 on any failure; image generation is disabled
            in the config before raising.
    """
    config = request.app.state.config
    try:
        engine = config.IMAGE_GENERATION_ENGINE

        if engine == "openai":
            return [
                {"id": "dall-e-2", "name": "DALL·E 2"},
                {"id": "dall-e-3", "name": "DALL·E 3"},
                {"id": "gpt-image-1", "name": "GPT-IMAGE 1"},
            ]
        elif engine == "gemini":
            return [
                {"id": "imagen-3.0-generate-002", "name": "imagen-3.0 generate-002"},
            ]
        elif engine == "comfyui":
            # TODO - get models from comfyui
            # Send the Authorization header only when a key is configured,
            # matching verify_url and image_generations.
            headers = None
            if config.COMFYUI_API_KEY:
                headers = {"Authorization": f"Bearer {config.COMFYUI_API_KEY}"}

            r = requests.get(
                url=f"{config.COMFYUI_BASE_URL}/object_info",
                headers=headers,
            )
            info = r.json()

            workflow = json.loads(config.COMFYUI_WORKFLOW)

            # Only the first workflow node entry of type "model" is considered.
            model_node_id = None
            for node in config.COMFYUI_WORKFLOW_NODES:
                if node["type"] == "model":
                    if node["node_ids"]:
                        model_node_id = node["node_ids"][0]
                    break

            model_names = None
            if model_node_id:
                class_type = workflow[model_node_id]["class_type"]
                log.info(class_type)

                # The model list is taken from the first required input whose
                # name contains "_name".
                required_inputs = info[class_type]["input"]["required"]
                model_list_key = next(
                    (key for key in required_inputs if "_name" in key), None
                )
                if model_list_key:
                    model_names = required_inputs[model_list_key][0]

            # Fall back to the stock checkpoint loader's list whenever the
            # workflow did not yield one, so the endpoint never silently
            # returns None for this engine.
            if model_names is None:
                model_names = info["CheckpointLoaderSimple"]["input"]["required"][
                    "ckpt_name"
                ][0]

            return [{"id": model, "name": model} for model in model_names]
        elif engine == "automatic1111" or engine == "":
            r = requests.get(
                url=f"{config.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models",
                headers={"authorization": get_automatic1111_api_auth(request)},
            )
            models = r.json()
            return [
                {"id": model["title"], "name": model["model_name"]}
                for model in models
            ]
    except Exception as e:
        config.ENABLE_IMAGE_GENERATION = False
        raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
class CreateImageForm(BaseModel):
    # Request body of POST /generations.

    # Only read by the AUTOMATIC1111 branch of image_generations, which passes
    # it to set_image_model when set.
    model: Optional[str] = None
    # Text prompt forwarded to the selected engine.
    prompt: str
    # Overrides the configured IMAGE_SIZE when it contains "x"
    # (parsed as "<width>x<height>").
    size: Optional[str] = None
    # Number of images requested.
    n: int = 1
    # Forwarded by the comfyui and automatic1111 branches when not None.
    negative_prompt: Optional[str] = None


GenerateImageForm = CreateImageForm # Alias for backward compatibility
def get_image_data(data: str, headers=None):
    """Load image bytes from a URL, a ``data:`` URI or a bare base64 string.

    Args:
        data: An ``http(s)://`` URL, a ``data:<mime>;base64,<payload>`` URI, or
            a plain base64 payload (assumed to be ``image/png``).
        headers: Optional request headers used when ``data`` is a URL.

    Returns:
        ``(image_bytes, mime_type)`` on success. ``(None, None)`` when the URL
        does not serve an ``image/*`` content type or when anything fails, so
        callers can always unpack two values.
    """
    try:
        if data.startswith(("http://", "https://")):
            r = requests.get(data, headers=headers or None)
            r.raise_for_status()

            mime_type = r.headers["content-type"]
            if mime_type.split("/")[0] == "image":
                return r.content, mime_type

            log.error("Url does not point to an image.")
            return None, None

        if "," in data:
            header, encoded = data.split(",", 1)
            # removeprefix() drops the literal "data:" scheme; lstrip("data:")
            # strips any leading 'd', 'a', 't', ':' characters and would turn
            # e.g. "application/json" into "pplication/json".
            mime_type = header.split(";")[0].removeprefix("data:")
        else:
            mime_type = "image/png"
            encoded = data

        return base64.b64decode(encoded), mime_type
    except Exception as e:
        log.exception(f"Error loading image data: {e}")
        return None, None
def upload_image(request, image_data, content_type, metadata, user):
    """Store generated image bytes as a file and return the URL path of its content.

    Args:
        request: Current request; passed to ``upload_file_handler`` and used
            to resolve the ``get_file_content_by_id`` route.
        image_data: Raw image bytes.
        content_type: MIME type of the image; also used to choose the filename
            extension.
        metadata: Metadata passed through to ``upload_file_handler``.
        user: User passed through to ``upload_file_handler``.

    Returns:
        The URL path of the ``get_file_content_by_id`` route for the new file.
    """
    # guess_extension() does not understand MIME parameters and returns None
    # for unknown types; strip parameters and fall back to no extension so the
    # filename never ends in the literal text "None".
    base_type = content_type.split(";", 1)[0].strip()
    image_format = mimetypes.guess_extension(base_type) or ""

    file = UploadFile(
        file=io.BytesIO(image_data),
        filename=f"generated-image{image_format}",  # will be converted to a unique ID on upload_file
        headers={
            "content-type": content_type,
        },
    )
    file_item = upload_file_handler(
        request,
        file=file,
        metadata=metadata,
        process=False,
        user=user,
    )
    url = request.app.url_path_for("get_file_content_by_id", id=file_item.id)
    return url
@router.post("/generations")
async def image_generations(
    request: Request,
    form_data: CreateImageForm,
    user=Depends(get_verified_user),
):
    """Generate images with the configured engine and store them as files.

    The request is dispatched on ``IMAGE_GENERATION_ENGINE`` (``openai``,
    ``gemini``, ``comfyui``, or AUTOMATIC1111 for ``"automatic1111"`` / ``""``).
    Each produced image is loaded with ``get_image_data`` and stored with
    ``upload_image``.

    Returns:
        A list of ``{"url": <file content URL path>}`` dicts, one per image.
        An unrecognised engine value yields ``None``.

    Raises:
        HTTPException: status 400 when the size string cannot be parsed as
            ``<width>x<height>`` or when generation fails. If the backend
            responded with a JSON body containing an ``error`` entry, that
            error (or its ``message``) is used as the detail.
    """
    config = request.app.state.config

    # if IMAGE_SIZE = 'auto', default WidthxHeight to the 512x512 default
    # This is only relevant when the user has set IMAGE_SIZE to 'auto' with an
    # image model other than gpt-image-1, which is warned about on settings save
    size = "512x512"
    if config.IMAGE_SIZE and "x" in config.IMAGE_SIZE:
        size = config.IMAGE_SIZE

    if form_data.size and "x" in form_data.size:
        size = form_data.size

    # A malformed size (e.g. "512x" or "1x2x3") is a client error, not a
    # server fault, so it is reported as 400 instead of escaping as ValueError.
    try:
        width, height = map(int, size.split("x"))
    except ValueError:
        raise HTTPException(
            status_code=400,
            detail=ERROR_MESSAGES.INCORRECT_FORMAT(" (e.g., 512x512)."),
        )

    model = get_image_model(request)

    r = None
    try:
        engine = config.IMAGE_GENERATION_ENGINE

        if engine == "openai":
            headers = {
                "Authorization": f"Bearer {config.IMAGES_OPENAI_API_KEY}",
                "Content-Type": "application/json",
            }

            if ENABLE_FORWARD_USER_INFO_HEADERS:
                headers = include_user_info_headers(headers, user)

            data = {
                "model": model,
                "prompt": form_data.prompt,
                "n": form_data.n,
                "size": form_data.size if form_data.size else config.IMAGE_SIZE,
                # `model` is the configured model or the engine default, so
                # this check also works when no model has been configured.
                **({} if "gpt-image-1" in model else {"response_format": "b64_json"}),
            }

            api_version_query_param = ""
            if config.IMAGES_OPENAI_API_VERSION:
                api_version_query_param = (
                    f"?api-version={config.IMAGES_OPENAI_API_VERSION}"
                )

            # Use asyncio.to_thread for the requests.post call
            r = await asyncio.to_thread(
                requests.post,
                url=f"{config.IMAGES_OPENAI_API_BASE_URL}/images/generations{api_version_query_param}",
                json=data,
                headers=headers,
            )

            r.raise_for_status()
            res = r.json()

            images = []
            for image in res["data"]:
                if image_url := image.get("url", None):
                    image_data, content_type = get_image_data(image_url, headers)
                else:
                    image_data, content_type = get_image_data(image["b64_json"])

                url = upload_image(request, image_data, content_type, data, user)
                images.append({"url": url})
            return images

        elif engine == "gemini":
            headers = {
                "Content-Type": "application/json",
                "x-goog-api-key": config.IMAGES_GEMINI_API_KEY,
            }

            data = {}
            if (
                config.IMAGES_GEMINI_ENDPOINT_METHOD == ""
                or config.IMAGES_GEMINI_ENDPOINT_METHOD == "predict"
            ):
                model = f"{model}:predict"
                data = {
                    "instances": {"prompt": form_data.prompt},
                    "parameters": {
                        "sampleCount": form_data.n,
                        "outputOptions": {"mimeType": "image/png"},
                    },
                }
            elif config.IMAGES_GEMINI_ENDPOINT_METHOD == "generateContent":
                model = f"{model}:generateContent"
                data = {"contents": [{"parts": [{"text": form_data.prompt}]}]}

            # Use asyncio.to_thread for the requests.post call
            r = await asyncio.to_thread(
                requests.post,
                url=f"{config.IMAGES_GEMINI_API_BASE_URL}/models/{model}",
                json=data,
                headers=headers,
            )

            r.raise_for_status()
            res = r.json()

            images = []
            if model.endswith(":predict"):
                for image in res["predictions"]:
                    image_data, content_type = get_image_data(
                        image["bytesBase64Encoded"]
                    )
                    url = upload_image(request, image_data, content_type, data, user)
                    images.append({"url": url})
            elif model.endswith(":generateContent"):
                for image in res["candidates"]:
                    for part in image["content"]["parts"]:
                        if part.get("inlineData", {}).get("data"):
                            image_data, content_type = get_image_data(
                                part["inlineData"]["data"]
                            )
                            url = upload_image(
                                request, image_data, content_type, data, user
                            )
                            images.append({"url": url})

            return images

        elif engine == "comfyui":
            data = {
                "prompt": form_data.prompt,
                "width": width,
                "height": height,
                "n": form_data.n,
            }

            if config.IMAGE_STEPS is not None:
                data["steps"] = config.IMAGE_STEPS

            if form_data.negative_prompt is not None:
                data["negative_prompt"] = form_data.negative_prompt

            comfyui_form = ComfyUICreateImageForm(
                **{
                    "workflow": ComfyUIWorkflow(
                        **{
                            "workflow": config.COMFYUI_WORKFLOW,
                            "nodes": config.COMFYUI_WORKFLOW_NODES,
                        }
                    ),
                    **data,
                }
            )
            res = await comfyui_create_image(
                model,
                comfyui_form,
                user.id,
                config.COMFYUI_BASE_URL,
                config.COMFYUI_API_KEY,
            )
            log.debug(f"res: {res}")

            # The Authorization header is only needed when a key is configured.
            headers = None
            if config.COMFYUI_API_KEY:
                headers = {"Authorization": f"Bearer {config.COMFYUI_API_KEY}"}

            images = []
            for image in res["data"]:
                image_data, content_type = get_image_data(image["url"], headers)
                url = upload_image(
                    request,
                    image_data,
                    content_type,
                    comfyui_form.model_dump(exclude_none=True),
                    user,
                )
                images.append({"url": url})
            return images

        elif engine == "automatic1111" or engine == "":
            if form_data.model:
                set_image_model(request, form_data.model)

            data = {
                "prompt": form_data.prompt,
                "batch_size": form_data.n,
                "width": width,
                "height": height,
            }

            if config.IMAGE_STEPS is not None:
                data["steps"] = config.IMAGE_STEPS

            if form_data.negative_prompt is not None:
                data["negative_prompt"] = form_data.negative_prompt

            if config.AUTOMATIC1111_PARAMS:
                data = {**data, **config.AUTOMATIC1111_PARAMS}

            # Use asyncio.to_thread for the requests.post call
            r = await asyncio.to_thread(
                requests.post,
                url=f"{config.AUTOMATIC1111_BASE_URL}/sdapi/v1/txt2img",
                json=data,
                headers={"authorization": get_automatic1111_api_auth(request)},
            )

            res = r.json()
            log.debug(f"res: {res}")

            images = []
            for image in res["images"]:
                image_data, content_type = get_image_data(image)
                url = upload_image(
                    request,
                    image_data,
                    content_type,
                    {**data, "info": res["info"]},
                    user,
                )
                images.append({"url": url})
            return images
    except Exception as e:
        error = e
        if r is not None:
            # Prefer the backend's own error description, but never let a
            # non-JSON body or an unexpected error shape mask the original
            # failure with a second exception.
            try:
                payload = r.json()
            except ValueError:
                payload = None

            if isinstance(payload, dict) and "error" in payload:
                backend_error = payload["error"]
                if isinstance(backend_error, dict):
                    error = backend_error.get("message", backend_error)
                else:
                    error = backend_error
        raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(error))