diff --git a/.github/dependabot.yml b/.github/dependabot.yml
index 166376305a..af0a8ed0ee 100644
--- a/.github/dependabot.yml
+++ b/.github/dependabot.yml
@@ -3,10 +3,10 @@ updates:
   - package-ecosystem: pip
     directory: '/backend'
     schedule:
-      interval: weekly
+      interval: monthly
     target-branch: 'dev'
   - package-ecosystem: 'github-actions'
    directory: '/'
     schedule:
-      # Check for updates to GitHub Actions every week
-      interval: 'weekly'
+      # Check for updates to GitHub Actions every month
+      interval: monthly
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 1b6bdd98fb..6b7bdef0b7 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -5,6 +5,37 @@ All notable changes to this project will be documented in this file.
 The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
 and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
 
+## [0.3.14] - 2024-08-21
+
+### Added
+
+- **🛠️ Custom ComfyUI Workflow**: Introduces a new, fully customizable ComfyUI workflow for a more tailored user experience, deprecating several older environment variables.
+- **🔀 Merge Responses in Many Model Chat**: Merges responses from multiple models into a single, coherent reply, improving interaction quality in many model chats.
+- **✅ Multiple Instances of Same Model in Chats**: Enhanced many model chat to support adding multiple instances of the same model.
+- **🔧 Quick Actions in Model Workspace**: Enhanced Shift key quick actions for hiding/unhiding and deleting models, facilitating a smoother workflow.
+- **🗨️ Markdown Rendering in User Messages**: User messages are now rendered in Markdown, enhancing readability and interaction.
+- **💬 Temporary Chat Feature**: Introduced a temporary chat feature, deprecating the old chat history setting to enhance user interaction flexibility.
+- **🖋️ User Message Editing**: Enhanced the user chat editing feature to allow saving changes without sending, providing more flexibility in message management.
+- **🛡️ Security Enhancements**: Various security improvements implemented across the platform to ensure safer user experiences.
+- **🌍 Updated Translations**: Enhanced translations for Chinese, Ukrainian, and Bahasa Malaysia, improving localization and user comprehension.
+
+### Fixed
+
+- **📑 Mermaid Rendering Issue**: Addressed issues with Mermaid chart rendering to ensure clean and clear visual data representation.
+- **🎭 PWA Icon Maskability**: Fixed the Progressive Web App icon to be maskable, ensuring proper display on various device home screens.
+- **🔀 Cloned Model Chat Freezing Issue**: Fixed a bug where cloning many model chats would cause freezing, enhancing stability and responsiveness.
+- **🔍 Generic Error Handling and Refinements**: Various minor fixes and refinements to address previously untracked issues, ensuring smoother operations.
+
+### Changed
+
+- **🖼️ Image Generation Refactor**: Overhauled image generation processes for improved efficiency and quality.
+- **🔨 Refactor Tool and Function Calling**: Refactored tool and function calling mechanisms for improved clarity and maintainability.
+- **🌐 Backend Library Updates**: Updated critical backend libraries, including SQLAlchemy, uvicorn[standard], faster-whisper, bcrypt, and boto3, for enhanced performance and security.
+
+### Removed
+
+- **🚫 Deprecated ComfyUI Environment Variables**: Removed several outdated environment variables related to ComfyUI settings, simplifying configuration management (see the illustrative node-mapping sketch below).
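For illustration only (the shape follows this patch's `ComfyUINodeInput` model, and the node ids match the default workflow added in `backend/config.py`; values are examples, not part of the release notes): the deprecated per-setting variables are superseded by a full workflow JSON in `COMFYUI_WORKFLOW` plus a node mapping in `COMFYUI_WORKFLOW_NODES` that tells the backend which node inputs receive the model, prompt, dimensions, and seed:

```json
[
  { "type": "model", "node_ids": ["4"], "key": "ckpt_name" },
  { "type": "prompt", "node_ids": ["6"], "key": "text" },
  { "type": "negative_prompt", "node_ids": ["7"], "key": "text" },
  { "type": "width", "node_ids": ["5"] },
  { "type": "height", "node_ids": ["5"] },
  { "type": "steps", "node_ids": ["3"] },
  { "type": "seed", "node_ids": ["3"] }
]
```

Entries without a recognized `type` fall back to writing `value` into the given `key`, so arbitrary workflow inputs (cfg, sampler, scheduler) stay reachable without dedicated environment variables.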
+ ## [0.3.13] - 2024-08-14 ### Added diff --git a/backend/apps/audio/main.py b/backend/apps/audio/main.py index 20519b59b1..d66a9fa11e 100644 --- a/backend/apps/audio/main.py +++ b/backend/apps/audio/main.py @@ -1,5 +1,12 @@ -import os +import hashlib +import json import logging +import os +import uuid +from functools import lru_cache +from pathlib import Path + +import requests from fastapi import ( FastAPI, Request, @@ -8,34 +15,14 @@ from fastapi import ( status, UploadFile, File, - Form, ) -from fastapi.responses import StreamingResponse, JSONResponse, FileResponse - from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import FileResponse from pydantic import BaseModel - -import uuid -import requests -import hashlib -from pathlib import Path -import json - -from constants import ERROR_MESSAGES -from utils.utils import ( - decode_token, - get_current_user, - get_verified_user, - get_admin_user, -) -from utils.misc import calculate_sha256 - - from config import ( SRC_LOG_LEVELS, CACHE_DIR, - UPLOAD_DIR, WHISPER_MODEL, WHISPER_MODEL_DIR, WHISPER_MODEL_AUTO_UPDATE, @@ -51,6 +38,13 @@ from config import ( AUDIO_TTS_MODEL, AUDIO_TTS_VOICE, AppConfig, + CORS_ALLOW_ORIGIN, +) +from constants import ERROR_MESSAGES +from utils.utils import ( + get_current_user, + get_verified_user, + get_admin_user, ) log = logging.getLogger(__name__) @@ -59,7 +53,7 @@ log.setLevel(SRC_LOG_LEVELS["AUDIO"]) app = FastAPI() app.add_middleware( CORSMiddleware, - allow_origins=["*"], + allow_origins=CORS_ALLOW_ORIGIN, allow_credentials=True, allow_methods=["*"], allow_headers=["*"], @@ -261,6 +255,13 @@ async def speech(request: Request, user=Depends(get_verified_user)): raise HTTPException(status_code=400, detail="Invalid JSON payload") voice_id = payload.get("voice", "") + + if voice_id not in get_available_voices(): + raise HTTPException( + status_code=400, + detail="Invalid voice id", + ) + url = f"https://api.elevenlabs.io/v1/text-to-speech/{voice_id}" headers = { @@ -466,39 +467,58 @@ async def get_models(user=Depends(get_verified_user)): return {"models": get_available_models()} -def get_available_voices() -> list[dict]: +def get_available_voices() -> dict: + """Returns {voice_id: voice_name} dict""" + ret = {} if app.state.config.TTS_ENGINE == "openai": - return [ - {"name": "alloy", "id": "alloy"}, - {"name": "echo", "id": "echo"}, - {"name": "fable", "id": "fable"}, - {"name": "onyx", "id": "onyx"}, - {"name": "nova", "id": "nova"}, - {"name": "shimmer", "id": "shimmer"}, - ] - elif app.state.config.TTS_ENGINE == "elevenlabs": - headers = { - "xi-api-key": app.state.config.TTS_API_KEY, - "Content-Type": "application/json", + ret = { + "alloy": "alloy", + "echo": "echo", + "fable": "fable", + "onyx": "onyx", + "nova": "nova", + "shimmer": "shimmer", } - + elif app.state.config.TTS_ENGINE == "elevenlabs": try: - response = requests.get( - "https://api.elevenlabs.io/v1/voices", headers=headers - ) - response.raise_for_status() - voices_data = response.json() + ret = get_elevenlabs_voices() + except Exception as e: + # Avoided @lru_cache with exception + pass - voices = [] - for voice in voices_data.get("voices", []): - voices.append({"name": voice["name"], "id": voice["voice_id"]}) - return voices - except requests.RequestException as e: - log.error(f"Error fetching voices: {str(e)}") + return ret - return [] + +@lru_cache +def get_elevenlabs_voices() -> dict: + """ + Note, set the following in your .env file to use Elevenlabs: + AUDIO_TTS_ENGINE=elevenlabs + AUDIO_TTS_API_KEY=sk_... 
# Your Elevenlabs API key + AUDIO_TTS_VOICE=EXAVITQu4vr4xnSDxMaL # From https://api.elevenlabs.io/v1/voices + AUDIO_TTS_MODEL=eleven_multilingual_v2 + """ + headers = { + "xi-api-key": app.state.config.TTS_API_KEY, + "Content-Type": "application/json", + } + try: + # TODO: Add retries + response = requests.get("https://api.elevenlabs.io/v1/voices", headers=headers) + response.raise_for_status() + voices_data = response.json() + + voices = {} + for voice in voices_data.get("voices", []): + voices[voice["voice_id"]] = voice["name"] + except requests.RequestException as e: + # Avoid @lru_cache with exception + log.error(f"Error fetching voices: {str(e)}") + raise RuntimeError(f"Error fetching voices: {str(e)}") + + return voices @app.get("/voices") async def get_voices(user=Depends(get_verified_user)): - return {"voices": get_available_voices()} + return {"voices": [{"id": k, "name": v} for k, v in get_available_voices().items()]} diff --git a/backend/apps/images/main.py b/backend/apps/images/main.py index d2f5ddd5d6..371b1c3bc2 100644 --- a/backend/apps/images/main.py +++ b/backend/apps/images/main.py @@ -1,26 +1,10 @@ -import re -import requests -import base64 from fastapi import ( FastAPI, Request, Depends, HTTPException, - status, - UploadFile, - File, - Form, ) from fastapi.middleware.cors import CORSMiddleware - -from constants import ERROR_MESSAGES -from utils.utils import ( - get_verified_user, - get_admin_user, -) - -from apps.images.utils.comfyui import ImageGenerationPayload, comfyui_generate_image -from utils.misc import calculate_sha256 from typing import Optional from pydantic import BaseModel from pathlib import Path @@ -29,7 +13,21 @@ import uuid import base64 import json import logging +import re +import requests +from utils.utils import ( + get_verified_user, + get_admin_user, +) + +from apps.images.utils.comfyui import ( + ComfyUIWorkflow, + ComfyUIGenerateImageForm, + comfyui_generate_image, +) + +from constants import ERROR_MESSAGES from config import ( SRC_LOG_LEVELS, CACHE_DIR, @@ -38,18 +36,14 @@ from config import ( AUTOMATIC1111_BASE_URL, AUTOMATIC1111_API_AUTH, COMFYUI_BASE_URL, - COMFYUI_CFG_SCALE, - COMFYUI_SAMPLER, - COMFYUI_SCHEDULER, - COMFYUI_SD3, - COMFYUI_FLUX, - COMFYUI_FLUX_WEIGHT_DTYPE, - COMFYUI_FLUX_FP8_CLIP, + COMFYUI_WORKFLOW, + COMFYUI_WORKFLOW_NODES, IMAGES_OPENAI_API_BASE_URL, IMAGES_OPENAI_API_KEY, IMAGE_GENERATION_MODEL, IMAGE_SIZE, IMAGE_STEPS, + CORS_ALLOW_ORIGIN, AppConfig, ) @@ -62,7 +56,7 @@ IMAGE_CACHE_DIR.mkdir(parents=True, exist_ok=True) app = FastAPI() app.add_middleware( CORSMiddleware, - allow_origins=["*"], + allow_origins=CORS_ALLOW_ORIGIN, allow_credentials=True, allow_methods=["*"], allow_headers=["*"], @@ -81,16 +75,94 @@ app.state.config.MODEL = IMAGE_GENERATION_MODEL app.state.config.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL app.state.config.AUTOMATIC1111_API_AUTH = AUTOMATIC1111_API_AUTH app.state.config.COMFYUI_BASE_URL = COMFYUI_BASE_URL +app.state.config.COMFYUI_WORKFLOW = COMFYUI_WORKFLOW +app.state.config.COMFYUI_WORKFLOW_NODES = COMFYUI_WORKFLOW_NODES app.state.config.IMAGE_SIZE = IMAGE_SIZE app.state.config.IMAGE_STEPS = IMAGE_STEPS -app.state.config.COMFYUI_CFG_SCALE = COMFYUI_CFG_SCALE -app.state.config.COMFYUI_SAMPLER = COMFYUI_SAMPLER -app.state.config.COMFYUI_SCHEDULER = COMFYUI_SCHEDULER -app.state.config.COMFYUI_SD3 = COMFYUI_SD3 -app.state.config.COMFYUI_FLUX = COMFYUI_FLUX -app.state.config.COMFYUI_FLUX_WEIGHT_DTYPE = COMFYUI_FLUX_WEIGHT_DTYPE -app.state.config.COMFYUI_FLUX_FP8_CLIP = 
COMFYUI_FLUX_FP8_CLIP + + +@app.get("/config") +async def get_config(request: Request, user=Depends(get_admin_user)): + return { + "enabled": app.state.config.ENABLED, + "engine": app.state.config.ENGINE, + "openai": { + "OPENAI_API_BASE_URL": app.state.config.OPENAI_API_BASE_URL, + "OPENAI_API_KEY": app.state.config.OPENAI_API_KEY, + }, + "automatic1111": { + "AUTOMATIC1111_BASE_URL": app.state.config.AUTOMATIC1111_BASE_URL, + "AUTOMATIC1111_API_AUTH": app.state.config.AUTOMATIC1111_API_AUTH, + }, + "comfyui": { + "COMFYUI_BASE_URL": app.state.config.COMFYUI_BASE_URL, + "COMFYUI_WORKFLOW": app.state.config.COMFYUI_WORKFLOW, + "COMFYUI_WORKFLOW_NODES": app.state.config.COMFYUI_WORKFLOW_NODES, + }, + } + + +class OpenAIConfigForm(BaseModel): + OPENAI_API_BASE_URL: str + OPENAI_API_KEY: str + + +class Automatic1111ConfigForm(BaseModel): + AUTOMATIC1111_BASE_URL: str + AUTOMATIC1111_API_AUTH: str + + +class ComfyUIConfigForm(BaseModel): + COMFYUI_BASE_URL: str + COMFYUI_WORKFLOW: str + COMFYUI_WORKFLOW_NODES: list[dict] + + +class ConfigForm(BaseModel): + enabled: bool + engine: str + openai: OpenAIConfigForm + automatic1111: Automatic1111ConfigForm + comfyui: ComfyUIConfigForm + + +@app.post("/config/update") +async def update_config(form_data: ConfigForm, user=Depends(get_admin_user)): + app.state.config.ENGINE = form_data.engine + app.state.config.ENABLED = form_data.enabled + + app.state.config.OPENAI_API_BASE_URL = form_data.openai.OPENAI_API_BASE_URL + app.state.config.OPENAI_API_KEY = form_data.openai.OPENAI_API_KEY + + app.state.config.AUTOMATIC1111_BASE_URL = ( + form_data.automatic1111.AUTOMATIC1111_BASE_URL + ) + app.state.config.AUTOMATIC1111_API_AUTH = ( + form_data.automatic1111.AUTOMATIC1111_API_AUTH + ) + + app.state.config.COMFYUI_BASE_URL = form_data.comfyui.COMFYUI_BASE_URL + app.state.config.COMFYUI_WORKFLOW = form_data.comfyui.COMFYUI_WORKFLOW + app.state.config.COMFYUI_WORKFLOW_NODES = form_data.comfyui.COMFYUI_WORKFLOW_NODES + + return { + "enabled": app.state.config.ENABLED, + "engine": app.state.config.ENGINE, + "openai": { + "OPENAI_API_BASE_URL": app.state.config.OPENAI_API_BASE_URL, + "OPENAI_API_KEY": app.state.config.OPENAI_API_KEY, + }, + "automatic1111": { + "AUTOMATIC1111_BASE_URL": app.state.config.AUTOMATIC1111_BASE_URL, + "AUTOMATIC1111_API_AUTH": app.state.config.AUTOMATIC1111_API_AUTH, + }, + "comfyui": { + "COMFYUI_BASE_URL": app.state.config.COMFYUI_BASE_URL, + "COMFYUI_WORKFLOW": app.state.config.COMFYUI_WORKFLOW, + "COMFYUI_WORKFLOW_NODES": app.state.config.COMFYUI_WORKFLOW_NODES, + }, + } def get_automatic1111_api_auth(): @@ -103,166 +175,110 @@ def get_automatic1111_api_auth(): return f"Basic {auth1111_base64_encoded_string}" -@app.get("/config") -async def get_config(request: Request, user=Depends(get_admin_user)): - return { - "engine": app.state.config.ENGINE, - "enabled": app.state.config.ENABLED, - } - - -class ConfigUpdateForm(BaseModel): - engine: str - enabled: bool - - -@app.post("/config/update") -async def update_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user)): - app.state.config.ENGINE = form_data.engine - app.state.config.ENABLED = form_data.enabled - return { - "engine": app.state.config.ENGINE, - "enabled": app.state.config.ENABLED, - } - - -class EngineUrlUpdateForm(BaseModel): - AUTOMATIC1111_BASE_URL: Optional[str] = None - AUTOMATIC1111_API_AUTH: Optional[str] = None - COMFYUI_BASE_URL: Optional[str] = None - - -@app.get("/url") -async def get_engine_url(user=Depends(get_admin_user)): - return { - 
"AUTOMATIC1111_BASE_URL": app.state.config.AUTOMATIC1111_BASE_URL, - "AUTOMATIC1111_API_AUTH": app.state.config.AUTOMATIC1111_API_AUTH, - "COMFYUI_BASE_URL": app.state.config.COMFYUI_BASE_URL, - } - - -@app.post("/url/update") -async def update_engine_url( - form_data: EngineUrlUpdateForm, user=Depends(get_admin_user) -): - if form_data.AUTOMATIC1111_BASE_URL is None: - app.state.config.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL - else: - url = form_data.AUTOMATIC1111_BASE_URL.strip("/") +@app.get("/config/url/verify") +async def verify_url(user=Depends(get_admin_user)): + if app.state.config.ENGINE == "automatic1111": try: - r = requests.head(url) + r = requests.get( + url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options", + headers={"authorization": get_automatic1111_api_auth()}, + ) r.raise_for_status() - app.state.config.AUTOMATIC1111_BASE_URL = url + return True except Exception as e: + app.state.config.ENABLED = False raise HTTPException(status_code=400, detail=ERROR_MESSAGES.INVALID_URL) - - if form_data.COMFYUI_BASE_URL is None: - app.state.config.COMFYUI_BASE_URL = COMFYUI_BASE_URL - else: - url = form_data.COMFYUI_BASE_URL.strip("/") - + elif app.state.config.ENGINE == "comfyui": try: - r = requests.head(url) + r = requests.get(url=f"{app.state.config.COMFYUI_BASE_URL}/object_info") r.raise_for_status() - app.state.config.COMFYUI_BASE_URL = url + return True except Exception as e: + app.state.config.ENABLED = False raise HTTPException(status_code=400, detail=ERROR_MESSAGES.INVALID_URL) - - if form_data.AUTOMATIC1111_API_AUTH is None: - app.state.config.AUTOMATIC1111_API_AUTH = AUTOMATIC1111_API_AUTH else: - app.state.config.AUTOMATIC1111_API_AUTH = form_data.AUTOMATIC1111_API_AUTH + return True + +def set_image_model(model: str): + app.state.config.MODEL = model + if app.state.config.ENGINE in ["", "automatic1111"]: + api_auth = get_automatic1111_api_auth() + r = requests.get( + url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options", + headers={"authorization": api_auth}, + ) + options = r.json() + if model != options["sd_model_checkpoint"]: + options["sd_model_checkpoint"] = model + r = requests.post( + url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options", + json=options, + headers={"authorization": api_auth}, + ) + return app.state.config.MODEL + + +def get_image_model(): + if app.state.config.ENGINE == "openai": + return app.state.config.MODEL if app.state.config.MODEL else "dall-e-2" + elif app.state.config.ENGINE == "comfyui": + return app.state.config.MODEL if app.state.config.MODEL else "" + elif app.state.config.ENGINE == "automatic1111" or app.state.config.ENGINE == "": + try: + r = requests.get( + url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options", + headers={"authorization": get_automatic1111_api_auth()}, + ) + options = r.json() + return options["sd_model_checkpoint"] + except Exception as e: + app.state.config.ENABLED = False + raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e)) + + +class ImageConfigForm(BaseModel): + MODEL: str + IMAGE_SIZE: str + IMAGE_STEPS: int + + +@app.get("/image/config") +async def get_image_config(user=Depends(get_admin_user)): return { - "AUTOMATIC1111_BASE_URL": app.state.config.AUTOMATIC1111_BASE_URL, - "AUTOMATIC1111_API_AUTH": app.state.config.AUTOMATIC1111_API_AUTH, - "COMFYUI_BASE_URL": app.state.config.COMFYUI_BASE_URL, - "status": True, + "MODEL": app.state.config.MODEL, + "IMAGE_SIZE": app.state.config.IMAGE_SIZE, + "IMAGE_STEPS": 
app.state.config.IMAGE_STEPS,
+    }
 
 
-class OpenAIConfigUpdateForm(BaseModel):
-    url: str
-    key: str
+@app.post("/image/config/update")
+async def update_image_config(form_data: ImageConfigForm, user=Depends(get_admin_user)):
+    app.state.config.MODEL = form_data.MODEL
 
-
-@app.get("/openai/config")
-async def get_openai_config(user=Depends(get_admin_user)):
-    return {
-        "OPENAI_API_BASE_URL": app.state.config.OPENAI_API_BASE_URL,
-        "OPENAI_API_KEY": app.state.config.OPENAI_API_KEY,
-    }
-
-
-@app.post("/openai/config/update")
-async def update_openai_config(
-    form_data: OpenAIConfigUpdateForm, user=Depends(get_admin_user)
-):
-    if form_data.key == "":
-        raise HTTPException(status_code=400, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)
-
-    app.state.config.OPENAI_API_BASE_URL = form_data.url
-    app.state.config.OPENAI_API_KEY = form_data.key
-
-    return {
-        "status": True,
-        "OPENAI_API_BASE_URL": app.state.config.OPENAI_API_BASE_URL,
-        "OPENAI_API_KEY": app.state.config.OPENAI_API_KEY,
-    }
-
-
-class ImageSizeUpdateForm(BaseModel):
-    size: str
-
-
-@app.get("/size")
-async def get_image_size(user=Depends(get_admin_user)):
-    return {"IMAGE_SIZE": app.state.config.IMAGE_SIZE}
-
-
-@app.post("/size/update")
-async def update_image_size(
-    form_data: ImageSizeUpdateForm, user=Depends(get_admin_user)
-):
-    pattern = r"^\d+x\d+$"  # Regular expression pattern
-    if re.match(pattern, form_data.size):
-        app.state.config.IMAGE_SIZE = form_data.size
-        return {
-            "IMAGE_SIZE": app.state.config.IMAGE_SIZE,
-            "status": True,
-        }
+    pattern = r"^\d+x\d+$"
+    if re.match(pattern, form_data.IMAGE_SIZE):
+        app.state.config.IMAGE_SIZE = form_data.IMAGE_SIZE
     else:
         raise HTTPException(
             status_code=400,
             detail=ERROR_MESSAGES.INCORRECT_FORMAT(" (e.g., 512x512)."),
         )
 
-
-class ImageStepsUpdateForm(BaseModel):
-    steps: int
-
-
-@app.get("/steps")
-async def get_image_size(user=Depends(get_admin_user)):
-    return {"IMAGE_STEPS": app.state.config.IMAGE_STEPS}
-
-
-@app.post("/steps/update")
-async def update_image_size(
-    form_data: ImageStepsUpdateForm, user=Depends(get_admin_user)
-):
-    if form_data.steps >= 0:
-        app.state.config.IMAGE_STEPS = form_data.steps
-        return {
-            "IMAGE_STEPS": app.state.config.IMAGE_STEPS,
-            "status": True,
-        }
+    if form_data.IMAGE_STEPS >= 0:
+        app.state.config.IMAGE_STEPS = form_data.IMAGE_STEPS
     else:
         raise HTTPException(
             status_code=400,
             detail=ERROR_MESSAGES.INCORRECT_FORMAT(" (e.g., 50)."),
         )
 
+    return {
+        "MODEL": app.state.config.MODEL,
+        "IMAGE_SIZE": app.state.config.IMAGE_SIZE,
+        "IMAGE_STEPS": app.state.config.IMAGE_STEPS,
+    }
+
 
 @app.get("/models")
 def get_models(user=Depends(get_verified_user)):
@@ -273,18 +289,50 @@ def get_models(user=Depends(get_verified_user)):
             {"id": "dall-e-3", "name": "DALL·E 3"},
         ]
     elif app.state.config.ENGINE == "comfyui":
-
+        # TODO - get models from comfyui
         r = requests.get(url=f"{app.state.config.COMFYUI_BASE_URL}/object_info")
         info = r.json()
 
-        return list(
-            map(
-                lambda model: {"id": model, "name": model},
-                info["CheckpointLoaderSimple"]["input"]["required"]["ckpt_name"][0],
-            )
-        )
+        workflow = json.loads(app.state.config.COMFYUI_WORKFLOW)
+        model_node_id = None
-
+        for node in app.state.config.COMFYUI_WORKFLOW_NODES:
+            if node["type"] == "model":
+                model_node_id = node["node_ids"][0]
+                break
+
+        if model_node_id:
+            model_list_key = None
+
+            log.debug(workflow[model_node_id]["class_type"])
+            for key in info[workflow[model_node_id]["class_type"]]["input"][
+                "required"
+            ]:
+                if "_name" in key:
+                    model_list_key = key
+                    break
+
+            if model_list_key:
+                return
list( + map( + lambda model: {"id": model, "name": model}, + info[workflow[model_node_id]["class_type"]]["input"][ + "required" + ][model_list_key][0], + ) + ) + else: + return list( + map( + lambda model: {"id": model, "name": model}, + info["CheckpointLoaderSimple"]["input"]["required"][ + "ckpt_name" + ][0], + ) + ) + elif ( + app.state.config.ENGINE == "automatic1111" or app.state.config.ENGINE == "" + ): r = requests.get( url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models", headers={"authorization": get_automatic1111_api_auth()}, @@ -301,69 +349,11 @@ def get_models(user=Depends(get_verified_user)): raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e)) -@app.get("/models/default") -async def get_default_model(user=Depends(get_admin_user)): - try: - if app.state.config.ENGINE == "openai": - return { - "model": ( - app.state.config.MODEL if app.state.config.MODEL else "dall-e-2" - ) - } - elif app.state.config.ENGINE == "comfyui": - return {"model": (app.state.config.MODEL if app.state.config.MODEL else "")} - else: - r = requests.get( - url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options", - headers={"authorization": get_automatic1111_api_auth()}, - ) - options = r.json() - return {"model": options["sd_model_checkpoint"]} - except Exception as e: - app.state.config.ENABLED = False - raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e)) - - -class UpdateModelForm(BaseModel): - model: str - - -def set_model_handler(model: str): - if app.state.config.ENGINE in ["openai", "comfyui"]: - app.state.config.MODEL = model - return app.state.config.MODEL - else: - api_auth = get_automatic1111_api_auth() - r = requests.get( - url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options", - headers={"authorization": api_auth}, - ) - options = r.json() - - if model != options["sd_model_checkpoint"]: - options["sd_model_checkpoint"] = model - r = requests.post( - url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options", - json=options, - headers={"authorization": api_auth}, - ) - - return options - - -@app.post("/models/default/update") -def update_default_model( - form_data: UpdateModelForm, - user=Depends(get_verified_user), -): - return set_model_handler(form_data.model) - - class GenerateImageForm(BaseModel): model: Optional[str] = None prompt: str - n: int = 1 size: Optional[str] = None + n: int = 1 negative_prompt: Optional[str] = None @@ -479,7 +469,6 @@ async def image_generations( return images elif app.state.config.ENGINE == "comfyui": - data = { "prompt": form_data.prompt, "width": width, @@ -493,32 +482,20 @@ async def image_generations( if form_data.negative_prompt is not None: data["negative_prompt"] = form_data.negative_prompt - if app.state.config.COMFYUI_CFG_SCALE: - data["cfg_scale"] = app.state.config.COMFYUI_CFG_SCALE - - if app.state.config.COMFYUI_SAMPLER is not None: - data["sampler"] = app.state.config.COMFYUI_SAMPLER - - if app.state.config.COMFYUI_SCHEDULER is not None: - data["scheduler"] = app.state.config.COMFYUI_SCHEDULER - - if app.state.config.COMFYUI_SD3 is not None: - data["sd3"] = app.state.config.COMFYUI_SD3 - - if app.state.config.COMFYUI_FLUX is not None: - data["flux"] = app.state.config.COMFYUI_FLUX - - if app.state.config.COMFYUI_FLUX_WEIGHT_DTYPE is not None: - data["flux_weight_dtype"] = app.state.config.COMFYUI_FLUX_WEIGHT_DTYPE - - if app.state.config.COMFYUI_FLUX_FP8_CLIP is not None: - data["flux_fp8_clip"] = app.state.config.COMFYUI_FLUX_FP8_CLIP - - data = 
ImageGenerationPayload(**data) - + form_data = ComfyUIGenerateImageForm( + **{ + "workflow": ComfyUIWorkflow( + **{ + "workflow": app.state.config.COMFYUI_WORKFLOW, + "nodes": app.state.config.COMFYUI_WORKFLOW_NODES, + } + ), + **data, + } + ) res = await comfyui_generate_image( app.state.config.MODEL, - data, + form_data, user.id, app.state.config.COMFYUI_BASE_URL, ) @@ -532,13 +509,15 @@ async def image_generations( file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}.json") with open(file_body_path, "w") as f: - json.dump(data.model_dump(exclude_none=True), f) + json.dump(form_data.model_dump(exclude_none=True), f) log.debug(f"images: {images}") return images - else: + elif ( + app.state.config.ENGINE == "automatic1111" or app.state.config.ENGINE == "" + ): if form_data.model: - set_model_handler(form_data.model) + set_image_model(form_data.model) data = { "prompt": form_data.prompt, @@ -560,7 +539,6 @@ async def image_generations( ) res = r.json() - log.debug(f"res: {res}") images = [] @@ -577,7 +555,6 @@ async def image_generations( except Exception as e: error = e - if r != None: data = r.json() if "error" in data: diff --git a/backend/apps/images/utils/comfyui.py b/backend/apps/images/utils/comfyui.py index f11dca57c5..cece3d737e 100644 --- a/backend/apps/images/utils/comfyui.py +++ b/backend/apps/images/utils/comfyui.py @@ -15,245 +15,6 @@ from pydantic import BaseModel from typing import Optional -COMFYUI_DEFAULT_PROMPT = """ -{ - "3": { - "inputs": { - "seed": 0, - "steps": 20, - "cfg": 8, - "sampler_name": "euler", - "scheduler": "normal", - "denoise": 1, - "model": [ - "4", - 0 - ], - "positive": [ - "6", - 0 - ], - "negative": [ - "7", - 0 - ], - "latent_image": [ - "5", - 0 - ] - }, - "class_type": "KSampler", - "_meta": { - "title": "KSampler" - } - }, - "4": { - "inputs": { - "ckpt_name": "model.safetensors" - }, - "class_type": "CheckpointLoaderSimple", - "_meta": { - "title": "Load Checkpoint" - } - }, - "5": { - "inputs": { - "width": 512, - "height": 512, - "batch_size": 1 - }, - "class_type": "EmptyLatentImage", - "_meta": { - "title": "Empty Latent Image" - } - }, - "6": { - "inputs": { - "text": "Prompt", - "clip": [ - "4", - 1 - ] - }, - "class_type": "CLIPTextEncode", - "_meta": { - "title": "CLIP Text Encode (Prompt)" - } - }, - "7": { - "inputs": { - "text": "Negative Prompt", - "clip": [ - "4", - 1 - ] - }, - "class_type": "CLIPTextEncode", - "_meta": { - "title": "CLIP Text Encode (Prompt)" - } - }, - "8": { - "inputs": { - "samples": [ - "3", - 0 - ], - "vae": [ - "4", - 2 - ] - }, - "class_type": "VAEDecode", - "_meta": { - "title": "VAE Decode" - } - }, - "9": { - "inputs": { - "filename_prefix": "ComfyUI", - "images": [ - "8", - 0 - ] - }, - "class_type": "SaveImage", - "_meta": { - "title": "Save Image" - } - } -} -""" - -FLUX_DEFAULT_PROMPT = """ -{ - "5": { - "inputs": { - "width": 1024, - "height": 1024, - "batch_size": 1 - }, - "class_type": "EmptyLatentImage" - }, - "6": { - "inputs": { - "text": "Input Text Here", - "clip": [ - "11", - 0 - ] - }, - "class_type": "CLIPTextEncode" - }, - "8": { - "inputs": { - "samples": [ - "13", - 0 - ], - "vae": [ - "10", - 0 - ] - }, - "class_type": "VAEDecode" - }, - "9": { - "inputs": { - "filename_prefix": "ComfyUI", - "images": [ - "8", - 0 - ] - }, - "class_type": "SaveImage" - }, - "10": { - "inputs": { - "vae_name": "ae.safetensors" - }, - "class_type": "VAELoader" - }, - "11": { - "inputs": { - "clip_name1": "clip_l.safetensors", - "clip_name2": "t5xxl_fp16.safetensors", - "type": "flux" - }, - 
"class_type": "DualCLIPLoader" - }, - "12": { - "inputs": { - "unet_name": "flux1-dev.safetensors", - "weight_dtype": "default" - }, - "class_type": "UNETLoader" - }, - "13": { - "inputs": { - "noise": [ - "25", - 0 - ], - "guider": [ - "22", - 0 - ], - "sampler": [ - "16", - 0 - ], - "sigmas": [ - "17", - 0 - ], - "latent_image": [ - "5", - 0 - ] - }, - "class_type": "SamplerCustomAdvanced" - }, - "16": { - "inputs": { - "sampler_name": "euler" - }, - "class_type": "KSamplerSelect" - }, - "17": { - "inputs": { - "scheduler": "simple", - "steps": 20, - "denoise": 1, - "model": [ - "12", - 0 - ] - }, - "class_type": "BasicScheduler" - }, - "22": { - "inputs": { - "model": [ - "12", - 0 - ], - "conditioning": [ - "6", - 0 - ] - }, - "class_type": "BasicGuider" - }, - "25": { - "inputs": { - "noise_seed": 778937779713005 - }, - "class_type": "RandomNoise" - } -} -""" - def queue_prompt(prompt, client_id, base_url): log.info("queue_prompt") @@ -311,82 +72,71 @@ def get_images(ws, prompt, client_id, base_url): return {"data": output_images} -class ImageGenerationPayload(BaseModel): +class ComfyUINodeInput(BaseModel): + type: Optional[str] = None + node_ids: list[str] = [] + key: Optional[str] = "text" + value: Optional[str] = None + + +class ComfyUIWorkflow(BaseModel): + workflow: str + nodes: list[ComfyUINodeInput] + + +class ComfyUIGenerateImageForm(BaseModel): + workflow: ComfyUIWorkflow + prompt: str - negative_prompt: Optional[str] = "" - steps: Optional[int] = None - seed: Optional[int] = None + negative_prompt: Optional[str] = None width: int height: int n: int = 1 - cfg_scale: Optional[float] = None - sampler: Optional[str] = None - scheduler: Optional[str] = None - sd3: Optional[bool] = None - flux: Optional[bool] = None - flux_weight_dtype: Optional[str] = None - flux_fp8_clip: Optional[bool] = None + + steps: Optional[int] = None + seed: Optional[int] = None async def comfyui_generate_image( - model: str, payload: ImageGenerationPayload, client_id, base_url + model: str, payload: ComfyUIGenerateImageForm, client_id, base_url ): ws_url = base_url.replace("http://", "ws://").replace("https://", "wss://") + workflow = json.loads(payload.workflow.workflow) - comfyui_prompt = json.loads(COMFYUI_DEFAULT_PROMPT) - - if payload.cfg_scale: - comfyui_prompt["3"]["inputs"]["cfg"] = payload.cfg_scale - - if payload.sampler: - comfyui_prompt["3"]["inputs"]["sampler"] = payload.sampler - - if payload.scheduler: - comfyui_prompt["3"]["inputs"]["scheduler"] = payload.scheduler - - if payload.sd3: - comfyui_prompt["5"]["class_type"] = "EmptySD3LatentImage" - - if payload.steps: - comfyui_prompt["3"]["inputs"]["steps"] = payload.steps - - comfyui_prompt["4"]["inputs"]["ckpt_name"] = model - comfyui_prompt["7"]["inputs"]["text"] = payload.negative_prompt - comfyui_prompt["3"]["inputs"]["seed"] = ( - payload.seed if payload.seed else random.randint(0, 18446744073709551614) - ) - - # as Flux uses a completely different workflow, we must treat it specially - if payload.flux: - comfyui_prompt = json.loads(FLUX_DEFAULT_PROMPT) - comfyui_prompt["12"]["inputs"]["unet_name"] = model - comfyui_prompt["25"]["inputs"]["noise_seed"] = ( - payload.seed if payload.seed else random.randint(0, 18446744073709551614) - ) - - if payload.sampler: - comfyui_prompt["16"]["inputs"]["sampler_name"] = payload.sampler - - if payload.steps: - comfyui_prompt["17"]["inputs"]["steps"] = payload.steps - - if payload.scheduler: - comfyui_prompt["17"]["inputs"]["scheduler"] = payload.scheduler - - if payload.flux_weight_dtype: - 
comfyui_prompt["12"]["inputs"]["weight_dtype"] = payload.flux_weight_dtype - - if payload.flux_fp8_clip: - comfyui_prompt["11"]["inputs"][ - "clip_name2" - ] = "t5xxl_fp8_e4m3fn.safetensors" - - comfyui_prompt["5"]["inputs"]["batch_size"] = payload.n - comfyui_prompt["5"]["inputs"]["width"] = payload.width - comfyui_prompt["5"]["inputs"]["height"] = payload.height - - # set the text prompt for our positive CLIPTextEncode - comfyui_prompt["6"]["inputs"]["text"] = payload.prompt + for node in payload.workflow.nodes: + if node.type: + if node.type == "model": + for node_id in node.node_ids: + workflow[node_id]["inputs"][node.key] = model + elif node.type == "prompt": + for node_id in node.node_ids: + workflow[node_id]["inputs"]["text"] = payload.prompt + elif node.type == "negative_prompt": + for node_id in node.node_ids: + workflow[node_id]["inputs"]["text"] = payload.negative_prompt + elif node.type == "width": + for node_id in node.node_ids: + workflow[node_id]["inputs"]["width"] = payload.width + elif node.type == "height": + for node_id in node.node_ids: + workflow[node_id]["inputs"]["height"] = payload.height + elif node.type == "n": + for node_id in node.node_ids: + workflow[node_id]["inputs"]["batch_size"] = payload.n + elif node.type == "steps": + for node_id in node.node_ids: + workflow[node_id]["inputs"]["steps"] = payload.steps + elif node.type == "seed": + seed = ( + payload.seed + if payload.seed + else random.randint(0, 18446744073709551614) + ) + for node_id in node.node_ids: + workflow[node.node_id]["inputs"]["seed"] = seed + else: + for node_id in node.node_ids: + workflow[node_id]["inputs"][node.key] = node.value try: ws = websocket.WebSocket() @@ -397,9 +147,7 @@ async def comfyui_generate_image( return None try: - images = await asyncio.to_thread( - get_images, ws, comfyui_prompt, client_id, base_url - ) + images = await asyncio.to_thread(get_images, ws, workflow, client_id, base_url) except Exception as e: log.exception(f"Error while receiving images: {e}") images = None diff --git a/backend/apps/ollama/main.py b/backend/apps/ollama/main.py index 03a8e198ee..d3931b1ab9 100644 --- a/backend/apps/ollama/main.py +++ b/backend/apps/ollama/main.py @@ -41,6 +41,7 @@ from config import ( MODEL_FILTER_LIST, UPLOAD_DIR, AppConfig, + CORS_ALLOW_ORIGIN, ) from utils.misc import ( calculate_sha256, @@ -55,7 +56,7 @@ log.setLevel(SRC_LOG_LEVELS["OLLAMA"]) app = FastAPI() app.add_middleware( CORSMiddleware, - allow_origins=["*"], + allow_origins=CORS_ALLOW_ORIGIN, allow_credentials=True, allow_methods=["*"], allow_headers=["*"], @@ -147,13 +148,17 @@ async def cleanup_response( await session.close() -async def post_streaming_url(url: str, payload: str, stream: bool = True): +async def post_streaming_url(url: str, payload: Union[str, bytes], stream: bool = True): r = None try: session = aiohttp.ClientSession( trust_env=True, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT) ) - r = await session.post(url, data=payload) + r = await session.post( + url, + data=payload, + headers={"Content-Type": "application/json"}, + ) r.raise_for_status() if stream: @@ -422,6 +427,7 @@ async def copy_model( r = requests.request( method="POST", url=f"{url}/api/copy", + headers={"Content-Type": "application/json"}, data=form_data.model_dump_json(exclude_none=True).encode(), ) @@ -470,6 +476,7 @@ async def delete_model( r = requests.request( method="DELETE", url=f"{url}/api/delete", + headers={"Content-Type": "application/json"}, data=form_data.model_dump_json(exclude_none=True).encode(), ) try: 
@@ -510,6 +517,7 @@ async def show_model_info(form_data: ModelNameForm, user=Depends(get_verified_us r = requests.request( method="POST", url=f"{url}/api/show", + headers={"Content-Type": "application/json"}, data=form_data.model_dump_json(exclude_none=True).encode(), ) try: @@ -567,6 +575,7 @@ async def generate_embeddings( r = requests.request( method="POST", url=f"{url}/api/embeddings", + headers={"Content-Type": "application/json"}, data=form_data.model_dump_json(exclude_none=True).encode(), ) try: @@ -616,6 +625,7 @@ def generate_ollama_embeddings( r = requests.request( method="POST", url=f"{url}/api/embeddings", + headers={"Content-Type": "application/json"}, data=form_data.model_dump_json(exclude_none=True).encode(), ) try: @@ -721,11 +731,8 @@ async def generate_chat_completion( url_idx: Optional[int] = None, user=Depends(get_verified_user), ): - log.debug(f"{form_data.model_dump_json(exclude_none=True).encode()}=") - - payload = { - **form_data.model_dump(exclude_none=True, exclude=["metadata"]), - } + payload = {**form_data.model_dump(exclude_none=True)} + log.debug(f"{payload = }") if "metadata" in payload: del payload["metadata"] diff --git a/backend/apps/openai/main.py b/backend/apps/openai/main.py index d344c66222..9ad67c40c7 100644 --- a/backend/apps/openai/main.py +++ b/backend/apps/openai/main.py @@ -32,6 +32,7 @@ from config import ( ENABLE_MODEL_FILTER, MODEL_FILTER_LIST, AppConfig, + CORS_ALLOW_ORIGIN, ) from typing import Optional, Literal, overload @@ -45,7 +46,7 @@ log.setLevel(SRC_LOG_LEVELS["OPENAI"]) app = FastAPI() app.add_middleware( CORSMiddleware, - allow_origins=["*"], + allow_origins=CORS_ALLOW_ORIGIN, allow_credentials=True, allow_methods=["*"], allow_headers=["*"], diff --git a/backend/apps/rag/main.py b/backend/apps/rag/main.py index f9788556bc..7b2fbc6794 100644 --- a/backend/apps/rag/main.py +++ b/backend/apps/rag/main.py @@ -129,6 +129,7 @@ from config import ( RAG_WEB_SEARCH_RESULT_COUNT, RAG_WEB_SEARCH_CONCURRENT_REQUESTS, RAG_EMBEDDING_OPENAI_BATCH_SIZE, + CORS_ALLOW_ORIGIN, ) from constants import ERROR_MESSAGES @@ -240,12 +241,9 @@ app.state.EMBEDDING_FUNCTION = get_embedding_function( app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE, ) -origins = ["*"] - - app.add_middleware( CORSMiddleware, - allow_origins=origins, + allow_origins=CORS_ALLOW_ORIGIN, allow_credentials=True, allow_methods=["*"], allow_headers=["*"], diff --git a/backend/apps/webui/main.py b/backend/apps/webui/main.py index dddf3fbb2a..3756151806 100644 --- a/backend/apps/webui/main.py +++ b/backend/apps/webui/main.py @@ -26,6 +26,7 @@ from utils.misc import ( apply_model_system_prompt_to_body, ) +from utils.tools import get_tools from config import ( SHOW_ADMIN_DETAILS, @@ -43,10 +44,12 @@ from config import ( JWT_EXPIRES_IN, WEBUI_BANNERS, ENABLE_COMMUNITY_SHARING, + ENABLE_MESSAGE_RATING, AppConfig, OAUTH_USERNAME_CLAIM, OAUTH_PICTURE_CLAIM, OAUTH_EMAIL_CLAIM, + CORS_ALLOW_ORIGIN, ) from apps.socket.main import get_event_call, get_event_emitter @@ -59,8 +62,6 @@ from pydantic import BaseModel app = FastAPI() -origins = ["*"] - app.state.config = AppConfig() app.state.config.ENABLE_SIGNUP = ENABLE_SIGNUP @@ -82,6 +83,7 @@ app.state.config.WEBHOOK_URL = WEBHOOK_URL app.state.config.BANNERS = WEBUI_BANNERS app.state.config.ENABLE_COMMUNITY_SHARING = ENABLE_COMMUNITY_SHARING +app.state.config.ENABLE_MESSAGE_RATING = ENABLE_MESSAGE_RATING app.state.config.OAUTH_USERNAME_CLAIM = OAUTH_USERNAME_CLAIM app.state.config.OAUTH_PICTURE_CLAIM = OAUTH_PICTURE_CLAIM @@ -93,7 +95,7 @@ 
app.state.FUNCTIONS = {} app.add_middleware( CORSMiddleware, - allow_origins=origins, + allow_origins=CORS_ALLOW_ORIGIN, allow_credentials=True, allow_methods=["*"], allow_headers=["*"], @@ -274,7 +276,9 @@ def get_function_params(function_module, form_data, user, extra_params={}): async def generate_function_chat_completion(form_data, user): model_id = form_data.get("model") model_info = Models.get_model_by_id(model_id) - metadata = form_data.pop("metadata", None) + metadata = form_data.pop("metadata", {}) + files = metadata.get("files", []) + tool_ids = metadata.get("tool_ids", []) __event_emitter__ = None __event_call__ = None @@ -286,6 +290,20 @@ async def generate_function_chat_completion(form_data, user): __event_call__ = get_event_call(metadata) __task__ = metadata.get("task", None) + extra_params = { + "__event_emitter__": __event_emitter__, + "__event_call__": __event_call__, + "__task__": __task__, + } + tools_params = { + **extra_params, + "__model__": app.state.MODELS[form_data["model"]], + "__messages__": form_data["messages"], + "__files__": files, + } + configured_tools = get_tools(app, tool_ids, user, tools_params) + + extra_params["__tools__"] = configured_tools if model_info: if model_info.base_model_id: form_data["model"] = model_info.base_model_id @@ -298,16 +316,7 @@ async def generate_function_chat_completion(form_data, user): function_module = get_function_module(pipe_id) pipe = function_module.pipe - params = get_function_params( - function_module, - form_data, - user, - { - "__event_emitter__": __event_emitter__, - "__event_call__": __event_call__, - "__task__": __task__, - }, - ) + params = get_function_params(function_module, form_data, user, extra_params) if form_data["stream"]: diff --git a/backend/apps/webui/models/users.py b/backend/apps/webui/models/users.py index 36dfa4f855..b6e85e2ca2 100644 --- a/backend/apps/webui/models/users.py +++ b/backend/apps/webui/models/users.py @@ -1,12 +1,10 @@ -from pydantic import BaseModel, ConfigDict, parse_obj_as -from typing import Union, Optional +from pydantic import BaseModel, ConfigDict +from typing import Optional import time from sqlalchemy import String, Column, BigInteger, Text -from utils.misc import get_gravatar_url - -from apps.webui.internal.db import Base, JSONField, Session, get_db +from apps.webui.internal.db import Base, JSONField, get_db from apps.webui.models.chats import Chats #################### @@ -78,7 +76,6 @@ class UserUpdateForm(BaseModel): class UsersTable: - def insert_new_user( self, id: str, @@ -122,7 +119,6 @@ class UsersTable: def get_user_by_api_key(self, api_key: str) -> Optional[UserModel]: try: with get_db() as db: - user = db.query(User).filter_by(api_key=api_key).first() return UserModel.model_validate(user) except Exception: @@ -131,7 +127,6 @@ class UsersTable: def get_user_by_email(self, email: str) -> Optional[UserModel]: try: with get_db() as db: - user = db.query(User).filter_by(email=email).first() return UserModel.model_validate(user) except Exception: @@ -140,7 +135,6 @@ class UsersTable: def get_user_by_oauth_sub(self, sub: str) -> Optional[UserModel]: try: with get_db() as db: - user = db.query(User).filter_by(oauth_sub=sub).first() return UserModel.model_validate(user) except Exception: @@ -195,7 +189,6 @@ class UsersTable: def update_user_last_active_by_id(self, id: str) -> Optional[UserModel]: try: with get_db() as db: - db.query(User).filter_by(id=id).update( {"last_active_at": int(time.time())} ) diff --git a/backend/apps/webui/routers/auths.py 
b/backend/apps/webui/routers/auths.py index e2d6a5036f..c1f46293d1 100644 --- a/backend/apps/webui/routers/auths.py +++ b/backend/apps/webui/routers/auths.py @@ -352,6 +352,7 @@ async def get_admin_config(request: Request, user=Depends(get_admin_user)): "DEFAULT_USER_ROLE": request.app.state.config.DEFAULT_USER_ROLE, "JWT_EXPIRES_IN": request.app.state.config.JWT_EXPIRES_IN, "ENABLE_COMMUNITY_SHARING": request.app.state.config.ENABLE_COMMUNITY_SHARING, + "ENABLE_MESSAGE_RATING": request.app.state.config.ENABLE_MESSAGE_RATING, } @@ -361,6 +362,7 @@ class AdminConfig(BaseModel): DEFAULT_USER_ROLE: str JWT_EXPIRES_IN: str ENABLE_COMMUNITY_SHARING: bool + ENABLE_MESSAGE_RATING: bool @router.post("/admin/config") @@ -382,6 +384,7 @@ async def update_admin_config( request.app.state.config.ENABLE_COMMUNITY_SHARING = ( form_data.ENABLE_COMMUNITY_SHARING ) + request.app.state.config.ENABLE_MESSAGE_RATING = form_data.ENABLE_MESSAGE_RATING return { "SHOW_ADMIN_DETAILS": request.app.state.config.SHOW_ADMIN_DETAILS, @@ -389,6 +392,7 @@ async def update_admin_config( "DEFAULT_USER_ROLE": request.app.state.config.DEFAULT_USER_ROLE, "JWT_EXPIRES_IN": request.app.state.config.JWT_EXPIRES_IN, "ENABLE_COMMUNITY_SHARING": request.app.state.config.ENABLE_COMMUNITY_SHARING, + "ENABLE_MESSAGE_RATING": request.app.state.config.ENABLE_MESSAGE_RATING, } diff --git a/backend/apps/webui/routers/utils.py b/backend/apps/webui/routers/utils.py index 7a3c339324..8bf8267da1 100644 --- a/backend/apps/webui/routers/utils.py +++ b/backend/apps/webui/routers/utils.py @@ -85,9 +85,10 @@ async def download_chat_as_pdf( pdf.add_font("NotoSans", "i", f"{FONTS_DIR}/NotoSans-Italic.ttf") pdf.add_font("NotoSansKR", "", f"{FONTS_DIR}/NotoSansKR-Regular.ttf") pdf.add_font("NotoSansJP", "", f"{FONTS_DIR}/NotoSansJP-Regular.ttf") + pdf.add_font("NotoSansSC", "", f"{FONTS_DIR}/NotoSansSC-Regular.ttf") pdf.set_font("NotoSans", size=12) - pdf.set_fallback_fonts(["NotoSansKR", "NotoSansJP"]) + pdf.set_fallback_fonts(["NotoSansKR", "NotoSansJP", "NotoSansSC"]) pdf.set_auto_page_break(auto=True, margin=15) diff --git a/backend/config.py b/backend/config.py index 07ee06a58c..0ffacca1be 100644 --- a/backend/config.py +++ b/backend/config.py @@ -3,6 +3,8 @@ import sys import logging import importlib.metadata import pkgutil +from urllib.parse import urlparse + import chromadb from chromadb import Settings from bs4 import BeautifulSoup @@ -174,7 +176,6 @@ for version in soup.find_all("h2"): CHANGELOG = changelog_json - #################################### # SAFE_MODE #################################### @@ -806,10 +807,24 @@ USER_PERMISSIONS_CHAT_DELETION = ( os.environ.get("USER_PERMISSIONS_CHAT_DELETION", "True").lower() == "true" ) +USER_PERMISSIONS_CHAT_EDITING = ( + os.environ.get("USER_PERMISSIONS_CHAT_EDITING", "True").lower() == "true" +) + +USER_PERMISSIONS_CHAT_TEMPORARY = ( + os.environ.get("USER_PERMISSIONS_CHAT_TEMPORARY", "True").lower() == "true" +) + USER_PERMISSIONS = PersistentConfig( "USER_PERMISSIONS", "ui.user_permissions", - {"chat": {"deletion": USER_PERMISSIONS_CHAT_DELETION}}, + { + "chat": { + "deletion": USER_PERMISSIONS_CHAT_DELETION, + "editing": USER_PERMISSIONS_CHAT_EDITING, + "temporary": USER_PERMISSIONS_CHAT_TEMPORARY, + } + }, ) ENABLE_MODEL_FILTER = PersistentConfig( @@ -840,6 +855,47 @@ ENABLE_COMMUNITY_SHARING = PersistentConfig( os.environ.get("ENABLE_COMMUNITY_SHARING", "True").lower() == "true", ) +ENABLE_MESSAGE_RATING = PersistentConfig( + "ENABLE_MESSAGE_RATING", + "ui.enable_message_rating", + 
os.environ.get("ENABLE_MESSAGE_RATING", "True").lower() == "true", +) + + +def validate_cors_origins(origins): + for origin in origins: + if origin != "*": + validate_cors_origin(origin) + + +def validate_cors_origin(origin): + parsed_url = urlparse(origin) + + # Check if the scheme is either http or https + if parsed_url.scheme not in ["http", "https"]: + raise ValueError( + f"Invalid scheme in CORS_ALLOW_ORIGIN: '{origin}'. Only 'http' and 'https' are allowed." + ) + + # Ensure that the netloc (domain + port) is present, indicating it's a valid URL + if not parsed_url.netloc: + raise ValueError(f"Invalid URL structure in CORS_ALLOW_ORIGIN: '{origin}'.") + + +# For production, you should only need one host as +# fastapi serves the svelte-kit built frontend and backend from the same host and port. +# To test CORS_ALLOW_ORIGIN locally, you can set something like +# CORS_ALLOW_ORIGIN=http://localhost:5173;http://localhost:8080 +# in your .env file depending on your frontend port, 5173 in this case. +CORS_ALLOW_ORIGIN = os.environ.get("CORS_ALLOW_ORIGIN", "*").split(";") + +if "*" in CORS_ALLOW_ORIGIN: + log.warning( + "\n\nWARNING: CORS_ALLOW_ORIGIN IS SET TO '*' - NOT RECOMMENDED FOR PRODUCTION DEPLOYMENTS.\n" + ) + +validate_cors_origins(CORS_ALLOW_ORIGIN) + class BannerModel(BaseModel): id: str @@ -895,10 +951,7 @@ TITLE_GENERATION_PROMPT_TEMPLATE = PersistentConfig( "task.title.prompt_template", os.environ.get( "TITLE_GENERATION_PROMPT_TEMPLATE", - """Here is the query: -{{prompt:middletruncate:8000}} - -Create a concise, 3-5 word phrase with an emoji as a title for the previous query. Suitable Emojis for the summary can be used to enhance understanding but avoid quotation marks or special formatting. RESPOND ONLY WITH THE TITLE TEXT. + """Create a concise, 3-5 word title with an emoji as a title for the prompt in the given language. Suitable Emojis for the summary can be used to enhance understanding but avoid quotation marks or special formatting. RESPOND ONLY WITH THE TITLE TEXT. Examples of titles: 📉 Stock Market Trends @@ -906,7 +959,9 @@ Examples of titles: Evolution of Music Streaming Remote Work Productivity Tips Artificial Intelligence in Healthcare -🎮 Video Game Development Insights""", +🎮 Video Game Development Insights + +Prompt: {{prompt:middletruncate:8000}}""", ), ) @@ -939,8 +994,7 @@ TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = PersistentConfig( "task.tools.prompt_template", os.environ.get( "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE", - """Tools: {{TOOLS}} -If a function tool doesn't match the query, return an empty string. Else, pick a function tool, fill in the parameters from the function tool's schema, and return it in the format { "name": \"functionName\", "parameters": { "key": "value" } }. Only pick a function if the user asks. Only return the object. Do not return any other text.""", + """Available Tools: {{TOOLS}}\nReturn an empty string if no tools match the query. If a function tool matches, construct and return a JSON object in the format {\"name\": \"functionName\", \"parameters\": {\"requiredFunctionParamKey\": \"requiredFunctionParamValue\"}} using the appropriate tool and its parameters. 
Only return the object and limit the response to the JSON object without additional text.""",
     ),
 )
 
@@ -1056,7 +1110,7 @@ RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE = (
 RAG_EMBEDDING_OPENAI_BATCH_SIZE = PersistentConfig(
     "RAG_EMBEDDING_OPENAI_BATCH_SIZE",
     "rag.embedding_openai_batch_size",
-    os.environ.get("RAG_EMBEDDING_OPENAI_BATCH_SIZE", 1),
+    int(os.environ.get("RAG_EMBEDDING_OPENAI_BATCH_SIZE", "1")),
 )
 
 RAG_RERANKING_MODEL = PersistentConfig(
@@ -1288,46 +1342,127 @@ COMFYUI_BASE_URL = PersistentConfig(
     os.getenv("COMFYUI_BASE_URL", ""),
 )
 
-COMFYUI_CFG_SCALE = PersistentConfig(
-    "COMFYUI_CFG_SCALE",
-    "image_generation.comfyui.cfg_scale",
-    os.getenv("COMFYUI_CFG_SCALE", ""),
+COMFYUI_DEFAULT_WORKFLOW = """
+{
+  "3": {
+    "inputs": {
+      "seed": 0,
+      "steps": 20,
+      "cfg": 8,
+      "sampler_name": "euler",
+      "scheduler": "normal",
+      "denoise": 1,
+      "model": [
+        "4",
+        0
+      ],
+      "positive": [
+        "6",
+        0
+      ],
+      "negative": [
+        "7",
+        0
+      ],
+      "latent_image": [
+        "5",
+        0
+      ]
+    },
+    "class_type": "KSampler",
+    "_meta": {
+      "title": "KSampler"
+    }
+  },
+  "4": {
+    "inputs": {
+      "ckpt_name": "model.safetensors"
+    },
+    "class_type": "CheckpointLoaderSimple",
+    "_meta": {
+      "title": "Load Checkpoint"
+    }
+  },
+  "5": {
+    "inputs": {
+      "width": 512,
+      "height": 512,
+      "batch_size": 1
+    },
+    "class_type": "EmptyLatentImage",
+    "_meta": {
+      "title": "Empty Latent Image"
+    }
+  },
+  "6": {
+    "inputs": {
+      "text": "Prompt",
+      "clip": [
+        "4",
+        1
+      ]
+    },
+    "class_type": "CLIPTextEncode",
+    "_meta": {
+      "title": "CLIP Text Encode (Prompt)"
+    }
+  },
+  "7": {
+    "inputs": {
+      "text": "",
+      "clip": [
+        "4",
+        1
+      ]
+    },
+    "class_type": "CLIPTextEncode",
+    "_meta": {
+      "title": "CLIP Text Encode (Prompt)"
+    }
+  },
+  "8": {
+    "inputs": {
+      "samples": [
+        "3",
+        0
+      ],
+      "vae": [
+        "4",
+        2
+      ]
+    },
+    "class_type": "VAEDecode",
+    "_meta": {
+      "title": "VAE Decode"
+    }
+  },
+  "9": {
+    "inputs": {
+      "filename_prefix": "ComfyUI",
+      "images": [
+        "8",
+        0
+      ]
+    },
+    "class_type": "SaveImage",
+    "_meta": {
+      "title": "Save Image"
+    }
+  }
+}
+"""
+
+
+COMFYUI_WORKFLOW = PersistentConfig(
+    "COMFYUI_WORKFLOW",
+    "image_generation.comfyui.workflow",
+    os.getenv("COMFYUI_WORKFLOW", COMFYUI_DEFAULT_WORKFLOW),
 )
 
-COMFYUI_SAMPLER = PersistentConfig(
-    "COMFYUI_SAMPLER",
-    "image_generation.comfyui.sampler",
-    os.getenv("COMFYUI_SAMPLER", ""),
-)
-
-COMFYUI_SCHEDULER = PersistentConfig(
-    "COMFYUI_SCHEDULER",
-    "image_generation.comfyui.scheduler",
-    os.getenv("COMFYUI_SCHEDULER", ""),
-)
-
-COMFYUI_SD3 = PersistentConfig(
-    "COMFYUI_SD3",
-    "image_generation.comfyui.sd3",
-    os.environ.get("COMFYUI_SD3", "").lower() == "true",
-)
-
-COMFYUI_FLUX = PersistentConfig(
-    "COMFYUI_FLUX",
-    "image_generation.comfyui.flux",
-    os.environ.get("COMFYUI_FLUX", "").lower() == "true",
-)
-
-COMFYUI_FLUX_WEIGHT_DTYPE = PersistentConfig(
-    "COMFYUI_FLUX_WEIGHT_DTYPE",
-    "image_generation.comfyui.flux_weight_dtype",
-    os.getenv("COMFYUI_FLUX_WEIGHT_DTYPE", ""),
-)
-
-COMFYUI_FLUX_FP8_CLIP = PersistentConfig(
-    "COMFYUI_FLUX_FP8_CLIP",
-    "image_generation.comfyui.flux_fp8_clip",
-    os.environ.get("COMFYUI_FLUX_FP8_CLIP", "").lower() == "true",
+COMFYUI_WORKFLOW_NODES = PersistentConfig(
+    "COMFYUI_WORKFLOW_NODES",
+    "image_generation.comfyui.nodes",
+    [],
 )
 
 IMAGES_OPENAI_API_BASE_URL = PersistentConfig(
@@ -1410,13 +1545,13 @@ AUDIO_TTS_ENGINE = PersistentConfig(
 AUDIO_TTS_MODEL = PersistentConfig(
     "AUDIO_TTS_MODEL",
     "audio.tts.model",
-    os.getenv("AUDIO_TTS_MODEL", "tts-1"),
+    os.getenv("AUDIO_TTS_MODEL", "tts-1"),  # OpenAI default model
 )
 
 AUDIO_TTS_VOICE
= PersistentConfig( "AUDIO_TTS_VOICE", "audio.tts.voice", - os.getenv("AUDIO_TTS_VOICE", "alloy"), + os.getenv("AUDIO_TTS_VOICE", "alloy"), # OpenAI default voice ) diff --git a/backend/constants.py b/backend/constants.py index b9c7fc430d..d55216bb5d 100644 --- a/backend/constants.py +++ b/backend/constants.py @@ -100,3 +100,4 @@ class TASKS(str, Enum): EMOJI_GENERATION = "emoji_generation" QUERY_GENERATION = "query_generation" FUNCTION_CALLING = "function_calling" + MOA_RESPONSE_GENERATION = "moa_response_generation" diff --git a/backend/main.py b/backend/main.py index d8ce5f5d78..2c0c810c92 100644 --- a/backend/main.py +++ b/backend/main.py @@ -14,6 +14,7 @@ import requests import mimetypes import shutil import inspect +from typing import Optional from fastapi import FastAPI, Request, Depends, status, UploadFile, File, Form from fastapi.staticfiles import StaticFiles @@ -51,15 +52,13 @@ from apps.webui.internal.db import Session from pydantic import BaseModel -from typing import Optional from apps.webui.models.auths import Auths from apps.webui.models.models import Models -from apps.webui.models.tools import Tools from apps.webui.models.functions import Functions -from apps.webui.models.users import Users +from apps.webui.models.users import Users, UserModel -from apps.webui.utils import load_toolkit_module_by_id, load_function_module_by_id +from apps.webui.utils import load_function_module_by_id from utils.utils import ( get_admin_user, @@ -68,12 +67,16 @@ from utils.utils import ( get_http_authorization_cred, get_password_hash, create_token, + decode_token, ) from utils.task import ( title_generation_template, search_query_generation_template, tools_function_calling_generation_template, + moa_response_generation_template, ) + +from utils.tools import get_tools from utils.misc import ( get_last_user_message, add_or_update_system_message, @@ -118,6 +121,7 @@ from config import ( WEBUI_SESSION_COOKIE_SECURE, ENABLE_ADMIN_CHAT_ACCESS, AppConfig, + CORS_ALLOW_ORIGIN, ) from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES, TASKS @@ -208,8 +212,6 @@ app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = ( app.state.MODELS = {} -origins = ["*"] - ################################## # @@ -218,25 +220,6 @@ origins = ["*"] ################################## -async def get_body_and_model_and_user(request): - # Read the original request body - body = await request.body() - body_str = body.decode("utf-8") - body = json.loads(body_str) if body_str else {} - - model_id = body["model"] - if model_id not in app.state.MODELS: - raise Exception("Model not found") - model = app.state.MODELS[model_id] - - user = get_current_user( - request, - get_http_authorization_cred(request.headers.get("Authorization")), - ) - - return body, model, user - - def get_task_model_id(default_model_id): # Set the task model task_model_id = default_model_id @@ -261,6 +244,7 @@ def get_filter_function_ids(model): def get_priority(function_id): function = Functions.get_function_by_id(function_id) if function is not None and hasattr(function, "valves"): + # TODO: Fix FunctionModel return (function.valves if function.valves else {}).get("priority", 0) return 0 @@ -282,164 +266,7 @@ def get_filter_function_ids(model): return filter_ids -async def get_function_call_response( - messages, - files, - tool_id, - template, - task_model_id, - user, - __event_emitter__=None, - __event_call__=None, -): - tool = Tools.get_tool_by_id(tool_id) - tools_specs = json.dumps(tool.specs, indent=2) - content = 
tools_function_calling_generation_template(template, tools_specs) - - user_message = get_last_user_message(messages) - prompt = ( - "History:\n" - + "\n".join( - [ - f"{message['role'].upper()}: \"\"\"{message['content']}\"\"\"" - for message in messages[::-1][:4] - ] - ) - + f"\nQuery: {user_message}" - ) - - print(prompt) - - payload = { - "model": task_model_id, - "messages": [ - {"role": "system", "content": content}, - {"role": "user", "content": f"Query: {prompt}"}, - ], - "stream": False, - "task": str(TASKS.FUNCTION_CALLING), - } - - try: - payload = filter_pipeline(payload, user) - except Exception as e: - raise e - - model = app.state.MODELS[task_model_id] - - response = None - try: - response = await generate_chat_completions(form_data=payload, user=user) - content = None - - if hasattr(response, "body_iterator"): - async for chunk in response.body_iterator: - data = json.loads(chunk.decode("utf-8")) - content = data["choices"][0]["message"]["content"] - - # Cleanup any remaining background tasks if necessary - if response.background is not None: - await response.background() - else: - content = response["choices"][0]["message"]["content"] - - if content is None: - return None, None, False - - # Parse the function response - print(f"content: {content}") - result = json.loads(content) - print(result) - - citation = None - - if "name" not in result: - return None, None, False - - # Call the function - if tool_id in webui_app.state.TOOLS: - toolkit_module = webui_app.state.TOOLS[tool_id] - else: - toolkit_module, _ = load_toolkit_module_by_id(tool_id) - webui_app.state.TOOLS[tool_id] = toolkit_module - - file_handler = False - # check if toolkit_module has file_handler self variable - if hasattr(toolkit_module, "file_handler"): - file_handler = True - print("file_handler: ", file_handler) - - if hasattr(toolkit_module, "valves") and hasattr(toolkit_module, "Valves"): - valves = Tools.get_tool_valves_by_id(tool_id) - toolkit_module.valves = toolkit_module.Valves(**(valves if valves else {})) - - function = getattr(toolkit_module, result["name"]) - function_result = None - try: - # Get the signature of the function - sig = inspect.signature(function) - params = result["parameters"] - - # Extra parameters to be passed to the function - extra_params = { - "__model__": model, - "__id__": tool_id, - "__messages__": messages, - "__files__": files, - "__event_emitter__": __event_emitter__, - "__event_call__": __event_call__, - } - - # Add extra params in contained in function signature - for key, value in extra_params.items(): - if key in sig.parameters: - params[key] = value - - if "__user__" in sig.parameters: - # Call the function with the '__user__' parameter included - __user__ = { - "id": user.id, - "email": user.email, - "name": user.name, - "role": user.role, - } - - try: - if hasattr(toolkit_module, "UserValves"): - __user__["valves"] = toolkit_module.UserValves( - **Tools.get_user_valves_by_id_and_user_id(tool_id, user.id) - ) - except Exception as e: - print(e) - - params = {**params, "__user__": __user__} - - if inspect.iscoroutinefunction(function): - function_result = await function(**params) - else: - function_result = function(**params) - - if hasattr(toolkit_module, "citation") and toolkit_module.citation: - citation = { - "source": {"name": f"TOOL:{tool.name}/{result['name']}"}, - "document": [function_result], - "metadata": [{"source": result["name"]}], - } - except Exception as e: - print(e) - - # Add the function result to the system prompt - if function_result is not 
None: - return function_result, citation, file_handler - except Exception as e: - print(f"Error: {e}") - - return None, None, False - - -async def chat_completion_functions_handler( - body, model, user, __event_emitter__, __event_call__ -): +async def chat_completion_filter_functions_handler(body, model, extra_params): skip_files = None filter_ids = get_filter_function_ids(model) @@ -475,37 +302,20 @@ async def chat_completion_functions_handler( params = {"body": body} # Extra parameters to be passed to the function - extra_params = { - "__model__": model, - "__id__": filter_id, - "__event_emitter__": __event_emitter__, - "__event_call__": __event_call__, - } - - # Add extra params in contained in function signature - for key, value in extra_params.items(): - if key in sig.parameters: - params[key] = value - - if "__user__" in sig.parameters: - __user__ = { - "id": user.id, - "email": user.email, - "name": user.name, - "role": user.role, - } - + custom_params = {**extra_params, "__model__": model, "__id__": filter_id} + if hasattr(function_module, "UserValves") and "__user__" in sig.parameters: try: - if hasattr(function_module, "UserValves"): - __user__["valves"] = function_module.UserValves( - **Functions.get_user_valves_by_id_and_user_id( - filter_id, user.id - ) - ) + uid = custom_params["__user__"]["id"] + custom_params["__user__"]["valves"] = function_module.UserValves( + **Functions.get_user_valves_by_id_and_user_id(filter_id, uid) + ) except Exception as e: print(e) - params = {**params, "__user__": __user__} + # Add extra params contained in the function signature + for key, value in custom_params.items(): + if key in sig.parameters: + params[key] = value if inspect.iscoroutinefunction(inlet): body = await inlet(**params) @@ -516,74 +326,146 @@ print(f"Error: {e}") raise e - if skip_files: - if "files" in body: - del body["files"] + if skip_files and "files" in body.get("metadata", {}): + del body["metadata"]["files"] return body, {} -async def chat_completion_tools_handler(body, user, __event_emitter__, __event_call__): - skip_files = None +def get_tools_function_calling_payload(messages, task_model_id, content): + user_message = get_last_user_message(messages) + history = "\n".join( + f"{message['role'].upper()}: \"\"\"{message['content']}\"\"\"" + for message in messages[::-1][:4] + ) - contexts = [] - citations = None + prompt = f"History:\n{history}\nQuery: {user_message}" - task_model_id = get_task_model_id(body["model"]) - - # If tool_ids field is present, call the functions - if "tool_ids" in body: - print(body["tool_ids"]) - for tool_id in body["tool_ids"]: - print(tool_id) - try: - response, citation, file_handler = await get_function_call_response( - messages=body["messages"], - files=body.get("files", []), - tool_id=tool_id, - template=app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE, - task_model_id=task_model_id, - user=user, - __event_emitter__=__event_emitter__, - __event_call__=__event_call__, - ) - - print(file_handler) - if isinstance(response, str): - contexts.append(response) - - if citation: - if citations is None: - citations = [citation] - else: - citations.append(citation) - - if file_handler: - skip_files = True - - except Exception as e: - print(f"Error: {e}") - del body["tool_ids"] - print(f"tool_contexts: {contexts}") - - if skip_files: - if "files" in body: - del body["files"] - - return body, { - **({"contexts": contexts} if contexts is not None else {}), - **({"citations": citations} if citations is
not None else {}), + return { + "model": task_model_id, + "messages": [ + {"role": "system", "content": content}, + {"role": "user", "content": f"Query: {prompt}"}, + ], + "stream": False, + "metadata": {"task": str(TASKS.FUNCTION_CALLING)}, } -async def chat_completion_files_handler(body): +async def get_content_from_response(response) -> Optional[str]: + content = None + if hasattr(response, "body_iterator"): + async for chunk in response.body_iterator: + data = json.loads(chunk.decode("utf-8")) + content = data["choices"][0]["message"]["content"] + + # Cleanup any remaining background tasks if necessary + if response.background is not None: + await response.background() + else: + content = response["choices"][0]["message"]["content"] + return content + + +async def chat_completion_tools_handler( + body: dict, user: UserModel, extra_params: dict +) -> tuple[dict, dict]: + # If tool_ids field is present, call the functions + metadata = body.get("metadata", {}) + tool_ids = metadata.get("tool_ids", None) + if not tool_ids: + return body, {} + + skip_files = False contexts = [] - citations = None + citations = [] - if "files" in body: - files = body["files"] - del body["files"] + task_model_id = get_task_model_id(body["model"]) + log.debug(f"{tool_ids=}") + + custom_params = { + **extra_params, + "__model__": app.state.MODELS[task_model_id], + "__messages__": body["messages"], + "__files__": metadata.get("files", []), + } + tools = get_tools(webui_app, tool_ids, user, custom_params) + log.info(f"{tools=}") + + specs = [tool["spec"] for tool in tools.values()] + tools_specs = json.dumps(specs) + + tools_function_calling_prompt = tools_function_calling_generation_template( + app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE, tools_specs + ) + log.info(f"{tools_function_calling_prompt=}") + payload = get_tools_function_calling_payload( + body["messages"], task_model_id, tools_function_calling_prompt + ) + + try: + payload = filter_pipeline(payload, user) + except Exception as e: + raise e + + try: + response = await generate_chat_completions(form_data=payload, user=user) + log.debug(f"{response=}") + content = await get_content_from_response(response) + log.debug(f"{content=}") + + if not content: + return body, {} + + result = json.loads(content) + + tool_function_name = result.get("name", None) + if tool_function_name not in tools: + return body, {} + + tool_function_params = result.get("parameters", {}) + + try: + tool_output = await tools[tool_function_name]["callable"]( + **tool_function_params + ) + except Exception as e: + tool_output = str(e) + + if tools[tool_function_name]["citation"]: + citations.append( + { + "source": { + "name": f"TOOL:{tools[tool_function_name]['toolkit_id']}/{tool_function_name}" + }, + "document": [tool_output], + "metadata": [{"source": tool_function_name}], + } + ) + if tools[tool_function_name]["file_handler"]: + skip_files = True + + if isinstance(tool_output, str): + contexts.append(tool_output) + + except Exception as e: + log.exception(f"Error: {e}") + content = None + + log.debug(f"tool_contexts: {contexts}") + + if skip_files and "files" in body.get("metadata", {}): + del body["metadata"]["files"] + + return body, {"contexts": contexts, "citations": citations} + + +async def chat_completion_files_handler(body) -> tuple[dict, dict[str, list]]: + contexts = [] + citations = [] + + if files := body.get("metadata", {}).get("files", None): contexts, citations = get_rag_context( files=files, messages=body["messages"], @@ -596,153 +478,163 @@ async def 
chat_completion_files_handler(body): log.debug(f"rag_contexts: {contexts}, citations: {citations}") - return body, { - **({"contexts": contexts} if contexts is not None else {}), - **({"citations": citations} if citations is not None else {}), - } + return body, {"contexts": contexts, "citations": citations} + + +def is_chat_completion_request(request): + return request.method == "POST" and any( + endpoint in request.url.path + for endpoint in ["/ollama/api/chat", "/chat/completions"] + ) + + +async def get_body_and_model_and_user(request): + # Read the original request body + body = await request.body() + body_str = body.decode("utf-8") + body = json.loads(body_str) if body_str else {} + + model_id = body["model"] + if model_id not in app.state.MODELS: + raise Exception("Model not found") + model = app.state.MODELS[model_id] + + user = get_current_user( + request, + get_http_authorization_cred(request.headers.get("Authorization")), + ) + + return body, model, user class ChatCompletionMiddleware(BaseHTTPMiddleware): async def dispatch(self, request: Request, call_next): - if request.method == "POST" and any( - endpoint in request.url.path - for endpoint in ["/ollama/api/chat", "/chat/completions"] - ): - log.debug(f"request.url.path: {request.url.path}") + if not is_chat_completion_request(request): + return await call_next(request) + log.debug(f"request.url.path: {request.url.path}") - try: - body, model, user = await get_body_and_model_and_user(request) - except Exception as e: - return JSONResponse( - status_code=status.HTTP_400_BAD_REQUEST, - content={"detail": str(e)}, + try: + body, model, user = await get_body_and_model_and_user(request) + except Exception as e: + return JSONResponse( + status_code=status.HTTP_400_BAD_REQUEST, + content={"detail": str(e)}, + ) + + metadata = { + "chat_id": body.pop("chat_id", None), + "message_id": body.pop("id", None), + "session_id": body.pop("session_id", None), + "valves": body.pop("valves", None), + "tool_ids": body.pop("tool_ids", None), + "files": body.pop("files", None), + } + body["metadata"] = metadata + + __user__ = { + "id": user.id, + "email": user.email, + "name": user.name, + "role": user.role, + } + + extra_params = { + "__user__": __user__, + "__event_emitter__": get_event_emitter(metadata), + "__event_call__": get_event_call(metadata), + } + + # Initialize data_items to store additional data to be sent to the client + # Initialize contexts and citations + data_items = [] + contexts = [] + citations = [] + + try: + body, flags = await chat_completion_filter_functions_handler( + body, model, extra_params + ) + except Exception as e: + return JSONResponse( + status_code=status.HTTP_400_BAD_REQUEST, + content={"detail": str(e)}, + ) + + try: + body, flags = await chat_completion_tools_handler(body, user, extra_params) + contexts.extend(flags.get("contexts", [])) + citations.extend(flags.get("citations", [])) + except Exception as e: + log.exception(e) + + try: + body, flags = await chat_completion_files_handler(body) + contexts.extend(flags.get("contexts", [])) + citations.extend(flags.get("citations", [])) + except Exception as e: + log.exception(e) + + # If context is not empty, insert it into the messages + if len(contexts) > 0: + context_string = "\n".join(contexts).strip() + prompt = get_last_user_message(body["messages"]) + if prompt is None: + raise Exception("No user message found") + # Workaround for Ollama 2.0+ system prompt issue + # TODO: replace with add_or_update_system_message + if model["owned_by"] == "ollama": +
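# Ollama 2.0+ applies system prompts differently (the workaround noted above), so on this branch the retrieved context is prepended to the first user message; other backends receive it via add_or_update_system_message below. +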
body["messages"] = prepend_to_first_user_message_content( + rag_template( + rag_app.state.config.RAG_TEMPLATE, context_string, prompt + ), + body["messages"], ) - - metadata = { - "chat_id": body.pop("chat_id", None), - "message_id": body.pop("id", None), - "session_id": body.pop("session_id", None), - "valves": body.pop("valves", None), - } - - __event_emitter__ = get_event_emitter(metadata) - __event_call__ = get_event_call(metadata) - - # Initialize data_items to store additional data to be sent to the client - data_items = [] - - # Initialize context, and citations - contexts = [] - citations = [] - - try: - body, flags = await chat_completion_functions_handler( - body, model, user, __event_emitter__, __event_call__ - ) - except Exception as e: - return JSONResponse( - status_code=status.HTTP_400_BAD_REQUEST, - content={"detail": str(e)}, - ) - - try: - body, flags = await chat_completion_tools_handler( - body, user, __event_emitter__, __event_call__ - ) - - contexts.extend(flags.get("contexts", [])) - citations.extend(flags.get("citations", [])) - except Exception as e: - print(e) - pass - - try: - body, flags = await chat_completion_files_handler(body) - - contexts.extend(flags.get("contexts", [])) - citations.extend(flags.get("citations", [])) - except Exception as e: - print(e) - pass - - # If context is not empty, insert it into the messages - if len(contexts) > 0: - context_string = "/n".join(contexts).strip() - prompt = get_last_user_message(body["messages"]) - - # Workaround for Ollama 2.0+ system prompt issue - # TODO: replace with add_or_update_system_message - if model["owned_by"] == "ollama": - body["messages"] = prepend_to_first_user_message_content( - rag_template( - rag_app.state.config.RAG_TEMPLATE, context_string, prompt - ), - body["messages"], - ) - else: - body["messages"] = add_or_update_system_message( - rag_template( - rag_app.state.config.RAG_TEMPLATE, context_string, prompt - ), - body["messages"], - ) - - # If there are citations, add them to the data_items - if len(citations) > 0: - data_items.append({"citations": citations}) - - body["metadata"] = metadata - modified_body_bytes = json.dumps(body).encode("utf-8") - # Replace the request body with the modified one - request._body = modified_body_bytes - # Set custom header to ensure content-length matches new body length - request.headers.__dict__["_list"] = [ - (b"content-length", str(len(modified_body_bytes)).encode("utf-8")), - *[ - (k, v) - for k, v in request.headers.raw - if k.lower() != b"content-length" - ], - ] - - response = await call_next(request) - if isinstance(response, StreamingResponse): - # If it's a streaming response, inject it as SSE event or NDJSON line - content_type = response.headers.get("Content-Type") - if "text/event-stream" in content_type: - return StreamingResponse( - self.openai_stream_wrapper(response.body_iterator, data_items), - ) - if "application/x-ndjson" in content_type: - return StreamingResponse( - self.ollama_stream_wrapper(response.body_iterator, data_items), - ) - - return response else: - return response + body["messages"] = add_or_update_system_message( + rag_template( + rag_app.state.config.RAG_TEMPLATE, context_string, prompt + ), + body["messages"], + ) + + # If there are citations, add them to the data_items + if len(citations) > 0: + data_items.append({"citations": citations}) + + modified_body_bytes = json.dumps(body).encode("utf-8") + # Replace the request body with the modified one + request._body = modified_body_bytes + # Set custom header to ensure 
content-length matches new body length + request.headers.__dict__["_list"] = [ + (b"content-length", str(len(modified_body_bytes)).encode("utf-8")), + *[(k, v) for k, v in request.headers.raw if k.lower() != b"content-length"], + ] - # If it's not a chat completion request, just pass it through response = await call_next(request) - return response + if not isinstance(response, StreamingResponse): + return response + + content_type = response.headers["Content-Type"] + is_openai = "text/event-stream" in content_type + is_ollama = "application/x-ndjson" in content_type + if not is_openai and not is_ollama: + return response + + def wrap_item(item): + return f"data: {item}\n\n" if is_openai else f"{item}\n" + + async def stream_wrapper(original_generator, data_items): + for item in data_items: + yield wrap_item(json.dumps(item)) + + async for data in original_generator: + yield data + + return StreamingResponse(stream_wrapper(response.body_iterator, data_items)) async def _receive(self, body: bytes): return {"type": "http.request", "body": body, "more_body": False} - async def openai_stream_wrapper(self, original_generator, data_items): - for item in data_items: - yield f"data: {json.dumps(item)}\n\n" - - async for data in original_generator: - yield data - - async def ollama_stream_wrapper(self, original_generator, data_items): - for item in data_items: - yield f"{json.dumps(item)}\n" - - async for data in original_generator: - yield data - app.add_middleware(ChatCompletionMiddleware) @@ -790,19 +682,21 @@ def filter_pipeline(payload, user): url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx] key = openai_app.state.config.OPENAI_API_KEYS[urlIdx] - if key != "": - headers = {"Authorization": f"Bearer {key}"} - r = requests.post( - f"{url}/{filter['id']}/filter/inlet", - headers=headers, - json={ - "user": user, - "body": payload, - }, - ) + if key == "": + continue - r.raise_for_status() - payload = r.json() + headers = {"Authorization": f"Bearer {key}"} + r = requests.post( + f"{url}/{filter['id']}/filter/inlet", + headers=headers, + json={ + "user": user, + "body": payload, + }, + ) + + r.raise_for_status() + payload = r.json() except Exception as e: # Handle connection error here print(f"Connection error: {e}") @@ -817,44 +711,39 @@ def filter_pipeline(payload, user): class PipelineMiddleware(BaseHTTPMiddleware): async def dispatch(self, request: Request, call_next): - if request.method == "POST" and ( - "/ollama/api/chat" in request.url.path - or "/chat/completions" in request.url.path - ): - log.debug(f"request.url.path: {request.url.path}") + if not is_chat_completion_request(request): + return await call_next(request) - # Read the original request body - body = await request.body() - # Decode body to string - body_str = body.decode("utf-8") - # Parse string to JSON - data = json.loads(body_str) if body_str else {} + log.debug(f"request.url.path: {request.url.path}") - user = get_current_user( - request, - get_http_authorization_cred(request.headers.get("Authorization")), + # Read the original request body + body = await request.body() + # Decode body to string + body_str = body.decode("utf-8") + # Parse string to JSON + data = json.loads(body_str) if body_str else {} + + user = get_current_user( + request, + get_http_authorization_cred(request.headers["Authorization"]), + ) + + try: + data = filter_pipeline(data, user) + except Exception as e: + return JSONResponse( + status_code=e.args[0], + content={"detail": e.args[1]}, ) - try: - data = filter_pipeline(data, user) - except 
Exception as e: - return JSONResponse( - status_code=e.args[0], - content={"detail": e.args[1]}, - ) - - modified_body_bytes = json.dumps(data).encode("utf-8") - # Replace the request body with the modified one - request._body = modified_body_bytes - # Set custom header to ensure content-length matches new body length - request.headers.__dict__["_list"] = [ - (b"content-length", str(len(modified_body_bytes)).encode("utf-8")), - *[ - (k, v) - for k, v in request.headers.raw - if k.lower() != b"content-length" - ], - ] + modified_body_bytes = json.dumps(data).encode("utf-8") + # Replace the request body with the modified one + request._body = modified_body_bytes + # Set custom header to ensure content-length matches new body length + request.headers.__dict__["_list"] = [ + (b"content-length", str(len(modified_body_bytes)).encode("utf-8")), + *[(k, v) for k, v in request.headers.raw if k.lower() != b"content-length"], + ] response = await call_next(request) return response @@ -868,7 +757,7 @@ app.add_middleware(PipelineMiddleware) app.add_middleware( CORSMiddleware, - allow_origins=origins, + allow_origins=CORS_ALLOW_ORIGIN, allow_credentials=True, allow_methods=["*"], allow_headers=["*"], @@ -1019,6 +908,8 @@ async def get_all_models(): model["actions"] = [] for action_id in action_ids: action = Functions.get_function_by_id(action_id) + if action is None: + raise Exception(f"Action not found: {action_id}") if action_id in webui_app.state.FUNCTIONS: function_module = webui_app.state.FUNCTIONS[action_id] @@ -1026,6 +917,10 @@ async def get_all_models(): function_module, _, _ = load_function_module_by_id(action_id) webui_app.state.FUNCTIONS[action_id] = function_module + __webui__ = False + if hasattr(function_module, "__webui__"): + __webui__ = function_module.__webui__ + if hasattr(function_module, "actions"): actions = function_module.actions model["actions"].extend( @@ -1039,6 +934,7 @@ async def get_all_models(): "icon_url": _action.get( "icon_url", action.meta.manifest.get("icon_url", None) ), + **({"__webui__": __webui__} if __webui__ else {}), } for _action in actions ] @@ -1050,6 +946,7 @@ async def get_all_models(): "name": action.name, "description": action.meta.description, "icon_url": action.meta.manifest.get("icon_url", None), + **({"__webui__": __webui__} if __webui__ else {}), } ) @@ -1092,23 +989,9 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u detail="Model not found", ) model = app.state.MODELS[model_id] - - # `task` field is used to determine the type of the request, e.g. `title_generation`, `query_generation`, etc. 
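# The task marker now travels inside the request's "metadata" (e.g. "metadata": {"task": str(TASKS.TITLE_GENERATION)} in the task endpoints further down), so this lifting step is deleted.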
- task = None - if "task" in form_data: - task = form_data["task"] - del form_data["task"] - - if task: - if "metadata" in form_data: - form_data["metadata"]["task"] = task - else: - form_data["metadata"] = {"task": task} - if model.get("pipe"): return await generate_function_chat_completion(form_data, user=user) if model["owned_by"] == "ollama": - print("generate_ollama_chat_completion") return await generate_ollama_chat_completion(form_data, user=user) else: return await generate_openai_chat_completion(form_data, user=user) @@ -1192,6 +1075,7 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)): def get_priority(function_id): function = Functions.get_function_by_id(function_id) if function is not None and hasattr(function, "valves"): + # TODO: Fix FunctionModel to include valves return (function.valves if function.valves else {}).get("priority", 0) return 0 @@ -1481,7 +1365,7 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)): "stream": False, "max_tokens": 50, "chat_id": form_data.get("chat_id", None), - "task": str(TASKS.TITLE_GENERATION), + "metadata": {"task": str(TASKS.TITLE_GENERATION)}, } log.debug(payload) @@ -1534,7 +1418,7 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user) "messages": [{"role": "user", "content": content}], "stream": False, "max_tokens": 30, - "task": str(TASKS.QUERY_GENERATION), + "metadata": {"task": str(TASKS.QUERY_GENERATION)}, } print(payload) @@ -1591,7 +1475,7 @@ Message: """{{prompt}}""" "stream": False, "max_tokens": 4, "chat_id": form_data.get("chat_id", None), - "task": str(TASKS.EMOJI_GENERATION), + "metadata": {"task": str(TASKS.EMOJI_GENERATION)}, } log.debug(payload) @@ -1610,9 +1494,9 @@ Message: """{{prompt}}""" return await generate_chat_completions(form_data=payload, user=user) -@app.post("/api/task/tools/completions") -async def get_tools_function_calling(form_data: dict, user=Depends(get_verified_user)): - print("get_tools_function_calling") +@app.post("/api/task/moa/completions") +async def generate_moa_response(form_data: dict, user=Depends(get_verified_user)): + print("generate_moa_response") model_id = form_data["model"] if model_id not in app.state.MODELS: @@ -1624,26 +1508,43 @@ async def get_tools_function_calling(form_data: dict, user=Depends(get_verified_ # Check if the user has a custom task model # If the user has a custom task model, use that model model_id = get_task_model_id(model_id) - print(model_id) - template = app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE + + template = """You have been provided with a set of responses from various models to the latest user query: "{{prompt}}" + +Your task is to synthesize these responses into a single, high-quality response. It is crucial to critically evaluate the information provided in these responses, recognizing that some of it may be biased or incorrect. Your response should not simply replicate the given answers but should offer a refined, accurate, and comprehensive reply to the instruction. Ensure your response is well-structured, coherent, and adheres to the highest standards of accuracy and reliability.
+ +Responses from models: {{responses}}""" + + content = moa_response_generation_template( + template, + form_data["prompt"], + form_data["responses"], + ) + + payload = { + "model": model_id, + "messages": [{"role": "user", "content": content}], + "stream": form_data.get("stream", False), + "chat_id": form_data.get("chat_id", None), + "metadata": {"task": str(TASKS.MOA_RESPONSE_GENERATION)}, + } + + log.debug(payload) try: - context, _, _ = await get_function_call_response( - form_data["messages"], - form_data.get("files", []), - form_data["tool_id"], - template, - model_id, - user, - ) - return context + payload = filter_pipeline(payload, user) except Exception as e: return JSONResponse( status_code=e.args[0], content={"detail": e.args[1]}, ) + if "chat_id" in payload: + del payload["chat_id"] + + return await generate_chat_completions(form_data=payload, user=user) + ################################## # @@ -1683,7 +1584,7 @@ async def upload_pipeline( ): print("upload_pipeline", urlIdx, file.filename) # Check if the uploaded file is a python file - if not file.filename.endswith(".py"): + if not (file.filename and file.filename.endswith(".py")): raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="Only Python (.py) files are allowed.", @@ -1980,40 +1881,61 @@ async def update_pipeline_valves( @app.get("/api/config") -async def get_app_config(): +async def get_app_config(request: Request): + user = None + if "token" in request.cookies: + token = request.cookies.get("token") + data = decode_token(token) + if data is not None and "id" in data: + user = Users.get_user_by_id(data["id"]) + return { "status": True, "name": WEBUI_NAME, "version": VERSION, "default_locale": str(DEFAULT_LOCALE), - "default_models": webui_app.state.config.DEFAULT_MODELS, - "default_prompt_suggestions": webui_app.state.config.DEFAULT_PROMPT_SUGGESTIONS, - "features": { - "auth": WEBUI_AUTH, - "auth_trusted_header": bool(webui_app.state.AUTH_TRUSTED_EMAIL_HEADER), - "enable_signup": webui_app.state.config.ENABLE_SIGNUP, - "enable_login_form": webui_app.state.config.ENABLE_LOGIN_FORM, - "enable_web_search": rag_app.state.config.ENABLE_RAG_WEB_SEARCH, - "enable_image_generation": images_app.state.config.ENABLED, - "enable_community_sharing": webui_app.state.config.ENABLE_COMMUNITY_SHARING, - "enable_admin_export": ENABLE_ADMIN_EXPORT, - "enable_admin_chat_access": ENABLE_ADMIN_CHAT_ACCESS, - }, - "audio": { - "tts": { - "engine": audio_app.state.config.TTS_ENGINE, - "voice": audio_app.state.config.TTS_VOICE, - }, - "stt": { - "engine": audio_app.state.config.STT_ENGINE, - }, - }, "oauth": { "providers": { name: config.get("name", name) for name, config in OAUTH_PROVIDERS.items() } }, + "features": { + "auth": WEBUI_AUTH, + "auth_trusted_header": bool(webui_app.state.AUTH_TRUSTED_EMAIL_HEADER), + "enable_signup": webui_app.state.config.ENABLE_SIGNUP, + "enable_login_form": webui_app.state.config.ENABLE_LOGIN_FORM, + **( + { + "enable_web_search": rag_app.state.config.ENABLE_RAG_WEB_SEARCH, + "enable_image_generation": images_app.state.config.ENABLED, + "enable_community_sharing": webui_app.state.config.ENABLE_COMMUNITY_SHARING, + "enable_message_rating": webui_app.state.config.ENABLE_MESSAGE_RATING, + "enable_admin_export": ENABLE_ADMIN_EXPORT, + "enable_admin_chat_access": ENABLE_ADMIN_CHAT_ACCESS, + } + if user is not None + else {} + ), + }, + **( + { + "default_models": webui_app.state.config.DEFAULT_MODELS, + "default_prompt_suggestions": webui_app.state.config.DEFAULT_PROMPT_SUGGESTIONS, + "audio": 
{ + "tts": { + "engine": audio_app.state.config.TTS_ENGINE, + "voice": audio_app.state.config.TTS_VOICE, + }, + "stt": { + "engine": audio_app.state.config.STT_ENGINE, + }, + }, + "permissions": {**webui_app.state.config.USER_PERMISSIONS}, + } + if user is not None + else {} + ), } @@ -2132,7 +2054,10 @@ async def oauth_login(provider: str, request: Request): redirect_uri = OAUTH_PROVIDERS[provider].get("redirect_uri") or request.url_for( "oauth_callback", provider=provider ) - return await oauth.create_client(provider).authorize_redirect(request, redirect_uri) + client = oauth.create_client(provider) + if client is None: + raise HTTPException(404) + return await client.authorize_redirect(request, redirect_uri) # OAuth login logic is as follows: @@ -2264,7 +2189,20 @@ async def get_manifest_json(): "display": "standalone", "background_color": "#343541", "orientation": "portrait-primary", - "icons": [{"src": "/static/logo.png", "type": "image/png", "sizes": "500x500"}], + "icons": [ + { + "src": "/static/logo.png", + "type": "image/png", + "sizes": "500x500", + "purpose": "any", + }, + { + "src": "/static/logo.png", + "type": "image/png", + "sizes": "500x500", + "purpose": "maskable", + }, + ], } diff --git a/backend/requirements.txt b/backend/requirements.txt index 6ef299b5fa..04b3261916 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -1,5 +1,5 @@ fastapi==0.111.0 -uvicorn[standard]==0.22.0 +uvicorn[standard]==0.30.6 pydantic==2.8.2 python-multipart==0.0.9 @@ -13,17 +13,17 @@ passlib[bcrypt]==1.7.4 requests==2.32.3 aiohttp==3.10.2 -sqlalchemy==2.0.31 +sqlalchemy==2.0.32 alembic==1.13.2 peewee==3.17.6 peewee-migrate==1.12.2 psycopg2-binary==2.9.9 PyMySQL==1.1.1 -bcrypt==4.1.3 +bcrypt==4.2.0 pymongo redis -boto3==1.34.153 +boto3==1.35.0 argon2-cffi==23.1.0 APScheduler==3.10.4 @@ -44,7 +44,7 @@ sentence-transformers==3.0.1 pypdf==4.3.1 docx2txt==0.8 python-pptx==1.0.0 -unstructured==0.15.0 +unstructured==0.15.5 Markdown==3.6 pypandoc==1.13 pandas==2.2.2 @@ -60,7 +60,7 @@ rapidocr-onnxruntime==1.3.24 fpdf2==2.7.9 rank-bm25==0.2.2 -faster-whisper==1.0.2 +faster-whisper==1.0.3 PyJWT[crypto]==2.9.0 authlib==1.3.1 diff --git a/backend/static/fonts/NotoSansSC-Regular.ttf b/backend/static/fonts/NotoSansSC-Regular.ttf new file mode 100644 index 0000000000..7056f5e97a Binary files /dev/null and b/backend/static/fonts/NotoSansSC-Regular.ttf differ diff --git a/backend/utils/schemas.py b/backend/utils/schemas.py new file mode 100644 index 0000000000..09b24897b9 --- /dev/null +++ b/backend/utils/schemas.py @@ -0,0 +1,104 @@ +from pydantic import BaseModel, Field, create_model +from typing import Any, Optional, Type + + +def json_schema_to_model(tool_dict: dict[str, Any]) -> Type[BaseModel]: + """ + Converts a JSON schema to a Pydantic BaseModel class. + + Args: + json_schema: The JSON schema to convert. + + Returns: + A Pydantic BaseModel class. + """ + + # Extract the model name from the schema title. + model_name = tool_dict["name"] + schema = tool_dict["parameters"] + + # Extract the field definitions from the schema properties. + field_definitions = { + name: json_schema_to_pydantic_field(name, prop, schema.get("required", [])) + for name, prop in schema.get("properties", {}).items() + } + + # Create the BaseModel class using create_model(). + return create_model(model_name, **field_definitions) + + +def json_schema_to_pydantic_field( + name: str, json_schema: dict[str, Any], required: list[str] +) -> Any: + """ + Converts a JSON schema property to a Pydantic field definition. 
+ + Args: + name: The field name. + json_schema: The JSON schema property. + + Returns: + A Pydantic field definition. + """ + + # Get the field type. + type_ = json_schema_to_pydantic_type(json_schema) + + # Get the field description. + description = json_schema.get("description") + + # Get the field examples. + examples = json_schema.get("examples") + + # Create a Field object with the type, description, and examples. + # The 'required' flag will be set later when creating the model. + return ( + type_, + Field( + description=description, + examples=examples, + default=... if name in required else None, + ), + ) + + +def json_schema_to_pydantic_type(json_schema: dict[str, Any]) -> Any: + """ + Converts a JSON schema type to a Pydantic type. + + Args: + json_schema: The JSON schema to convert. + + Returns: + A Pydantic type. + """ + + type_ = json_schema.get("type") + + if type_ == "string" or type_ == "str": + return str + elif type_ == "integer" or type_ == "int": + return int + elif type_ == "number" or type_ == "float": + return float + elif type_ == "boolean" or type_ == "bool": + return bool + elif type_ == "array": + items_schema = json_schema.get("items") + if items_schema: + item_type = json_schema_to_pydantic_type(items_schema) + return list[item_type] + else: + return list + elif type_ == "object": + # Handle nested models. + properties = json_schema.get("properties") + if properties: + nested_model = json_schema_to_model(json_schema) + return nested_model + else: + return dict + elif type_ == "null": + return Optional[Any] # Use Optional[Any] for nullable fields + else: + raise ValueError(f"Unsupported JSON schema type: {type_}") diff --git a/backend/utils/task.py b/backend/utils/task.py index 1b2276c9c5..ea9254c4f7 100644 --- a/backend/utils/task.py +++ b/backend/utils/task.py @@ -121,6 +121,43 @@ def search_query_generation_template( return template +def moa_response_generation_template( + template: str, prompt: str, responses: list[str] +) -> str: + def replacement_function(match): + full_match = match.group(0) + start_length = match.group(1) + end_length = match.group(2) + middle_length = match.group(3) + + if full_match == "{{prompt}}": + return prompt + elif start_length is not None: + return prompt[: int(start_length)] + elif end_length is not None: + return prompt[-int(end_length) :] + elif middle_length is not None: + middle_length = int(middle_length) + if len(prompt) <= middle_length: + return prompt + start = prompt[: math.ceil(middle_length / 2)] + end = prompt[-math.floor(middle_length / 2) :] + return f"{start}...{end}" + return "" + + template = re.sub( + r"{{prompt}}|{{prompt:start:(\d+)}}|{{prompt:end:(\d+)}}|{{prompt:middletruncate:(\d+)}}", + replacement_function, + template, + ) + + responses = [f'"""{response}"""' for response in responses] + responses = "\n\n".join(responses) + + template = template.replace("{{responses}}", responses) + return template + + def tools_function_calling_generation_template(template: str, tools_specs: str) -> str: template = template.replace("{{TOOLS}}", tools_specs) return template diff --git a/backend/utils/tools.py b/backend/utils/tools.py index eac36b5d90..1a2fea32b0 100644 --- a/backend/utils/tools.py +++ b/backend/utils/tools.py @@ -1,5 +1,90 @@ import inspect -from typing import get_type_hints +import logging +from typing import Awaitable, Callable, get_type_hints + +from apps.webui.models.tools import Tools +from apps.webui.models.users import UserModel +from apps.webui.utils import load_toolkit_module_by_id + +from 
utils.schemas import json_schema_to_model + +log = logging.getLogger(__name__) + + +def apply_extra_params_to_tool_function( + function: Callable, extra_params: dict +) -> Callable[..., Awaitable]: + sig = inspect.signature(function) + extra_params = { + key: value for key, value in extra_params.items() if key in sig.parameters + } + is_coroutine = inspect.iscoroutinefunction(function) + + async def new_function(**kwargs): + extra_kwargs = kwargs | extra_params + if is_coroutine: + return await function(**extra_kwargs) + return function(**extra_kwargs) + + return new_function + + +# Mutation on extra_params +def get_tools( + webui_app, tool_ids: list[str], user: UserModel, extra_params: dict +) -> dict[str, dict]: + tools = {} + for tool_id in tool_ids: + toolkit = Tools.get_tool_by_id(tool_id) + if toolkit is None: + continue + + module = webui_app.state.TOOLS.get(tool_id, None) + if module is None: + module, _ = load_toolkit_module_by_id(tool_id) + webui_app.state.TOOLS[tool_id] = module + + extra_params["__id__"] = tool_id + if hasattr(module, "valves") and hasattr(module, "Valves"): + valves = Tools.get_tool_valves_by_id(tool_id) or {} + module.valves = module.Valves(**valves) + + if hasattr(module, "UserValves"): + extra_params["__user__"]["valves"] = module.UserValves( # type: ignore + **Tools.get_user_valves_by_id_and_user_id(tool_id, user.id) + ) + + for spec in toolkit.specs: + # TODO: Fix hack for OpenAI API + for val in spec.get("parameters", {}).get("properties", {}).values(): + if val["type"] == "str": + val["type"] = "string" + function_name = spec["name"] + + # convert to function that takes only model params and inserts custom params + original_func = getattr(module, function_name) + callable = apply_extra_params_to_tool_function(original_func, extra_params) + if hasattr(original_func, "__doc__"): + callable.__doc__ = original_func.__doc__ + + # TODO: This needs to be a pydantic model + tool_dict = { + "toolkit_id": tool_id, + "callable": callable, + "spec": spec, + "pydantic_model": json_schema_to_model(spec), + "file_handler": hasattr(module, "file_handler") and module.file_handler, + "citation": hasattr(module, "citation") and module.citation, + } + + # TODO: if collision, prepend toolkit name + if function_name in tools: + log.warning(f"Tool {function_name} already exists in another toolkit!") + log.warning(f"Collision between {toolkit} and {tool_id}.") + log.warning(f"Discarding {toolkit}.{function_name}") + else: + tools[function_name] = tool_dict + return tools def doc_to_dict(docstring): diff --git a/package-lock.json b/package-lock.json index 52c5f89335..88d4d89281 100644 --- a/package-lock.json +++ b/package-lock.json @@ -1,12 +1,12 @@ { "name": "open-webui", - "version": "0.3.13", + "version": "0.3.14", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "open-webui", - "version": "0.3.13", + "version": "0.3.14", "dependencies": { "@codemirror/lang-javascript": "^6.2.2", "@codemirror/lang-python": "^6.1.6", @@ -18,6 +18,7 @@ "codemirror": "^6.0.1", "crc-32": "^1.2.2", "dayjs": "^1.11.10", + "dompurify": "^3.1.6", "eventsource-parser": "^1.1.2", "file-saver": "^2.0.5", "fuse.js": "^7.0.0", @@ -29,7 +30,6 @@ "js-sha256": "^0.10.1", "katex": "^0.16.9", "marked": "^9.1.0", - "marked-katex-extension": "^5.1.1", "mermaid": "^10.9.1", "pyodide": "^0.26.1", "socket.io-client": "^4.2.0", @@ -1545,11 +1545,6 @@ "integrity": "sha512-5+fP8P8MFNC+AyZCDxrB2pkZFPGzqQWUzpSeuuVLvm8VMcorNYavBqoFcxK8bQz4Qsbn4oUEEem4wDLfcysGHA==", "dev": true }, - 
"node_modules/@types/katex": { - "version": "0.16.7", - "resolved": "https://registry.npmjs.org/@types/katex/-/katex-0.16.7.tgz", - "integrity": "sha512-HMwFiRujE5PjrgwHQ25+bsLJgowjGjm5Z8FVSf0N6PwgJrwxH0QxzHYDcKsTfV3wva0vzrpqMTJS2jXPr5BMEQ==" - }, "node_modules/@types/mdast": { "version": "3.0.15", "resolved": "https://registry.npmjs.org/@types/mdast/-/mdast-3.0.15.tgz", @@ -3918,9 +3913,9 @@ } }, "node_modules/dompurify": { - "version": "3.1.5", - "resolved": "https://registry.npmjs.org/dompurify/-/dompurify-3.1.5.tgz", - "integrity": "sha512-lwG+n5h8QNpxtyrJW/gJWckL+1/DQiYMX8f7t8Z2AZTPw1esVrqjI63i7Zc2Gz0aKzLVMYC1V1PL/ky+aY/NgA==" + "version": "3.1.6", + "resolved": "https://registry.npmjs.org/dompurify/-/dompurify-3.1.6.tgz", + "integrity": "sha512-cTOAhc36AalkjtBpfG6O8JimdTMWNXjiePT2xQH/ppBGi/4uIpmj8eKyIkMJErXWARyINV/sB38yf8JCLF5pbQ==" }, "node_modules/domutils": { "version": "3.1.0", @@ -6042,18 +6037,6 @@ "node": ">= 16" } }, - "node_modules/marked-katex-extension": { - "version": "5.1.1", - "resolved": "https://registry.npmjs.org/marked-katex-extension/-/marked-katex-extension-5.1.1.tgz", - "integrity": "sha512-piquiCyZpZ1aiocoJlJkRXr+hkk5UI4xw9GhRZiIAAgvX5rhzUDSJ0seup1JcsgueC8MLNDuqe5cRcAzkFE42Q==", - "dependencies": { - "@types/katex": "^0.16.7" - }, - "peerDependencies": { - "katex": ">=0.16 <0.17", - "marked": ">=4 <15" - } - }, "node_modules/matcher-collection": { "version": "2.0.1", "resolved": "https://registry.npmjs.org/matcher-collection/-/matcher-collection-2.0.1.tgz", diff --git a/package.json b/package.json index 2d32422d1b..08cf84b313 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "open-webui", - "version": "0.3.13", + "version": "0.3.14", "private": true, "scripts": { "dev": "npm run pyodide:fetch && vite dev --host", @@ -59,6 +59,7 @@ "codemirror": "^6.0.1", "crc-32": "^1.2.2", "dayjs": "^1.11.10", + "dompurify": "^3.1.6", "eventsource-parser": "^1.1.2", "file-saver": "^2.0.5", "fuse.js": "^7.0.0", @@ -70,7 +71,6 @@ "js-sha256": "^0.10.1", "katex": "^0.16.9", "marked": "^9.1.0", - "marked-katex-extension": "^5.1.1", "mermaid": "^10.9.1", "pyodide": "^0.26.1", "socket.io-client": "^4.2.0", diff --git a/pyproject.toml b/pyproject.toml index 159bce0727..61c3e5417c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,7 @@ authors = [ license = { file = "LICENSE" } dependencies = [ "fastapi==0.111.0", - "uvicorn[standard]==0.22.0", + "uvicorn[standard]==0.30.6", "pydantic==2.8.2", "python-multipart==0.0.9", @@ -21,17 +21,17 @@ dependencies = [ "requests==2.32.3", "aiohttp==3.10.2", - "sqlalchemy==2.0.31", + "sqlalchemy==2.0.32", "alembic==1.13.2", "peewee==3.17.6", "peewee-migrate==1.12.2", "psycopg2-binary==2.9.9", "PyMySQL==1.1.1", - "bcrypt==4.1.3", + "bcrypt==4.2.0", "pymongo", "redis", - "boto3==1.34.153", + "boto3==1.35.0", "argon2-cffi==23.1.0", "APScheduler==3.10.4", @@ -51,7 +51,7 @@ dependencies = [ "pypdf==4.3.1", "docx2txt==0.8", "python-pptx==1.0.0", - "unstructured==0.15.0", + "unstructured==0.15.5", "Markdown==3.6", "pypandoc==1.13", "pandas==2.2.2", @@ -67,7 +67,7 @@ dependencies = [ "fpdf2==2.7.9", "rank-bm25==0.2.2", - "faster-whisper==1.0.2", + "faster-whisper==1.0.3", "PyJWT[crypto]==2.9.0", "authlib==1.3.1", diff --git a/requirements-dev.lock b/requirements-dev.lock index 6b3f518512..01dcaa2c3c 100644 --- a/requirements-dev.lock +++ b/requirements-dev.lock @@ -50,7 +50,7 @@ backoff==2.2.1 # via langfuse # via posthog # via unstructured -bcrypt==4.1.3 +bcrypt==4.2.0 # via chromadb # via open-webui # via passlib @@ -63,9 
+63,9 @@ black==24.8.0 # via open-webui blinker==1.8.2 # via flask -boto3==1.34.153 +boto3==1.35.0 # via open-webui -botocore==1.34.155 +botocore==1.35.2 # via boto3 # via s3transfer build==1.2.1 @@ -156,7 +156,7 @@ fastapi==0.111.0 # via open-webui fastapi-cli==0.0.4 # via fastapi -faster-whisper==1.0.2 +faster-whisper==1.0.3 # via open-webui filelock==3.14.0 # via huggingface-hub @@ -632,7 +632,7 @@ sniffio==1.3.1 # via openai soupsieve==2.5 # via beautifulsoup4 -sqlalchemy==2.0.31 +sqlalchemy==2.0.32 # via alembic # via langchain # via langchain-community @@ -703,7 +703,7 @@ tzlocal==5.2 # via extract-msg ujson==5.10.0 # via fastapi -unstructured==0.15.0 +unstructured==0.15.5 # via open-webui unstructured-client==0.22.0 # via unstructured @@ -715,7 +715,7 @@ urllib3==2.2.1 # via kubernetes # via requests # via unstructured-client -uvicorn==0.22.0 +uvicorn==0.30.6 # via chromadb # via fastapi # via open-webui diff --git a/requirements.lock b/requirements.lock index 6b3f518512..01dcaa2c3c 100644 --- a/requirements.lock +++ b/requirements.lock @@ -50,7 +50,7 @@ backoff==2.2.1 # via langfuse # via posthog # via unstructured -bcrypt==4.1.3 +bcrypt==4.2.0 # via chromadb # via open-webui # via passlib @@ -63,9 +63,9 @@ black==24.8.0 # via open-webui blinker==1.8.2 # via flask -boto3==1.34.153 +boto3==1.35.0 # via open-webui -botocore==1.34.155 +botocore==1.35.2 # via boto3 # via s3transfer build==1.2.1 @@ -156,7 +156,7 @@ fastapi==0.111.0 # via open-webui fastapi-cli==0.0.4 # via fastapi -faster-whisper==1.0.2 +faster-whisper==1.0.3 # via open-webui filelock==3.14.0 # via huggingface-hub @@ -632,7 +632,7 @@ sniffio==1.3.1 # via openai soupsieve==2.5 # via beautifulsoup4 -sqlalchemy==2.0.31 +sqlalchemy==2.0.32 # via alembic # via langchain # via langchain-community @@ -703,7 +703,7 @@ tzlocal==5.2 # via extract-msg ujson==5.10.0 # via fastapi -unstructured==0.15.0 +unstructured==0.15.5 # via open-webui unstructured-client==0.22.0 # via unstructured @@ -715,7 +715,7 @@ urllib3==2.2.1 # via kubernetes # via requests # via unstructured-client -uvicorn==0.22.0 +uvicorn==0.30.6 # via chromadb # via fastapi # via open-webui diff --git a/src/app.css b/src/app.css index 4345bb3777..a421d90ae4 100644 --- a/src/app.css +++ b/src/app.css @@ -34,6 +34,10 @@ math { @apply rounded-lg; } +.markdown-prose { + @apply prose dark:prose-invert prose-p:my-0 prose-img:my-1 prose-headings:my-1 prose-pre:my-0 prose-table:my-0 prose-blockquote:my-0 prose-ul:-my-0 prose-ol:-my-0 prose-li:-my-0 whitespace-pre-line; +} + .markdown a { @apply underline; } diff --git a/src/lib/apis/images/index.ts b/src/lib/apis/images/index.ts index 3f624704eb..2e6510437b 100644 --- a/src/lib/apis/images/index.ts +++ b/src/lib/apis/images/index.ts @@ -1,6 +1,6 @@ import { IMAGES_API_BASE_URL } from '$lib/constants'; -export const getImageGenerationConfig = async (token: string = '') => { +export const getConfig = async (token: string = '') => { let error = null; const res = await fetch(`${IMAGES_API_BASE_URL}/config`, { @@ -32,11 +32,7 @@ export const getImageGenerationConfig = async (token: string = '') => { return res; }; -export const updateImageGenerationConfig = async ( - token: string = '', - engine: string, - enabled: boolean -) => { +export const updateConfig = async (token: string = '', config: object) => { let error = null; const res = await fetch(`${IMAGES_API_BASE_URL}/config/update`, { @@ -47,8 +43,7 @@ export const updateImageGenerationConfig = async ( ...(token && { authorization: `Bearer ${token}` }) }, body: 
JSON.stringify({ - engine, - enabled + ...config }) }) .then(async (res) => { @@ -72,10 +67,10 @@ export const updateImageGenerationConfig = async ( return res; }; -export const getOpenAIConfig = async (token: string = '') => { +export const verifyConfigUrl = async (token: string = '') => { let error = null; - const res = await fetch(`${IMAGES_API_BASE_URL}/openai/config`, { + const res = await fetch(`${IMAGES_API_BASE_URL}/config/url/verify`, { method: 'GET', headers: { Accept: 'application/json', @@ -104,46 +99,10 @@ export const getOpenAIConfig = async (token: string = '') => { return res; }; -export const updateOpenAIConfig = async (token: string = '', url: string, key: string) => { +export const getImageGenerationConfig = async (token: string = '') => { let error = null; - const res = await fetch(`${IMAGES_API_BASE_URL}/openai/config/update`, { - method: 'POST', - headers: { - Accept: 'application/json', - 'Content-Type': 'application/json', - ...(token && { authorization: `Bearer ${token}` }) - }, - body: JSON.stringify({ - url: url, - key: key - }) - }) - .then(async (res) => { - if (!res.ok) throw await res.json(); - return res.json(); - }) - .catch((err) => { - console.log(err); - if ('detail' in err) { - error = err.detail; - } else { - error = 'Server connection failed'; - } - return null; - }); - - if (error) { - throw error; - } - - return res; -}; - -export const getImageGenerationEngineUrls = async (token: string = '') => { - let error = null; - - const res = await fetch(`${IMAGES_API_BASE_URL}/url`, { + const res = await fetch(`${IMAGES_API_BASE_URL}/image/config`, { method: 'GET', headers: { Accept: 'application/json', @@ -172,19 +131,17 @@ export const getImageGenerationEngineUrls = async (token: string = '') => { return res; }; -export const updateImageGenerationEngineUrls = async (token: string = '', urls: object = {}) => { +export const updateImageGenerationConfig = async (token: string = '', config: object) => { let error = null; - const res = await fetch(`${IMAGES_API_BASE_URL}/url/update`, { + const res = await fetch(`${IMAGES_API_BASE_URL}/image/config/update`, { method: 'POST', headers: { Accept: 'application/json', 'Content-Type': 'application/json', ...(token && { authorization: `Bearer ${token}` }) }, - body: JSON.stringify({ - ...urls - }) + body: JSON.stringify({ ...config }) }) .then(async (res) => { if (!res.ok) throw await res.json(); @@ -207,138 +164,6 @@ export const updateImageGenerationEngineUrls = async (token: string = '', urls: return res; }; -export const getImageSize = async (token: string = '') => { - let error = null; - - const res = await fetch(`${IMAGES_API_BASE_URL}/size`, { - method: 'GET', - headers: { - Accept: 'application/json', - 'Content-Type': 'application/json', - ...(token && { authorization: `Bearer ${token}` }) - } - }) - .then(async (res) => { - if (!res.ok) throw await res.json(); - return res.json(); - }) - .catch((err) => { - console.log(err); - if ('detail' in err) { - error = err.detail; - } else { - error = 'Server connection failed'; - } - return null; - }); - - if (error) { - throw error; - } - - return res.IMAGE_SIZE; -}; - -export const updateImageSize = async (token: string = '', size: string) => { - let error = null; - - const res = await fetch(`${IMAGES_API_BASE_URL}/size/update`, { - method: 'POST', - headers: { - Accept: 'application/json', - 'Content-Type': 'application/json', - ...(token && { authorization: `Bearer ${token}` }) - }, - body: JSON.stringify({ - size: size - }) - }) - .then(async (res) => { - if 
(!res.ok) throw await res.json(); - return res.json(); - }) - .catch((err) => { - console.log(err); - if ('detail' in err) { - error = err.detail; - } else { - error = 'Server connection failed'; - } - return null; - }); - - if (error) { - throw error; - } - - return res.IMAGE_SIZE; -}; - -export const getImageSteps = async (token: string = '') => { - let error = null; - - const res = await fetch(`${IMAGES_API_BASE_URL}/steps`, { - method: 'GET', - headers: { - Accept: 'application/json', - 'Content-Type': 'application/json', - ...(token && { authorization: `Bearer ${token}` }) - } - }) - .then(async (res) => { - if (!res.ok) throw await res.json(); - return res.json(); - }) - .catch((err) => { - console.log(err); - if ('detail' in err) { - error = err.detail; - } else { - error = 'Server connection failed'; - } - return null; - }); - - if (error) { - throw error; - } - - return res.IMAGE_STEPS; -}; - -export const updateImageSteps = async (token: string = '', steps: number) => { - let error = null; - - const res = await fetch(`${IMAGES_API_BASE_URL}/steps/update`, { - method: 'POST', - headers: { - Accept: 'application/json', - 'Content-Type': 'application/json', - ...(token && { authorization: `Bearer ${token}` }) - }, - body: JSON.stringify({ steps }) - }) - .then(async (res) => { - if (!res.ok) throw await res.json(); - return res.json(); - }) - .catch((err) => { - console.log(err); - if ('detail' in err) { - error = err.detail; - } else { - error = 'Server connection failed'; - } - return null; - }); - - if (error) { - throw error; - } - - return res.IMAGE_STEPS; -}; - export const getImageGenerationModels = async (token: string = '') => { let error = null; @@ -371,73 +196,6 @@ export const getImageGenerationModels = async (token: string = '') => { return res; }; -export const getDefaultImageGenerationModel = async (token: string = '') => { - let error = null; - - const res = await fetch(`${IMAGES_API_BASE_URL}/models/default`, { - method: 'GET', - headers: { - Accept: 'application/json', - 'Content-Type': 'application/json', - ...(token && { authorization: `Bearer ${token}` }) - } - }) - .then(async (res) => { - if (!res.ok) throw await res.json(); - return res.json(); - }) - .catch((err) => { - console.log(err); - if ('detail' in err) { - error = err.detail; - } else { - error = 'Server connection failed'; - } - return null; - }); - - if (error) { - throw error; - } - - return res.model; -}; - -export const updateDefaultImageGenerationModel = async (token: string = '', model: string) => { - let error = null; - - const res = await fetch(`${IMAGES_API_BASE_URL}/models/default/update`, { - method: 'POST', - headers: { - Accept: 'application/json', - 'Content-Type': 'application/json', - ...(token && { authorization: `Bearer ${token}` }) - }, - body: JSON.stringify({ - model: model - }) - }) - .then(async (res) => { - if (!res.ok) throw await res.json(); - return res.json(); - }) - .catch((err) => { - console.log(err); - if ('detail' in err) { - error = err.detail; - } else { - error = 'Server connection failed'; - } - return null; - }); - - if (error) { - throw error; - } - - return res.model; -}; - export const imageGenerations = async (token: string = '', prompt: string) => { let error = null; diff --git a/src/lib/apis/index.ts b/src/lib/apis/index.ts index c4778cadbd..8432554785 100644 --- a/src/lib/apis/index.ts +++ b/src/lib/apis/index.ts @@ -333,6 +333,42 @@ export const generateSearchQuery = async ( return res?.choices[0]?.message?.content.replace(/["']/g, '') ?? 
prompt; }; +export const generateMoACompletion = async ( + token: string = '', + model: string, + prompt: string, + responses: string[] +) => { + const controller = new AbortController(); + let error = null; + + const res = await fetch(`${WEBUI_BASE_URL}/api/task/moa/completions`, { + signal: controller.signal, + method: 'POST', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + Authorization: `Bearer ${token}` + }, + body: JSON.stringify({ + model: model, + prompt: prompt, + responses: responses, + stream: true + }) + }).catch((err) => { + console.log(err); + error = err; + return null; + }); + + if (error) { + throw error; + } + + return [res, controller]; +}; + export const getPipelinesList = async (token: string = '') => { let error = null; @@ -629,6 +665,7 @@ export const getBackendConfig = async () => { const res = await fetch(`${WEBUI_BASE_URL}/api/config`, { method: 'GET', + credentials: 'include', headers: { 'Content-Type': 'application/json' } @@ -913,6 +950,7 @@ export interface ModelConfig { export interface ModelMeta { description?: string; capabilities?: object; + profile_image_url?: string; } export interface ModelParams {} diff --git a/src/lib/apis/ollama/index.ts b/src/lib/apis/ollama/index.ts index c4c449156c..d4e994312e 100644 --- a/src/lib/apis/ollama/index.ts +++ b/src/lib/apis/ollama/index.ts @@ -396,7 +396,7 @@ export const deleteModel = async (token: string, tagName: string, urlIdx: string return res; }; -export const pullModel = async (token: string, tagName: string, urlIdx: string | null = null) => { +export const pullModel = async (token: string, tagName: string, urlIdx: number | null = null) => { let error = null; const controller = new AbortController(); diff --git a/src/lib/components/admin/Settings.svelte b/src/lib/components/admin/Settings.svelte index afb8736ea1..e242ab632a 100644 --- a/src/lib/components/admin/Settings.svelte +++ b/src/lib/components/admin/Settings.svelte @@ -336,8 +336,11 @@
{#if selectedTab === 'general'} { + saveHandler={async () => { toast.success($i18n.t('Settings saved successfully!')); + + await tick(); + await config.set(await getBackendConfig()); }} /> {:else if selectedTab === 'users'} diff --git a/src/lib/components/admin/Settings/General.svelte b/src/lib/components/admin/Settings/General.svelte index bc66c2e01f..776b7ff8d0 100644 --- a/src/lib/components/admin/Settings/General.svelte +++ b/src/lib/components/admin/Settings/General.svelte @@ -1,22 +1,10 @@ @@ -155,263 +175,366 @@
[The two hunk bodies above survive only as tag-stripped fragments. The General.svelte hunk body (22 lines reduced to 10) is lost entirely. The second hunk evidently belongs to the image-generation settings component (src/lib/components/admin/Settings/Images.svelte; its file header is also missing), and its recoverable outline is: the old form submit handler, which called updateOpenAIConfig, updateDefaultImageGenerationModel, updateImageSize, and updateImageSteps before dispatching 'save', is replaced by a single saveHandler() call; the whole form is now gated behind {#if config && imageGenerationConfig}; the 'Image Generation (Experimental)' toggle validates per-engine prerequisites before enabling (AUTOMATIC1111 Base URL, ComfyUI Base URL, or OpenAI API Key must be set) and then calls updateConfigHandler(); the standalone 'Set Default Model' block is removed; and the per-engine fields — AUTOMATIC1111 Base URL and Api Auth String with their `--api` / `--api-auth` hints, plus ComfyUI Base URL and the new ComfyUI Workflow editor — now read from and write to the consolidated config object instead of individual settings endpoints.]
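A note on the tool-calling refactor above: it hinges on apply_extra_params_to_tool_function in backend/utils/tools.py. The sketch below is a minimal, self-contained illustration of its behavior — the helper body mirrors the diff, while the lookup tool and its parameters are invented for illustration:

import asyncio
import inspect
from typing import Awaitable, Callable


def apply_extra_params_to_tool_function(
    function: Callable, extra_params: dict
) -> Callable[..., Awaitable]:
    # Keep only the injected params the tool actually declares in its signature.
    sig = inspect.signature(function)
    extra_params = {k: v for k, v in extra_params.items() if k in sig.parameters}
    is_coroutine = inspect.iscoroutinefunction(function)

    async def new_function(**kwargs):
        # Merge model-supplied kwargs with the pre-bound per-request params.
        extra_kwargs = kwargs | extra_params
        if is_coroutine:
            return await function(**extra_kwargs)
        return function(**extra_kwargs)

    return new_function


# Hypothetical toolkit function: one model-facing param plus an injected __user__.
def lookup(city: str, __user__: dict) -> str:
    return f"{__user__['name']} asked about {city}"


wrapped = apply_extra_params_to_tool_function(
    lookup,
    {"__user__": {"name": "ada"}, "__files__": []},  # __files__ is filtered out
)
print(asyncio.run(wrapped(city="Paris")))  # -> ada asked about Paris

This is why chat_completion_tools_handler can invoke tools[name]["callable"](**tool_function_params) with only the arguments the model produced: the per-request __user__, __messages__, and __files__ values are already bound in, and any injected parameter a given tool does not declare is silently dropped.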