diff --git a/CHANGELOG.md b/CHANGELOG.md
index a3bceff258..d62360f878 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -5,6 +5,37 @@ All notable changes to this project will be documented in this file.
 The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
 and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
 
+## [0.3.11] - 2024-08-02
+
+### Added
+
+- **📊 Model Information Display**: Added visuals for model selection, including images next to model names for more intuitive navigation.
+- **🗣 ElevenLabs Voice Adaptations**: Voice enhancements, including support for selecting ElevenLabs voices by name for personalized vocal interactions.
+- **⌨️ Arrow Keys Model Selection**: Users can now use arrow keys for quicker model selection, enhancing accessibility.
+- **🔍 Fuzzy Search in Model Selector**: Enhanced model selector with fuzzy search to locate models swiftly, including their descriptions.
+- **🕹️ ComfyUI Flux Image Generation**: Added support for the new Flux image generation model; introduces environment controls such as weight precision and CLIP model options in Settings.
+- **💾 Display File Size for Uploads**: Enhanced file interface now displays file size, preparing for upcoming upload restrictions.
+- **🎚️ Advanced Params "Min P"**: Added the 'Min P' parameter in the advanced settings for customized model precision control.
+- **🔒 Enhanced OAuth**: Introduced custom redirect URI support for OAuth behind reverse proxies, enabling safer authentication processes.
+- **🖥 Enhanced LaTeX Rendering**: Adjustments made to LaTeX rendering processes, now accurately detecting and presenting LaTeX inputs from text.
+- **🌐 Internationalization**: Enhanced with new Romanian and updated Vietnamese and Ukrainian translations, helping broaden accessibility for international users.
+
+### Fixed
+
+- **🔧 Tags Handling in Document Upload**: Tags are now properly sent to the upload document handler, resolving issues with missing metadata.
+- **🖥️ Sensitive Input Fields**: Corrected browser misinterpretation of secure input fields, preventing misclassification as password fields.
+- **📂 Static Path Resolution in PDF Generation**: Fixed static paths to resolve dynamically, preventing issues across various environments.
+
+### Changed
+
+- **🎨 UI/UX Styling Enhancements**: Multiple minor styling updates for a cleaner and more intuitive user interface.
+- **🚧 Refactoring Various Components**: Numerous refactoring changes across styling, file handling, and function simplifications for clarity and performance.
+- **🎛️ User Valves Management**: Moved user valves from settings to direct chat controls for more user-friendly access during interactions.
+
+### Removed
+
+- **⚙️ Health Check Logging**: Removed verbose logging of health check requests to declutter logs and improve backend performance.
+
 ## [0.3.10] - 2024-07-17
 
 ### Fixed
diff --git a/Dockerfile b/Dockerfile
index a217595ea8..8078bf0eac 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -151,7 +151,7 @@ COPY --chown=$UID:$GID ./backend .
EXPOSE 8080 -HEALTHCHECK CMD curl --silent --fail http://localhost:8080/health | jq -e '.status == true' || exit 1 +HEALTHCHECK CMD curl --silent --fail http://localhost:${PORT:-8080}/health | jq -ne 'input.status == true' || exit 1 USER $UID:$GID diff --git a/backend/apps/audio/main.py b/backend/apps/audio/main.py index f866d867f1..167db77bae 100644 --- a/backend/apps/audio/main.py +++ b/backend/apps/audio/main.py @@ -10,12 +10,12 @@ from fastapi import ( File, Form, ) - from fastapi.responses import StreamingResponse, JSONResponse, FileResponse from fastapi.middleware.cors import CORSMiddleware from pydantic import BaseModel +from typing import List import uuid import requests import hashlib @@ -31,6 +31,7 @@ from utils.utils import ( ) from utils.misc import calculate_sha256 + from config import ( SRC_LOG_LEVELS, CACHE_DIR, @@ -43,6 +44,7 @@ from config import ( AUDIO_STT_OPENAI_API_KEY, AUDIO_TTS_OPENAI_API_BASE_URL, AUDIO_TTS_OPENAI_API_KEY, + AUDIO_TTS_API_KEY, AUDIO_STT_ENGINE, AUDIO_STT_MODEL, AUDIO_TTS_ENGINE, @@ -75,6 +77,7 @@ app.state.config.TTS_OPENAI_API_KEY = AUDIO_TTS_OPENAI_API_KEY app.state.config.TTS_ENGINE = AUDIO_TTS_ENGINE app.state.config.TTS_MODEL = AUDIO_TTS_MODEL app.state.config.TTS_VOICE = AUDIO_TTS_VOICE +app.state.config.TTS_API_KEY = AUDIO_TTS_API_KEY # setting device type for whisper model whisper_device_type = DEVICE_TYPE if DEVICE_TYPE and DEVICE_TYPE == "cuda" else "cpu" @@ -87,6 +90,7 @@ SPEECH_CACHE_DIR.mkdir(parents=True, exist_ok=True) class TTSConfigForm(BaseModel): OPENAI_API_BASE_URL: str OPENAI_API_KEY: str + API_KEY: str ENGINE: str MODEL: str VOICE: str @@ -137,6 +141,7 @@ async def get_audio_config(user=Depends(get_admin_user)): "tts": { "OPENAI_API_BASE_URL": app.state.config.TTS_OPENAI_API_BASE_URL, "OPENAI_API_KEY": app.state.config.TTS_OPENAI_API_KEY, + "API_KEY": app.state.config.TTS_API_KEY, "ENGINE": app.state.config.TTS_ENGINE, "MODEL": app.state.config.TTS_MODEL, "VOICE": app.state.config.TTS_VOICE, @@ -156,6 +161,7 @@ async def update_audio_config( ): app.state.config.TTS_OPENAI_API_BASE_URL = form_data.tts.OPENAI_API_BASE_URL app.state.config.TTS_OPENAI_API_KEY = form_data.tts.OPENAI_API_KEY + app.state.config.TTS_API_KEY = form_data.tts.API_KEY app.state.config.TTS_ENGINE = form_data.tts.ENGINE app.state.config.TTS_MODEL = form_data.tts.MODEL app.state.config.TTS_VOICE = form_data.tts.VOICE @@ -169,6 +175,7 @@ async def update_audio_config( "tts": { "OPENAI_API_BASE_URL": app.state.config.TTS_OPENAI_API_BASE_URL, "OPENAI_API_KEY": app.state.config.TTS_OPENAI_API_KEY, + "API_KEY": app.state.config.TTS_API_KEY, "ENGINE": app.state.config.TTS_ENGINE, "MODEL": app.state.config.TTS_MODEL, "VOICE": app.state.config.TTS_VOICE, @@ -194,55 +201,111 @@ async def speech(request: Request, user=Depends(get_verified_user)): if file_path.is_file(): return FileResponse(file_path) - headers = {} - headers["Authorization"] = f"Bearer {app.state.config.TTS_OPENAI_API_KEY}" - headers["Content-Type"] = "application/json" + if app.state.config.TTS_ENGINE == "openai": + headers = {} + headers["Authorization"] = f"Bearer {app.state.config.TTS_OPENAI_API_KEY}" + headers["Content-Type"] = "application/json" - try: - body = body.decode("utf-8") - body = json.loads(body) - body["model"] = app.state.config.TTS_MODEL - body = json.dumps(body).encode("utf-8") - except Exception as e: - pass + try: + body = body.decode("utf-8") + body = json.loads(body) + body["model"] = app.state.config.TTS_MODEL + body = json.dumps(body).encode("utf-8") + except Exception as e: 
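+            # best-effort model injection; if the body isn't valid JSON, forward it unchanged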
+            pass
 
-    r = None
-    try:
-        r = requests.post(
-            url=f"{app.state.config.TTS_OPENAI_API_BASE_URL}/audio/speech",
-            data=body,
-            headers=headers,
-            stream=True,
-        )
+        r = None
+        try:
+            r = requests.post(
+                url=f"{app.state.config.TTS_OPENAI_API_BASE_URL}/audio/speech",
+                data=body,
+                headers=headers,
+                stream=True,
+            )
 
-        r.raise_for_status()
+            r.raise_for_status()
 
-        # Save the streaming content to a file
-        with open(file_path, "wb") as f:
-            for chunk in r.iter_content(chunk_size=8192):
-                f.write(chunk)
+            # Save the streaming content to a file
+            with open(file_path, "wb") as f:
+                for chunk in r.iter_content(chunk_size=8192):
+                    f.write(chunk)
 
-        with open(file_body_path, "w") as f:
-            json.dump(json.loads(body.decode("utf-8")), f)
+            with open(file_body_path, "w") as f:
+                json.dump(json.loads(body.decode("utf-8")), f)
 
-        # Return the saved file
-        return FileResponse(file_path)
+            # Return the saved file
+            return FileResponse(file_path)
 
-    except Exception as e:
-        log.exception(e)
-        error_detail = "Open WebUI: Server Connection Error"
-        if r is not None:
-            try:
-                res = r.json()
-                if "error" in res:
-                    error_detail = f"External: {res['error']['message']}"
-            except:
-                error_detail = f"External: {e}"
+        except Exception as e:
+            log.exception(e)
+            error_detail = "Open WebUI: Server Connection Error"
+            if r is not None:
+                try:
+                    res = r.json()
+                    if "error" in res:
+                        error_detail = f"External: {res['error']['message']}"
+                except:
+                    error_detail = f"External: {e}"
 
-        raise HTTPException(
-            status_code=r.status_code if r != None else 500,
-            detail=error_detail,
-        )
+            raise HTTPException(
+                status_code=r.status_code if r != None else 500,
+                detail=error_detail,
+            )
+
+    elif app.state.config.TTS_ENGINE == "elevenlabs":
+        payload = None
+        try:
+            payload = json.loads(body.decode("utf-8"))
+        except Exception as e:
+            log.exception(e)
+            raise HTTPException(status_code=400, detail="Invalid JSON payload")
+
+        voice_id = payload.get("voice", "")
+        url = f"https://api.elevenlabs.io/v1/text-to-speech/{voice_id}"
+
+        headers = {
+            "Accept": "audio/mpeg",
+            "Content-Type": "application/json",
+            "xi-api-key": app.state.config.TTS_API_KEY,
+        }
+
+        data = {
+            "text": payload["input"],
+            "model_id": app.state.config.TTS_MODEL,
+            "voice_settings": {"stability": 0.5, "similarity_boost": 0.5},
+        }
+
+        r = None
+        try:
+            r = requests.post(url, json=data, headers=headers, stream=True)
+
+            r.raise_for_status()
+
+            # Save the streaming content to a file
+            with open(file_path, "wb") as f:
+                for chunk in r.iter_content(chunk_size=8192):
+                    f.write(chunk)
+
+            with open(file_body_path, "w") as f:
+                json.dump(json.loads(body.decode("utf-8")), f)
+
+            # Return the saved file
+            return FileResponse(file_path)
+
+        except Exception as e:
+            log.exception(e)
+            error_detail = "Open WebUI: Server Connection Error"
+            if r is not None:
+                try:
+                    res = r.json()
+                    if "error" in res:
+                        error_detail = f"External: {res['error']['message']}"
+                except:
+                    error_detail = f"External: {e}"
+
+            raise HTTPException(
+                status_code=r.status_code if r != None else 500,
+                detail=error_detail,
+            )
 
 
 @app.post("/transcriptions")
@@ -373,3 +436,69 @@ def transcribe(
             status_code=status.HTTP_400_BAD_REQUEST,
             detail=ERROR_MESSAGES.DEFAULT(e),
         )
+
+
+def get_available_models() -> List[dict]:
+    if app.state.config.TTS_ENGINE == "openai":
+        return [{"id": "tts-1"}, {"id": "tts-1-hd"}]
+    elif app.state.config.TTS_ENGINE == "elevenlabs":
+        headers = {
+            "xi-api-key": app.state.config.TTS_API_KEY,
+            "Content-Type": "application/json",
+        }
+
+        try:
+            response = requests.get(
+
"https://api.elevenlabs.io/v1/models", headers=headers + ) + response.raise_for_status() + models = response.json() + return [ + {"name": model["name"], "id": model["model_id"]} for model in models + ] + except requests.RequestException as e: + log.error(f"Error fetching voices: {str(e)}") + return [] + + +@app.get("/models") +async def get_models(user=Depends(get_verified_user)): + return {"models": get_available_models()} + + +def get_available_voices() -> List[dict]: + if app.state.config.TTS_ENGINE == "openai": + return [ + {"name": "alloy", "id": "alloy"}, + {"name": "echo", "id": "echo"}, + {"name": "fable", "id": "fable"}, + {"name": "onyx", "id": "onyx"}, + {"name": "nova", "id": "nova"}, + {"name": "shimmer", "id": "shimmer"}, + ] + elif app.state.config.TTS_ENGINE == "elevenlabs": + headers = { + "xi-api-key": app.state.config.TTS_API_KEY, + "Content-Type": "application/json", + } + + try: + response = requests.get( + "https://api.elevenlabs.io/v1/voices", headers=headers + ) + response.raise_for_status() + voices_data = response.json() + + voices = [] + for voice in voices_data.get("voices", []): + voices.append({"name": voice["name"], "id": voice["voice_id"]}) + return voices + except requests.RequestException as e: + log.error(f"Error fetching voices: {str(e)}") + + return [] + + +@app.get("/voices") +async def get_voices(user=Depends(get_verified_user)): + return {"voices": get_available_voices()} diff --git a/backend/apps/images/main.py b/backend/apps/images/main.py index 9ae0ad67b7..4239f3f457 100644 --- a/backend/apps/images/main.py +++ b/backend/apps/images/main.py @@ -42,6 +42,9 @@ from config import ( COMFYUI_SAMPLER, COMFYUI_SCHEDULER, COMFYUI_SD3, + COMFYUI_FLUX, + COMFYUI_FLUX_WEIGHT_DTYPE, + COMFYUI_FLUX_FP8_CLIP, IMAGES_OPENAI_API_BASE_URL, IMAGES_OPENAI_API_KEY, IMAGE_GENERATION_MODEL, @@ -85,6 +88,9 @@ app.state.config.COMFYUI_CFG_SCALE = COMFYUI_CFG_SCALE app.state.config.COMFYUI_SAMPLER = COMFYUI_SAMPLER app.state.config.COMFYUI_SCHEDULER = COMFYUI_SCHEDULER app.state.config.COMFYUI_SD3 = COMFYUI_SD3 +app.state.config.COMFYUI_FLUX = COMFYUI_FLUX +app.state.config.COMFYUI_FLUX_WEIGHT_DTYPE = COMFYUI_FLUX_WEIGHT_DTYPE +app.state.config.COMFYUI_FLUX_FP8_CLIP = COMFYUI_FLUX_FP8_CLIP def get_automatic1111_api_auth(): @@ -497,6 +503,15 @@ async def image_generations( if app.state.config.COMFYUI_SD3 is not None: data["sd3"] = app.state.config.COMFYUI_SD3 + if app.state.config.COMFYUI_FLUX is not None: + data["flux"] = app.state.config.COMFYUI_FLUX + + if app.state.config.COMFYUI_FLUX_WEIGHT_DTYPE is not None: + data["flux_weight_dtype"] = app.state.config.COMFYUI_FLUX_WEIGHT_DTYPE + + if app.state.config.COMFYUI_FLUX_FP8_CLIP is not None: + data["flux_fp8_clip"] = app.state.config.COMFYUI_FLUX_FP8_CLIP + data = ImageGenerationPayload(**data) res = comfyui_generate_image( diff --git a/backend/apps/images/utils/comfyui.py b/backend/apps/images/utils/comfyui.py index 599b1f3379..6c37f0c497 100644 --- a/backend/apps/images/utils/comfyui.py +++ b/backend/apps/images/utils/comfyui.py @@ -125,6 +125,135 @@ COMFYUI_DEFAULT_PROMPT = """ } """ +FLUX_DEFAULT_PROMPT = """ +{ + "5": { + "inputs": { + "width": 1024, + "height": 1024, + "batch_size": 1 + }, + "class_type": "EmptyLatentImage" + }, + "6": { + "inputs": { + "text": "Input Text Here", + "clip": [ + "11", + 0 + ] + }, + "class_type": "CLIPTextEncode" + }, + "8": { + "inputs": { + "samples": [ + "13", + 0 + ], + "vae": [ + "10", + 0 + ] + }, + "class_type": "VAEDecode" + }, + "9": { + "inputs": { + "filename_prefix": 
"ComfyUI", + "images": [ + "8", + 0 + ] + }, + "class_type": "SaveImage" + }, + "10": { + "inputs": { + "vae_name": "ae.sft" + }, + "class_type": "VAELoader" + }, + "11": { + "inputs": { + "clip_name1": "clip_l.safetensors", + "clip_name2": "t5xxl_fp16.safetensors", + "type": "flux" + }, + "class_type": "DualCLIPLoader" + }, + "12": { + "inputs": { + "unet_name": "flux1-dev.sft", + "weight_dtype": "default" + }, + "class_type": "UNETLoader" + }, + "13": { + "inputs": { + "noise": [ + "25", + 0 + ], + "guider": [ + "22", + 0 + ], + "sampler": [ + "16", + 0 + ], + "sigmas": [ + "17", + 0 + ], + "latent_image": [ + "5", + 0 + ] + }, + "class_type": "SamplerCustomAdvanced" + }, + "16": { + "inputs": { + "sampler_name": "euler" + }, + "class_type": "KSamplerSelect" + }, + "17": { + "inputs": { + "scheduler": "simple", + "steps": 20, + "denoise": 1, + "model": [ + "12", + 0 + ] + }, + "class_type": "BasicScheduler" + }, + "22": { + "inputs": { + "model": [ + "12", + 0 + ], + "conditioning": [ + "6", + 0 + ] + }, + "class_type": "BasicGuider" + }, + "25": { + "inputs": { + "noise_seed": 778937779713005 + }, + "class_type": "RandomNoise" + } +} +""" + def queue_prompt(prompt, client_id, base_url): log.info("queue_prompt") @@ -194,6 +323,9 @@ class ImageGenerationPayload(BaseModel): sampler: Optional[str] = None scheduler: Optional[str] = None sd3: Optional[bool] = None + flux: Optional[bool] = None + flux_weight_dtype: Optional[str] = None + flux_fp8_clip: Optional[bool] = None def comfyui_generate_image( @@ -215,21 +347,46 @@ def comfyui_generate_image( if payload.sd3: comfyui_prompt["5"]["class_type"] = "EmptySD3LatentImage" + if payload.steps: + comfyui_prompt["3"]["inputs"]["steps"] = payload.steps + comfyui_prompt["4"]["inputs"]["ckpt_name"] = model + comfyui_prompt["7"]["inputs"]["text"] = payload.negative_prompt + comfyui_prompt["3"]["inputs"]["seed"] = ( + payload.seed if payload.seed else random.randint(0, 18446744073709551614) + ) + + # as Flux uses a completely different workflow, we must treat it specially + if payload.flux: + comfyui_prompt = json.loads(FLUX_DEFAULT_PROMPT) + comfyui_prompt["12"]["inputs"]["unet_name"] = model + comfyui_prompt["25"]["inputs"]["noise_seed"] = ( + payload.seed if payload.seed else random.randint(0, 18446744073709551614) + ) + + if payload.sampler: + comfyui_prompt["16"]["inputs"]["sampler_name"] = payload.sampler + + if payload.steps: + comfyui_prompt["17"]["inputs"]["steps"] = payload.steps + + if payload.scheduler: + comfyui_prompt["17"]["inputs"]["scheduler"] = payload.scheduler + + if payload.flux_weight_dtype: + comfyui_prompt["12"]["inputs"]["weight_dtype"] = payload.flux_weight_dtype + + if payload.flux_fp8_clip: + comfyui_prompt["11"]["inputs"][ + "clip_name2" + ] = "t5xxl_fp8_e4m3fn.safetensors" + comfyui_prompt["5"]["inputs"]["batch_size"] = payload.n comfyui_prompt["5"]["inputs"]["width"] = payload.width comfyui_prompt["5"]["inputs"]["height"] = payload.height # set the text prompt for our positive CLIPTextEncode comfyui_prompt["6"]["inputs"]["text"] = payload.prompt - comfyui_prompt["7"]["inputs"]["text"] = payload.negative_prompt - - if payload.steps: - comfyui_prompt["3"]["inputs"]["steps"] = payload.steps - - comfyui_prompt["3"]["inputs"]["seed"] = ( - payload.seed if payload.seed else random.randint(0, 18446744073709551614) - ) try: ws = websocket.WebSocket() diff --git a/backend/apps/ollama/main.py b/backend/apps/ollama/main.py index 0a36d4c2be..442d99ff26 100644 --- a/backend/apps/ollama/main.py +++ b/backend/apps/ollama/main.py @@ 
-805,7 +805,7 @@ async def generate_chat_completion( ) if ( - model_info.params.get("temperature", None) + model_info.params.get("temperature", None) is not None and payload["options"].get("temperature") is None ): payload["options"]["temperature"] = model_info.params.get( @@ -813,7 +813,7 @@ async def generate_chat_completion( ) if ( - model_info.params.get("seed", None) + model_info.params.get("seed", None) is not None and payload["options"].get("seed") is None ): payload["options"]["seed"] = model_info.params.get("seed", None) @@ -857,6 +857,12 @@ async def generate_chat_completion( ): payload["options"]["top_p"] = model_info.params.get("top_p", None) + if ( + model_info.params.get("min_p", None) + and payload["options"].get("min_p") is None + ): + payload["options"]["min_p"] = model_info.params.get("min_p", None) + if ( model_info.params.get("use_mmap", None) and payload["options"].get("use_mmap") is None diff --git a/backend/apps/openai/main.py b/backend/apps/openai/main.py index 6c2906095a..c712709a5c 100644 --- a/backend/apps/openai/main.py +++ b/backend/apps/openai/main.py @@ -372,7 +372,7 @@ async def generate_chat_completion( if model_info.params: if ( - model_info.params.get("temperature", None) + model_info.params.get("temperature", None) is not None and payload.get("temperature") is None ): payload["temperature"] = float(model_info.params.get("temperature")) @@ -394,7 +394,10 @@ async def generate_chat_completion( model_info.params.get("frequency_penalty", None) ) - if model_info.params.get("seed", None) and payload.get("seed") is None: + if ( + model_info.params.get("seed", None) is not None + and payload.get("seed") is None + ): payload["seed"] = model_info.params.get("seed", None) if model_info.params.get("stop", None) and payload.get("stop") is None: @@ -459,6 +462,9 @@ async def generate_chat_completion( headers = {} headers["Authorization"] = f"Bearer {key}" headers["Content-Type"] = "application/json" + if "openrouter.ai" in app.state.config.OPENAI_API_BASE_URLS[idx]: + headers["HTTP-Referer"] = "https://openwebui.com/" + headers["X-Title"] = "Open WebUI" r = None session = None diff --git a/backend/apps/rag/main.py b/backend/apps/rag/main.py index 8631846ec7..dc6b8830ef 100644 --- a/backend/apps/rag/main.py +++ b/backend/apps/rag/main.py @@ -1099,6 +1099,13 @@ def get_loader(filename: str, file_content_type: str, file_path: str): "vue", "svelte", "msg", + "ex", + "exs", + "erl", + "tsx", + "jsx", + "hs", + "lhs", ] if ( diff --git a/backend/apps/socket/main.py b/backend/apps/socket/main.py index 18ce7a6072..1d98d37ff1 100644 --- a/backend/apps/socket/main.py +++ b/backend/apps/socket/main.py @@ -52,7 +52,6 @@ async def user_join(sid, data): user = Users.get_user_by_id(data["id"]) if user: - SESSION_POOL[sid] = user.id if user.id in USER_POOL: USER_POOL[user.id].append(sid) @@ -80,7 +79,6 @@ def get_models_in_use(): @sio.on("usage") async def usage(sid, data): - model_id = data["model"] # Cancel previous callback if there is one @@ -139,7 +137,7 @@ async def disconnect(sid): print(f"Unknown session ID {sid} disconnected") -async def get_event_emitter(request_info): +def get_event_emitter(request_info): async def __event_emitter__(event_data): await sio.emit( "chat-events", @@ -154,7 +152,7 @@ async def get_event_emitter(request_info): return __event_emitter__ -async def get_event_call(request_info): +def get_event_call(request_info): async def __event_call__(event_data): response = await sio.call( "chat-events", diff --git a/backend/apps/webui/main.py 
b/backend/apps/webui/main.py index 570cad9f19..972562a04d 100644 --- a/backend/apps/webui/main.py +++ b/backend/apps/webui/main.py @@ -1,9 +1,6 @@ -from fastapi import FastAPI, Depends -from fastapi.routing import APIRoute +from fastapi import FastAPI from fastapi.responses import StreamingResponse from fastapi.middleware.cors import CORSMiddleware -from starlette.middleware.sessions import SessionMiddleware -from sqlalchemy.orm import Session from apps.webui.routers import ( auths, users, @@ -22,12 +19,15 @@ from apps.webui.models.functions import Functions from apps.webui.models.models import Models from apps.webui.utils import load_function_module_by_id -from utils.misc import stream_message_template +from utils.misc import ( + openai_chat_chunk_message_template, + openai_chat_completion_message_template, + add_or_update_system_message, +) from utils.task import prompt_template from config import ( - WEBUI_BUILD_HASH, SHOW_ADMIN_DETAILS, ADMIN_EMAIL, WEBUI_AUTH, @@ -35,6 +35,7 @@ from config import ( DEFAULT_PROMPT_SUGGESTIONS, DEFAULT_USER_ROLE, ENABLE_SIGNUP, + ENABLE_LOGIN_FORM, USER_PERMISSIONS, WEBHOOK_URL, WEBUI_AUTH_TRUSTED_EMAIL_HEADER, @@ -50,11 +51,9 @@ from config import ( from apps.socket.main import get_event_call, get_event_emitter import inspect -import uuid -import time import json -from typing import Iterator, Generator, Optional +from typing import Iterator, Generator, AsyncGenerator from pydantic import BaseModel app = FastAPI() @@ -64,6 +63,7 @@ origins = ["*"] app.state.config = AppConfig() app.state.config.ENABLE_SIGNUP = ENABLE_SIGNUP +app.state.config.ENABLE_LOGIN_FORM = ENABLE_LOGIN_FORM app.state.config.JWT_EXPIRES_IN = JWT_EXPIRES_IN app.state.AUTH_TRUSTED_EMAIL_HEADER = WEBUI_AUTH_TRUSTED_EMAIL_HEADER app.state.AUTH_TRUSTED_NAME_HEADER = WEBUI_AUTH_TRUSTED_NAME_HEADER @@ -125,60 +125,58 @@ async def get_status(): } +def get_function_module(pipe_id: str): + # Check if function is already loaded + if pipe_id not in app.state.FUNCTIONS: + function_module, _, _ = load_function_module_by_id(pipe_id) + app.state.FUNCTIONS[pipe_id] = function_module + else: + function_module = app.state.FUNCTIONS[pipe_id] + + if hasattr(function_module, "valves") and hasattr(function_module, "Valves"): + valves = Functions.get_function_valves_by_id(pipe_id) + function_module.valves = function_module.Valves(**(valves if valves else {})) + return function_module + + async def get_pipe_models(): pipes = Functions.get_functions_by_type("pipe", active_only=True) pipe_models = [] for pipe in pipes: - # Check if function is already loaded - if pipe.id not in app.state.FUNCTIONS: - function_module, function_type, frontmatter = load_function_module_by_id( - pipe.id - ) - app.state.FUNCTIONS[pipe.id] = function_module - else: - function_module = app.state.FUNCTIONS[pipe.id] - - if hasattr(function_module, "valves") and hasattr(function_module, "Valves"): - valves = Functions.get_function_valves_by_id(pipe.id) - function_module.valves = function_module.Valves( - **(valves if valves else {}) - ) + function_module = get_function_module(pipe.id) # Check if function is a manifold - if hasattr(function_module, "type"): - if function_module.type == "manifold": - manifold_pipes = [] + if hasattr(function_module, "pipes"): + manifold_pipes = [] - # Check if pipes is a function or a list - if callable(function_module.pipes): - manifold_pipes = function_module.pipes() - else: - manifold_pipes = function_module.pipes + # Check if pipes is a function or a list + if callable(function_module.pipes): + 
manifold_pipes = function_module.pipes() + else: + manifold_pipes = function_module.pipes - for p in manifold_pipes: - manifold_pipe_id = f'{pipe.id}.{p["id"]}' - manifold_pipe_name = p["name"] + for p in manifold_pipes: + manifold_pipe_id = f'{pipe.id}.{p["id"]}' + manifold_pipe_name = p["name"] - if hasattr(function_module, "name"): - manifold_pipe_name = ( - f"{function_module.name}{manifold_pipe_name}" - ) + if hasattr(function_module, "name"): + manifold_pipe_name = f"{function_module.name}{manifold_pipe_name}" - pipe_flag = {"type": pipe.type} - if hasattr(function_module, "ChatValves"): - pipe_flag["valves_spec"] = function_module.ChatValves.schema() + pipe_flag = {"type": pipe.type} + if hasattr(function_module, "ChatValves"): + pipe_flag["valves_spec"] = function_module.ChatValves.schema() - pipe_models.append( - { - "id": manifold_pipe_id, - "name": manifold_pipe_name, - "object": "model", - "created": pipe.created_at, - "owned_by": "openai", - "pipe": pipe_flag, - } - ) + pipe_models.append( + { + "id": manifold_pipe_id, + "name": manifold_pipe_name, + "object": "model", + "created": pipe.created_at, + "owned_by": "openai", + "pipe": pipe_flag, + } + ) else: pipe_flag = {"type": "pipe"} if hasattr(function_module, "ChatValves"): @@ -198,262 +196,211 @@ async def get_pipe_models(): return pipe_models +async def execute_pipe(pipe, params): + if inspect.iscoroutinefunction(pipe): + return await pipe(**params) + else: + return pipe(**params) + + +async def get_message_content(res: str | Generator | AsyncGenerator) -> str: + if isinstance(res, str): + return res + if isinstance(res, Generator): + return "".join(map(str, res)) + if isinstance(res, AsyncGenerator): + return "".join([str(stream) async for stream in res]) + + +def process_line(form_data: dict, line): + if isinstance(line, BaseModel): + line = line.model_dump_json() + line = f"data: {line}" + if isinstance(line, dict): + line = f"data: {json.dumps(line)}" + + try: + line = line.decode("utf-8") + except Exception: + pass + + if line.startswith("data:"): + return f"{line}\n\n" + else: + line = openai_chat_chunk_message_template(form_data["model"], line) + return f"data: {json.dumps(line)}\n\n" + + +def get_pipe_id(form_data: dict) -> str: + pipe_id = form_data["model"] + if "." 
in pipe_id: + pipe_id, _ = pipe_id.split(".", 1) + print(pipe_id) + return pipe_id + + +def get_function_params(function_module, form_data, user, extra_params={}): + pipe_id = get_pipe_id(form_data) + # Get the signature of the function + sig = inspect.signature(function_module.pipe) + params = {"body": form_data} + + for key, value in extra_params.items(): + if key in sig.parameters: + params[key] = value + + if "__user__" in sig.parameters: + __user__ = { + "id": user.id, + "email": user.email, + "name": user.name, + "role": user.role, + } + + try: + if hasattr(function_module, "UserValves"): + __user__["valves"] = function_module.UserValves( + **Functions.get_user_valves_by_id_and_user_id(pipe_id, user.id) + ) + except Exception as e: + print(e) + + params["__user__"] = __user__ + return params + + +# inplace function: form_data is modified +def apply_model_params_to_body(params: dict, form_data: dict) -> dict: + if not params: + return form_data + + mappings = { + "temperature": float, + "top_p": int, + "max_tokens": int, + "frequency_penalty": int, + "seed": lambda x: x, + "stop": lambda x: [bytes(s, "utf-8").decode("unicode_escape") for s in x], + } + + for key, cast_func in mappings.items(): + if (value := params.get(key)) is not None: + form_data[key] = cast_func(value) + + return form_data + + +# inplace function: form_data is modified +def apply_model_system_prompt_to_body(params: dict, form_data: dict, user) -> dict: + system = params.get("system", None) + if not system: + return form_data + + if user: + template_params = { + "user_name": user.name, + "user_location": user.info.get("location") if user.info else None, + } + else: + template_params = {} + system = prompt_template(system, **template_params) + form_data["messages"] = add_or_update_system_message( + system, form_data.get("messages", []) + ) + return form_data + + async def generate_function_chat_completion(form_data, user): model_id = form_data.get("model") model_info = Models.get_model_by_id(model_id) - - metadata = None - if "metadata" in form_data: - metadata = form_data["metadata"] - del form_data["metadata"] + metadata = form_data.pop("metadata", None) __event_emitter__ = None __event_call__ = None __task__ = None if metadata: - if ( - metadata.get("session_id") - and metadata.get("chat_id") - and metadata.get("message_id") - ): - __event_emitter__ = await get_event_emitter(metadata) - __event_call__ = await get_event_call(metadata) - - if metadata.get("task"): - __task__ = metadata.get("task") + if all(k in metadata for k in ("session_id", "chat_id", "message_id")): + __event_emitter__ = get_event_emitter(metadata) + __event_call__ = get_event_call(metadata) + __task__ = metadata.get("task", None) if model_info: if model_info.base_model_id: form_data["model"] = model_info.base_model_id - model_info.params = model_info.params.model_dump() + params = model_info.params.model_dump() + form_data = apply_model_params_to_body(params, form_data) + form_data = apply_model_system_prompt_to_body(params, form_data, user) - if model_info.params: - if model_info.params.get("temperature", None) is not None: - form_data["temperature"] = float(model_info.params.get("temperature")) + pipe_id = get_pipe_id(form_data) + function_module = get_function_module(pipe_id) - if model_info.params.get("top_p", None): - form_data["top_p"] = int(model_info.params.get("top_p", None)) + pipe = function_module.pipe + params = get_function_params( + function_module, + form_data, + user, + { + "__event_emitter__": __event_emitter__, + 
"__event_call__": __event_call__, + "__task__": __task__, + }, + ) - if model_info.params.get("max_tokens", None): - form_data["max_tokens"] = int(model_info.params.get("max_tokens", None)) - - if model_info.params.get("frequency_penalty", None): - form_data["frequency_penalty"] = int( - model_info.params.get("frequency_penalty", None) - ) - - if model_info.params.get("seed", None): - form_data["seed"] = model_info.params.get("seed", None) - - if model_info.params.get("stop", None): - form_data["stop"] = ( - [ - bytes(stop, "utf-8").decode("unicode_escape") - for stop in model_info.params["stop"] - ] - if model_info.params.get("stop", None) - else None - ) - - system = model_info.params.get("system", None) - if system: - system = prompt_template( - system, - **( - { - "user_name": user.name, - "user_location": ( - user.info.get("location") if user.info else None - ), - } - if user - else {} - ), - ) - # Check if the payload already has a system message - # If not, add a system message to the payload - if form_data.get("messages"): - for message in form_data["messages"]: - if message.get("role") == "system": - message["content"] = system + message["content"] - break - else: - form_data["messages"].insert( - 0, - { - "role": "system", - "content": system, - }, - ) - - else: - pass - - async def job(): - pipe_id = form_data["model"] - if "." in pipe_id: - pipe_id, sub_pipe_id = pipe_id.split(".", 1) - print(pipe_id) - - # Check if function is already loaded - if pipe_id not in app.state.FUNCTIONS: - function_module, function_type, frontmatter = load_function_module_by_id( - pipe_id - ) - app.state.FUNCTIONS[pipe_id] = function_module - else: - function_module = app.state.FUNCTIONS[pipe_id] - - if hasattr(function_module, "valves") and hasattr(function_module, "Valves"): - - valves = Functions.get_function_valves_by_id(pipe_id) - function_module.valves = function_module.Valves( - **(valves if valves else {}) - ) - - pipe = function_module.pipe - - # Get the signature of the function - sig = inspect.signature(pipe) - params = {"body": form_data} - - if "__user__" in sig.parameters: - __user__ = { - "id": user.id, - "email": user.email, - "name": user.name, - "role": user.role, - } + if form_data["stream"]: + async def stream_content(): try: - if hasattr(function_module, "UserValves"): - __user__["valves"] = function_module.UserValves( - **Functions.get_user_valves_by_id_and_user_id(pipe_id, user.id) - ) - except Exception as e: - print(e) + res = await execute_pipe(pipe, params) - params = {**params, "__user__": __user__} - - if "__event_emitter__" in sig.parameters: - params = {**params, "__event_emitter__": __event_emitter__} - - if "__event_call__" in sig.parameters: - params = {**params, "__event_call__": __event_call__} - - if "__task__" in sig.parameters: - params = {**params, "__task__": __task__} - - if form_data["stream"]: - - async def stream_content(): - try: - if inspect.iscoroutinefunction(pipe): - res = await pipe(**params) - else: - res = pipe(**params) - - # Directly return if the response is a StreamingResponse - if isinstance(res, StreamingResponse): - async for data in res.body_iterator: - yield data - return - if isinstance(res, dict): - yield f"data: {json.dumps(res)}\n\n" - return - - except Exception as e: - print(f"Error: {e}") - yield f"data: {json.dumps({'error': {'detail':str(e)}})}\n\n" + # Directly return if the response is a StreamingResponse + if isinstance(res, StreamingResponse): + async for data in res.body_iterator: + yield data + return + if isinstance(res, 
dict): + yield f"data: {json.dumps(res)}\n\n" return - if isinstance(res, str): - message = stream_message_template(form_data["model"], res) - yield f"data: {json.dumps(message)}\n\n" - - if isinstance(res, Iterator): - for line in res: - if isinstance(line, BaseModel): - line = line.model_dump_json() - line = f"data: {line}" - if isinstance(line, dict): - line = f"data: {json.dumps(line)}" - - try: - line = line.decode("utf-8") - except: - pass - - if line.startswith("data:"): - yield f"{line}\n\n" - else: - line = stream_message_template(form_data["model"], line) - yield f"data: {json.dumps(line)}\n\n" - - if isinstance(res, str) or isinstance(res, Generator): - finish_message = { - "id": f"{form_data['model']}-{str(uuid.uuid4())}", - "object": "chat.completion.chunk", - "created": int(time.time()), - "model": form_data["model"], - "choices": [ - { - "index": 0, - "delta": {}, - "logprobs": None, - "finish_reason": "stop", - } - ], - } - - yield f"data: {json.dumps(finish_message)}\n\n" - yield f"data: [DONE]" - - return StreamingResponse(stream_content(), media_type="text/event-stream") - else: - - try: - if inspect.iscoroutinefunction(pipe): - res = await pipe(**params) - else: - res = pipe(**params) - - if isinstance(res, StreamingResponse): - return res except Exception as e: print(f"Error: {e}") - return {"error": {"detail": str(e)}} + yield f"data: {json.dumps({'error': {'detail':str(e)}})}\n\n" + return - if isinstance(res, dict): - return res - elif isinstance(res, BaseModel): - return res.model_dump() - else: - message = "" - if isinstance(res, str): - message = res - if isinstance(res, Generator): - for stream in res: - message = f"{message}{stream}" + if isinstance(res, str): + message = openai_chat_chunk_message_template(form_data["model"], res) + yield f"data: {json.dumps(message)}\n\n" - return { - "id": f"{form_data['model']}-{str(uuid.uuid4())}", - "object": "chat.completion", - "created": int(time.time()), - "model": form_data["model"], - "choices": [ - { - "index": 0, - "message": { - "role": "assistant", - "content": message, - }, - "logprobs": None, - "finish_reason": "stop", - } - ], - } + if isinstance(res, Iterator): + for line in res: + yield process_line(form_data, line) - return await job() + if isinstance(res, AsyncGenerator): + async for line in res: + yield process_line(form_data, line) + + if isinstance(res, str) or isinstance(res, Generator): + finish_message = openai_chat_chunk_message_template( + form_data["model"], "" + ) + finish_message["choices"][0]["finish_reason"] = "stop" + yield f"data: {json.dumps(finish_message)}\n\n" + yield "data: [DONE]" + + return StreamingResponse(stream_content(), media_type="text/event-stream") + else: + try: + res = await execute_pipe(pipe, params) + + except Exception as e: + print(f"Error: {e}") + return {"error": {"detail": str(e)}} + + if isinstance(res, StreamingResponse) or isinstance(res, dict): + return res + if isinstance(res, BaseModel): + return res.model_dump() + + message = await get_message_content(res) + return openai_chat_completion_message_template(form_data["model"], message) diff --git a/backend/apps/webui/models/chats.py b/backend/apps/webui/models/chats.py index c03abb2338..abde4f2b31 100644 --- a/backend/apps/webui/models/chats.py +++ b/backend/apps/webui/models/chats.py @@ -245,6 +245,38 @@ class ChatTable: ) return [ChatModel.model_validate(chat) for chat in all_chats] + def get_chat_title_id_list_by_user_id( + self, + user_id: str, + include_archived: bool = False, + skip: int = 0, + limit: int = 
50,
+    ) -> List[ChatTitleIdResponse]:
+        with get_db() as db:
+            query = db.query(Chat).filter_by(user_id=user_id)
+            if not include_archived:
+                query = query.filter_by(archived=False)
+
+            all_chats = (
+                query.order_by(Chat.updated_at.desc())
+                # limit cols
+                .with_entities(
+                    Chat.id, Chat.title, Chat.updated_at, Chat.created_at
+                ).all()
+            )
+            # result has to be destructured from the sqlalchemy `row` and mapped to a dict, since `ChatModel` is not the returned dataclass.
+            return [
+                ChatTitleIdResponse.model_validate(
+                    {
+                        "id": chat[0],
+                        "title": chat[1],
+                        "updated_at": chat[2],
+                        "created_at": chat[3],
+                    }
+                )
+                for chat in all_chats
+            ]
+
     def get_chat_list_by_chat_ids(
         self, chat_ids: List[str], skip: int = 0, limit: int = 50
     ) -> List[ChatModel]:
diff --git a/backend/apps/webui/models/models.py b/backend/apps/webui/models/models.py
index 3b128c7d67..8277d1d0ba 100644
--- a/backend/apps/webui/models/models.py
+++ b/backend/apps/webui/models/models.py
@@ -1,13 +1,11 @@
-import json
 import logging
-from typing import Optional
+from typing import Optional, List
 
 from pydantic import BaseModel, ConfigDict
-from sqlalchemy import String, Column, BigInteger, Text
+from sqlalchemy import Column, BigInteger, Text
 
 from apps.webui.internal.db import Base, JSONField, get_db
-from typing import List, Union, Optional
 
 from config import SRC_LOG_LEVELS
 
 import time
@@ -113,7 +111,6 @@ class ModelForm(BaseModel):
 
 
 class ModelsTable:
-
     def insert_new_model(
         self, form_data: ModelForm, user_id: str
     ) -> Optional[ModelModel]:
@@ -126,9 +123,7 @@ class ModelsTable:
             }
         )
         try:
-
             with get_db() as db:
-
                 result = Model(**model.model_dump())
                 db.add(result)
                 db.commit()
@@ -144,13 +139,11 @@ class ModelsTable:
 
     def get_all_models(self) -> List[ModelModel]:
         with get_db() as db:
-
             return [ModelModel.model_validate(model) for model in db.query(Model).all()]
 
     def get_model_by_id(self, id: str) -> Optional[ModelModel]:
         try:
             with get_db() as db:
-
                 model = db.get(Model, id)
                 return ModelModel.model_validate(model)
         except:
@@ -178,7 +171,6 @@ class ModelsTable:
     def delete_model_by_id(self, id: str) -> bool:
         try:
             with get_db() as db:
-
                 db.query(Model).filter_by(id=id).delete()
                 db.commit()
diff --git a/backend/apps/webui/routers/chats.py b/backend/apps/webui/routers/chats.py
index d3ccb9cce5..80308a451b 100644
--- a/backend/apps/webui/routers/chats.py
+++ b/backend/apps/webui/routers/chats.py
@@ -45,7 +45,7 @@ router = APIRouter()
 async def get_session_user_chat_list(
     user=Depends(get_verified_user), skip: int = 0, limit: int = 50
 ):
-    return Chats.get_chat_list_by_user_id(user.id, skip, limit)
+    return Chats.get_chat_title_id_list_by_user_id(user.id, skip=skip, limit=limit)
 
 
 ############################
diff --git a/backend/apps/webui/routers/utils.py b/backend/apps/webui/routers/utils.py
index 780ed6b43e..4ffe748b0b 100644
--- a/backend/apps/webui/routers/utils.py
+++ b/backend/apps/webui/routers/utils.py
@@ -1,3 +1,6 @@
+from pathlib import Path
+import site
+
 from fastapi import APIRouter, UploadFile, File, Response
 from fastapi import Depends, HTTPException, status
 from starlette.responses import StreamingResponse, FileResponse
@@ -64,8 +67,18 @@ async def download_chat_as_pdf(
     pdf = FPDF()
     pdf.add_page()
 
-    STATIC_DIR = "./static"
-    FONTS_DIR = f"{STATIC_DIR}/fonts"
+    # When running in docker, workdir is /app/backend, so fonts is in /app/backend/static/fonts
+    FONTS_DIR = Path("./static/fonts")
+
+    # Non-Docker installation
+
+    # When running using `pip install` the static directory is in the site packages.
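+    # (site.getsitepackages()[0] points at the active environment's site-packages directory, where packaged static files are installed)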
+    if not FONTS_DIR.exists():
+        FONTS_DIR = Path(site.getsitepackages()[0]) / "static/fonts"
+    # When running using `pip install -e .` the static directory stays in the project tree.
+    # This path only works if `open-webui serve` is run from the root of this project.
+    if not FONTS_DIR.exists():
+        FONTS_DIR = Path("./backend/static/fonts")
 
     pdf.add_font("NotoSans", "", f"{FONTS_DIR}/NotoSans-Regular.ttf")
     pdf.add_font("NotoSans", "b", f"{FONTS_DIR}/NotoSans-Bold.ttf")
diff --git a/backend/config.py b/backend/config.py
index fe68eee34c..e976b226df 100644
--- a/backend/config.py
+++ b/backend/config.py
@@ -77,6 +77,16 @@ for source in log_sources:
 
 log.setLevel(SRC_LOG_LEVELS["CONFIG"])
 
+
+class EndpointFilter(logging.Filter):
+    def filter(self, record: logging.LogRecord) -> bool:
+        return record.getMessage().find("/health") == -1
+
+
+# Filter /health requests out of the uvicorn access logs
+logging.getLogger("uvicorn.access").addFilter(EndpointFilter())
+
+
 WEBUI_NAME = os.environ.get("WEBUI_NAME", "Open WebUI")
 if WEBUI_NAME != "Open WebUI":
     WEBUI_NAME += " (Open WebUI)"
@@ -339,6 +349,12 @@ GOOGLE_OAUTH_SCOPE = PersistentConfig(
     os.environ.get("GOOGLE_OAUTH_SCOPE", "openid email profile"),
 )
 
+GOOGLE_REDIRECT_URI = PersistentConfig(
+    "GOOGLE_REDIRECT_URI",
+    "oauth.google.redirect_uri",
+    os.environ.get("GOOGLE_REDIRECT_URI", ""),
+)
+
 MICROSOFT_CLIENT_ID = PersistentConfig(
     "MICROSOFT_CLIENT_ID",
     "oauth.microsoft.client_id",
@@ -363,6 +379,12 @@ MICROSOFT_OAUTH_SCOPE = PersistentConfig(
     os.environ.get("MICROSOFT_OAUTH_SCOPE", "openid email profile"),
 )
 
+MICROSOFT_REDIRECT_URI = PersistentConfig(
+    "MICROSOFT_REDIRECT_URI",
+    "oauth.microsoft.redirect_uri",
+    os.environ.get("MICROSOFT_REDIRECT_URI", ""),
+)
+
 OAUTH_CLIENT_ID = PersistentConfig(
     "OAUTH_CLIENT_ID",
     "oauth.oidc.client_id",
@@ -381,6 +403,12 @@ OPENID_PROVIDER_URL = PersistentConfig(
     os.environ.get("OPENID_PROVIDER_URL", ""),
 )
 
+OPENID_REDIRECT_URI = PersistentConfig(
+    "OPENID_REDIRECT_URI",
+    "oauth.oidc.redirect_uri",
+    os.environ.get("OPENID_REDIRECT_URI", ""),
+)
+
 OAUTH_SCOPES = PersistentConfig(
     "OAUTH_SCOPES",
     "oauth.oidc.scopes",
@@ -414,6 +442,7 @@ def load_oauth_providers():
             "client_secret": GOOGLE_CLIENT_SECRET.value,
             "server_metadata_url": "https://accounts.google.com/.well-known/openid-configuration",
             "scope": GOOGLE_OAUTH_SCOPE.value,
+            "redirect_uri": GOOGLE_REDIRECT_URI.value,
         }
 
     if (
@@ -426,6 +455,7 @@ def load_oauth_providers():
             "client_secret": MICROSOFT_CLIENT_SECRET.value,
             "server_metadata_url": f"https://login.microsoftonline.com/{MICROSOFT_CLIENT_TENANT_ID.value}/v2.0/.well-known/openid-configuration",
             "scope": MICROSOFT_OAUTH_SCOPE.value,
+            "redirect_uri": MICROSOFT_REDIRECT_URI.value,
         }
 
     if (
@@ -439,6 +469,7 @@ def load_oauth_providers():
             "server_metadata_url": OPENID_PROVIDER_URL.value,
             "scope": OAUTH_SCOPES.value,
             "name": OAUTH_PROVIDER_NAME.value,
+            "redirect_uri": OPENID_REDIRECT_URI.value,
         }
 
 
@@ -709,6 +740,12 @@ ENABLE_SIGNUP = PersistentConfig(
     ),
 )
 
+ENABLE_LOGIN_FORM = PersistentConfig(
+    "ENABLE_LOGIN_FORM",
+    "ui.ENABLE_LOGIN_FORM",
+    os.environ.get("ENABLE_LOGIN_FORM", "True").lower() == "true",
+)
+
 DEFAULT_LOCALE = PersistentConfig(
     "DEFAULT_LOCALE",
     "ui.default_locale",
@@ -1265,6 +1302,24 @@ COMFYUI_SD3 = PersistentConfig(
     os.environ.get("COMFYUI_SD3", "").lower() == "true",
 )
 
+COMFYUI_FLUX = PersistentConfig(
+    "COMFYUI_FLUX",
+    "image_generation.comfyui.flux",
+    os.environ.get("COMFYUI_FLUX", "").lower() == "true",
+)
+
+COMFYUI_FLUX_WEIGHT_DTYPE = PersistentConfig(
+    "COMFYUI_FLUX_WEIGHT_DTYPE",
+
"image_generation.comfyui.flux_weight_dtype", + os.getenv("COMFYUI_FLUX_WEIGHT_DTYPE", ""), +) + +COMFYUI_FLUX_FP8_CLIP = PersistentConfig( + "COMFYUI_FLUX_FP8_CLIP", + "image_generation.comfyui.flux_fp8_clip", + os.getenv("COMFYUI_FLUX_FP8_CLIP", ""), +) + IMAGES_OPENAI_API_BASE_URL = PersistentConfig( "IMAGES_OPENAI_API_BASE_URL", "image_generation.openai.api_base_url", @@ -1329,6 +1384,11 @@ AUDIO_TTS_OPENAI_API_KEY = PersistentConfig( os.getenv("AUDIO_TTS_OPENAI_API_KEY", OPENAI_API_KEY), ) +AUDIO_TTS_API_KEY = PersistentConfig( + "AUDIO_TTS_API_KEY", + "audio.tts.api_key", + os.getenv("AUDIO_TTS_API_KEY", ""), +) AUDIO_TTS_ENGINE = PersistentConfig( "AUDIO_TTS_ENGINE", diff --git a/backend/main.py b/backend/main.py index 62f07a868c..a7dd8bc23b 100644 --- a/backend/main.py +++ b/backend/main.py @@ -13,8 +13,6 @@ import aiohttp import requests import mimetypes import shutil -import os -import uuid import inspect from fastapi import FastAPI, Request, Depends, status, UploadFile, File, Form @@ -29,7 +27,7 @@ from starlette.middleware.sessions import SessionMiddleware from starlette.responses import StreamingResponse, Response, RedirectResponse -from apps.socket.main import sio, app as socket_app, get_event_emitter, get_event_call +from apps.socket.main import app as socket_app, get_event_emitter, get_event_call from apps.ollama.main import ( app as ollama_app, get_all_models as get_ollama_models, @@ -79,6 +77,7 @@ from utils.task import ( from utils.misc import ( get_last_user_message, add_or_update_system_message, + prepend_to_first_user_message_content, parse_duration, ) @@ -618,32 +617,15 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): content={"detail": str(e)}, ) - # Extract valves from the request body - valves = None - if "valves" in body: - valves = body["valves"] - del body["valves"] + metadata = { + "chat_id": body.pop("chat_id", None), + "message_id": body.pop("id", None), + "session_id": body.pop("session_id", None), + "valves": body.pop("valves", None), + } - # Extract session_id, chat_id and message_id from the request body - session_id = None - if "session_id" in body: - session_id = body["session_id"] - del body["session_id"] - chat_id = None - if "chat_id" in body: - chat_id = body["chat_id"] - del body["chat_id"] - message_id = None - if "id" in body: - message_id = body["id"] - del body["id"] - - __event_emitter__ = await get_event_emitter( - {"chat_id": chat_id, "message_id": message_id, "session_id": session_id} - ) - __event_call__ = await get_event_call( - {"chat_id": chat_id, "message_id": message_id, "session_id": session_id} - ) + __event_emitter__ = get_event_emitter(metadata) + __event_call__ = get_event_call(metadata) # Initialize data_items to store additional data to be sent to the client data_items = [] @@ -686,24 +668,29 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): if len(contexts) > 0: context_string = "/n".join(contexts).strip() prompt = get_last_user_message(body["messages"]) - body["messages"] = add_or_update_system_message( - rag_template( - rag_app.state.config.RAG_TEMPLATE, context_string, prompt - ), - body["messages"], - ) + + # Workaround for Ollama 2.0+ system prompt issue + # TODO: replace with add_or_update_system_message + if model["owned_by"] == "ollama": + body["messages"] = prepend_to_first_user_message_content( + rag_template( + rag_app.state.config.RAG_TEMPLATE, context_string, prompt + ), + body["messages"], + ) + else: + body["messages"] = add_or_update_system_message( + rag_template( + 
rag_app.state.config.RAG_TEMPLATE, context_string, prompt + ), + body["messages"], + ) # If there are citations, add them to the data_items if len(citations) > 0: data_items.append({"citations": citations}) - body["metadata"] = { - "session_id": session_id, - "chat_id": chat_id, - "message_id": message_id, - "valves": valves, - } - + body["metadata"] = metadata modified_body_bytes = json.dumps(body).encode("utf-8") # Replace the request body with the modified one request._body = modified_body_bytes @@ -979,32 +966,15 @@ async def get_all_models(): model["name"] = custom_model.name model["info"] = custom_model.model_dump() - action_ids = [] + global_action_ids + action_ids = [] if "info" in model and "meta" in model["info"]: action_ids.extend(model["info"]["meta"].get("actionIds", [])) - action_ids = list(set(action_ids)) - action_ids = [ - action_id - for action_id in action_ids - if action_id in enabled_action_ids - ] - - model["actions"] = [] - for action_id in action_ids: - action = Functions.get_function_by_id(action_id) - model["actions"].append( - { - "id": action_id, - "name": action.name, - "description": action.meta.description, - "icon_url": action.meta.manifest.get("icon_url", None), - } - ) + model["action_ids"] = action_ids else: owned_by = "openai" pipe = None - actions = [] + action_ids = [] for model in models: if ( @@ -1015,26 +985,8 @@ async def get_all_models(): if "pipe" in model: pipe = model["pipe"] - action_ids = [] + global_action_ids if "info" in model and "meta" in model["info"]: action_ids.extend(model["info"]["meta"].get("actionIds", [])) - action_ids = list(set(action_ids)) - action_ids = [ - action_id - for action_id in action_ids - if action_id in enabled_action_ids - ] - - actions = [ - { - "id": action_id, - "name": Functions.get_function_by_id(action_id).name, - "description": Functions.get_function_by_id( - action_id - ).meta.description, - } - for action_id in action_ids - ] break models.append( @@ -1047,10 +999,59 @@ async def get_all_models(): "info": custom_model.model_dump(), "preset": True, **({"pipe": pipe} if pipe is not None else {}), - "actions": actions, + "action_ids": action_ids, } ) + for model in models: + action_ids = [] + if "action_ids" in model: + action_ids = model["action_ids"] + del model["action_ids"] + + action_ids = action_ids + global_action_ids + action_ids = list(set(action_ids)) + action_ids = [ + action_id for action_id in action_ids if action_id in enabled_action_ids + ] + + model["actions"] = [] + for action_id in action_ids: + action = Functions.get_function_by_id(action_id) + + if action_id in webui_app.state.FUNCTIONS: + function_module = webui_app.state.FUNCTIONS[action_id] + else: + function_module, _, _ = load_function_module_by_id(action_id) + webui_app.state.FUNCTIONS[action_id] = function_module + + if hasattr(function_module, "actions"): + actions = function_module.actions + model["actions"].extend( + [ + { + "id": f"{action_id}.{_action['id']}", + "name": _action.get( + "name", f"{action.name} ({_action['id']})" + ), + "description": action.meta.description, + "icon_url": _action.get( + "icon_url", action.meta.manifest.get("icon_url", None) + ), + } + for _action in actions + ] + ) + else: + model["actions"].append( + { + "id": action_id, + "name": action.name, + "description": action.meta.description, + "icon_url": action.meta.manifest.get("icon_url", None), + } + ) + app.state.MODELS = {model["id"]: model for model in models} webui_app.state.MODELS = app.state.MODELS @@ -1165,13 +1166,13 @@ async def 
chat_completed(form_data: dict, user=Depends(get_verified_user)): status_code=r.status_code, content=res, ) - except: + except Exception: pass else: pass - __event_emitter__ = await get_event_emitter( + __event_emitter__ = get_event_emitter( { "chat_id": data["chat_id"], "message_id": data["id"], @@ -1179,7 +1180,7 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)): } ) - __event_call__ = await get_event_call( + __event_call__ = get_event_call( { "chat_id": data["chat_id"], "message_id": data["id"], @@ -1284,9 +1285,12 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)): @app.post("/api/chat/actions/{action_id}") -async def chat_completed( - action_id: str, form_data: dict, user=Depends(get_verified_user) -): +async def chat_action(action_id: str, form_data: dict, user=Depends(get_verified_user)): + if "." in action_id: + action_id, sub_action_id = action_id.split(".") + else: + sub_action_id = None + action = Functions.get_function_by_id(action_id) if not action: raise HTTPException( @@ -1303,14 +1307,14 @@ async def chat_completed( ) model = app.state.MODELS[model_id] - __event_emitter__ = await get_event_emitter( + __event_emitter__ = get_event_emitter( { "chat_id": data["chat_id"], "message_id": data["id"], "session_id": data["session_id"], } ) - __event_call__ = await get_event_call( + __event_call__ = get_event_call( { "chat_id": data["chat_id"], "message_id": data["id"], @@ -1339,7 +1343,7 @@ async def chat_completed( # Extra parameters to be passed to the function extra_params = { "__model__": model, - "__id__": action_id, + "__id__": sub_action_id if sub_action_id is not None else action_id, "__event_emitter__": __event_emitter__, "__event_call__": __event_call__, } @@ -1739,7 +1743,6 @@ class AddPipelineForm(BaseModel): @app.post("/api/pipelines/add") async def add_pipeline(form_data: AddPipelineForm, user=Depends(get_admin_user)): - r = None try: urlIdx = form_data.urlIdx @@ -1782,7 +1785,6 @@ class DeletePipelineForm(BaseModel): @app.delete("/api/pipelines/delete") async def delete_pipeline(form_data: DeletePipelineForm, user=Depends(get_admin_user)): - r = None try: urlIdx = form_data.urlIdx @@ -1860,7 +1862,6 @@ async def get_pipeline_valves( models = await get_all_models() r = None try: - url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx] key = openai_app.state.config.OPENAI_API_KEYS[urlIdx] @@ -1995,6 +1996,7 @@ async def get_app_config(): "auth": WEBUI_AUTH, "auth_trusted_header": bool(webui_app.state.AUTH_TRUSTED_EMAIL_HEADER), "enable_signup": webui_app.state.config.ENABLE_SIGNUP, + "enable_login_form": webui_app.state.config.ENABLE_LOGIN_FORM, "enable_web_search": rag_app.state.config.ENABLE_RAG_WEB_SEARCH, "enable_image_generation": images_app.state.config.ENABLED, "enable_community_sharing": webui_app.state.config.ENABLE_COMMUNITY_SHARING, @@ -2111,6 +2113,7 @@ for provider_name, provider_config in OAUTH_PROVIDERS.items(): client_kwargs={ "scope": provider_config["scope"], }, + redirect_uri=provider_config["redirect_uri"], ) # SessionMiddleware is used by authlib for oauth @@ -2128,7 +2131,10 @@ if len(OAUTH_PROVIDERS) > 0: async def oauth_login(provider: str, request: Request): if provider not in OAUTH_PROVIDERS: raise HTTPException(404) - redirect_uri = request.url_for("oauth_callback", provider=provider) + # If the provider has a custom redirect URL, use that, otherwise automatically generate one + redirect_uri = OAUTH_PROVIDERS[provider].get("redirect_uri") or request.url_for( + "oauth_callback", 
provider=provider + ) return await oauth.create_client(provider).authorize_redirect(request, redirect_uri) diff --git a/backend/requirements.txt b/backend/requirements.txt index 61185796db..8b12854a08 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -30,11 +30,11 @@ APScheduler==3.10.4 # AI libraries openai anthropic -google-generativeai==0.5.4 +google-generativeai==0.7.2 tiktoken -langchain==0.2.6 -langchain-community==0.2.6 +langchain==0.2.11 +langchain-community==0.2.10 langchain-chroma==0.1.2 fake-useragent==1.5.1 @@ -43,7 +43,7 @@ sentence-transformers==3.0.1 pypdf==4.2.0 docx2txt==0.8 python-pptx==0.6.23 -unstructured==0.14.10 +unstructured==0.15.0 Markdown==3.6 pypandoc==1.13 pandas==2.2.2 @@ -54,7 +54,7 @@ validators==0.28.1 psutil opencv-python-headless==4.10.0.84 -rapidocr-onnxruntime==1.3.22 +rapidocr-onnxruntime==1.3.24 fpdf2==2.7.9 rank-bm25==0.2.2 @@ -65,13 +65,13 @@ PyJWT[crypto]==2.8.0 authlib==1.3.1 black==24.4.2 -langfuse==2.38.0 +langfuse==2.39.2 youtube-transcript-api==0.6.2 pytube==15.0.0 extract_msg pydub -duckduckgo-search~=6.1.12 +duckduckgo-search~=6.2.1 ## Tests docker~=7.1.0 diff --git a/backend/utils/misc.py b/backend/utils/misc.py index 5a05f167d5..c4e2eda6f0 100644 --- a/backend/utils/misc.py +++ b/backend/utils/misc.py @@ -1,6 +1,5 @@ from pathlib import Path import hashlib -import json import re from datetime import timedelta from typing import Optional, List, Tuple @@ -8,37 +7,39 @@ import uuid import time -def get_last_user_message_item(messages: List[dict]) -> str: +def get_last_user_message_item(messages: List[dict]) -> Optional[dict]: for message in reversed(messages): if message["role"] == "user": return message return None -def get_last_user_message(messages: List[dict]) -> str: - message = get_last_user_message_item(messages) - - if message is not None: - if isinstance(message["content"], list): - for item in message["content"]: - if item["type"] == "text": - return item["text"] +def get_content_from_message(message: dict) -> Optional[str]: + if isinstance(message["content"], list): + for item in message["content"]: + if item["type"] == "text": + return item["text"] + else: return message["content"] return None -def get_last_assistant_message(messages: List[dict]) -> str: +def get_last_user_message(messages: List[dict]) -> Optional[str]: + message = get_last_user_message_item(messages) + if message is None: + return None + + return get_content_from_message(message) + + +def get_last_assistant_message(messages: List[dict]) -> Optional[str]: for message in reversed(messages): if message["role"] == "assistant": - if isinstance(message["content"], list): - for item in message["content"]: - if item["type"] == "text": - return item["text"] - return message["content"] + return get_content_from_message(message) return None -def get_system_message(messages: List[dict]) -> dict: +def get_system_message(messages: List[dict]) -> Optional[dict]: for message in messages: if message["role"] == "system": return message @@ -49,10 +50,25 @@ def remove_system_message(messages: List[dict]) -> List[dict]: return [message for message in messages if message["role"] != "system"] -def pop_system_message(messages: List[dict]) -> Tuple[dict, List[dict]]: +def pop_system_message(messages: List[dict]) -> Tuple[Optional[dict], List[dict]]: return get_system_message(messages), remove_system_message(messages) +def prepend_to_first_user_message_content( + content: str, messages: List[dict] +) -> List[dict]: + for message in messages: + if message["role"] == 
"user": + if isinstance(message["content"], list): + for item in message["content"]: + if item["type"] == "text": + item["text"] = f"{content}\n{item['text']}" + else: + message["content"] = f"{content}\n{message['content']}" + break + return messages + + def add_or_update_system_message(content: str, messages: List[dict]): """ Adds a new system message at the beginning of the messages list @@ -72,23 +88,29 @@ def add_or_update_system_message(content: str, messages: List[dict]): return messages -def stream_message_template(model: str, message: str): +def openai_chat_message_template(model: str): return { "id": f"{model}-{str(uuid.uuid4())}", - "object": "chat.completion.chunk", "created": int(time.time()), "model": model, - "choices": [ - { - "index": 0, - "delta": {"content": message}, - "logprobs": None, - "finish_reason": None, - } - ], + "choices": [{"index": 0, "logprobs": None, "finish_reason": None}], } +def openai_chat_chunk_message_template(model: str, message: str): + template = openai_chat_message_template(model) + template["object"] = "chat.completion.chunk" + template["choices"][0]["delta"] = {"content": message} + return template + + +def openai_chat_completion_message_template(model: str, message: str): + template = openai_chat_message_template(model) + template["object"] = "chat.completion" + template["choices"][0]["message"] = {"content": message, "role": "assistant"} + template["choices"][0]["finish_reason"] = "stop" + + def get_gravatar_url(email): # Trim leading and trailing whitespace from # an email address and force all characters @@ -159,7 +181,7 @@ def extract_folders_after_data_docs(path): tags = [] folders = parts[index_docs:-1] - for idx, part in enumerate(folders): + for idx, _ in enumerate(folders): tags.append("/".join(folders[: idx + 1])) return tags @@ -255,11 +277,11 @@ def parse_ollama_modelfile(model_text): value = param_match.group(1) try: - if param_type == int: + if param_type is int: value = int(value) - elif param_type == float: + elif param_type is float: value = float(value) - elif param_type == bool: + elif param_type is bool: value = value.lower() == "true" except Exception as e: print(e) diff --git a/cypress/data/example-doc.txt b/cypress/data/example-doc.txt new file mode 100644 index 0000000000..d4f6f455ed --- /dev/null +++ b/cypress/data/example-doc.txt @@ -0,0 +1,9 @@ +Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Pellentesque elit eget gravida cum sociis natoque. Morbi tristique senectus et netus et malesuada. Sapien nec sagittis aliquam malesuada bibendum. Amet consectetur adipiscing elit duis tristique sollicitudin. Non pulvinar neque laoreet suspendisse interdum consectetur libero. Arcu cursus vitae congue mauris rhoncus aenean vel elit scelerisque. Nec feugiat nisl pretium fusce id velit. Imperdiet proin fermentum leo vel. Arcu dui vivamus arcu felis bibendum ut tristique et egestas. Pellentesque sit amet porttitor eget dolor morbi non arcu risus. Egestas tellus rutrum tellus pellentesque eu tincidunt tortor aliquam. Et ultrices neque ornare aenean euismod. + +Enim nulla aliquet porttitor lacus luctus accumsan tortor posuere ac. Viverra nibh cras pulvinar mattis nunc. Lacinia at quis risus sed vulputate. Ac tortor vitae purus faucibus ornare suspendisse sed nisi lacus. Bibendum arcu vitae elementum curabitur vitae nunc. Consectetur adipiscing elit duis tristique sollicitudin nibh sit amet commodo. Velit egestas dui id ornare arcu odio ut. 
Et malesuada fames ac turpis egestas integer eget aliquet. Lacus suspendisse faucibus interdum posuere lorem ipsum dolor sit. Morbi tristique senectus et netus. Pretium viverra suspendisse potenti nullam ac tortor vitae. Parturient montes nascetur ridiculus mus mauris vitae. Quis viverra nibh cras pulvinar mattis nunc sed blandit libero. Euismod nisi porta lorem mollis aliquam ut porttitor leo. Mauris in aliquam sem fringilla ut morbi. Faucibus pulvinar elementum integer enim neque. Neque sodales ut etiam sit. Consectetur a erat nam at. + +Sed nisi lacus sed viverra tellus in hac habitasse. Proin sagittis nisl rhoncus mattis rhoncus. Risus commodo viverra maecenas accumsan lacus. Morbi quis commodo odio aenean sed adipiscing. Mollis nunc sed id semper risus in. Ultricies mi eget mauris pharetra et ultrices neque. Amet luctus venenatis lectus magna fringilla urna porttitor rhoncus. Eget magna fermentum iaculis eu non diam phasellus. Id diam maecenas ultricies mi eget mauris pharetra et ultrices. Id donec ultrices tincidunt arcu non sodales. Sed cras ornare arcu dui vivamus arcu felis bibendum ut. Urna duis convallis convallis tellus id interdum velit. Rhoncus mattis rhoncus urna neque viverra justo nec. Purus semper eget duis at tellus at urna condimentum. Et odio pellentesque diam volutpat commodo sed egestas. Blandit volutpat maecenas volutpat blandit. In egestas erat imperdiet sed euismod nisi porta lorem mollis. Est ullamcorper eget nulla facilisi etiam dignissim. + +Justo nec ultrices dui sapien eget mi proin sed. Purus gravida quis blandit turpis cursus in hac. Placerat orci nulla pellentesque dignissim enim sit. Morbi tristique senectus et netus et malesuada fames ac. Consequat mauris nunc congue nisi. Eu lobortis elementum nibh tellus molestie nunc non blandit. Viverra justo nec ultrices dui. Morbi non arcu risus quis. Elementum sagittis vitae et leo duis. Lectus mauris ultrices eros in cursus. Neque laoreet suspendisse interdum consectetur. + +Facilisis gravida neque convallis a cras. Nisl rhoncus mattis rhoncus urna neque viverra justo. Faucibus purus in massa tempor. Lacus laoreet non curabitur gravida arcu ac tortor. Tincidunt eget nullam non nisi est sit amet. Ornare lectus sit amet est placerat in egestas. Sollicitudin tempor id eu nisl nunc mi. Scelerisque viverra mauris in aliquam sem fringilla ut. Ullamcorper sit amet risus nullam. Mauris rhoncus aenean vel elit scelerisque mauris pellentesque pulvinar. Velit euismod in pellentesque massa placerat duis ultricies lacus. Pharetra magna ac placerat vestibulum lectus mauris ultrices eros in. Lorem ipsum dolor sit amet. Sit amet mauris commodo quis imperdiet. Quam pellentesque nec nam aliquam sem et tortor. Amet nisl purus in mollis nunc. Sed risus pretium quam vulputate dignissim suspendisse in est. Nisl condimentum id venenatis a condimentum. Velit euismod in pellentesque massa. Quam id leo in vitae turpis massa sed. 
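A note on the backend/utils/misc.py hunks above, before the Cypress additions: stream_message_template is split into a shared openai_chat_message_template base plus chunk and completion variants. A minimal, runnable sketch of how the new helpers compose (the model name and tokens are invented for illustration; note that the completion variant in the hunk above ends without `return template`, so as written it would return None, and the sketch adds the return it presumably intends):

```python
# Self-contained re-statement of the new template helpers; mirrors the hunk
# above except for the final return, which the hunk appears to omit.
import json
import time
import uuid


def openai_chat_message_template(model: str) -> dict:
    return {
        "id": f"{model}-{uuid.uuid4()}",
        "created": int(time.time()),
        "model": model,
        "choices": [{"index": 0, "logprobs": None, "finish_reason": None}],
    }


def openai_chat_chunk_message_template(model: str, message: str) -> dict:
    # Streaming chunk: content arrives via "delta", finish_reason stays None.
    template = openai_chat_message_template(model)
    template["object"] = "chat.completion.chunk"
    template["choices"][0]["delta"] = {"content": message}
    return template


def openai_chat_completion_message_template(model: str, message: str) -> dict:
    # Non-streaming completion: full message plus finish_reason "stop".
    template = openai_chat_message_template(model)
    template["object"] = "chat.completion"
    template["choices"][0]["message"] = {"content": message, "role": "assistant"}
    template["choices"][0]["finish_reason"] = "stop"
    return template  # this return is missing in the hunk above


# A streaming endpoint would emit one SSE line per chunk:
for token in ["Hello", ",", " world"]:
    print(f"data: {json.dumps(openai_chat_chunk_message_template('my-model', token))}")
```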
diff --git a/cypress/e2e/documents.cy.ts b/cypress/e2e/documents.cy.ts new file mode 100644 index 0000000000..6ca14980d2 --- /dev/null +++ b/cypress/e2e/documents.cy.ts @@ -0,0 +1,46 @@ +// eslint-disable-next-line @typescript-eslint/triple-slash-reference +/// + +describe('Documents', () => { + const timestamp = Date.now(); + + before(() => { + cy.uploadTestDocument(timestamp); + }); + + after(() => { + cy.deleteTestDocument(timestamp); + }); + + context('Admin', () => { + beforeEach(() => { + // Login as the admin user + cy.loginAdmin(); + // Visit the home page + cy.visit('/workspace/documents'); + cy.get('button').contains('#cypress-test').click(); + }); + + it('can see documents', () => { + cy.get('div').contains(`document-test-initial-${timestamp}.txt`).should('have.length', 1); + }); + + it('can see edit button', () => { + cy.get('div') + .contains(`document-test-initial-${timestamp}.txt`) + .get("button[aria-label='Edit Doc']") + .should('exist'); + }); + + it('can see delete button', () => { + cy.get('div') + .contains(`document-test-initial-${timestamp}.txt`) + .get("button[aria-label='Delete Doc']") + .should('exist'); + }); + + it('can see upload button', () => { + cy.get("button[aria-label='Add Docs']").should('exist'); + }); + }); +}); diff --git a/cypress/support/e2e.ts b/cypress/support/e2e.ts index 1eedc98dfe..9847887333 100644 --- a/cypress/support/e2e.ts +++ b/cypress/support/e2e.ts @@ -1,4 +1,6 @@ /// +// eslint-disable-next-line @typescript-eslint/triple-slash-reference +/// export const adminUser = { name: 'Admin User', @@ -10,6 +12,9 @@ const login = (email: string, password: string) => { return cy.session( email, () => { + // Make sure to test against us english to have stable tests, + // regardless on local language preferences + localStorage.setItem('locale', 'en-US'); // Visit auth page cy.visit('/auth'); // Fill out the form @@ -68,6 +73,50 @@ Cypress.Commands.add('register', (name, email, password) => register(name, email Cypress.Commands.add('registerAdmin', () => registerAdmin()); Cypress.Commands.add('loginAdmin', () => loginAdmin()); +Cypress.Commands.add('uploadTestDocument', (suffix: any) => { + // Login as admin + cy.loginAdmin(); + // upload example document + cy.visit('/workspace/documents'); + // Create a document + cy.get("button[aria-label='Add Docs']").click(); + cy.readFile('cypress/data/example-doc.txt').then((text) => { + // select file + cy.get('#upload-doc-input').selectFile( + { + contents: Cypress.Buffer.from(text + Date.now()), + fileName: `document-test-initial-${suffix}.txt`, + mimeType: 'text/plain', + lastModified: Date.now() + }, + { + force: true + } + ); + // open tag input + cy.get("button[aria-label='Add Tag']").click(); + cy.get("input[placeholder='Add a tag']").type('cypress-test'); + cy.get("button[aria-label='Save Tag']").click(); + + // submit to upload + cy.get("button[type='submit']").click(); + + // wait for upload to finish + cy.get('button').contains('#cypress-test').should('exist'); + cy.get('div').contains(`document-test-initial-${suffix}.txt`).should('exist'); + }); +}); + +Cypress.Commands.add('deleteTestDocument', (suffix: any) => { + cy.loginAdmin(); + cy.visit('/workspace/documents'); + // clean up uploaded documents + cy.get('div') + .contains(`document-test-initial-${suffix}.txt`) + .find("button[aria-label='Delete Doc']") + .click(); +}); + before(() => { cy.registerAdmin(); }); diff --git a/cypress/support/index.d.ts b/cypress/support/index.d.ts index e6c69121a9..647db92115 100644 --- 
a/cypress/support/index.d.ts +++ b/cypress/support/index.d.ts @@ -7,5 +7,7 @@ declare namespace Cypress { register(name: string, email: string, password: string): Chainable; registerAdmin(): Chainable; loginAdmin(): Chainable; + uploadTestDocument(suffix: any): Chainable; + deleteTestDocument(suffix: any): Chainable; } } diff --git a/package-lock.json b/package-lock.json index cf04da5c62..f3e8d2e38d 100644 --- a/package-lock.json +++ b/package-lock.json @@ -1,12 +1,12 @@ { "name": "open-webui", - "version": "0.3.10", + "version": "0.3.11", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "open-webui", - "version": "0.3.10", + "version": "0.3.11", "dependencies": { "@codemirror/lang-javascript": "^6.2.2", "@codemirror/lang-python": "^6.1.6", @@ -20,6 +20,7 @@ "dayjs": "^1.11.10", "eventsource-parser": "^1.1.2", "file-saver": "^2.0.5", + "fuse.js": "^7.0.0", "highlight.js": "^11.9.0", "i18next": "^23.10.0", "i18next-browser-languagedetector": "^7.2.0", @@ -4820,6 +4821,14 @@ "url": "https://github.com/sponsors/ljharb" } }, + "node_modules/fuse.js": { + "version": "7.0.0", + "resolved": "https://registry.npmjs.org/fuse.js/-/fuse.js-7.0.0.tgz", + "integrity": "sha512-14F4hBIxqKvD4Zz/XjDc3y94mNZN6pRv3U13Udo0lNLCWRBUsrMv2xwcF/y/Z5sV6+FQW+/ow68cHpm4sunt8Q==", + "engines": { + "node": ">=10" + } + }, "node_modules/gc-hook": { "version": "0.3.1", "resolved": "https://registry.npmjs.org/gc-hook/-/gc-hook-0.3.1.tgz", diff --git a/package.json b/package.json index f7cc125982..6687cef755 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "open-webui", - "version": "0.3.10", + "version": "0.3.11", "private": true, "scripts": { "dev": "npm run pyodide:fetch && vite dev --host", @@ -60,6 +60,7 @@ "dayjs": "^1.11.10", "eventsource-parser": "^1.1.2", "file-saver": "^2.0.5", + "fuse.js": "^7.0.0", "highlight.js": "^11.9.0", "i18next": "^23.10.0", "i18next-browser-languagedetector": "^7.2.0", diff --git a/pyproject.toml b/pyproject.toml index f8e7295cf7..efce1158fb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,7 +23,7 @@ dependencies = [ "peewee==3.17.5", "peewee-migrate==1.12.2", "psycopg2-binary==2.9.9", - "PyMySQL==1.1.0", + "PyMySQL==1.1.1", "bcrypt==4.1.3", "boto3==1.34.110", @@ -33,7 +33,7 @@ dependencies = [ "google-generativeai==0.5.4", "langchain==0.2.0", - "langchain-community==0.2.0", + "langchain-community==0.2.9", "langchain-chroma==0.1.1", "fake-useragent==1.5.1", diff --git a/src/app.css b/src/app.css index c3388f1d36..69e107be98 100644 --- a/src/app.css +++ b/src/app.css @@ -154,3 +154,7 @@ input[type='number'] { .tippy-box[data-theme~='dark'] { @apply rounded-lg bg-gray-950 text-xs border border-gray-900 shadow-xl; } + +.password { + -webkit-text-security: disc; +} diff --git a/src/lib/apis/audio/index.ts b/src/lib/apis/audio/index.ts index 9716c552a7..af09af9907 100644 --- a/src/lib/apis/audio/index.ts +++ b/src/lib/apis/audio/index.ts @@ -131,3 +131,59 @@ export const synthesizeOpenAISpeech = async ( return res; }; + +export const getModels = async (token: string = '') => { + let error = null; + + const res = await fetch(`${AUDIO_API_BASE_URL}/models`, { + method: 'GET', + headers: { + 'Content-Type': 'application/json', + Authorization: `Bearer ${token}` + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + error = err.detail; + console.log(err); + + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const getVoices = async (token: string = 
'') => { + let error = null; + + const res = await fetch(`${AUDIO_API_BASE_URL}/voices`, { + method: 'GET', + headers: { + 'Content-Type': 'application/json', + Authorization: `Bearer ${token}` + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + error = err.detail; + console.log(err); + + return null; + }); + + if (error) { + throw error; + } + + return res; +}; diff --git a/src/lib/components/admin/Settings/Audio.svelte b/src/lib/components/admin/Settings/Audio.svelte index c0b018124a..7c33005682 100644 --- a/src/lib/components/admin/Settings/Audio.svelte +++ b/src/lib/components/admin/Settings/Audio.svelte @@ -1,13 +1,19 @@ @@ -138,7 +154,7 @@
[Markup unrecoverable: extraction stripped the HTML tags from this file's hunks, including the script-block hunk @@ -1,13 +1,19 @@ above. The surviving fragments show: the TTS engine <select> handler becoming on:change={async (e) => { await updateConfigHandler(); await getVoices(); await getModels(); ... }}, replacing the old getOpenAIVoices()/getWebAPIVoices() calls, setting TTS_VOICE = 'alloy' and TTS_MODEL = 'tts-1' when 'openai' is selected and clearing both otherwise; hunk @@ -203,7 +223,7 @@ adding an {:else if TTS_ENGINE === 'elevenlabs'} branch whose inner markup did not survive; and hunks @@ -252,7 +283,7 @@ and @@ -263,15 +294,56 @@ reworking the voice and model pickers, iterating {#each voices as voice} / {#each models as model} and adding ElevenLabs-specific 'TTS Voice' and 'TTS Model' selectors over the fetched lists.]
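The stripped Audio.svelte hunks above add ElevenLabs 'TTS Voice' and 'TTS Model' pickers fed by the new getVoices/getModels wrappers in src/lib/apis/audio/index.ts. On the backend, such lists presumably come from ElevenLabs itself; a hedged sketch of a name-to-voice_id lookup against ElevenLabs' public /v1/voices listing (the endpoint path, xi-api-key header, and response shape are assumptions about the external API, and this helper is illustrative, not code from this diff):

```python
# Hedged sketch: fetching ElevenLabs voices and resolving a display name to
# the voice_id the speech endpoint expects. All endpoint details are
# assumptions about ElevenLabs' public API, not code from this repository.
import requests

ELEVENLABS_BASE_URL = "https://api.elevenlabs.io/v1"  # assumed base URL


def resolve_voice_id(api_key: str, voice_name: str) -> str:
    response = requests.get(
        f"{ELEVENLABS_BASE_URL}/voices",
        headers={"xi-api-key": api_key},  # assumed auth header
        timeout=10,
    )
    response.raise_for_status()
    # Assumed response shape: {"voices": [{"voice_id": "...", "name": "..."}]}
    for voice in response.json().get("voices", []):
        if voice["name"].lower() == voice_name.lower():
            return voice["voice_id"]
    raise ValueError(f"No ElevenLabs voice named {voice_name!r}")
```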
diff --git a/src/lib/components/admin/Settings/Documents.svelte b/src/lib/components/admin/Settings/Documents.svelte index 6932961f34..1b0b2c3fa8 100644 --- a/src/lib/components/admin/Settings/Documents.svelte +++ b/src/lib/components/admin/Settings/Documents.svelte
[Markup unrecoverable: extraction stripped the HTML tags from this file's hunks (@@ -269,7 +269,7 @@ around the saveHandler() form, @@ -615,7 +615,7 @@ under the 'Query Params' / 'Top K' labels, and @@ -632,7 +632,7 @@ inside the {#if querySettings.hybrid === true} 'Minimum Score' block). Each hunk changes a single wrapper line around the 'General Settings', 'Top K', and 'Minimum Score' headings; the attributes on those wrappers cannot be recovered.]
diff --git a/src/lib/components/chat/Chat.svelte b/src/lib/components/chat/Chat.svelte index 7f3c435a44..78a22010e5 100644 --- a/src/lib/components/chat/Chat.svelte +++ b/src/lib/components/chat/Chat.svelte @@ -80,6 +80,7 @@ let eventConfirmationMessage = ''; let eventConfirmationInput = false; let eventConfirmationInputPlaceholder = ''; + let eventConfirmationInputValue = ''; let eventCallback = null; let showModelSelector = true; @@ -108,7 +109,6 @@ }; let params = {}; - let valves = {}; $: if (history.currentId !== null) { let _messages = []; @@ -182,6 +182,7 @@ eventConfirmationTitle = data.title; eventConfirmationMessage = data.message; eventConfirmationInputPlaceholder = data.placeholder; + eventConfirmationInputValue = data?.value ?? ''; } else { console.log('Unknown message type', data); } @@ -281,6 +282,10 @@ if ($page.url.searchParams.get('q')) { prompt = $page.url.searchParams.get('q') ?? ''; + selectedToolIds = ($page.url.searchParams.get('tool_ids') ?? '') .split(',') .map((id) => id.trim()) .filter((id) => id); if (prompt) { await tick(); @@ -706,6 +711,7 @@ let _response = null; const responseMessage = history.messages[responseMessageId]; + const userMessage = history.messages[responseMessage.parentId]; // Wait until history/message have been updated await tick(); @@ -772,11 +778,12 @@ if (model?.info?.meta?.knowledge ?? false) { files.push(...model.info.meta.knowledge); } - if (responseMessage?.files) { - files.push( - ...responseMessage?.files.filter((item) => ['web_search_results'].includes(item.type)) - ); - } + files.push( + ...(userMessage?.files ?? []).filter((item) => + ['doc', 'file', 'collection'].includes(item.type) + ), + ...(responseMessage?.files ?? []).filter((item) => ['web_search_results'].includes(item.type)) + ); eventTarget.dispatchEvent( new CustomEvent('chat:start', { @@ -808,7 +815,6 @@ keep_alive: $settings.keepAlive ?? undefined, tool_ids: selectedToolIds.length > 0 ? selectedToolIds : undefined, files: files.length > 0 ? files : undefined, - ...(Object.keys(valves).length ? { valves } : {}), session_id: $socket?.id, chat_id: $chatId, id: responseMessageId @@ -1006,17 +1012,20 @@ const sendPromptOpenAI = async (model, userPrompt, responseMessageId, _chatId) => { let _response = null; + const responseMessage = history.messages[responseMessageId]; + const userMessage = history.messages[responseMessage.parentId]; let files = JSON.parse(JSON.stringify(chatFiles)); if (model?.info?.meta?.knowledge ?? false) { files.push(...model.info.meta.knowledge); } - if (responseMessage?.files) { - files.push( - ...responseMessage?.files.filter((item) => ['web_search_results'].includes(item.type)) - ); - } + files.push( + ...(userMessage?.files ?? []).filter((item) => + ['doc', 'file', 'collection'].includes(item.type) + ), + ...(responseMessage?.files ?? []).filter((item) => ['web_search_results'].includes(item.type)) + ); scrollToBottom(); @@ -1105,7 +1114,6 @@ max_tokens: params?.max_tokens ?? $settings?.params?.max_tokens ?? undefined, tool_ids: selectedToolIds.length > 0 ? selectedToolIds : undefined, files: files.length > 0 ? files : undefined, - ...(Object.keys(valves).length ?
{ valves } : {}), session_id: $socket?.id, chat_id: $chatId, id: responseMessageId @@ -1484,6 +1492,7 @@ message={eventConfirmationMessage} input={eventConfirmationInput} inputPlaceholder={eventConfirmationInputPlaceholder} + inputValue={eventConfirmationInputValue} on:confirm={(e) => { if (e.detail) { eventCallback(e.detail); @@ -1631,7 +1640,6 @@ bind:show={showControls} bind:chatFiles bind:params - bind:valves />
{/if} diff --git a/src/lib/components/chat/ChatControls.svelte b/src/lib/components/chat/ChatControls.svelte index f67e6d6efd..3de095b0d9 100644 --- a/src/lib/components/chat/ChatControls.svelte +++ b/src/lib/components/chat/ChatControls.svelte @@ -9,9 +9,7 @@ export let models = []; export let chatId = null; - export let chatFiles = []; - export let valves = {}; export let params = {}; let largeScreen = false; @@ -50,7 +48,6 @@ }} {models} bind:chatFiles - bind:valves bind:params />
@@ -66,7 +63,6 @@ }} {models} bind:chatFiles - bind:valves bind:params />
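Two things are worth unpacking from the Chat.svelte/ChatControls.svelte hunks above: the per-chat valves binding is gone (valves move into the chat Controls component below), and the attachments sent with a completion are now assembled from three sources. A small Python re-statement of that merge rule (the dict shapes mirror the diff; the helper itself is invented for illustration):

```python
# Illustrative re-statement of the new file-merging rule in Chat.svelte's
# sendPrompt paths; field names mirror the diff, the helper is invented.
from typing import List, Optional


def collect_request_files(
    model_knowledge: List[dict],
    user_message: Optional[dict],
    response_message: Optional[dict],
) -> List[dict]:
    files = list(model_knowledge)  # model?.info?.meta?.knowledge
    # doc/file/collection attachments now come from the *user* message...
    files += [
        item
        for item in ((user_message or {}).get("files") or [])
        if item.get("type") in ("doc", "file", "collection")
    ]
    # ...while web search results stay attached to the response message.
    files += [
        item
        for item in ((response_message or {}).get("files") or [])
        if item.get("type") == "web_search_results"
    ]
    return files


user_msg = {"files": [{"type": "file", "id": "f1"}, {"type": "image", "id": "i1"}]}
resp_msg = {"files": [{"type": "web_search_results", "urls": []}]}
print(collect_request_files([], user_msg, resp_msg))
# -> the 'file' attachment and the web_search_results entry, but not the image
```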
diff --git a/src/lib/components/chat/Controls/Controls.svelte b/src/lib/components/chat/Controls/Controls.svelte index ee8fcfff33..69034a305a 100644 --- a/src/lib/components/chat/Controls/Controls.svelte +++ b/src/lib/components/chat/Controls/Controls.svelte @@ -5,13 +5,13 @@ import XMark from '$lib/components/icons/XMark.svelte'; import AdvancedParams from '../Settings/Advanced/AdvancedParams.svelte'; - import Valves from '$lib/components/common/Valves.svelte'; + import Valves from '$lib/components/chat/Controls/Valves.svelte'; import FileItem from '$lib/components/common/FileItem.svelte'; + import Collapsible from '$lib/components/common/Collapsible.svelte'; export let models = []; export let chatFiles = []; - export let valves = {}; export let params = {}; @@ -28,18 +28,17 @@
[Markup unrecoverable: extraction stripped the HTML tags from this file's template hunks (@@ -28,18 +28,17 @@ above and @@ -50,44 +49,38 @@). The surviving fragments show the 'Files' section ({#if chatFiles.length > 0} ... {#each chatFiles as file, fileIdx} with a per-file dismiss handler that removes the file from the chatFiles array ... {/if}) and the 'System Prompt' section being rewrapped, consistent with the Collapsible import added above; the removed lines include the {#if models.length === 1 && models[0]?.pipe?.valves_spec} guard, so the 'Valves' section now renders unconditionally through the relocated chat Controls Valves component.]