mirror of
https://github.com/open-webui/open-webui.git
synced 2025-12-12 20:35:19 +00:00
974 lines
37 KiB
Python
974 lines
37 KiB
Python
import asyncio
|
|
import base64
|
|
import uuid
|
|
import io
|
|
import json
|
|
import logging
|
|
import mimetypes
|
|
import re
|
|
from pathlib import Path
|
|
from typing import Optional
|
|
|
|
from urllib.parse import quote
|
|
import requests
|
|
from fastapi import APIRouter, Depends, HTTPException, Request, UploadFile
|
|
from fastapi.responses import FileResponse
|
|
|
|
from open_webui.config import CACHE_DIR
|
|
from open_webui.constants import ERROR_MESSAGES
|
|
from open_webui.env import ENABLE_FORWARD_USER_INFO_HEADERS, SRC_LOG_LEVELS
|
|
from open_webui.routers.files import upload_file_handler, get_file_content_by_id
|
|
from open_webui.utils.auth import get_admin_user, get_verified_user
|
|
from open_webui.utils.headers import include_user_info_headers
|
|
from open_webui.utils.images.comfyui import (
|
|
ComfyUICreateImageForm,
|
|
ComfyUIWorkflow,
|
|
comfyui_create_image,
|
|
)
|
|
from pydantic import BaseModel
|
|
|
|
log = logging.getLogger(__name__)
|
|
log.setLevel(SRC_LOG_LEVELS["IMAGES"])
|
|
|
|
IMAGE_CACHE_DIR = CACHE_DIR / "image" / "generations"
|
|
IMAGE_CACHE_DIR.mkdir(parents=True, exist_ok=True)
|
|
|
|
router = APIRouter()
|
|
|
|
|
|
def set_image_model(request: Request, model: str):
|
|
log.info(f"Setting image model to {model}")
|
|
request.app.state.config.IMAGE_GENERATION_MODEL = model
|
|
if request.app.state.config.IMAGE_GENERATION_ENGINE in ["", "automatic1111"]:
|
|
api_auth = get_automatic1111_api_auth(request)
|
|
r = requests.get(
|
|
url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
|
|
headers={"authorization": api_auth},
|
|
)
|
|
options = r.json()
|
|
if model != options["sd_model_checkpoint"]:
|
|
options["sd_model_checkpoint"] = model
|
|
r = requests.post(
|
|
url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
|
|
json=options,
|
|
headers={"authorization": api_auth},
|
|
)
|
|
return request.app.state.config.IMAGE_GENERATION_MODEL
|
|
|
|
|
|
def get_image_model(request):
|
|
if request.app.state.config.IMAGE_GENERATION_ENGINE == "openai":
|
|
return (
|
|
request.app.state.config.IMAGE_GENERATION_MODEL
|
|
if request.app.state.config.IMAGE_GENERATION_MODEL
|
|
else "dall-e-2"
|
|
)
|
|
elif request.app.state.config.IMAGE_GENERATION_ENGINE == "gemini":
|
|
return (
|
|
request.app.state.config.IMAGE_GENERATION_MODEL
|
|
if request.app.state.config.IMAGE_GENERATION_MODEL
|
|
else "imagen-3.0-generate-002"
|
|
)
|
|
elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui":
|
|
return (
|
|
request.app.state.config.IMAGE_GENERATION_MODEL
|
|
if request.app.state.config.IMAGE_GENERATION_MODEL
|
|
else ""
|
|
)
|
|
elif (
|
|
request.app.state.config.IMAGE_GENERATION_ENGINE == "automatic1111"
|
|
or request.app.state.config.IMAGE_GENERATION_ENGINE == ""
|
|
):
|
|
try:
|
|
r = requests.get(
|
|
url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
|
|
headers={"authorization": get_automatic1111_api_auth(request)},
|
|
)
|
|
options = r.json()
|
|
return options["sd_model_checkpoint"]
|
|
except Exception as e:
|
|
request.app.state.config.ENABLE_IMAGE_GENERATION = False
|
|
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
|
|
|
|
|
|
class ImagesConfig(BaseModel):
|
|
ENABLE_IMAGE_GENERATION: bool
|
|
ENABLE_IMAGE_PROMPT_GENERATION: bool
|
|
|
|
IMAGE_GENERATION_ENGINE: str
|
|
IMAGE_GENERATION_MODEL: str
|
|
IMAGE_SIZE: Optional[str]
|
|
IMAGE_STEPS: Optional[int]
|
|
|
|
IMAGES_OPENAI_API_BASE_URL: str
|
|
IMAGES_OPENAI_API_KEY: str
|
|
IMAGES_OPENAI_API_VERSION: str
|
|
|
|
AUTOMATIC1111_BASE_URL: str
|
|
AUTOMATIC1111_API_AUTH: str
|
|
AUTOMATIC1111_PARAMS: Optional[dict | str]
|
|
|
|
COMFYUI_BASE_URL: str
|
|
COMFYUI_API_KEY: str
|
|
COMFYUI_WORKFLOW: str
|
|
COMFYUI_WORKFLOW_NODES: list[dict]
|
|
|
|
IMAGES_GEMINI_API_BASE_URL: str
|
|
IMAGES_GEMINI_API_KEY: str
|
|
IMAGES_GEMINI_ENDPOINT_METHOD: str
|
|
|
|
IMAGE_EDIT_ENGINE: str
|
|
IMAGE_EDIT_MODEL: str
|
|
IMAGE_EDIT_SIZE: Optional[str]
|
|
|
|
IMAGES_EDIT_OPENAI_API_BASE_URL: str
|
|
IMAGES_EDIT_OPENAI_API_KEY: str
|
|
IMAGES_EDIT_OPENAI_API_VERSION: str
|
|
IMAGES_EDIT_GEMINI_API_BASE_URL: str
|
|
IMAGES_EDIT_GEMINI_API_KEY: str
|
|
|
|
|
|
@router.get("/config", response_model=ImagesConfig)
|
|
async def get_config(request: Request, user=Depends(get_admin_user)):
|
|
return {
|
|
"ENABLE_IMAGE_GENERATION": request.app.state.config.ENABLE_IMAGE_GENERATION,
|
|
"ENABLE_IMAGE_PROMPT_GENERATION": request.app.state.config.ENABLE_IMAGE_PROMPT_GENERATION,
|
|
"IMAGE_GENERATION_ENGINE": request.app.state.config.IMAGE_GENERATION_ENGINE,
|
|
"IMAGE_GENERATION_MODEL": request.app.state.config.IMAGE_GENERATION_MODEL,
|
|
"IMAGE_SIZE": request.app.state.config.IMAGE_SIZE,
|
|
"IMAGE_STEPS": request.app.state.config.IMAGE_STEPS,
|
|
"IMAGES_OPENAI_API_BASE_URL": request.app.state.config.IMAGES_OPENAI_API_BASE_URL,
|
|
"IMAGES_OPENAI_API_KEY": request.app.state.config.IMAGES_OPENAI_API_KEY,
|
|
"IMAGES_OPENAI_API_VERSION": request.app.state.config.IMAGES_OPENAI_API_VERSION,
|
|
"AUTOMATIC1111_BASE_URL": request.app.state.config.AUTOMATIC1111_BASE_URL,
|
|
"AUTOMATIC1111_API_AUTH": request.app.state.config.AUTOMATIC1111_API_AUTH,
|
|
"AUTOMATIC1111_PARAMS": request.app.state.config.AUTOMATIC1111_PARAMS,
|
|
"COMFYUI_BASE_URL": request.app.state.config.COMFYUI_BASE_URL,
|
|
"COMFYUI_API_KEY": request.app.state.config.COMFYUI_API_KEY,
|
|
"COMFYUI_WORKFLOW": request.app.state.config.COMFYUI_WORKFLOW,
|
|
"COMFYUI_WORKFLOW_NODES": request.app.state.config.COMFYUI_WORKFLOW_NODES,
|
|
"IMAGES_GEMINI_API_BASE_URL": request.app.state.config.IMAGES_GEMINI_API_BASE_URL,
|
|
"IMAGES_GEMINI_API_KEY": request.app.state.config.IMAGES_GEMINI_API_KEY,
|
|
"IMAGES_GEMINI_ENDPOINT_METHOD": request.app.state.config.IMAGES_GEMINI_ENDPOINT_METHOD,
|
|
"IMAGE_EDIT_ENGINE": request.app.state.config.IMAGE_EDIT_ENGINE,
|
|
"IMAGE_EDIT_MODEL": request.app.state.config.IMAGE_EDIT_MODEL,
|
|
"IMAGE_EDIT_SIZE": request.app.state.config.IMAGE_EDIT_SIZE,
|
|
"IMAGES_EDIT_OPENAI_API_BASE_URL": request.app.state.config.IMAGES_EDIT_OPENAI_API_BASE_URL,
|
|
"IMAGES_EDIT_OPENAI_API_KEY": request.app.state.config.IMAGES_EDIT_OPENAI_API_KEY,
|
|
"IMAGES_EDIT_OPENAI_API_VERSION": request.app.state.config.IMAGES_EDIT_OPENAI_API_VERSION,
|
|
"IMAGES_EDIT_GEMINI_API_BASE_URL": request.app.state.config.IMAGES_EDIT_GEMINI_API_BASE_URL,
|
|
"IMAGES_EDIT_GEMINI_API_KEY": request.app.state.config.IMAGES_EDIT_GEMINI_API_KEY,
|
|
}
|
|
|
|
|
|
@router.post("/config/update")
|
|
async def update_config(
|
|
request: Request, form_data: ImagesConfig, user=Depends(get_admin_user)
|
|
):
|
|
request.app.state.config.ENABLE_IMAGE_GENERATION = form_data.ENABLE_IMAGE_GENERATION
|
|
|
|
# Create Image
|
|
request.app.state.config.ENABLE_IMAGE_PROMPT_GENERATION = (
|
|
form_data.ENABLE_IMAGE_PROMPT_GENERATION
|
|
)
|
|
|
|
request.app.state.config.IMAGE_GENERATION_ENGINE = form_data.IMAGE_GENERATION_ENGINE
|
|
set_image_model(request, form_data.IMAGE_GENERATION_MODEL)
|
|
if (
|
|
form_data.IMAGE_SIZE == "auto"
|
|
and form_data.IMAGE_GENERATION_MODEL != "gpt-image-1"
|
|
):
|
|
raise HTTPException(
|
|
status_code=400,
|
|
detail=ERROR_MESSAGES.INCORRECT_FORMAT(
|
|
" (auto is only allowed with gpt-image-1)."
|
|
),
|
|
)
|
|
|
|
pattern = r"^\d+x\d+$"
|
|
if (
|
|
form_data.IMAGE_SIZE == "auto"
|
|
or form_data.IMAGE_SIZE == ""
|
|
or re.match(pattern, form_data.IMAGE_SIZE)
|
|
):
|
|
request.app.state.config.IMAGE_SIZE = form_data.IMAGE_SIZE
|
|
else:
|
|
raise HTTPException(
|
|
status_code=400,
|
|
detail=ERROR_MESSAGES.INCORRECT_FORMAT(" (e.g., 512x512)."),
|
|
)
|
|
|
|
if form_data.IMAGE_STEPS >= 0:
|
|
request.app.state.config.IMAGE_STEPS = form_data.IMAGE_STEPS
|
|
else:
|
|
raise HTTPException(
|
|
status_code=400,
|
|
detail=ERROR_MESSAGES.INCORRECT_FORMAT(" (e.g., 50)."),
|
|
)
|
|
|
|
request.app.state.config.IMAGES_OPENAI_API_BASE_URL = (
|
|
form_data.IMAGES_OPENAI_API_BASE_URL
|
|
)
|
|
request.app.state.config.IMAGES_OPENAI_API_KEY = form_data.IMAGES_OPENAI_API_KEY
|
|
request.app.state.config.IMAGES_OPENAI_API_VERSION = (
|
|
form_data.IMAGES_OPENAI_API_VERSION
|
|
)
|
|
|
|
request.app.state.config.AUTOMATIC1111_BASE_URL = form_data.AUTOMATIC1111_BASE_URL
|
|
request.app.state.config.AUTOMATIC1111_API_AUTH = form_data.AUTOMATIC1111_API_AUTH
|
|
request.app.state.config.AUTOMATIC1111_PARAMS = form_data.AUTOMATIC1111_PARAMS
|
|
|
|
request.app.state.config.COMFYUI_BASE_URL = form_data.COMFYUI_BASE_URL.strip("/")
|
|
request.app.state.config.COMFYUI_API_KEY = form_data.COMFYUI_API_KEY
|
|
request.app.state.config.COMFYUI_WORKFLOW = form_data.COMFYUI_WORKFLOW
|
|
request.app.state.config.COMFYUI_WORKFLOW_NODES = form_data.COMFYUI_WORKFLOW_NODES
|
|
|
|
request.app.state.config.IMAGES_GEMINI_API_BASE_URL = (
|
|
form_data.IMAGES_GEMINI_API_BASE_URL
|
|
)
|
|
request.app.state.config.IMAGES_GEMINI_API_KEY = form_data.IMAGES_GEMINI_API_KEY
|
|
request.app.state.config.IMAGES_GEMINI_ENDPOINT_METHOD = (
|
|
form_data.IMAGES_GEMINI_ENDPOINT_METHOD
|
|
)
|
|
|
|
# Edit Image
|
|
request.app.state.config.IMAGE_EDIT_ENGINE = form_data.IMAGE_EDIT_ENGINE
|
|
request.app.state.config.IMAGE_EDIT_MODEL = form_data.IMAGE_EDIT_MODEL
|
|
request.app.state.config.IMAGE_EDIT_SIZE = form_data.IMAGE_EDIT_SIZE
|
|
|
|
request.app.state.config.IMAGES_EDIT_OPENAI_API_BASE_URL = (
|
|
form_data.IMAGES_OPENAI_API_BASE_URL
|
|
)
|
|
request.app.state.config.IMAGES_EDIT_OPENAI_API_KEY = (
|
|
form_data.IMAGES_OPENAI_API_KEY
|
|
)
|
|
request.app.state.config.IMAGES_EDIT_OPENAI_API_VERSION = (
|
|
form_data.IMAGES_EDIT_OPENAI_API_VERSION
|
|
)
|
|
|
|
request.app.state.config.IMAGES_EDIT_GEMINI_API_BASE_URL = (
|
|
form_data.IMAGES_EDIT_GEMINI_API_BASE_URL
|
|
)
|
|
request.app.state.config.IMAGES_EDIT_GEMINI_API_KEY = (
|
|
form_data.IMAGES_EDIT_GEMINI_API_KEY
|
|
)
|
|
|
|
return {
|
|
"ENABLE_IMAGE_GENERATION": request.app.state.config.ENABLE_IMAGE_GENERATION,
|
|
"ENABLE_IMAGE_PROMPT_GENERATION": request.app.state.config.ENABLE_IMAGE_PROMPT_GENERATION,
|
|
"IMAGE_GENERATION_ENGINE": request.app.state.config.IMAGE_GENERATION_ENGINE,
|
|
"IMAGE_GENERATION_MODEL": request.app.state.config.IMAGE_GENERATION_MODEL,
|
|
"IMAGE_SIZE": request.app.state.config.IMAGE_SIZE,
|
|
"IMAGE_STEPS": request.app.state.config.IMAGE_STEPS,
|
|
"IMAGES_OPENAI_API_BASE_URL": request.app.state.config.IMAGES_OPENAI_API_BASE_URL,
|
|
"IMAGES_OPENAI_API_KEY": request.app.state.config.IMAGES_OPENAI_API_KEY,
|
|
"IMAGES_OPENAI_API_VERSION": request.app.state.config.IMAGES_OPENAI_API_VERSION,
|
|
"AUTOMATIC1111_BASE_URL": request.app.state.config.AUTOMATIC1111_BASE_URL,
|
|
"AUTOMATIC1111_API_AUTH": request.app.state.config.AUTOMATIC1111_API_AUTH,
|
|
"AUTOMATIC1111_PARAMS": request.app.state.config.AUTOMATIC1111_PARAMS,
|
|
"COMFYUI_BASE_URL": request.app.state.config.COMFYUI_BASE_URL,
|
|
"COMFYUI_API_KEY": request.app.state.config.COMFYUI_API_KEY,
|
|
"COMFYUI_WORKFLOW": request.app.state.config.COMFYUI_WORKFLOW,
|
|
"COMFYUI_WORKFLOW_NODES": request.app.state.config.COMFYUI_WORKFLOW_NODES,
|
|
"IMAGES_GEMINI_API_BASE_URL": request.app.state.config.IMAGES_GEMINI_API_BASE_URL,
|
|
"IMAGES_GEMINI_API_KEY": request.app.state.config.IMAGES_GEMINI_API_KEY,
|
|
"IMAGES_GEMINI_ENDPOINT_METHOD": request.app.state.config.IMAGES_GEMINI_ENDPOINT_METHOD,
|
|
"IMAGE_EDIT_ENGINE": request.app.state.config.IMAGE_EDIT_ENGINE,
|
|
"IMAGE_EDIT_MODEL": request.app.state.config.IMAGE_EDIT_MODEL,
|
|
"IMAGE_EDIT_SIZE": request.app.state.config.IMAGE_EDIT_SIZE,
|
|
"IMAGES_EDIT_OPENAI_API_BASE_URL": request.app.state.config.IMAGES_EDIT_OPENAI_API_BASE_URL,
|
|
"IMAGES_EDIT_OPENAI_API_KEY": request.app.state.config.IMAGES_EDIT_OPENAI_API_KEY,
|
|
"IMAGES_EDIT_OPENAI_API_VERSION": request.app.state.config.IMAGES_EDIT_OPENAI_API_VERSION,
|
|
"IMAGES_EDIT_GEMINI_API_BASE_URL": request.app.state.config.IMAGES_EDIT_GEMINI_API_BASE_URL,
|
|
"IMAGES_EDIT_GEMINI_API_KEY": request.app.state.config.IMAGES_EDIT_GEMINI_API_KEY,
|
|
}
|
|
|
|
|
|
def get_automatic1111_api_auth(request: Request):
|
|
if request.app.state.config.AUTOMATIC1111_API_AUTH is None:
|
|
return ""
|
|
else:
|
|
auth1111_byte_string = request.app.state.config.AUTOMATIC1111_API_AUTH.encode(
|
|
"utf-8"
|
|
)
|
|
auth1111_base64_encoded_bytes = base64.b64encode(auth1111_byte_string)
|
|
auth1111_base64_encoded_string = auth1111_base64_encoded_bytes.decode("utf-8")
|
|
return f"Basic {auth1111_base64_encoded_string}"
|
|
|
|
|
|
@router.get("/config/url/verify")
|
|
async def verify_url(request: Request, user=Depends(get_admin_user)):
|
|
if request.app.state.config.IMAGE_GENERATION_ENGINE == "automatic1111":
|
|
try:
|
|
r = requests.get(
|
|
url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
|
|
headers={"authorization": get_automatic1111_api_auth(request)},
|
|
)
|
|
r.raise_for_status()
|
|
return True
|
|
except Exception:
|
|
request.app.state.config.ENABLE_IMAGE_GENERATION = False
|
|
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.INVALID_URL)
|
|
elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui":
|
|
headers = None
|
|
if request.app.state.config.COMFYUI_API_KEY:
|
|
headers = {
|
|
"Authorization": f"Bearer {request.app.state.config.COMFYUI_API_KEY}"
|
|
}
|
|
try:
|
|
r = requests.get(
|
|
url=f"{request.app.state.config.COMFYUI_BASE_URL}/object_info",
|
|
headers=headers,
|
|
)
|
|
r.raise_for_status()
|
|
return True
|
|
except Exception:
|
|
request.app.state.config.ENABLE_IMAGE_GENERATION = False
|
|
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.INVALID_URL)
|
|
else:
|
|
return True
|
|
|
|
|
|
@router.get("/models")
|
|
def get_models(request: Request, user=Depends(get_verified_user)):
|
|
try:
|
|
if request.app.state.config.IMAGE_GENERATION_ENGINE == "openai":
|
|
return [
|
|
{"id": "dall-e-2", "name": "DALL·E 2"},
|
|
{"id": "dall-e-3", "name": "DALL·E 3"},
|
|
{"id": "gpt-image-1", "name": "GPT-IMAGE 1"},
|
|
]
|
|
elif request.app.state.config.IMAGE_GENERATION_ENGINE == "gemini":
|
|
return [
|
|
{"id": "imagen-3.0-generate-002", "name": "imagen-3.0 generate-002"},
|
|
]
|
|
elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui":
|
|
# TODO - get models from comfyui
|
|
headers = {
|
|
"Authorization": f"Bearer {request.app.state.config.COMFYUI_API_KEY}"
|
|
}
|
|
r = requests.get(
|
|
url=f"{request.app.state.config.COMFYUI_BASE_URL}/object_info",
|
|
headers=headers,
|
|
)
|
|
info = r.json()
|
|
|
|
workflow = json.loads(request.app.state.config.COMFYUI_WORKFLOW)
|
|
model_node_id = None
|
|
|
|
for node in request.app.state.config.COMFYUI_WORKFLOW_NODES:
|
|
if node["type"] == "model":
|
|
if node["node_ids"]:
|
|
model_node_id = node["node_ids"][0]
|
|
break
|
|
|
|
if model_node_id:
|
|
model_list_key = None
|
|
|
|
log.info(workflow[model_node_id]["class_type"])
|
|
for key in info[workflow[model_node_id]["class_type"]]["input"][
|
|
"required"
|
|
]:
|
|
if "_name" in key:
|
|
model_list_key = key
|
|
break
|
|
|
|
if model_list_key:
|
|
return list(
|
|
map(
|
|
lambda model: {"id": model, "name": model},
|
|
info[workflow[model_node_id]["class_type"]]["input"][
|
|
"required"
|
|
][model_list_key][0],
|
|
)
|
|
)
|
|
else:
|
|
return list(
|
|
map(
|
|
lambda model: {"id": model, "name": model},
|
|
info["CheckpointLoaderSimple"]["input"]["required"][
|
|
"ckpt_name"
|
|
][0],
|
|
)
|
|
)
|
|
elif (
|
|
request.app.state.config.IMAGE_GENERATION_ENGINE == "automatic1111"
|
|
or request.app.state.config.IMAGE_GENERATION_ENGINE == ""
|
|
):
|
|
r = requests.get(
|
|
url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models",
|
|
headers={"authorization": get_automatic1111_api_auth(request)},
|
|
)
|
|
models = r.json()
|
|
return list(
|
|
map(
|
|
lambda model: {"id": model["title"], "name": model["model_name"]},
|
|
models,
|
|
)
|
|
)
|
|
except Exception as e:
|
|
request.app.state.config.ENABLE_IMAGE_GENERATION = False
|
|
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
|
|
|
|
|
|
class CreateImageForm(BaseModel):
|
|
model: Optional[str] = None
|
|
prompt: str
|
|
size: Optional[str] = None
|
|
n: int = 1
|
|
negative_prompt: Optional[str] = None
|
|
|
|
|
|
GenerateImageForm = CreateImageForm # Alias for backward compatibility
|
|
|
|
|
|
def get_image_data(data: str, headers=None):
|
|
try:
|
|
if data.startswith("http://") or data.startswith("https://"):
|
|
if headers:
|
|
r = requests.get(data, headers=headers)
|
|
else:
|
|
r = requests.get(data)
|
|
|
|
r.raise_for_status()
|
|
if r.headers["content-type"].split("/")[0] == "image":
|
|
mime_type = r.headers["content-type"]
|
|
return r.content, mime_type
|
|
else:
|
|
log.error("Url does not point to an image.")
|
|
return None
|
|
else:
|
|
if "," in data:
|
|
header, encoded = data.split(",", 1)
|
|
mime_type = header.split(";")[0].lstrip("data:")
|
|
img_data = base64.b64decode(encoded)
|
|
else:
|
|
mime_type = "image/png"
|
|
img_data = base64.b64decode(data)
|
|
return img_data, mime_type
|
|
except Exception as e:
|
|
log.exception(f"Error loading image data: {e}")
|
|
return None, None
|
|
|
|
|
|
def upload_image(request, image_data, content_type, metadata, user):
|
|
image_format = mimetypes.guess_extension(content_type)
|
|
file = UploadFile(
|
|
file=io.BytesIO(image_data),
|
|
filename=f"generated-image{image_format}", # will be converted to a unique ID on upload_file
|
|
headers={
|
|
"content-type": content_type,
|
|
},
|
|
)
|
|
file_item = upload_file_handler(
|
|
request,
|
|
file=file,
|
|
metadata=metadata,
|
|
process=False,
|
|
user=user,
|
|
)
|
|
url = request.app.url_path_for("get_file_content_by_id", id=file_item.id)
|
|
return url
|
|
|
|
|
|
@router.post("/generations")
|
|
async def image_generations(
|
|
request: Request,
|
|
form_data: CreateImageForm,
|
|
user=Depends(get_verified_user),
|
|
):
|
|
# if IMAGE_SIZE = 'auto', default WidthxHeight to the 512x512 default
|
|
# This is only relevant when the user has set IMAGE_SIZE to 'auto' with an
|
|
# image model other than gpt-image-1, which is warned about on settings save
|
|
|
|
size = "512x512"
|
|
if (
|
|
request.app.state.config.IMAGE_SIZE
|
|
and "x" in request.app.state.config.IMAGE_SIZE
|
|
):
|
|
size = request.app.state.config.IMAGE_SIZE
|
|
|
|
if form_data.size and "x" in form_data.size:
|
|
size = form_data.size
|
|
|
|
width, height = tuple(map(int, size.split("x")))
|
|
model = get_image_model(request)
|
|
|
|
r = None
|
|
try:
|
|
if request.app.state.config.IMAGE_GENERATION_ENGINE == "openai":
|
|
|
|
headers = {
|
|
"Authorization": f"Bearer {request.app.state.config.IMAGES_OPENAI_API_KEY}",
|
|
"Content-Type": "application/json",
|
|
}
|
|
|
|
if ENABLE_FORWARD_USER_INFO_HEADERS:
|
|
headers = include_user_info_headers(headers, user)
|
|
|
|
data = {
|
|
"model": model,
|
|
"prompt": form_data.prompt,
|
|
"n": form_data.n,
|
|
"size": (
|
|
form_data.size
|
|
if form_data.size
|
|
else request.app.state.config.IMAGE_SIZE
|
|
),
|
|
**(
|
|
{}
|
|
if "gpt-image-1" in request.app.state.config.IMAGE_GENERATION_MODEL
|
|
else {"response_format": "b64_json"}
|
|
),
|
|
}
|
|
|
|
api_version_query_param = ""
|
|
if request.app.state.config.IMAGES_OPENAI_API_VERSION:
|
|
api_version_query_param = (
|
|
f"?api-version={request.app.state.config.IMAGES_OPENAI_API_VERSION}"
|
|
)
|
|
|
|
# Use asyncio.to_thread for the requests.post call
|
|
r = await asyncio.to_thread(
|
|
requests.post,
|
|
url=f"{request.app.state.config.IMAGES_OPENAI_API_BASE_URL}/images/generations{api_version_query_param}",
|
|
json=data,
|
|
headers=headers,
|
|
)
|
|
|
|
r.raise_for_status()
|
|
res = r.json()
|
|
|
|
images = []
|
|
|
|
for image in res["data"]:
|
|
if image_url := image.get("url", None):
|
|
image_data, content_type = get_image_data(image_url, headers)
|
|
else:
|
|
image_data, content_type = get_image_data(image["b64_json"])
|
|
|
|
url = upload_image(request, image_data, content_type, data, user)
|
|
images.append({"url": url})
|
|
return images
|
|
|
|
elif request.app.state.config.IMAGE_GENERATION_ENGINE == "gemini":
|
|
headers = {
|
|
"Content-Type": "application/json",
|
|
"x-goog-api-key": request.app.state.config.IMAGES_GEMINI_API_KEY,
|
|
}
|
|
|
|
data = {}
|
|
|
|
if (
|
|
request.app.state.config.IMAGES_GEMINI_ENDPOINT_METHOD == ""
|
|
or request.app.state.config.IMAGES_GEMINI_ENDPOINT_METHOD == "predict"
|
|
):
|
|
model = f"{model}:predict"
|
|
data = {
|
|
"instances": {"prompt": form_data.prompt},
|
|
"parameters": {
|
|
"sampleCount": form_data.n,
|
|
"outputOptions": {"mimeType": "image/png"},
|
|
},
|
|
}
|
|
|
|
elif (
|
|
request.app.state.config.IMAGES_GEMINI_ENDPOINT_METHOD
|
|
== "generateContent"
|
|
):
|
|
model = f"{model}:generateContent"
|
|
data = {"contents": [{"parts": [{"text": form_data.prompt}]}]}
|
|
|
|
# Use asyncio.to_thread for the requests.post call
|
|
r = await asyncio.to_thread(
|
|
requests.post,
|
|
url=f"{request.app.state.config.IMAGES_GEMINI_API_BASE_URL}/models/{model}",
|
|
json=data,
|
|
headers=headers,
|
|
)
|
|
|
|
r.raise_for_status()
|
|
res = r.json()
|
|
|
|
images = []
|
|
|
|
if model.endswith(":predict"):
|
|
for image in res["predictions"]:
|
|
image_data, content_type = get_image_data(
|
|
image["bytesBase64Encoded"]
|
|
)
|
|
url = upload_image(request, image_data, content_type, data, user)
|
|
images.append({"url": url})
|
|
elif model.endswith(":generateContent"):
|
|
for image in res["candidates"]:
|
|
for part in image["content"]["parts"]:
|
|
if part.get("inlineData", {}).get("data"):
|
|
image_data, content_type = get_image_data(
|
|
part["inlineData"]["data"]
|
|
)
|
|
url = upload_image(
|
|
request, image_data, content_type, data, user
|
|
)
|
|
images.append({"url": url})
|
|
|
|
return images
|
|
|
|
elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui":
|
|
data = {
|
|
"prompt": form_data.prompt,
|
|
"width": width,
|
|
"height": height,
|
|
"n": form_data.n,
|
|
}
|
|
|
|
if request.app.state.config.IMAGE_STEPS is not None:
|
|
data["steps"] = request.app.state.config.IMAGE_STEPS
|
|
|
|
if form_data.negative_prompt is not None:
|
|
data["negative_prompt"] = form_data.negative_prompt
|
|
|
|
form_data = ComfyUICreateImageForm(
|
|
**{
|
|
"workflow": ComfyUIWorkflow(
|
|
**{
|
|
"workflow": request.app.state.config.COMFYUI_WORKFLOW,
|
|
"nodes": request.app.state.config.COMFYUI_WORKFLOW_NODES,
|
|
}
|
|
),
|
|
**data,
|
|
}
|
|
)
|
|
res = await comfyui_create_image(
|
|
model,
|
|
form_data,
|
|
user.id,
|
|
request.app.state.config.COMFYUI_BASE_URL,
|
|
request.app.state.config.COMFYUI_API_KEY,
|
|
)
|
|
log.debug(f"res: {res}")
|
|
|
|
images = []
|
|
|
|
for image in res["data"]:
|
|
headers = None
|
|
if request.app.state.config.COMFYUI_API_KEY:
|
|
headers = {
|
|
"Authorization": f"Bearer {request.app.state.config.COMFYUI_API_KEY}"
|
|
}
|
|
|
|
image_data, content_type = get_image_data(image["url"], headers)
|
|
url = upload_image(
|
|
request,
|
|
image_data,
|
|
content_type,
|
|
form_data.model_dump(exclude_none=True),
|
|
user,
|
|
)
|
|
images.append({"url": url})
|
|
return images
|
|
elif (
|
|
request.app.state.config.IMAGE_GENERATION_ENGINE == "automatic1111"
|
|
or request.app.state.config.IMAGE_GENERATION_ENGINE == ""
|
|
):
|
|
if form_data.model:
|
|
set_image_model(request, form_data.model)
|
|
|
|
data = {
|
|
"prompt": form_data.prompt,
|
|
"batch_size": form_data.n,
|
|
"width": width,
|
|
"height": height,
|
|
}
|
|
|
|
if request.app.state.config.IMAGE_STEPS is not None:
|
|
data["steps"] = request.app.state.config.IMAGE_STEPS
|
|
|
|
if form_data.negative_prompt is not None:
|
|
data["negative_prompt"] = form_data.negative_prompt
|
|
|
|
if request.app.state.config.AUTOMATIC1111_PARAMS:
|
|
data = {**data, **request.app.state.config.AUTOMATIC1111_PARAMS}
|
|
|
|
# Use asyncio.to_thread for the requests.post call
|
|
r = await asyncio.to_thread(
|
|
requests.post,
|
|
url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/txt2img",
|
|
json=data,
|
|
headers={"authorization": get_automatic1111_api_auth(request)},
|
|
)
|
|
|
|
res = r.json()
|
|
log.debug(f"res: {res}")
|
|
|
|
images = []
|
|
|
|
for image in res["images"]:
|
|
image_data, content_type = get_image_data(image)
|
|
url = upload_image(
|
|
request,
|
|
image_data,
|
|
content_type,
|
|
{**data, "info": res["info"]},
|
|
user,
|
|
)
|
|
images.append({"url": url})
|
|
return images
|
|
except Exception as e:
|
|
error = e
|
|
if r != None:
|
|
data = r.json()
|
|
if "error" in data:
|
|
error = data["error"]["message"]
|
|
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(error))
|
|
|
|
|
|
class EditImageForm(BaseModel):
|
|
image: str | list[str] # base64-encoded image(s) or URL(s)
|
|
prompt: str
|
|
model: Optional[str] = None
|
|
size: Optional[str] = None
|
|
n: Optional[int] = None
|
|
negative_prompt: Optional[str] = None
|
|
|
|
|
|
@router.post("/edit")
|
|
async def image_edits(
|
|
request: Request,
|
|
form_data: EditImageForm,
|
|
user=Depends(get_verified_user),
|
|
):
|
|
|
|
size = None
|
|
width, height = None, None
|
|
if (
|
|
request.app.state.config.IMAGE_EDIT_SIZE
|
|
and "x" in request.app.state.config.IMAGE_EDIT_SIZE
|
|
) or (form_data.size and "x" in form_data.size):
|
|
size = (
|
|
form_data.size
|
|
if form_data.size
|
|
else request.app.state.config.IMAGE_EDIT_SIZE
|
|
)
|
|
width, height = tuple(map(int, size.split("x")))
|
|
|
|
model = (
|
|
request.app.state.config.IMAGE_EDIT_MODEL
|
|
if form_data.model is None
|
|
else form_data.model
|
|
)
|
|
|
|
def load_url_image(string):
|
|
if string.startswith("http://") or string.startswith("https://"):
|
|
r = requests.get(string)
|
|
r.raise_for_status()
|
|
image_data = base64.b64encode(r.content).decode("utf-8")
|
|
return f"data:{r.headers['content-type']};base64,{image_data}"
|
|
|
|
elif string.startswith("/api/v1/files"):
|
|
file_id = string.split("/api/v1/files/")[1].split("/content")[0]
|
|
file_response = get_file_content_by_id(file_id, user)
|
|
|
|
if isinstance(file_response, FileResponse):
|
|
file_bytes = file_response.body
|
|
mime_type = file_response.headers.get("content-type", "image/png")
|
|
image_data = base64.b64encode(file_bytes).decode("utf-8")
|
|
return f"data:{mime_type};base64,{image_data}"
|
|
return string
|
|
|
|
# Load image(s) from URL(s) if necessary
|
|
if isinstance(form_data.image, str):
|
|
form_data.image = load_url_image(form_data.image)
|
|
elif isinstance(form_data.image, list):
|
|
form_data.image = [load_url_image(img) for img in form_data.image]
|
|
|
|
r = None
|
|
try:
|
|
if request.app.state.config.IMAGE_EDIT_ENGINE == "openai":
|
|
headers = {
|
|
"Authorization": f"Bearer {request.app.state.config.IMAGES_EDIT_OPENAI_API_KEY}",
|
|
}
|
|
|
|
if ENABLE_FORWARD_USER_INFO_HEADERS:
|
|
headers = include_user_info_headers(headers, user)
|
|
|
|
data = {
|
|
"model": model,
|
|
"prompt": form_data.prompt,
|
|
**({"n": form_data.n} if form_data.n else {}),
|
|
**({"size": size} if size else {}),
|
|
**(
|
|
{}
|
|
if "gpt-image-1" in request.app.state.config.IMAGE_GENERATION_MODEL
|
|
else {"response_format": "b64_json"}
|
|
),
|
|
}
|
|
|
|
def get_image_file_item(base64_string):
|
|
data = base64_string
|
|
header, encoded = data.split(",", 1)
|
|
mime_type = header.split(";")[0].lstrip("data:")
|
|
image_data = base64.b64decode(encoded)
|
|
return (
|
|
"image",
|
|
(
|
|
f"{uuid.uuid4()}.png",
|
|
io.BytesIO(image_data),
|
|
mime_type if mime_type else "image/png",
|
|
),
|
|
)
|
|
|
|
files = []
|
|
if isinstance(form_data.image, str):
|
|
files = [get_image_file_item(form_data.image)]
|
|
elif isinstance(form_data.image, list):
|
|
for img in form_data.image:
|
|
files.append(get_image_file_item(img))
|
|
|
|
url_search_params = ""
|
|
if request.app.state.config.IMAGES_EDIT_OPENAI_API_VERSION:
|
|
url_search_params += f"?api-version={request.app.state.config.IMAGES_EDIT_OPENAI_API_VERSION}"
|
|
|
|
# Use asyncio.to_thread for the requests.post call
|
|
r = await asyncio.to_thread(
|
|
requests.post,
|
|
url=f"{request.app.state.config.IMAGES_OPENAI_API_BASE_URL}/images/edits{url_search_params}",
|
|
headers=headers,
|
|
files=files,
|
|
data=data,
|
|
)
|
|
|
|
r.raise_for_status()
|
|
res = r.json()
|
|
|
|
images = []
|
|
for image in res["data"]:
|
|
if image_url := image.get("url", None):
|
|
image_data, content_type = get_image_data(image_url, headers)
|
|
else:
|
|
image_data, content_type = get_image_data(image["b64_json"])
|
|
|
|
url = upload_image(request, image_data, content_type, data, user)
|
|
images.append({"url": url})
|
|
return images
|
|
|
|
elif request.app.state.config.IMAGE_GENERATION_ENGINE == "gemini":
|
|
headers = {
|
|
"Content-Type": "application/json",
|
|
"x-goog-api-key": request.app.state.config.IMAGES_GEMINI_API_KEY,
|
|
}
|
|
|
|
model = f"{model}:generateContent"
|
|
data = {"contents": [{"parts": [{"text": form_data.prompt}]}]}
|
|
|
|
if isinstance(form_data.image, str):
|
|
data["contents"][0]["parts"].append(
|
|
{
|
|
"inline_data": {
|
|
"mime_type": "image/png",
|
|
"data": form_data.image.split(",", 1)[1],
|
|
}
|
|
}
|
|
)
|
|
elif isinstance(form_data.image, list):
|
|
data["contents"][0]["parts"].extend(
|
|
[
|
|
{
|
|
"inline_data": {
|
|
"mime_type": "image/png",
|
|
"data": image.split(",", 1)[1],
|
|
}
|
|
}
|
|
for image in form_data.image
|
|
]
|
|
)
|
|
|
|
# Use asyncio.to_thread for the requests.post call
|
|
r = await asyncio.to_thread(
|
|
requests.post,
|
|
url=f"{request.app.state.config.IMAGES_GEMINI_API_BASE_URL}/models/{model}",
|
|
json=data,
|
|
headers=headers,
|
|
)
|
|
|
|
r.raise_for_status()
|
|
res = r.json()
|
|
|
|
images = []
|
|
for image in res["candidates"]:
|
|
for part in image["content"]["parts"]:
|
|
if part.get("inlineData", {}).get("data"):
|
|
image_data, content_type = get_image_data(
|
|
part["inlineData"]["data"]
|
|
)
|
|
url = upload_image(
|
|
request, image_data, content_type, data, user
|
|
)
|
|
images.append({"url": url})
|
|
|
|
return images
|
|
|
|
elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui":
|
|
data = {
|
|
"prompt": form_data.prompt,
|
|
"width": width,
|
|
"height": height,
|
|
"n": form_data.n,
|
|
}
|
|
|
|
if request.app.state.config.IMAGE_EDIT_STEPS is not None:
|
|
data["steps"] = request.app.state.config.IMAGE_EDIT_STEPS
|
|
|
|
if form_data.negative_prompt is not None:
|
|
data["negative_prompt"] = form_data.negative_prompt
|
|
|
|
form_data = ComfyUICreateImageForm(
|
|
**{
|
|
"workflow": ComfyUIWorkflow(
|
|
**{
|
|
"workflow": request.app.state.config.COMFYUI_WORKFLOW,
|
|
"nodes": request.app.state.config.COMFYUI_WORKFLOW_NODES,
|
|
}
|
|
),
|
|
**data,
|
|
}
|
|
)
|
|
res = await comfyui_create_image(
|
|
model,
|
|
form_data,
|
|
user.id,
|
|
request.app.state.config.COMFYUI_BASE_URL,
|
|
request.app.state.config.COMFYUI_API_KEY,
|
|
)
|
|
log.debug(f"res: {res}")
|
|
|
|
images = []
|
|
|
|
for image in res["data"]:
|
|
headers = None
|
|
if request.app.state.config.COMFYUI_API_KEY:
|
|
headers = {
|
|
"Authorization": f"Bearer {request.app.state.config.COMFYUI_API_KEY}"
|
|
}
|
|
|
|
image_data, content_type = get_image_data(image["url"], headers)
|
|
url = upload_image(
|
|
request,
|
|
image_data,
|
|
content_type,
|
|
form_data.model_dump(exclude_none=True),
|
|
user,
|
|
)
|
|
images.append({"url": url})
|
|
return images
|
|
except Exception as e:
|
|
error = e
|
|
if r != None:
|
|
data = r.text
|
|
try:
|
|
data = json.loads(data)
|
|
if "error" in data:
|
|
error = data["error"]["message"]
|
|
except Exception:
|
|
error = data
|
|
|
|
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(error))
|