diff --git a/backend/open_webui/config.py b/backend/open_webui/config.py index ea8d36aa9a..3a5b800566 100644 --- a/backend/open_webui/config.py +++ b/backend/open_webui/config.py @@ -3074,16 +3074,30 @@ EXTERNAL_WEB_LOADER_API_KEY = PersistentConfig( # Images #################################### +ENABLE_IMAGE_GENERATION = PersistentConfig( + "ENABLE_IMAGE_GENERATION", + "image_generation.enable", + os.environ.get("ENABLE_IMAGE_GENERATION", "").lower() == "true", +) + IMAGE_GENERATION_ENGINE = PersistentConfig( "IMAGE_GENERATION_ENGINE", "image_generation.engine", os.getenv("IMAGE_GENERATION_ENGINE", "openai"), ) -ENABLE_IMAGE_GENERATION = PersistentConfig( - "ENABLE_IMAGE_GENERATION", - "image_generation.enable", - os.environ.get("ENABLE_IMAGE_GENERATION", "").lower() == "true", +IMAGE_GENERATION_MODEL = PersistentConfig( + "IMAGE_GENERATION_MODEL", + "image_generation.model", + os.getenv("IMAGE_GENERATION_MODEL", ""), +) + +IMAGE_SIZE = PersistentConfig( + "IMAGE_SIZE", "image_generation.size", os.getenv("IMAGE_SIZE", "512x512") +) + +IMAGE_STEPS = PersistentConfig( + "IMAGE_STEPS", "image_generation.steps", int(os.getenv("IMAGE_STEPS", 50)) ) ENABLE_IMAGE_PROMPT_GENERATION = PersistentConfig( @@ -3285,20 +3299,52 @@ IMAGES_GEMINI_ENDPOINT_METHOD = PersistentConfig( os.getenv("IMAGES_GEMINI_ENDPOINT_METHOD", ""), ) -IMAGE_SIZE = PersistentConfig( - "IMAGE_SIZE", "image_generation.size", os.getenv("IMAGE_SIZE", "512x512") + +IMAGE_EDIT_ENGINE = PersistentConfig( + "IMAGE_EDIT_ENGINE", + "images.edit.engine", + os.getenv("IMAGE_EDIT_ENGINE", "openai"), ) -IMAGE_STEPS = PersistentConfig( - "IMAGE_STEPS", "image_generation.steps", int(os.getenv("IMAGE_STEPS", 50)) +IMAGE_EDIT_MODEL = PersistentConfig( + "IMAGE_EDIT_MODEL", + "images.edit.model", + os.getenv("IMAGE_EDIT_MODEL", ""), ) -IMAGE_GENERATION_MODEL = PersistentConfig( - "IMAGE_GENERATION_MODEL", - "image_generation.model", - os.getenv("IMAGE_GENERATION_MODEL", ""), +IMAGE_EDIT_SIZE = PersistentConfig( + "IMAGE_EDIT_SIZE", "images.edit.size", os.getenv("IMAGE_EDIT_SIZE", "") ) +IMAGES_EDIT_OPENAI_API_BASE_URL = PersistentConfig( + "IMAGES_EDIT_OPENAI_API_BASE_URL", + "images.edit.openai.api_base_url", + os.getenv("IMAGES_EDIT_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL), +) +IMAGES_EDIT_OPENAI_API_VERSION = PersistentConfig( + "IMAGES_EDIT_OPENAI_API_VERSION", + "images.edit.openai.api_version", + os.getenv("IMAGES_EDIT_OPENAI_API_VERSION", ""), +) + +IMAGES_EDIT_OPENAI_API_KEY = PersistentConfig( + "IMAGES_EDIT_OPENAI_API_KEY", + "images.edit.openai.api_key", + os.getenv("IMAGES_EDIT_OPENAI_API_KEY", OPENAI_API_KEY), +) + +IMAGES_EDIT_GEMINI_API_BASE_URL = PersistentConfig( + "IMAGES_EDIT_GEMINI_API_BASE_URL", + "images.edit.gemini.api_base_url", + os.getenv("IMAGES_EDIT_GEMINI_API_BASE_URL", GEMINI_API_BASE_URL), +) +IMAGES_EDIT_GEMINI_API_KEY = PersistentConfig( + "IMAGES_EDIT_GEMINI_API_KEY", + "images.edit.gemini.api_key", + os.getenv("IMAGES_EDIT_GEMINI_API_KEY", GEMINI_API_KEY), +) + + #################################### # Audio #################################### diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index 0512bff1eb..44f0560925 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -163,6 +163,14 @@ from open_webui.config import ( IMAGES_GEMINI_API_BASE_URL, IMAGES_GEMINI_API_KEY, IMAGES_GEMINI_ENDPOINT_METHOD, + IMAGE_EDIT_ENGINE, + IMAGE_EDIT_MODEL, + IMAGE_EDIT_SIZE, + IMAGES_EDIT_OPENAI_API_BASE_URL, + IMAGES_EDIT_OPENAI_API_KEY, + IMAGES_EDIT_OPENAI_API_VERSION, + IMAGES_EDIT_GEMINI_API_BASE_URL, + IMAGES_EDIT_GEMINI_API_KEY, # Audio AUDIO_STT_ENGINE, AUDIO_STT_MODEL, @@ -1078,7 +1086,6 @@ app.state.config.IMAGES_GEMINI_API_BASE_URL = IMAGES_GEMINI_API_BASE_URL app.state.config.IMAGES_GEMINI_API_KEY = IMAGES_GEMINI_API_KEY app.state.config.IMAGES_GEMINI_ENDPOINT_METHOD = IMAGES_GEMINI_ENDPOINT_METHOD - app.state.config.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL app.state.config.AUTOMATIC1111_API_AUTH = AUTOMATIC1111_API_AUTH app.state.config.AUTOMATIC1111_PARAMS = AUTOMATIC1111_PARAMS @@ -1089,6 +1096,16 @@ app.state.config.COMFYUI_WORKFLOW = COMFYUI_WORKFLOW app.state.config.COMFYUI_WORKFLOW_NODES = COMFYUI_WORKFLOW_NODES +app.state.config.IMAGE_EDIT_ENGINE = IMAGE_EDIT_ENGINE +app.state.config.IMAGE_EDIT_MODEL = IMAGE_EDIT_MODEL +app.state.config.IMAGE_EDIT_SIZE = IMAGE_EDIT_SIZE +app.state.config.IMAGES_EDIT_OPENAI_API_BASE_URL = IMAGES_EDIT_OPENAI_API_BASE_URL +app.state.config.IMAGES_EDIT_OPENAI_API_KEY = IMAGES_EDIT_OPENAI_API_KEY +app.state.config.IMAGES_EDIT_OPENAI_API_VERSION = IMAGES_EDIT_OPENAI_API_VERSION +app.state.config.IMAGES_EDIT_GEMINI_API_BASE_URL = IMAGES_EDIT_GEMINI_API_BASE_URL +app.state.config.IMAGES_EDIT_GEMINI_API_KEY = IMAGES_EDIT_GEMINI_API_KEY + + ######################################## # # AUDIO diff --git a/backend/open_webui/routers/images.py b/backend/open_webui/routers/images.py index dd7b02550a..469cd448f8 100644 --- a/backend/open_webui/routers/images.py +++ b/backend/open_webui/routers/images.py @@ -1,5 +1,6 @@ import asyncio import base64 +import uuid import io import json import logging @@ -10,18 +11,13 @@ from typing import Optional from urllib.parse import quote import requests -from fastapi import ( - APIRouter, - Depends, - HTTPException, - Request, - UploadFile, -) +from fastapi import APIRouter, Depends, HTTPException, Request, UploadFile +from fastapi.responses import FileResponse from open_webui.config import CACHE_DIR from open_webui.constants import ERROR_MESSAGES from open_webui.env import ENABLE_FORWARD_USER_INFO_HEADERS, SRC_LOG_LEVELS -from open_webui.routers.files import upload_file_handler +from open_webui.routers.files import upload_file_handler, get_file_content_by_id from open_webui.utils.auth import get_admin_user, get_verified_user from open_webui.utils.headers import include_user_info_headers from open_webui.utils.images.comfyui import ( @@ -121,6 +117,16 @@ class ImagesConfig(BaseModel): IMAGES_GEMINI_API_KEY: str IMAGES_GEMINI_ENDPOINT_METHOD: str + IMAGE_EDIT_ENGINE: str + IMAGE_EDIT_MODEL: str + IMAGE_EDIT_SIZE: Optional[str] + + IMAGES_EDIT_OPENAI_API_BASE_URL: str + IMAGES_EDIT_OPENAI_API_KEY: str + IMAGES_EDIT_OPENAI_API_VERSION: str + IMAGES_EDIT_GEMINI_API_BASE_URL: str + IMAGES_EDIT_GEMINI_API_KEY: str + @router.get("/config", response_model=ImagesConfig) async def get_config(request: Request, user=Depends(get_admin_user)): @@ -144,6 +150,14 @@ async def get_config(request: Request, user=Depends(get_admin_user)): "IMAGES_GEMINI_API_BASE_URL": request.app.state.config.IMAGES_GEMINI_API_BASE_URL, "IMAGES_GEMINI_API_KEY": request.app.state.config.IMAGES_GEMINI_API_KEY, "IMAGES_GEMINI_ENDPOINT_METHOD": request.app.state.config.IMAGES_GEMINI_ENDPOINT_METHOD, + "IMAGE_EDIT_ENGINE": request.app.state.config.IMAGE_EDIT_ENGINE, + "IMAGE_EDIT_MODEL": request.app.state.config.IMAGE_EDIT_MODEL, + "IMAGE_EDIT_SIZE": request.app.state.config.IMAGE_EDIT_SIZE, + "IMAGES_EDIT_OPENAI_API_BASE_URL": request.app.state.config.IMAGES_EDIT_OPENAI_API_BASE_URL, + "IMAGES_EDIT_OPENAI_API_KEY": request.app.state.config.IMAGES_EDIT_OPENAI_API_KEY, + "IMAGES_EDIT_OPENAI_API_VERSION": request.app.state.config.IMAGES_EDIT_OPENAI_API_VERSION, + "IMAGES_EDIT_GEMINI_API_BASE_URL": request.app.state.config.IMAGES_EDIT_GEMINI_API_BASE_URL, + "IMAGES_EDIT_GEMINI_API_KEY": request.app.state.config.IMAGES_EDIT_GEMINI_API_KEY, } @@ -152,6 +166,8 @@ async def update_config( request: Request, form_data: ImagesConfig, user=Depends(get_admin_user) ): request.app.state.config.ENABLE_IMAGE_GENERATION = form_data.ENABLE_IMAGE_GENERATION + + # Create Image request.app.state.config.ENABLE_IMAGE_PROMPT_GENERATION = ( form_data.ENABLE_IMAGE_PROMPT_GENERATION ) @@ -215,6 +231,28 @@ async def update_config( form_data.IMAGES_GEMINI_ENDPOINT_METHOD ) + # Edit Image + request.app.state.config.IMAGE_EDIT_ENGINE = form_data.IMAGE_EDIT_ENGINE + request.app.state.config.IMAGE_EDIT_MODEL = form_data.IMAGE_EDIT_MODEL + request.app.state.config.IMAGE_EDIT_SIZE = form_data.IMAGE_EDIT_SIZE + + request.app.state.config.IMAGES_EDIT_OPENAI_API_BASE_URL = ( + form_data.IMAGES_OPENAI_API_BASE_URL + ) + request.app.state.config.IMAGES_EDIT_OPENAI_API_KEY = ( + form_data.IMAGES_OPENAI_API_KEY + ) + request.app.state.config.IMAGES_EDIT_OPENAI_API_VERSION = ( + form_data.IMAGES_EDIT_OPENAI_API_VERSION + ) + + request.app.state.config.IMAGES_EDIT_GEMINI_API_BASE_URL = ( + form_data.IMAGES_EDIT_GEMINI_API_BASE_URL + ) + request.app.state.config.IMAGES_EDIT_GEMINI_API_KEY = ( + form_data.IMAGES_EDIT_GEMINI_API_KEY + ) + return { "ENABLE_IMAGE_GENERATION": request.app.state.config.ENABLE_IMAGE_GENERATION, "ENABLE_IMAGE_PROMPT_GENERATION": request.app.state.config.ENABLE_IMAGE_PROMPT_GENERATION, @@ -235,6 +273,14 @@ async def update_config( "IMAGES_GEMINI_API_BASE_URL": request.app.state.config.IMAGES_GEMINI_API_BASE_URL, "IMAGES_GEMINI_API_KEY": request.app.state.config.IMAGES_GEMINI_API_KEY, "IMAGES_GEMINI_ENDPOINT_METHOD": request.app.state.config.IMAGES_GEMINI_ENDPOINT_METHOD, + "IMAGE_EDIT_ENGINE": request.app.state.config.IMAGE_EDIT_ENGINE, + "IMAGE_EDIT_MODEL": request.app.state.config.IMAGE_EDIT_MODEL, + "IMAGE_EDIT_SIZE": request.app.state.config.IMAGE_EDIT_SIZE, + "IMAGES_EDIT_OPENAI_API_BASE_URL": request.app.state.config.IMAGES_EDIT_OPENAI_API_BASE_URL, + "IMAGES_EDIT_OPENAI_API_KEY": request.app.state.config.IMAGES_EDIT_OPENAI_API_KEY, + "IMAGES_EDIT_OPENAI_API_VERSION": request.app.state.config.IMAGES_EDIT_OPENAI_API_VERSION, + "IMAGES_EDIT_GEMINI_API_BASE_URL": request.app.state.config.IMAGES_EDIT_GEMINI_API_BASE_URL, + "IMAGES_EDIT_GEMINI_API_KEY": request.app.state.config.IMAGES_EDIT_GEMINI_API_KEY, } @@ -674,3 +720,255 @@ async def image_generations( if "error" in data: error = data["error"]["message"] raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(error)) + + +class EditImageForm(BaseModel): + image: str | list[str] # base64-encoded image(s) or URL(s) + prompt: str + model: Optional[str] = None + size: Optional[str] = None + n: Optional[int] = None + negative_prompt: Optional[str] = None + + +@router.post("/edit") +async def image_edits( + request: Request, + form_data: EditImageForm, + user=Depends(get_verified_user), +): + + size = None + width, height = None, None + if ( + request.app.state.config.IMAGE_EDIT_SIZE + and "x" in request.app.state.config.IMAGE_EDIT_SIZE + ) or (form_data.size and "x" in form_data.size): + size = ( + form_data.size + if form_data.size + else request.app.state.config.IMAGE_EDIT_SIZE + ) + width, height = tuple(map(int, size.split("x"))) + + model = ( + request.app.state.config.IMAGE_EDIT_MODEL + if form_data.model is None + else form_data.model + ) + + def load_url_image(string): + if string.startswith("http://") or string.startswith("https://"): + r = requests.get(string) + r.raise_for_status() + image_data = base64.b64encode(r.content).decode("utf-8") + return f"data:{r.headers['content-type']};base64,{image_data}" + + elif string.startswith("/api/v1/files"): + file_id = string.split("/api/v1/files/")[1].split("/content")[0] + file_response = get_file_content_by_id(file_id, user) + + if isinstance(file_response, FileResponse): + file_bytes = file_response.body + mime_type = file_response.headers.get("content-type", "image/png") + image_data = base64.b64encode(file_bytes).decode("utf-8") + return f"data:{mime_type};base64,{image_data}" + return string + + # Load image(s) from URL(s) if necessary + if isinstance(form_data.image, str): + form_data.image = load_url_image(form_data.image) + elif isinstance(form_data.image, list): + form_data.image = [load_url_image(img) for img in form_data.image] + + r = None + try: + if request.app.state.config.IMAGE_EDIT_ENGINE == "openai": + headers = { + "Authorization": f"Bearer {request.app.state.config.IMAGES_EDIT_OPENAI_API_KEY}", + } + + if ENABLE_FORWARD_USER_INFO_HEADERS: + headers = include_user_info_headers(headers, user) + + data = { + "model": model, + "prompt": form_data.prompt, + **({"n": form_data.n} if form_data.n else {}), + **({"size": size} if size else {}), + **( + {} + if "gpt-image-1" in request.app.state.config.IMAGE_GENERATION_MODEL + else {"response_format": "b64_json"} + ), + } + + def get_image_file_item(base64_string): + data = base64_string + header, encoded = data.split(",", 1) + mime_type = header.split(";")[0].lstrip("data:") + image_data = base64.b64decode(encoded) + return ( + "image", + ( + f"{uuid.uuid4()}.png", + io.BytesIO(image_data), + mime_type if mime_type else "image/png", + ), + ) + + files = [] + if isinstance(form_data.image, str): + files = [get_image_file_item(form_data.image)] + elif isinstance(form_data.image, list): + for img in form_data.image: + files.append(get_image_file_item(img)) + + url_search_params = "" + if request.app.state.config.IMAGES_EDIT_OPENAI_API_VERSION: + url_search_params += f"?api-version={request.app.state.config.IMAGES_EDIT_OPENAI_API_VERSION}" + + # Use asyncio.to_thread for the requests.post call + r = await asyncio.to_thread( + requests.post, + url=f"{request.app.state.config.IMAGES_OPENAI_API_BASE_URL}/images/edits{url_search_params}", + headers=headers, + files=files, + data=data, + ) + + r.raise_for_status() + res = r.json() + + images = [] + for image in res["data"]: + if image_url := image.get("url", None): + image_data, content_type = get_image_data(image_url, headers) + else: + image_data, content_type = get_image_data(image["b64_json"]) + + url = upload_image(request, image_data, content_type, data, user) + images.append({"url": url}) + return images + + elif request.app.state.config.IMAGE_GENERATION_ENGINE == "gemini": + headers = { + "Content-Type": "application/json", + "x-goog-api-key": request.app.state.config.IMAGES_GEMINI_API_KEY, + } + + model = f"{model}:generateContent" + data = {"contents": [{"parts": [{"text": form_data.prompt}]}]} + + if isinstance(form_data.image, str): + data["contents"][0]["parts"].append( + { + "inline_data": { + "mime_type": "image/png", + "data": form_data.image.split(",", 1)[1], + } + } + ) + elif isinstance(form_data.image, list): + data["contents"][0]["parts"].extend( + [ + { + "inline_data": { + "mime_type": "image/png", + "data": image.split(",", 1)[1], + } + } + for image in form_data.image + ] + ) + + # Use asyncio.to_thread for the requests.post call + r = await asyncio.to_thread( + requests.post, + url=f"{request.app.state.config.IMAGES_GEMINI_API_BASE_URL}/models/{model}", + json=data, + headers=headers, + ) + + r.raise_for_status() + res = r.json() + + images = [] + for image in res["candidates"]: + for part in image["content"]["parts"]: + if part.get("inlineData", {}).get("data"): + image_data, content_type = get_image_data( + part["inlineData"]["data"] + ) + url = upload_image( + request, image_data, content_type, data, user + ) + images.append({"url": url}) + + return images + + elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui": + data = { + "prompt": form_data.prompt, + "width": width, + "height": height, + "n": form_data.n, + } + + if request.app.state.config.IMAGE_EDIT_STEPS is not None: + data["steps"] = request.app.state.config.IMAGE_EDIT_STEPS + + if form_data.negative_prompt is not None: + data["negative_prompt"] = form_data.negative_prompt + + form_data = ComfyUICreateImageForm( + **{ + "workflow": ComfyUIWorkflow( + **{ + "workflow": request.app.state.config.COMFYUI_WORKFLOW, + "nodes": request.app.state.config.COMFYUI_WORKFLOW_NODES, + } + ), + **data, + } + ) + res = await comfyui_create_image( + model, + form_data, + user.id, + request.app.state.config.COMFYUI_BASE_URL, + request.app.state.config.COMFYUI_API_KEY, + ) + log.debug(f"res: {res}") + + images = [] + + for image in res["data"]: + headers = None + if request.app.state.config.COMFYUI_API_KEY: + headers = { + "Authorization": f"Bearer {request.app.state.config.COMFYUI_API_KEY}" + } + + image_data, content_type = get_image_data(image["url"], headers) + url = upload_image( + request, + image_data, + content_type, + form_data.model_dump(exclude_none=True), + user, + ) + images.append({"url": url}) + return images + except Exception as e: + error = e + if r != None: + data = r.text + try: + data = json.loads(data) + if "error" in data: + error = data["error"]["message"] + except Exception: + error = data + + raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(error)) diff --git a/backend/open_webui/utils/middleware.py b/backend/open_webui/utils/middleware.py index e4cace689a..e3cad01222 100644 --- a/backend/open_webui/utils/middleware.py +++ b/backend/open_webui/utils/middleware.py @@ -47,7 +47,8 @@ from open_webui.routers.retrieval import ( from open_webui.routers.images import ( image_generations, CreateImageForm, - upload_image, + image_edits, + EditImageForm, ) from open_webui.routers.pipelines import ( process_pipeline_inlet_filter, @@ -717,9 +718,31 @@ async def chat_web_search_handler( return form_data +def get_last_images(message_list): + images = [] + for message in reversed(message_list): + images_flag = False + for file in message.get("files", []): + if file.get("type") == "image": + images.append(file.get("url")) + images_flag = True + + if images_flag: + break + + return images + + async def chat_image_generation_handler( request: Request, form_data: dict, extra_params: dict, user ): + metadata = extra_params.get("__metadata__", {}) + chat_id = metadata.get("chat_id", None) + if not chat_id: + return form_data + + chat = Chats.get_chat_by_id_and_user_id(chat_id, user.id) + __event_emitter__ = extra_params["__event_emitter__"] await __event_emitter__( { @@ -728,87 +751,151 @@ async def chat_image_generation_handler( } ) - messages = form_data["messages"] - user_message = get_last_user_message(messages) + messages_map = chat.chat.get("history", {}).get("messages", {}) + message_id = chat.chat.get("history", {}).get("currentId") + message_list = get_message_list(messages_map, message_id) + user_message = get_last_user_message(message_list) prompt = user_message - negative_prompt = "" - - if request.app.state.config.ENABLE_IMAGE_PROMPT_GENERATION: - try: - res = await generate_image_prompt( - request, - { - "model": form_data["model"], - "messages": messages, - }, - user, - ) - - response = res["choices"][0]["message"]["content"] - - try: - bracket_start = response.find("{") - bracket_end = response.rfind("}") + 1 - - if bracket_start == -1 or bracket_end == -1: - raise Exception("No JSON object found in the response") - - response = response[bracket_start:bracket_end] - response = json.loads(response) - prompt = response.get("prompt", []) - except Exception as e: - prompt = user_message - - except Exception as e: - log.exception(e) - prompt = user_message + input_images = get_last_images(message_list) system_message_content = "" + if len(input_images) == 0: + # Create image(s) + if request.app.state.config.ENABLE_IMAGE_PROMPT_GENERATION: + try: + res = await generate_image_prompt( + request, + { + "model": form_data["model"], + "messages": form_data["messages"], + }, + user, + ) - try: - images = await image_generations( - request=request, - form_data=CreateImageForm(**{"prompt": prompt}), - user=user, - ) + response = res["choices"][0]["message"]["content"] - await __event_emitter__( - { - "type": "status", - "data": {"description": "Image created", "done": True}, - } - ) + try: + bracket_start = response.find("{") + bracket_end = response.rfind("}") + 1 - await __event_emitter__( - { - "type": "files", - "data": { - "files": [ - { - "type": "image", - "url": image["url"], - } - for image in images - ] - }, - } - ) + if bracket_start == -1 or bracket_end == -1: + raise Exception("No JSON object found in the response") - system_message_content = "User is shown the generated image, tell the user that the image has been generated" - except Exception as e: - log.exception(e) - await __event_emitter__( - { - "type": "status", - "data": { - "description": f"An error occurred while generating an image", - "done": True, - }, - } - ) + response = response[bracket_start:bracket_end] + response = json.loads(response) + prompt = response.get("prompt", []) + except Exception as e: + prompt = user_message - system_message_content = "Unable to generate an image, tell the user that an error occurred" + except Exception as e: + log.exception(e) + prompt = user_message + + try: + images = await image_generations( + request=request, + form_data=CreateImageForm(**{"prompt": prompt}), + user=user, + ) + + await __event_emitter__( + { + "type": "status", + "data": {"description": "Image created", "done": True}, + } + ) + + await __event_emitter__( + { + "type": "files", + "data": { + "files": [ + { + "type": "image", + "url": image["url"], + } + for image in images + ] + }, + } + ) + + system_message_content = "The requested image has been created and is now being shown to the user. Let them know that it has been generated." + except Exception as e: + log.debug(e) + + error_message = "" + if isinstance(e, HTTPException): + if e.detail and isinstance(e.detail, dict): + error_message = e.detail.get("message", str(e.detail)) + else: + error_message = str(e.detail) + + await __event_emitter__( + { + "type": "status", + "data": { + "description": f"An error occurred while generating an image", + "done": True, + }, + } + ) + + system_message_content = f"Image generation was attempted but failed. The system is currently unable to generate the image. Tell the user that an error occurred: {error_message}" + else: + # Edit image(s) + try: + images = await image_edits( + request=request, + form_data=EditImageForm(**{"prompt": prompt, "image": input_images}), + user=user, + ) + + await __event_emitter__( + { + "type": "status", + "data": {"description": "Image created", "done": True}, + } + ) + + await __event_emitter__( + { + "type": "files", + "data": { + "files": [ + { + "type": "image", + "url": image["url"], + } + for image in images + ] + }, + } + ) + + system_message_content = "The requested image has been created and is now being shown to the user. Let them know that it has been generated." + except Exception as e: + log.debug(e) + + error_message = "" + if isinstance(e, HTTPException): + if e.detail and isinstance(e.detail, dict): + error_message = e.detail.get("message", str(e.detail)) + else: + error_message = str(e.detail) + + await __event_emitter__( + { + "type": "status", + "data": { + "description": f"An error occurred while generating an image", + "done": True, + }, + } + ) + + system_message_content = f"Image generation was attempted but failed. The system is currently unable to generate the image. Tell the user that an error occurred: {error_message}" if system_message_content: form_data["messages"] = add_or_update_system_message( diff --git a/src/lib/components/admin/Settings/Images.svelte b/src/lib/components/admin/Settings/Images.svelte index 247eb85d51..75b389e067 100644 --- a/src/lib/components/admin/Settings/Images.svelte +++ b/src/lib/components/admin/Settings/Images.svelte @@ -29,7 +29,7 @@ let config = null; let showComfyUIWorkflowEditor = false; - let requiredWorkflowNodes = [ + let REQUIRED_WORKFLOW_NODES = [ { type: 'prompt', key: 'text', @@ -62,6 +62,29 @@ } ]; + let REQUIRED_EDIT_WORKFLOW_NODES = [ + { + type: 'prompt', + key: 'text', + node_ids: '' + }, + { + type: 'model', + key: 'ckpt_name', + node_ids: '' + }, + { + type: 'width', + key: 'width', + node_ids: '' + }, + { + type: 'height', + key: 'height', + node_ids: '' + } + ]; + const getModels = async () => { models = await getImageGenerationModels(localStorage.token).catch((error) => { toast.error(`${error}`); @@ -137,7 +160,7 @@ } if (config?.COMFYUI_WORKFLOW) { - config.COMFYUI_WORKFLOW_NODES = requiredWorkflowNodes.map((node) => { + config.COMFYUI_WORKFLOW_NODES = REQUIRED_WORKFLOW_NODES.map((node) => { return { type: node.type, key: node.key, @@ -178,7 +201,7 @@ } } - requiredWorkflowNodes = requiredWorkflowNodes.map((node) => { + REQUIRED_WORKFLOW_NODES = REQUIRED_WORKFLOW_NODES.map((node) => { const n = config.COMFYUI_WORKFLOW_NODES.find((n) => n.type === node.type) ?? node; console.debug(n); @@ -665,7 +688,7 @@
- {#each requiredWorkflowNodes as node} + {#each REQUIRED_WORKFLOW_NODES as node}
@@ -791,13 +814,13 @@ placeholder={$i18n.t('Select Engine')} > - - + +
- {#if config.ENABLE_IMAGE_EDIT} + {#if config.ENABLE_IMAGE_GENERATION}
@@ -918,7 +941,7 @@
@@ -977,7 +1000,7 @@
{ - config.COMFYUI_WORKFLOW = e.target.result; + config.IMAGES_EDIT_COMFYUI_WORKFLOW = e.target.result; e.target.value = null; }; @@ -1002,7 +1025,7 @@
- {#if config.COMFYUI_WORKFLOW} + {#if config.IMAGES_EDIT_COMFYUI_WORKFLOW}
- {#if config.COMFYUI_WORKFLOW} + {#if config.IMAGES_EDIT_COMFYUI_WORKFLOW}
@@ -1067,7 +1082,7 @@
- {#each requiredWorkflowNodes as node} + {#each REQUIRED_EDIT_WORKFLOW_NODES as node}
@@ -1111,6 +1126,47 @@
{/if} + {:else if config?.IMAGE_GENERATION_ENGINE === 'gemini'} +
+
+
+
+ {$i18n.t('Gemini Base URL')} +
+
+ +
+
+ +
+
+
+
+ +
+
+
+
+ {$i18n.t('Gemini API Key')} +
+
+ +
+
+ +
+
+
+
{/if}