diff --git a/CHANGELOG.md b/CHANGELOG.md index f7ee69b3a7..016794f404 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,39 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [0.3.33] - 2024-10-24 + +### Added + +- **🏆 Evaluation Leaderboard**: Easily track your performance through a new leaderboard system where your ratings contribute to a real-time ranking based on the Elo system. Sibling responses (regenerations, many model chats) are required for your ratings to count in the leaderboard. Additionally, you can opt-in to share your feedback history and be part of the community-wide leaderboard. Expect further improvements as we refine the algorithm—help us build the best community leaderboard! +- **⚔️ Arena Model Evaluation**: Enable blind A/B testing of models directly from Admin Settings > Evaluation for a true side-by-side comparison. Ideal for pinpointing the best model for your needs. +- **🎯 Topic-Based Leaderboard**: Discover more accurate rankings with experimental topic-based reranking, which adjusts leaderboard standings based on tag similarity in feedback. Get more relevant insights based on specific topics! +- **📁 Folders Support for Chats**: Organize your chats better by grouping them into folders. Drag and drop chats between folders and export them seamlessly for easy sharing or analysis. +- **📤 Easy Chat Import via Drag & Drop**: Save time by simply dragging and dropping chat exports (JSON) directly onto the sidebar to import them into your workspace—streamlined, efficient, and intuitive! +- **📚 Enhanced Knowledge Collection**: Now, you can reference individual files from a knowledge collection—ideal for more precise Retrieval-Augmented Generations (RAG) queries and document analysis. +- **🏷️ Enhanced Tagging System**: Tags now take up less space! Utilize the new 'tag:' query system to manage, search, and organize your conversations more effectively without cluttering the interface. +- **🧠 Auto-Tagging for Chats**: Your conversations are now automatically tagged for improved organization, mirroring the efficiency of auto-generated titles. +- **🔍 Backend Chat Query System**: Chat filtering has become more efficient, now handled through the backend\*\* instead of your browser, improving search performance and accuracy. +- **🎮 Revamped Playground**: Experience a refreshed and optimized Playground for smoother testing, tweaks, and experimentation of your models and tools. +- **🧩 Token-Based Text Splitter**: Introducing token-based text splitting (tiktoken), giving you more precise control over how text is processed. Previously, only character-based splitting was available. +- **🔢 Ollama Batch Embeddings**: Leverage new batch embedding support for improved efficiency and performance with Ollama embedding models. +- **🔍 Enhanced Add Text Content Modal**: Enjoy a cleaner, more intuitive workflow for adding and curating knowledge content with an upgraded input modal from our Knowledge workspace. +- **🖋️ Rich Text Input for Chats**: Make your chat inputs more dynamic with support for rich text formatting. Your conversations just got a lot more polished and professional. +- **⚡ Faster Whisper Model Configurability**: Customize your local faster whisper model directly from the WebUI. +- **☁️ Experimental S3 Support**: Enable stateless WebUI instances with S3 support, greatly enhancing scalability and balancing heavy workloads. +- **🔕 Disable Update Toast**: Now you can streamline your workspace even further—choose to disable update notifications for a more focused experience. +- **🌟 RAG Citation Relevance Percentage**: Easily assess citation accuracy with the addition of relevance percentages in RAG results. +- **⚙️ Mermaid Copy Button**: Mermaid diagrams now come with a handy copy button, simplifying the extraction and use of diagram contents directly in your workflow. +- **🎨 UI Redesign**: Major interface redesign that will make navigation smoother, keep your focus where it matters, and ensure a modern look. + +### Fixed + +- **🎙️ Voice Note Mic Stopping Issue**: Fixed the issue where the microphone stayed active after ending a voice note recording, ensuring your audio workflow runs smoothly. + +### Removed + +- **👋 Goodbye Sidebar Tags**: Sidebar tag clutter is gone. We’ve shifted tag buttons to more effective query-based tag filtering for a sleeker, more agile interface. + ## [0.3.32] - 2024-10-06 ### Added diff --git a/Dockerfile b/Dockerfile index 2e898dc890..ec879d732d 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,6 +1,6 @@ # syntax=docker/dockerfile:1 # Initialize device type args -# use build args in the docker build commmand with --build-arg="BUILDARG=true" +# use build args in the docker build command with --build-arg="BUILDARG=true" ARG USE_CUDA=false ARG USE_OLLAMA=false # Tested with cu117 for CUDA 11 and cu121 for CUDA 12 (default) @@ -11,6 +11,10 @@ ARG USE_CUDA_VER=cu121 # IMPORTANT: If you change the embedding model (sentence-transformers/all-MiniLM-L6-v2) and vice versa, you aren't able to use RAG Chat with your previous documents loaded in the WebUI! You need to re-embed them. ARG USE_EMBEDDING_MODEL=sentence-transformers/all-MiniLM-L6-v2 ARG USE_RERANKING_MODEL="" + +# Tiktoken encoding name; models to use can be found at https://huggingface.co/models?library=tiktoken +ARG USE_TIKTOKEN_ENCODING_NAME="cl100k_base" + ARG BUILD_HASH=dev-build # Override at your own risk - non-root configurations are untested ARG UID=0 @@ -72,6 +76,10 @@ ENV RAG_EMBEDDING_MODEL="$USE_EMBEDDING_MODEL_DOCKER" \ RAG_RERANKING_MODEL="$USE_RERANKING_MODEL_DOCKER" \ SENTENCE_TRANSFORMERS_HOME="/app/backend/data/cache/embedding/models" +## Tiktoken model settings ## +ENV TIKTOKEN_ENCODING_NAME="$USE_TIKTOKEN_ENCODING_NAME" \ + TIKTOKEN_CACHE_DIR="/app/backend/data/cache/tiktoken" + ## Hugging Face download cache ## ENV HF_HOME="/app/backend/data/cache/embedding/models" @@ -131,11 +139,13 @@ RUN pip3 install uv && \ uv pip install --system -r requirements.txt --no-cache-dir && \ python -c "import os; from sentence_transformers import SentenceTransformer; SentenceTransformer(os.environ['RAG_EMBEDDING_MODEL'], device='cpu')" && \ python -c "import os; from faster_whisper import WhisperModel; WhisperModel(os.environ['WHISPER_MODEL'], device='cpu', compute_type='int8', download_root=os.environ['WHISPER_MODEL_DIR'])"; \ + python -c "import os; import tiktoken; tiktoken.get_encoding(os.environ['TIKTOKEN_ENCODING_NAME'])"; \ else \ pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu --no-cache-dir && \ uv pip install --system -r requirements.txt --no-cache-dir && \ python -c "import os; from sentence_transformers import SentenceTransformer; SentenceTransformer(os.environ['RAG_EMBEDDING_MODEL'], device='cpu')" && \ python -c "import os; from faster_whisper import WhisperModel; WhisperModel(os.environ['WHISPER_MODEL'], device='cpu', compute_type='int8', download_root=os.environ['WHISPER_MODEL_DIR'])"; \ + python -c "import os; import tiktoken; tiktoken.get_encoding(os.environ['TIKTOKEN_ENCODING_NAME'])"; \ fi; \ chown -R $UID:$GID /app/backend/data/ diff --git a/TROUBLESHOOTING.md b/TROUBLESHOOTING.md index 9bf242381c..83251a3a91 100644 --- a/TROUBLESHOOTING.md +++ b/TROUBLESHOOTING.md @@ -18,7 +18,7 @@ If you're experiencing connection issues, it’s often due to the WebUI docker c docker run -d --network=host -v open-webui:/app/backend/data -e OLLAMA_BASE_URL=http://127.0.0.1:11434 --name open-webui --restart always ghcr.io/open-webui/open-webui:main ``` -### Error on Slow Reponses for Ollama +### Error on Slow Responses for Ollama Open WebUI has a default timeout of 5 minutes for Ollama to finish generating the response. If needed, this can be adjusted via the environment variable AIOHTTP_CLIENT_TIMEOUT, which sets the timeout in seconds. diff --git a/backend/open_webui/apps/audio/main.py b/backend/open_webui/apps/audio/main.py index 0e56720138..148430da87 100644 --- a/backend/open_webui/apps/audio/main.py +++ b/backend/open_webui/apps/audio/main.py @@ -63,6 +63,9 @@ app.state.config.STT_OPENAI_API_KEY = AUDIO_STT_OPENAI_API_KEY app.state.config.STT_ENGINE = AUDIO_STT_ENGINE app.state.config.STT_MODEL = AUDIO_STT_MODEL +app.state.config.WHISPER_MODEL = WHISPER_MODEL +app.state.faster_whisper_model = None + app.state.config.TTS_OPENAI_API_BASE_URL = AUDIO_TTS_OPENAI_API_BASE_URL app.state.config.TTS_OPENAI_API_KEY = AUDIO_TTS_OPENAI_API_KEY app.state.config.TTS_ENGINE = AUDIO_TTS_ENGINE @@ -82,6 +85,31 @@ SPEECH_CACHE_DIR = Path(CACHE_DIR).joinpath("./audio/speech/") SPEECH_CACHE_DIR.mkdir(parents=True, exist_ok=True) +def set_faster_whisper_model(model: str, auto_update: bool = False): + if model and app.state.config.STT_ENGINE == "": + from faster_whisper import WhisperModel + + faster_whisper_kwargs = { + "model_size_or_path": model, + "device": whisper_device_type, + "compute_type": "int8", + "download_root": WHISPER_MODEL_DIR, + "local_files_only": not auto_update, + } + + try: + app.state.faster_whisper_model = WhisperModel(**faster_whisper_kwargs) + except Exception: + log.warning( + "WhisperModel initialization failed, attempting download with local_files_only=False" + ) + faster_whisper_kwargs["local_files_only"] = False + app.state.faster_whisper_model = WhisperModel(**faster_whisper_kwargs) + + else: + app.state.faster_whisper_model = None + + class TTSConfigForm(BaseModel): OPENAI_API_BASE_URL: str OPENAI_API_KEY: str @@ -99,6 +127,7 @@ class STTConfigForm(BaseModel): OPENAI_API_KEY: str ENGINE: str MODEL: str + WHISPER_MODEL: str class AudioConfigUpdateForm(BaseModel): @@ -152,6 +181,7 @@ async def get_audio_config(user=Depends(get_admin_user)): "OPENAI_API_KEY": app.state.config.STT_OPENAI_API_KEY, "ENGINE": app.state.config.STT_ENGINE, "MODEL": app.state.config.STT_MODEL, + "WHISPER_MODEL": app.state.config.WHISPER_MODEL, }, } @@ -176,6 +206,8 @@ async def update_audio_config( app.state.config.STT_OPENAI_API_KEY = form_data.stt.OPENAI_API_KEY app.state.config.STT_ENGINE = form_data.stt.ENGINE app.state.config.STT_MODEL = form_data.stt.MODEL + app.state.config.WHISPER_MODEL = form_data.stt.WHISPER_MODEL + set_faster_whisper_model(form_data.stt.WHISPER_MODEL, WHISPER_MODEL_AUTO_UPDATE) return { "tts": { @@ -194,6 +226,7 @@ async def update_audio_config( "OPENAI_API_KEY": app.state.config.STT_OPENAI_API_KEY, "ENGINE": app.state.config.STT_ENGINE, "MODEL": app.state.config.STT_MODEL, + "WHISPER_MODEL": app.state.config.WHISPER_MODEL, }, } @@ -367,27 +400,10 @@ def transcribe(file_path): id = filename.split(".")[0] if app.state.config.STT_ENGINE == "": - from faster_whisper import WhisperModel - - whisper_kwargs = { - "model_size_or_path": WHISPER_MODEL, - "device": whisper_device_type, - "compute_type": "int8", - "download_root": WHISPER_MODEL_DIR, - "local_files_only": not WHISPER_MODEL_AUTO_UPDATE, - } - - log.debug(f"whisper_kwargs: {whisper_kwargs}") - - try: - model = WhisperModel(**whisper_kwargs) - except Exception: - log.warning( - "WhisperModel initialization failed, attempting download with local_files_only=False" - ) - whisper_kwargs["local_files_only"] = False - model = WhisperModel(**whisper_kwargs) + if app.state.faster_whisper_model is None: + set_faster_whisper_model(app.state.config.WHISPER_MODEL) + model = app.state.faster_whisper_model segments, info = model.transcribe(file_path, beam_size=5) log.info( "Detected language '%s' with probability %f" @@ -395,7 +411,6 @@ def transcribe(file_path): ) transcript = "".join([segment.text for segment in list(segments)]) - data = {"text": transcript.strip()} # save the transcript to a json file @@ -403,7 +418,7 @@ def transcribe(file_path): with open(transcript_file, "w") as f: json.dump(data, f) - print(data) + log.debug(data) return data elif app.state.config.STT_ENGINE == "openai": if is_mp4_audio(file_path): @@ -417,7 +432,7 @@ def transcribe(file_path): files = {"file": (filename, open(file_path, "rb"))} data = {"model": app.state.config.STT_MODEL} - print(files, data) + log.debug(files, data) r = None try: @@ -450,7 +465,7 @@ def transcribe(file_path): except Exception: error_detail = f"External: {e}" - raise error_detail + raise Exception(error_detail) @app.post("/transcriptions") diff --git a/backend/open_webui/apps/images/utils/comfyui.py b/backend/open_webui/apps/images/utils/comfyui.py index 0a3e3a1d9b..4c421d7c52 100644 --- a/backend/open_webui/apps/images/utils/comfyui.py +++ b/backend/open_webui/apps/images/utils/comfyui.py @@ -125,22 +125,34 @@ async def comfyui_generate_image( workflow[node_id]["inputs"][node.key] = model elif node.type == "prompt": for node_id in node.node_ids: - workflow[node_id]["inputs"]["text"] = payload.prompt + workflow[node_id]["inputs"][ + node.key if node.key else "text" + ] = payload.prompt elif node.type == "negative_prompt": for node_id in node.node_ids: - workflow[node_id]["inputs"]["text"] = payload.negative_prompt + workflow[node_id]["inputs"][ + node.key if node.key else "text" + ] = payload.negative_prompt elif node.type == "width": for node_id in node.node_ids: - workflow[node_id]["inputs"]["width"] = payload.width + workflow[node_id]["inputs"][ + node.key if node.key else "width" + ] = payload.width elif node.type == "height": for node_id in node.node_ids: - workflow[node_id]["inputs"]["height"] = payload.height + workflow[node_id]["inputs"][ + node.key if node.key else "height" + ] = payload.height elif node.type == "n": for node_id in node.node_ids: - workflow[node_id]["inputs"]["batch_size"] = payload.n + workflow[node_id]["inputs"][ + node.key if node.key else "batch_size" + ] = payload.n elif node.type == "steps": for node_id in node.node_ids: - workflow[node_id]["inputs"]["steps"] = payload.steps + workflow[node_id]["inputs"][ + node.key if node.key else "steps" + ] = payload.steps elif node.type == "seed": seed = ( payload.seed diff --git a/backend/open_webui/apps/ollama/main.py b/backend/open_webui/apps/ollama/main.py index 33d9846557..cb38a53eb6 100644 --- a/backend/open_webui/apps/ollama/main.py +++ b/backend/open_webui/apps/ollama/main.py @@ -547,8 +547,8 @@ class GenerateEmbeddingsForm(BaseModel): class GenerateEmbedForm(BaseModel): model: str - input: str - truncate: Optional[bool] + input: list[str] | str + truncate: Optional[bool] = None options: Optional[dict] = None keep_alive: Optional[Union[int, str]] = None @@ -560,48 +560,7 @@ async def generate_embeddings( url_idx: Optional[int] = None, user=Depends(get_verified_user), ): - if url_idx is None: - model = form_data.model - - if ":" not in model: - model = f"{model}:latest" - - if model in app.state.MODELS: - url_idx = random.choice(app.state.MODELS[model]["urls"]) - else: - raise HTTPException( - status_code=400, - detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model), - ) - - url = app.state.config.OLLAMA_BASE_URLS[url_idx] - log.info(f"url: {url}") - - r = requests.request( - method="POST", - url=f"{url}/api/embed", - headers={"Content-Type": "application/json"}, - data=form_data.model_dump_json(exclude_none=True).encode(), - ) - try: - r.raise_for_status() - - return r.json() - except Exception as e: - log.exception(e) - error_detail = "Open WebUI: Server Connection Error" - if r is not None: - try: - res = r.json() - if "error" in res: - error_detail = f"Ollama: {res['error']}" - except Exception: - error_detail = f"Ollama: {e}" - - raise HTTPException( - status_code=r.status_code if r else 500, - detail=error_detail, - ) + return generate_ollama_batch_embeddings(form_data, url_idx) @app.post("/api/embeddings") @@ -611,48 +570,7 @@ async def generate_embeddings( url_idx: Optional[int] = None, user=Depends(get_verified_user), ): - if url_idx is None: - model = form_data.model - - if ":" not in model: - model = f"{model}:latest" - - if model in app.state.MODELS: - url_idx = random.choice(app.state.MODELS[model]["urls"]) - else: - raise HTTPException( - status_code=400, - detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model), - ) - - url = app.state.config.OLLAMA_BASE_URLS[url_idx] - log.info(f"url: {url}") - - r = requests.request( - method="POST", - url=f"{url}/api/embeddings", - headers={"Content-Type": "application/json"}, - data=form_data.model_dump_json(exclude_none=True).encode(), - ) - try: - r.raise_for_status() - - return r.json() - except Exception as e: - log.exception(e) - error_detail = "Open WebUI: Server Connection Error" - if r is not None: - try: - res = r.json() - if "error" in res: - error_detail = f"Ollama: {res['error']}" - except Exception: - error_detail = f"Ollama: {e}" - - raise HTTPException( - status_code=r.status_code if r else 500, - detail=error_detail, - ) + return generate_ollama_embeddings(form_data=form_data, url_idx=url_idx) def generate_ollama_embeddings( @@ -692,7 +610,64 @@ def generate_ollama_embeddings( log.info(f"generate_ollama_embeddings {data}") if "embedding" in data: - return data["embedding"] + return data + else: + raise Exception("Something went wrong :/") + except Exception as e: + log.exception(e) + error_detail = "Open WebUI: Server Connection Error" + if r is not None: + try: + res = r.json() + if "error" in res: + error_detail = f"Ollama: {res['error']}" + except Exception: + error_detail = f"Ollama: {e}" + + raise HTTPException( + status_code=r.status_code if r else 500, + detail=error_detail, + ) + + +def generate_ollama_batch_embeddings( + form_data: GenerateEmbedForm, + url_idx: Optional[int] = None, +): + log.info(f"generate_ollama_batch_embeddings {form_data}") + + if url_idx is None: + model = form_data.model + + if ":" not in model: + model = f"{model}:latest" + + if model in app.state.MODELS: + url_idx = random.choice(app.state.MODELS[model]["urls"]) + else: + raise HTTPException( + status_code=400, + detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model), + ) + + url = app.state.config.OLLAMA_BASE_URLS[url_idx] + log.info(f"url: {url}") + + r = requests.request( + method="POST", + url=f"{url}/api/embed", + headers={"Content-Type": "application/json"}, + data=form_data.model_dump_json(exclude_none=True).encode(), + ) + try: + r.raise_for_status() + + data = r.json() + + log.info(f"generate_ollama_batch_embeddings {data}") + + if "embeddings" in data: + return data else: raise Exception("Something went wrong :/") except Exception as e: @@ -788,8 +763,7 @@ async def generate_chat_completion( user=Depends(get_verified_user), ): payload = {**form_data.model_dump(exclude_none=True)} - log.debug(f"{payload = }") - + log.debug(f"generate_chat_completion() - 1.payload = {payload}") if "metadata" in payload: del payload["metadata"] @@ -824,7 +798,7 @@ async def generate_chat_completion( url = get_ollama_url(url_idx, payload["model"]) log.info(f"url: {url}") - log.debug(payload) + log.debug(f"generate_chat_completion() - 2.payload = {payload}") return await post_streaming_url( f"{url}/api/chat", diff --git a/backend/open_webui/apps/openai/main.py b/backend/open_webui/apps/openai/main.py index 70cefb29ca..3647977cad 100644 --- a/backend/open_webui/apps/openai/main.py +++ b/backend/open_webui/apps/openai/main.py @@ -18,7 +18,10 @@ from open_webui.config import ( OPENAI_API_KEYS, AppConfig, ) -from open_webui.env import AIOHTTP_CLIENT_TIMEOUT +from open_webui.env import ( + AIOHTTP_CLIENT_TIMEOUT, + AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST, +) from open_webui.constants import ERROR_MESSAGES from open_webui.env import SRC_LOG_LEVELS @@ -179,7 +182,7 @@ async def speech(request: Request, user=Depends(get_verified_user)): async def fetch_url(url, key): - timeout = aiohttp.ClientTimeout(total=3) + timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST) try: headers = {"Authorization": f"Bearer {key}"} async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session: @@ -237,9 +240,7 @@ def merge_models_lists(model_lists): def is_openai_api_disabled(): - api_keys = app.state.config.OPENAI_API_KEYS - no_keys = len(api_keys) == 1 and api_keys[0] == "" - return no_keys or not app.state.config.ENABLE_OPENAI_API + return not app.state.config.ENABLE_OPENAI_API async def get_all_models_raw() -> list: diff --git a/backend/open_webui/apps/retrieval/main.py b/backend/open_webui/apps/retrieval/main.py index 52cebeabc4..04eece38c6 100644 --- a/backend/open_webui/apps/retrieval/main.py +++ b/backend/open_webui/apps/retrieval/main.py @@ -15,6 +15,9 @@ from fastapi import Depends, FastAPI, File, Form, HTTPException, UploadFile, sta from fastapi.middleware.cors import CORSMiddleware from pydantic import BaseModel + +from open_webui.storage.provider import Storage +from open_webui.apps.webui.models.knowledge import Knowledges from open_webui.apps.retrieval.vector.connector import VECTOR_DB_CLIENT # Document loaders @@ -47,6 +50,8 @@ from open_webui.apps.retrieval.utils import ( from open_webui.apps.webui.models.files import Files from open_webui.config import ( BRAVE_SEARCH_API_KEY, + TIKTOKEN_ENCODING_NAME, + RAG_TEXT_SPLITTER, CHUNK_OVERLAP, CHUNK_SIZE, CONTENT_EXTRACTION_ENGINE, @@ -63,7 +68,7 @@ from open_webui.config import ( RAG_EMBEDDING_MODEL, RAG_EMBEDDING_MODEL_AUTO_UPDATE, RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE, - RAG_EMBEDDING_OPENAI_BATCH_SIZE, + RAG_EMBEDDING_BATCH_SIZE, RAG_FILE_MAX_COUNT, RAG_FILE_MAX_SIZE, RAG_OPENAI_API_BASE_URL, @@ -102,7 +107,7 @@ from open_webui.utils.misc import ( ) from open_webui.utils.utils import get_admin_user, get_verified_user -from langchain.text_splitter import RecursiveCharacterTextSplitter +from langchain.text_splitter import RecursiveCharacterTextSplitter, TokenTextSplitter from langchain_community.document_loaders import ( YoutubeLoader, ) @@ -129,12 +134,15 @@ app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = ( app.state.config.CONTENT_EXTRACTION_ENGINE = CONTENT_EXTRACTION_ENGINE app.state.config.TIKA_SERVER_URL = TIKA_SERVER_URL +app.state.config.TEXT_SPLITTER = RAG_TEXT_SPLITTER +app.state.config.TIKTOKEN_ENCODING_NAME = TIKTOKEN_ENCODING_NAME + app.state.config.CHUNK_SIZE = CHUNK_SIZE app.state.config.CHUNK_OVERLAP = CHUNK_OVERLAP app.state.config.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE app.state.config.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL -app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE = RAG_EMBEDDING_OPENAI_BATCH_SIZE +app.state.config.RAG_EMBEDDING_BATCH_SIZE = RAG_EMBEDDING_BATCH_SIZE app.state.config.RAG_RERANKING_MODEL = RAG_RERANKING_MODEL app.state.config.RAG_TEMPLATE = RAG_TEMPLATE @@ -171,9 +179,9 @@ def update_embedding_model( auto_update: bool = False, ): if embedding_model and app.state.config.RAG_EMBEDDING_ENGINE == "": - import sentence_transformers + from sentence_transformers import SentenceTransformer - app.state.sentence_transformer_ef = sentence_transformers.SentenceTransformer( + app.state.sentence_transformer_ef = SentenceTransformer( get_model_path(embedding_model, auto_update), device=DEVICE_TYPE, trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE, @@ -233,7 +241,7 @@ app.state.EMBEDDING_FUNCTION = get_embedding_function( app.state.sentence_transformer_ef, app.state.config.OPENAI_API_KEY, app.state.config.OPENAI_API_BASE_URL, - app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE, + app.state.config.RAG_EMBEDDING_BATCH_SIZE, ) app.add_middleware( @@ -267,7 +275,7 @@ async def get_status(): "embedding_engine": app.state.config.RAG_EMBEDDING_ENGINE, "embedding_model": app.state.config.RAG_EMBEDDING_MODEL, "reranking_model": app.state.config.RAG_RERANKING_MODEL, - "openai_batch_size": app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE, + "embedding_batch_size": app.state.config.RAG_EMBEDDING_BATCH_SIZE, } @@ -277,10 +285,10 @@ async def get_embedding_config(user=Depends(get_admin_user)): "status": True, "embedding_engine": app.state.config.RAG_EMBEDDING_ENGINE, "embedding_model": app.state.config.RAG_EMBEDDING_MODEL, + "embedding_batch_size": app.state.config.RAG_EMBEDDING_BATCH_SIZE, "openai_config": { "url": app.state.config.OPENAI_API_BASE_URL, "key": app.state.config.OPENAI_API_KEY, - "batch_size": app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE, }, } @@ -296,13 +304,13 @@ async def get_reraanking_config(user=Depends(get_admin_user)): class OpenAIConfigForm(BaseModel): url: str key: str - batch_size: Optional[int] = None class EmbeddingModelUpdateForm(BaseModel): openai_config: Optional[OpenAIConfigForm] = None embedding_engine: str embedding_model: str + embedding_batch_size: Optional[int] = 1 @app.post("/embedding/update") @@ -320,11 +328,7 @@ async def update_embedding_config( if form_data.openai_config is not None: app.state.config.OPENAI_API_BASE_URL = form_data.openai_config.url app.state.config.OPENAI_API_KEY = form_data.openai_config.key - app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE = ( - form_data.openai_config.batch_size - if form_data.openai_config.batch_size - else 1 - ) + app.state.config.RAG_EMBEDDING_BATCH_SIZE = form_data.embedding_batch_size update_embedding_model(app.state.config.RAG_EMBEDDING_MODEL) @@ -334,17 +338,17 @@ async def update_embedding_config( app.state.sentence_transformer_ef, app.state.config.OPENAI_API_KEY, app.state.config.OPENAI_API_BASE_URL, - app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE, + app.state.config.RAG_EMBEDDING_BATCH_SIZE, ) return { "status": True, "embedding_engine": app.state.config.RAG_EMBEDDING_ENGINE, "embedding_model": app.state.config.RAG_EMBEDDING_MODEL, + "embedding_batch_size": app.state.config.RAG_EMBEDDING_BATCH_SIZE, "openai_config": { "url": app.state.config.OPENAI_API_BASE_URL, "key": app.state.config.OPENAI_API_KEY, - "batch_size": app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE, }, } except Exception as e: @@ -388,18 +392,19 @@ async def get_rag_config(user=Depends(get_admin_user)): return { "status": True, "pdf_extract_images": app.state.config.PDF_EXTRACT_IMAGES, - "file": { - "max_size": app.state.config.FILE_MAX_SIZE, - "max_count": app.state.config.FILE_MAX_COUNT, - }, "content_extraction": { "engine": app.state.config.CONTENT_EXTRACTION_ENGINE, "tika_server_url": app.state.config.TIKA_SERVER_URL, }, "chunk": { + "text_splitter": app.state.config.TEXT_SPLITTER, "chunk_size": app.state.config.CHUNK_SIZE, "chunk_overlap": app.state.config.CHUNK_OVERLAP, }, + "file": { + "max_size": app.state.config.FILE_MAX_SIZE, + "max_count": app.state.config.FILE_MAX_COUNT, + }, "youtube": { "language": app.state.config.YOUTUBE_LOADER_LANGUAGE, "translation": app.state.YOUTUBE_LOADER_TRANSLATION, @@ -438,6 +443,7 @@ class ContentExtractionConfig(BaseModel): class ChunkParamUpdateForm(BaseModel): + text_splitter: Optional[str] = None chunk_size: int chunk_overlap: int @@ -497,6 +503,7 @@ async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_ app.state.config.TIKA_SERVER_URL = form_data.content_extraction.tika_server_url if form_data.chunk is not None: + app.state.config.TEXT_SPLITTER = form_data.chunk.text_splitter app.state.config.CHUNK_SIZE = form_data.chunk.chunk_size app.state.config.CHUNK_OVERLAP = form_data.chunk.chunk_overlap @@ -543,6 +550,7 @@ async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_ "tika_server_url": app.state.config.TIKA_SERVER_URL, }, "chunk": { + "text_splitter": app.state.config.TEXT_SPLITTER, "chunk_size": app.state.config.CHUNK_SIZE, "chunk_overlap": app.state.config.CHUNK_OVERLAP, }, @@ -603,11 +611,10 @@ class QuerySettingsForm(BaseModel): async def update_query_settings( form_data: QuerySettingsForm, user=Depends(get_admin_user) ): - app.state.config.RAG_TEMPLATE = ( - form_data.template if form_data.template != "" else DEFAULT_RAG_TEMPLATE - ) + app.state.config.RAG_TEMPLATE = form_data.template app.state.config.TOP_K = form_data.k if form_data.k else 4 app.state.config.RELEVANCE_THRESHOLD = form_data.r if form_data.r else 0.0 + app.state.config.ENABLE_RAG_HYBRID_SEARCH = ( form_data.hybrid if form_data.hybrid else False ) @@ -645,25 +652,48 @@ def save_docs_to_vector_db( filter={"hash": metadata["hash"]}, ) - if result: + if result is not None: existing_doc_ids = result.ids[0] if existing_doc_ids: log.info(f"Document with hash {metadata['hash']} already exists") raise ValueError(ERROR_MESSAGES.DUPLICATE_CONTENT) if split: - text_splitter = RecursiveCharacterTextSplitter( - chunk_size=app.state.config.CHUNK_SIZE, - chunk_overlap=app.state.config.CHUNK_OVERLAP, - add_start_index=True, - ) + if app.state.config.TEXT_SPLITTER in ["", "character"]: + text_splitter = RecursiveCharacterTextSplitter( + chunk_size=app.state.config.CHUNK_SIZE, + chunk_overlap=app.state.config.CHUNK_OVERLAP, + add_start_index=True, + ) + elif app.state.config.TEXT_SPLITTER == "token": + text_splitter = TokenTextSplitter( + encoding_name=app.state.config.TIKTOKEN_ENCODING_NAME, + chunk_size=app.state.config.CHUNK_SIZE, + chunk_overlap=app.state.config.CHUNK_OVERLAP, + add_start_index=True, + ) + else: + raise ValueError(ERROR_MESSAGES.DEFAULT("Invalid text splitter")) + docs = text_splitter.split_documents(docs) if len(docs) == 0: raise ValueError(ERROR_MESSAGES.EMPTY_CONTENT) texts = [doc.page_content for doc in docs] - metadatas = [{**doc.metadata, **(metadata if metadata else {})} for doc in docs] + metadatas = [ + { + **doc.metadata, + **(metadata if metadata else {}), + "embedding_config": json.dumps( + { + "engine": app.state.config.RAG_EMBEDDING_ENGINE, + "model": app.state.config.RAG_EMBEDDING_MODEL, + } + ), + } + for doc in docs + ] # ChromaDB does not like datetime formats # for meta-data so convert them to string. @@ -679,8 +709,10 @@ def save_docs_to_vector_db( if overwrite: VECTOR_DB_CLIENT.delete_collection(collection_name=collection_name) log.info(f"deleting existing collection {collection_name}") - - if add is False: + elif add is False: + log.info( + f"collection {collection_name} already exists, overwrite is False and add is False" + ) return True log.info(f"adding to collection {collection_name}") @@ -690,7 +722,7 @@ def save_docs_to_vector_db( app.state.sentence_transformer_ef, app.state.config.OPENAI_API_KEY, app.state.config.OPENAI_API_BASE_URL, - app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE, + app.state.config.RAG_EMBEDDING_BATCH_SIZE, ) embeddings = embedding_function( @@ -767,7 +799,7 @@ def process_file( collection_name=f"file-{file.id}", filter={"file_id": file.id} ) - if len(result.ids[0]) > 0: + if result is not None and len(result.ids[0]) > 0: docs = [ Document( page_content=result.documents[0][idx], @@ -792,15 +824,14 @@ def process_file( else: # Process the file and save the content # Usage: /files/ - - file_path = file.meta.get("path", None) + file_path = file.path if file_path: + file_path = Storage.get_file(file_path) loader = Loader( engine=app.state.config.CONTENT_EXTRACTION_ENGINE, TIKA_SERVER_URL=app.state.config.TIKA_SERVER_URL, PDF_EXTRACT_IMAGES=app.state.config.PDF_EXTRACT_IMAGES, ) - docs = loader.load( file.filename, file.meta.get("content_type"), file_path ) @@ -816,7 +847,6 @@ def process_file( }, ) ] - text_content = " ".join([doc.page_content for doc in docs]) log.debug(f"text_content: {text_content}") @@ -1259,6 +1289,7 @@ def delete_entries_from_collection(form_data: DeleteForm, user=Depends(get_admin @app.post("/reset/db") def reset_vector_db(user=Depends(get_admin_user)): VECTOR_DB_CLIENT.reset() + Knowledges.delete_all_knowledge() @app.post("/reset/uploads") @@ -1281,28 +1312,6 @@ def reset_upload_dir(user=Depends(get_admin_user)) -> bool: print(f"The directory {folder} does not exist") except Exception as e: print(f"Failed to process the directory {folder}. Reason: {e}") - - return True - - -@app.post("/reset") -def reset(user=Depends(get_admin_user)) -> bool: - folder = f"{UPLOAD_DIR}" - for filename in os.listdir(folder): - file_path = os.path.join(folder, filename) - try: - if os.path.isfile(file_path) or os.path.islink(file_path): - os.unlink(file_path) - elif os.path.isdir(file_path): - shutil.rmtree(file_path) - except Exception as e: - log.error("Failed to delete %s. Reason: %s" % (file_path, e)) - - try: - VECTOR_DB_CLIENT.reset() - except Exception as e: - log.exception(e) - return True diff --git a/backend/open_webui/apps/retrieval/utils.py b/backend/open_webui/apps/retrieval/utils.py index 0fe206c966..153bd804ff 100644 --- a/backend/open_webui/apps/retrieval/utils.py +++ b/backend/open_webui/apps/retrieval/utils.py @@ -12,13 +12,14 @@ from langchain_core.documents import Document from open_webui.apps.ollama.main import ( - GenerateEmbeddingsForm, - generate_ollama_embeddings, + GenerateEmbedForm, + generate_ollama_batch_embeddings, ) from open_webui.apps.retrieval.vector.connector import VECTOR_DB_CLIENT from open_webui.utils.misc import get_last_user_message from open_webui.env import SRC_LOG_LEVELS +from open_webui.config import DEFAULT_RAG_TEMPLATE log = logging.getLogger(__name__) @@ -193,7 +194,8 @@ def query_collection( k=k, query_embedding=query_embedding, ) - results.append(result.model_dump()) + if result is not None: + results.append(result.model_dump()) except Exception as e: log.exception(f"Error when querying the collection: {e}") else: @@ -238,8 +240,13 @@ def query_collection_with_hybrid_search( def rag_template(template: str, context: str, query: str): - count = template.count("[context]") - assert "[context]" in template, "RAG template does not contain '[context]'" + if template == "": + template = DEFAULT_RAG_TEMPLATE + + if "[context]" not in template and "{{CONTEXT}}" not in template: + log.debug( + "WARNING: The RAG template does not contain the '[context]' or '{{CONTEXT}}' placeholder." + ) if "" in context and "" in context: log.debug( @@ -248,14 +255,25 @@ def rag_template(template: str, context: str, query: str): "nothing, or the user might be trying to hack something." ) + query_placeholders = [] if "[query]" in context: - query_placeholder = f"[query-{str(uuid.uuid4())}]" + query_placeholder = "{{QUERY" + str(uuid.uuid4()) + "}}" template = template.replace("[query]", query_placeholder) - template = template.replace("[context]", context) + query_placeholders.append(query_placeholder) + + if "{{QUERY}}" in context: + query_placeholder = "{{QUERY" + str(uuid.uuid4()) + "}}" + template = template.replace("{{QUERY}}", query_placeholder) + query_placeholders.append(query_placeholder) + + template = template.replace("[context]", context) + template = template.replace("{{CONTEXT}}", context) + template = template.replace("[query]", query) + template = template.replace("{{QUERY}}", query) + + for query_placeholder in query_placeholders: template = template.replace(query_placeholder, query) - else: - template = template.replace("[context]", context) - template = template.replace("[query]", query) + return template @@ -265,39 +283,27 @@ def get_embedding_function( embedding_function, openai_key, openai_url, - batch_size, + embedding_batch_size, ): if embedding_engine == "": return lambda query: embedding_function.encode(query).tolist() elif embedding_engine in ["ollama", "openai"]: - if embedding_engine == "ollama": - func = lambda query: generate_ollama_embeddings( - GenerateEmbeddingsForm( - **{ - "model": embedding_model, - "prompt": query, - } - ) - ) - elif embedding_engine == "openai": - func = lambda query: generate_openai_embeddings( - model=embedding_model, - text=query, - key=openai_key, - url=openai_url, - ) + func = lambda query: generate_embeddings( + engine=embedding_engine, + model=embedding_model, + text=query, + key=openai_key if embedding_engine == "openai" else "", + url=openai_url if embedding_engine == "openai" else "", + ) - def generate_multiple(query, f): + def generate_multiple(query, func): if isinstance(query, list): - if embedding_engine == "openai": - embeddings = [] - for i in range(0, len(query), batch_size): - embeddings.extend(f(query[i : i + batch_size])) - return embeddings - else: - return [f(q) for q in query] + embeddings = [] + for i in range(0, len(query), embedding_batch_size): + embeddings.extend(func(query[i : i + embedding_batch_size])) + return embeddings else: - return f(query) + return func(query) return lambda query: generate_multiple(query, func) @@ -379,6 +385,8 @@ def get_rag_context( extracted_collections.extend(collection_names) if context: + if "data" in file: + del file["data"] relevant_contexts.append({**context, "file": file}) contexts = [] @@ -386,23 +394,37 @@ def get_rag_context( for context in relevant_contexts: try: if "documents" in context: + file_names = list( + set( + [ + metadata["name"] + for metadata in context["metadatas"][0] + if metadata is not None and "name" in metadata + ] + ) + ) contexts.append( - "\n\n".join( + ((", ".join(file_names) + ":\n\n") if file_names else "") + + "\n\n".join( [text for text in context["documents"][0] if text is not None] ) ) if "metadatas" in context: - citations.append( - { - "source": context["file"], - "document": context["documents"][0], - "metadata": context["metadatas"][0], - } - ) + citation = { + "source": context["file"], + "document": context["documents"][0], + "metadata": context["metadatas"][0], + } + if "distances" in context and context["distances"]: + citation["distances"] = context["distances"][0] + citations.append(citation) except Exception as e: log.exception(e) + print("contexts", contexts) + print("citations", citations) + return contexts, citations @@ -444,20 +466,6 @@ def get_model_path(model: str, update_model: bool = False): return model -def generate_openai_embeddings( - model: str, - text: Union[str, list[str]], - key: str, - url: str = "https://api.openai.com/v1", -): - if isinstance(text, list): - embeddings = generate_openai_batch_embeddings(model, text, key, url) - else: - embeddings = generate_openai_batch_embeddings(model, [text], key, url) - - return embeddings[0] if isinstance(text, str) else embeddings - - def generate_openai_batch_embeddings( model: str, texts: list[str], key: str, url: str = "https://api.openai.com/v1" ) -> Optional[list[list[float]]]: @@ -481,6 +489,33 @@ def generate_openai_batch_embeddings( return None +def generate_embeddings(engine: str, model: str, text: Union[str, list[str]], **kwargs): + if engine == "ollama": + if isinstance(text, list): + embeddings = generate_ollama_batch_embeddings( + GenerateEmbedForm(**{"model": model, "input": text}) + ) + else: + embeddings = generate_ollama_batch_embeddings( + GenerateEmbedForm(**{"model": model, "input": [text]}) + ) + return ( + embeddings["embeddings"][0] + if isinstance(text, str) + else embeddings["embeddings"] + ) + elif engine == "openai": + key = kwargs.get("key", "") + url = kwargs.get("url", "https://api.openai.com/v1") + + if isinstance(text, list): + embeddings = generate_openai_batch_embeddings(model, text, key, url) + else: + embeddings = generate_openai_batch_embeddings(model, [text], key, url) + + return embeddings[0] if isinstance(text, str) else embeddings + + import operator from typing import Optional, Sequence diff --git a/backend/open_webui/apps/retrieval/vector/connector.py b/backend/open_webui/apps/retrieval/vector/connector.py index 1f33b17219..c7f00f5fd1 100644 --- a/backend/open_webui/apps/retrieval/vector/connector.py +++ b/backend/open_webui/apps/retrieval/vector/connector.py @@ -4,6 +4,10 @@ if VECTOR_DB == "milvus": from open_webui.apps.retrieval.vector.dbs.milvus import MilvusClient VECTOR_DB_CLIENT = MilvusClient() +elif VECTOR_DB == "qdrant": + from open_webui.apps.retrieval.vector.dbs.qdrant import QdrantClient + + VECTOR_DB_CLIENT = QdrantClient() else: from open_webui.apps.retrieval.vector.dbs.chroma import ChromaClient diff --git a/backend/open_webui/apps/retrieval/vector/dbs/chroma.py b/backend/open_webui/apps/retrieval/vector/dbs/chroma.py index 84f80b2531..7782671a23 100644 --- a/backend/open_webui/apps/retrieval/vector/dbs/chroma.py +++ b/backend/open_webui/apps/retrieval/vector/dbs/chroma.py @@ -109,7 +109,9 @@ class ChromaClient: def insert(self, collection_name: str, items: list[VectorItem]): # Insert the items into the collection, if the collection does not exist, it will be created. - collection = self.client.get_or_create_collection(name=collection_name) + collection = self.client.get_or_create_collection( + name=collection_name, metadata={"hnsw:space": "cosine"} + ) ids = [item["id"] for item in items] documents = [item["text"] for item in items] @@ -127,7 +129,9 @@ class ChromaClient: def upsert(self, collection_name: str, items: list[VectorItem]): # Update the items in the collection, if the items are not present, insert them. If the collection does not exist, it will be created. - collection = self.client.get_or_create_collection(name=collection_name) + collection = self.client.get_or_create_collection( + name=collection_name, metadata={"hnsw:space": "cosine"} + ) ids = [item["id"] for item in items] documents = [item["text"] for item in items] diff --git a/backend/open_webui/apps/retrieval/vector/dbs/qdrant.py b/backend/open_webui/apps/retrieval/vector/dbs/qdrant.py new file mode 100644 index 0000000000..c1e06872f9 --- /dev/null +++ b/backend/open_webui/apps/retrieval/vector/dbs/qdrant.py @@ -0,0 +1,179 @@ +from typing import Optional + +from qdrant_client import QdrantClient as Qclient +from qdrant_client.http.models import PointStruct +from qdrant_client.models import models + +from open_webui.apps.retrieval.vector.main import VectorItem, SearchResult, GetResult +from open_webui.config import QDRANT_URI + +NO_LIMIT = 999999999 + + +class QdrantClient: + def __init__(self): + self.collection_prefix = "open-webui" + self.QDRANT_URI = QDRANT_URI + self.client = Qclient(url=self.QDRANT_URI) if self.QDRANT_URI else None + + def _result_to_get_result(self, points) -> GetResult: + ids = [] + documents = [] + metadatas = [] + + for point in points: + payload = point.payload + ids.append(point.id) + documents.append(payload["text"]) + metadatas.append(payload["metadata"]) + + return GetResult( + **{ + "ids": [ids], + "documents": [documents], + "metadatas": [metadatas], + } + ) + + def _create_collection(self, collection_name: str, dimension: int): + collection_name_with_prefix = f"{self.collection_prefix}_{collection_name}" + self.client.create_collection( + collection_name=collection_name_with_prefix, + vectors_config=models.VectorParams( + size=dimension, distance=models.Distance.COSINE + ), + ) + + print(f"collection {collection_name_with_prefix} successfully created!") + + def _create_collection_if_not_exists(self, collection_name, dimension): + if not self.has_collection(collection_name=collection_name): + self._create_collection( + collection_name=collection_name, dimension=dimension + ) + + def _create_points(self, items: list[VectorItem]): + return [ + PointStruct( + id=item["id"], + vector=item["vector"], + payload={"text": item["text"], "metadata": item["metadata"]}, + ) + for item in items + ] + + def has_collection(self, collection_name: str) -> bool: + return self.client.collection_exists( + f"{self.collection_prefix}_{collection_name}" + ) + + def delete_collection(self, collection_name: str): + return self.client.delete_collection( + collection_name=f"{self.collection_prefix}_{collection_name}" + ) + + def search( + self, collection_name: str, vectors: list[list[float | int]], limit: int + ) -> Optional[SearchResult]: + # Search for the nearest neighbor items based on the vectors and return 'limit' number of results. + if limit is None: + limit = NO_LIMIT # otherwise qdrant would set limit to 10! + + query_response = self.client.query_points( + collection_name=f"{self.collection_prefix}_{collection_name}", + query=vectors[0], + limit=limit, + ) + get_result = self._result_to_get_result(query_response.points) + return SearchResult( + ids=get_result.ids, + documents=get_result.documents, + metadatas=get_result.metadatas, + distances=[[point.score for point in query_response.points]], + ) + + def query(self, collection_name: str, filter: dict, limit: Optional[int] = None): + # Construct the filter string for querying + if not self.has_collection(collection_name): + return None + try: + if limit is None: + limit = NO_LIMIT # otherwise qdrant would set limit to 10! + + field_conditions = [] + for key, value in filter.items(): + field_conditions.append( + models.FieldCondition( + key=f"metadata.{key}", match=models.MatchValue(value=value) + ) + ) + + points = self.client.query_points( + collection_name=f"{self.collection_prefix}_{collection_name}", + query_filter=models.Filter(should=field_conditions), + limit=limit, + ) + return self._result_to_get_result(points.points) + except Exception as e: + print(e) + return None + + def get(self, collection_name: str) -> Optional[GetResult]: + # Get all the items in the collection. + points = self.client.query_points( + collection_name=f"{self.collection_prefix}_{collection_name}", + limit=NO_LIMIT, # otherwise qdrant would set limit to 10! + ) + return self._result_to_get_result(points.points) + + def insert(self, collection_name: str, items: list[VectorItem]): + # Insert the items into the collection, if the collection does not exist, it will be created. + self._create_collection_if_not_exists(collection_name, len(items[0]["vector"])) + points = self._create_points(items) + self.client.upload_points(f"{self.collection_prefix}_{collection_name}", points) + + def upsert(self, collection_name: str, items: list[VectorItem]): + # Update the items in the collection, if the items are not present, insert them. If the collection does not exist, it will be created. + self._create_collection_if_not_exists(collection_name, len(items[0]["vector"])) + points = self._create_points(items) + return self.client.upsert(f"{self.collection_prefix}_{collection_name}", points) + + def delete( + self, + collection_name: str, + ids: Optional[list[str]] = None, + filter: Optional[dict] = None, + ): + # Delete the items from the collection based on the ids. + field_conditions = [] + + if ids: + for id_value in ids: + field_conditions.append( + models.FieldCondition( + key="metadata.id", + match=models.MatchValue(value=id_value), + ), + ), + elif filter: + for key, value in filter.items(): + field_conditions.append( + models.FieldCondition( + key=f"metadata.{key}", + match=models.MatchValue(value=value), + ), + ), + + return self.client.delete( + collection_name=f"{self.collection_prefix}_{collection_name}", + points_selector=models.FilterSelector( + filter=models.Filter(must=field_conditions) + ), + ) + + def reset(self): + # Resets the database. This will delete all collections and item entries. + collection_names = self.client.get_collections().collections + for collection_name in collection_names: + if collection_name.name.startswith(self.collection_prefix): + self.client.delete_collection(collection_name=collection_name.name) diff --git a/backend/open_webui/apps/webui/main.py b/backend/open_webui/apps/webui/main.py index b36db10546..5a0a83961b 100644 --- a/backend/open_webui/apps/webui/main.py +++ b/backend/open_webui/apps/webui/main.py @@ -1,6 +1,7 @@ import inspect import json import logging +import time from typing import AsyncGenerator, Generator, Iterator from open_webui.apps.socket.main import get_event_call, get_event_emitter @@ -9,6 +10,7 @@ from open_webui.apps.webui.models.models import Models from open_webui.apps.webui.routers import ( auths, chats, + folders, configs, files, functions, @@ -16,6 +18,7 @@ from open_webui.apps.webui.routers import ( models, knowledge, prompts, + evaluations, tools, users, utils, @@ -31,10 +34,17 @@ from open_webui.config import ( ENABLE_LOGIN_FORM, ENABLE_MESSAGE_RATING, ENABLE_SIGNUP, + ENABLE_EVALUATION_ARENA_MODELS, + EVALUATION_ARENA_MODELS, + DEFAULT_ARENA_MODEL, JWT_EXPIRES_IN, + ENABLE_OAUTH_ROLE_MANAGEMENT, + OAUTH_ROLES_CLAIM, OAUTH_EMAIL_CLAIM, OAUTH_PICTURE_CLAIM, OAUTH_USERNAME_CLAIM, + OAUTH_ALLOWED_ROLES, + OAUTH_ADMIN_ROLES, SHOW_ADMIN_DETAILS, USER_PERMISSIONS, WEBHOOK_URL, @@ -89,10 +99,18 @@ app.state.config.BANNERS = WEBUI_BANNERS app.state.config.ENABLE_COMMUNITY_SHARING = ENABLE_COMMUNITY_SHARING app.state.config.ENABLE_MESSAGE_RATING = ENABLE_MESSAGE_RATING +app.state.config.ENABLE_EVALUATION_ARENA_MODELS = ENABLE_EVALUATION_ARENA_MODELS +app.state.config.EVALUATION_ARENA_MODELS = EVALUATION_ARENA_MODELS + app.state.config.OAUTH_USERNAME_CLAIM = OAUTH_USERNAME_CLAIM app.state.config.OAUTH_PICTURE_CLAIM = OAUTH_PICTURE_CLAIM app.state.config.OAUTH_EMAIL_CLAIM = OAUTH_EMAIL_CLAIM +app.state.config.ENABLE_OAUTH_ROLE_MANAGEMENT = ENABLE_OAUTH_ROLE_MANAGEMENT +app.state.config.OAUTH_ROLES_CLAIM = OAUTH_ROLES_CLAIM +app.state.config.OAUTH_ALLOWED_ROLES = OAUTH_ALLOWED_ROLES +app.state.config.OAUTH_ADMIN_ROLES = OAUTH_ADMIN_ROLES + app.state.MODELS = {} app.state.TOOLS = {} app.state.FUNCTIONS = {} @@ -107,19 +125,24 @@ app.add_middleware( app.include_router(configs.router, prefix="/configs", tags=["configs"]) + app.include_router(auths.router, prefix="/auths", tags=["auths"]) app.include_router(users.router, prefix="/users", tags=["users"]) + app.include_router(chats.router, prefix="/chats", tags=["chats"]) app.include_router(models.router, prefix="/models", tags=["models"]) app.include_router(knowledge.router, prefix="/knowledge", tags=["knowledge"]) app.include_router(prompts.router, prefix="/prompts", tags=["prompts"]) - -app.include_router(files.router, prefix="/files", tags=["files"]) app.include_router(tools.router, prefix="/tools", tags=["tools"]) app.include_router(functions.router, prefix="/functions", tags=["functions"]) app.include_router(memories.router, prefix="/memories", tags=["memories"]) +app.include_router(evaluations.router, prefix="/evaluations", tags=["evaluations"]) + +app.include_router(folders.router, prefix="/folders", tags=["folders"]) +app.include_router(files.router, prefix="/files", tags=["files"]) + app.include_router(utils.router, prefix="/utils", tags=["utils"]) @@ -133,6 +156,47 @@ async def get_status(): } +async def get_all_models(): + models = [] + pipe_models = await get_pipe_models() + models = models + pipe_models + + if app.state.config.ENABLE_EVALUATION_ARENA_MODELS: + arena_models = [] + if len(app.state.config.EVALUATION_ARENA_MODELS) > 0: + arena_models = [ + { + "id": model["id"], + "name": model["name"], + "info": { + "meta": model["meta"], + }, + "object": "model", + "created": int(time.time()), + "owned_by": "arena", + "arena": True, + } + for model in app.state.config.EVALUATION_ARENA_MODELS + ] + else: + # Add default arena model + arena_models = [ + { + "id": DEFAULT_ARENA_MODEL["id"], + "name": DEFAULT_ARENA_MODEL["name"], + "info": { + "meta": DEFAULT_ARENA_MODEL["meta"], + }, + "object": "model", + "created": int(time.time()), + "owned_by": "arena", + "arena": True, + } + ] + models = models + arena_models + return models + + def get_function_module(pipe_id: str): # Check if function is already loaded if pipe_id not in app.state.FUNCTIONS: diff --git a/backend/open_webui/apps/webui/models/chats.py b/backend/open_webui/apps/webui/models/chats.py index f364dcc700..f6a1e45483 100644 --- a/backend/open_webui/apps/webui/models/chats.py +++ b/backend/open_webui/apps/webui/models/chats.py @@ -4,8 +4,13 @@ import uuid from typing import Optional from open_webui.apps.webui.internal.db import Base, get_db +from open_webui.apps.webui.models.tags import TagModel, Tag, Tags + + from pydantic import BaseModel, ConfigDict -from sqlalchemy import BigInteger, Boolean, Column, String, Text +from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON +from sqlalchemy import or_, func, select, and_, text +from sqlalchemy.sql import exists #################### # Chat DB Schema @@ -18,13 +23,17 @@ class Chat(Base): id = Column(String, primary_key=True) user_id = Column(String) title = Column(Text) - chat = Column(Text) # Save Chat JSON as Text + chat = Column(JSON) created_at = Column(BigInteger) updated_at = Column(BigInteger) share_id = Column(Text, unique=True, nullable=True) archived = Column(Boolean, default=False) + pinned = Column(Boolean, default=False, nullable=True) + + meta = Column(JSON, server_default="{}") + folder_id = Column(Text, nullable=True) class ChatModel(BaseModel): @@ -33,13 +42,17 @@ class ChatModel(BaseModel): id: str user_id: str title: str - chat: str + chat: dict created_at: int # timestamp in epoch updated_at: int # timestamp in epoch share_id: Optional[str] = None archived: bool = False + pinned: Optional[bool] = False + + meta: dict = {} + folder_id: Optional[str] = None #################### @@ -51,6 +64,17 @@ class ChatForm(BaseModel): chat: dict +class ChatImportForm(ChatForm): + meta: Optional[dict] = {} + pinned: Optional[bool] = False + folder_id: Optional[str] = None + + +class ChatTitleMessagesForm(BaseModel): + title: str + messages: list[dict] + + class ChatTitleForm(BaseModel): title: str @@ -64,6 +88,9 @@ class ChatResponse(BaseModel): created_at: int # timestamp in epoch share_id: Optional[str] = None # id of the chat to be shared archived: bool + pinned: Optional[bool] = False + meta: dict = {} + folder_id: Optional[str] = None class ChatTitleIdResponse(BaseModel): @@ -86,7 +113,36 @@ class ChatTable: if "title" in form_data.chat else "New Chat" ), - "chat": json.dumps(form_data.chat), + "chat": form_data.chat, + "created_at": int(time.time()), + "updated_at": int(time.time()), + } + ) + + result = Chat(**chat.model_dump()) + db.add(result) + db.commit() + db.refresh(result) + return ChatModel.model_validate(result) if result else None + + def import_chat( + self, user_id: str, form_data: ChatImportForm + ) -> Optional[ChatModel]: + with get_db() as db: + id = str(uuid.uuid4()) + chat = ChatModel( + **{ + "id": id, + "user_id": user_id, + "title": ( + form_data.chat["title"] + if "title" in form_data.chat + else "New Chat" + ), + "chat": form_data.chat, + "meta": form_data.meta, + "pinned": form_data.pinned, + "folder_id": form_data.folder_id, "created_at": int(time.time()), "updated_at": int(time.time()), } @@ -101,14 +157,14 @@ class ChatTable: def update_chat_by_id(self, id: str, chat: dict) -> Optional[ChatModel]: try: with get_db() as db: - chat_obj = db.get(Chat, id) - chat_obj.chat = json.dumps(chat) - chat_obj.title = chat["title"] if "title" in chat else "New Chat" - chat_obj.updated_at = int(time.time()) + chat_item = db.get(Chat, id) + chat_item.chat = chat + chat_item.title = chat["title"] if "title" in chat else "New Chat" + chat_item.updated_at = int(time.time()) db.commit() - db.refresh(chat_obj) + db.refresh(chat_item) - return ChatModel.model_validate(chat_obj) + return ChatModel.model_validate(chat_item) except Exception: return None @@ -182,11 +238,24 @@ class ChatTable: except Exception: return None + def toggle_chat_pinned_by_id(self, id: str) -> Optional[ChatModel]: + try: + with get_db() as db: + chat = db.get(Chat, id) + chat.pinned = not chat.pinned + chat.updated_at = int(time.time()) + db.commit() + db.refresh(chat) + return ChatModel.model_validate(chat) + except Exception: + return None + def toggle_chat_archive_by_id(self, id: str) -> Optional[ChatModel]: try: with get_db() as db: chat = db.get(Chat, id) chat.archived = not chat.archived + chat.updated_at = int(time.time()) db.commit() db.refresh(chat) return ChatModel.model_validate(chat) @@ -223,14 +292,18 @@ class ChatTable: limit: int = 50, ) -> list[ChatModel]: with get_db() as db: - query = db.query(Chat).filter_by(user_id=user_id) + query = db.query(Chat).filter_by(user_id=user_id).filter_by(folder_id=None) if not include_archived: query = query.filter_by(archived=False) - all_chats = ( - query.order_by(Chat.updated_at.desc()) - # .limit(limit).offset(skip) - .all() - ) + + query = query.order_by(Chat.updated_at.desc()) + + if skip: + query = query.offset(skip) + if limit: + query = query.limit(limit) + + all_chats = query.all() return [ChatModel.model_validate(chat) for chat in all_chats] def get_chat_title_id_list_by_user_id( @@ -241,7 +314,9 @@ class ChatTable: limit: Optional[int] = None, ) -> list[ChatTitleIdResponse]: with get_db() as db: - query = db.query(Chat).filter_by(user_id=user_id) + query = db.query(Chat).filter_by(user_id=user_id).filter_by(folder_id=None) + query = query.filter(or_(Chat.pinned == False, Chat.pinned == None)) + if not include_archived: query = query.filter_by(archived=False) @@ -249,10 +324,10 @@ class ChatTable: Chat.id, Chat.title, Chat.updated_at, Chat.created_at ) - if limit: - query = query.limit(limit) if skip: query = query.offset(skip) + if limit: + query = query.limit(limit) all_chats = query.all() @@ -328,6 +403,15 @@ class ChatTable: ) return [ChatModel.model_validate(chat) for chat in all_chats] + def get_pinned_chats_by_user_id(self, user_id: str) -> list[ChatModel]: + with get_db() as db: + all_chats = ( + db.query(Chat) + .filter_by(user_id=user_id, pinned=True, archived=False) + .order_by(Chat.updated_at.desc()) + ) + return [ChatModel.model_validate(chat) for chat in all_chats] + def get_archived_chats_by_user_id(self, user_id: str) -> list[ChatModel]: with get_db() as db: all_chats = ( @@ -337,6 +421,329 @@ class ChatTable: ) return [ChatModel.model_validate(chat) for chat in all_chats] + def get_chats_by_user_id_and_search_text( + self, + user_id: str, + search_text: str, + include_archived: bool = False, + skip: int = 0, + limit: int = 60, + ) -> list[ChatModel]: + """ + Filters chats based on a search query using Python, allowing pagination using skip and limit. + """ + search_text = search_text.lower().strip() + + if not search_text: + return self.get_chat_list_by_user_id(user_id, include_archived, skip, limit) + + search_text_words = search_text.split(" ") + + # search_text might contain 'tag:tag_name' format so we need to extract the tag_name, split the search_text and remove the tags + tag_ids = [ + word.replace("tag:", "").replace(" ", "_").lower() + for word in search_text_words + if word.startswith("tag:") + ] + + search_text_words = [ + word for word in search_text_words if not word.startswith("tag:") + ] + + search_text = " ".join(search_text_words) + + with get_db() as db: + query = db.query(Chat).filter(Chat.user_id == user_id) + + if not include_archived: + query = query.filter(Chat.archived == False) + + query = query.order_by(Chat.updated_at.desc()) + + # Check if the database dialect is either 'sqlite' or 'postgresql' + dialect_name = db.bind.dialect.name + if dialect_name == "sqlite": + # SQLite case: using JSON1 extension for JSON searching + query = query.filter( + ( + Chat.title.ilike( + f"%{search_text}%" + ) # Case-insensitive search in title + | text( + """ + EXISTS ( + SELECT 1 + FROM json_each(Chat.chat, '$.messages') AS message + WHERE LOWER(message.value->>'content') LIKE '%' || :search_text || '%' + ) + """ + ) + ).params(search_text=search_text) + ) + + # Check if there are any tags to filter, it should have all the tags + if "none" in tag_ids: + query = query.filter( + text( + """ + NOT EXISTS ( + SELECT 1 + FROM json_each(Chat.meta, '$.tags') AS tag + ) + """ + ) + ) + elif tag_ids: + query = query.filter( + and_( + *[ + text( + f""" + EXISTS ( + SELECT 1 + FROM json_each(Chat.meta, '$.tags') AS tag + WHERE tag.value = :tag_id_{tag_idx} + ) + """ + ).params(**{f"tag_id_{tag_idx}": tag_id}) + for tag_idx, tag_id in enumerate(tag_ids) + ] + ) + ) + + elif dialect_name == "postgresql": + # PostgreSQL relies on proper JSON query for search + query = query.filter( + ( + Chat.title.ilike( + f"%{search_text}%" + ) # Case-insensitive search in title + | text( + """ + EXISTS ( + SELECT 1 + FROM json_array_elements(Chat.chat->'messages') AS message + WHERE LOWER(message->>'content') LIKE '%' || :search_text || '%' + ) + """ + ) + ).params(search_text=search_text) + ) + + # Check if there are any tags to filter, it should have all the tags + if "none" in tag_ids: + query = query.filter( + text( + """ + NOT EXISTS ( + SELECT 1 + FROM json_array_elements_text(Chat.meta->'tags') AS tag + ) + """ + ) + ) + elif tag_ids: + query = query.filter( + and_( + *[ + text( + f""" + EXISTS ( + SELECT 1 + FROM json_array_elements_text(Chat.meta->'tags') AS tag + WHERE tag = :tag_id_{tag_idx} + ) + """ + ).params(**{f"tag_id_{tag_idx}": tag_id}) + for tag_idx, tag_id in enumerate(tag_ids) + ] + ) + ) + else: + raise NotImplementedError( + f"Unsupported dialect: {db.bind.dialect.name}" + ) + + # Perform pagination at the SQL level + all_chats = query.offset(skip).limit(limit).all() + + print(len(all_chats)) + + # Validate and return chats + return [ChatModel.model_validate(chat) for chat in all_chats] + + def get_chats_by_folder_id_and_user_id( + self, folder_id: str, user_id: str + ) -> list[ChatModel]: + with get_db() as db: + query = db.query(Chat).filter_by(folder_id=folder_id, user_id=user_id) + query = query.filter(or_(Chat.pinned == False, Chat.pinned == None)) + query = query.filter_by(archived=False) + + query = query.order_by(Chat.updated_at.desc()) + + all_chats = query.all() + return [ChatModel.model_validate(chat) for chat in all_chats] + + def get_chats_by_folder_ids_and_user_id( + self, folder_ids: list[str], user_id: str + ) -> list[ChatModel]: + with get_db() as db: + query = db.query(Chat).filter( + Chat.folder_id.in_(folder_ids), Chat.user_id == user_id + ) + query = query.filter(or_(Chat.pinned == False, Chat.pinned == None)) + query = query.filter_by(archived=False) + + query = query.order_by(Chat.updated_at.desc()) + + all_chats = query.all() + return [ChatModel.model_validate(chat) for chat in all_chats] + + def update_chat_folder_id_by_id_and_user_id( + self, id: str, user_id: str, folder_id: str + ) -> Optional[ChatModel]: + try: + with get_db() as db: + chat = db.get(Chat, id) + chat.folder_id = folder_id + chat.updated_at = int(time.time()) + chat.pinned = False + db.commit() + db.refresh(chat) + return ChatModel.model_validate(chat) + except Exception: + return None + + def get_chat_tags_by_id_and_user_id(self, id: str, user_id: str) -> list[TagModel]: + with get_db() as db: + chat = db.get(Chat, id) + tags = chat.meta.get("tags", []) + return [Tags.get_tag_by_name_and_user_id(tag, user_id) for tag in tags] + + def get_chat_list_by_user_id_and_tag_name( + self, user_id: str, tag_name: str, skip: int = 0, limit: int = 50 + ) -> list[ChatModel]: + with get_db() as db: + query = db.query(Chat).filter_by(user_id=user_id) + tag_id = tag_name.replace(" ", "_").lower() + + print(db.bind.dialect.name) + if db.bind.dialect.name == "sqlite": + # SQLite JSON1 querying for tags within the meta JSON field + query = query.filter( + text( + f"EXISTS (SELECT 1 FROM json_each(Chat.meta, '$.tags') WHERE json_each.value = :tag_id)" + ) + ).params(tag_id=tag_id) + elif db.bind.dialect.name == "postgresql": + # PostgreSQL JSON query for tags within the meta JSON field (for `json` type) + query = query.filter( + text( + "EXISTS (SELECT 1 FROM json_array_elements_text(Chat.meta->'tags') elem WHERE elem = :tag_id)" + ) + ).params(tag_id=tag_id) + else: + raise NotImplementedError( + f"Unsupported dialect: {db.bind.dialect.name}" + ) + + all_chats = query.all() + print("all_chats", all_chats) + return [ChatModel.model_validate(chat) for chat in all_chats] + + def add_chat_tag_by_id_and_user_id_and_tag_name( + self, id: str, user_id: str, tag_name: str + ) -> Optional[ChatModel]: + tag = Tags.get_tag_by_name_and_user_id(tag_name, user_id) + if tag is None: + tag = Tags.insert_new_tag(tag_name, user_id) + try: + with get_db() as db: + chat = db.get(Chat, id) + + tag_id = tag.id + if tag_id not in chat.meta.get("tags", []): + chat.meta = { + **chat.meta, + "tags": list(set(chat.meta.get("tags", []) + [tag_id])), + } + + db.commit() + db.refresh(chat) + return ChatModel.model_validate(chat) + except Exception: + return None + + def count_chats_by_tag_name_and_user_id(self, tag_name: str, user_id: str) -> int: + with get_db() as db: # Assuming `get_db()` returns a session object + query = db.query(Chat).filter_by(user_id=user_id, archived=False) + + # Normalize the tag_name for consistency + tag_id = tag_name.replace(" ", "_").lower() + + if db.bind.dialect.name == "sqlite": + # SQLite JSON1 support for querying the tags inside the `meta` JSON field + query = query.filter( + text( + f"EXISTS (SELECT 1 FROM json_each(Chat.meta, '$.tags') WHERE json_each.value = :tag_id)" + ) + ).params(tag_id=tag_id) + + elif db.bind.dialect.name == "postgresql": + # PostgreSQL JSONB support for querying the tags inside the `meta` JSON field + query = query.filter( + text( + "EXISTS (SELECT 1 FROM json_array_elements_text(Chat.meta->'tags') elem WHERE elem = :tag_id)" + ) + ).params(tag_id=tag_id) + + else: + raise NotImplementedError( + f"Unsupported dialect: {db.bind.dialect.name}" + ) + + # Get the count of matching records + count = query.count() + + # Debugging output for inspection + print(f"Count of chats for tag '{tag_name}':", count) + + return count + + def delete_tag_by_id_and_user_id_and_tag_name( + self, id: str, user_id: str, tag_name: str + ) -> bool: + try: + with get_db() as db: + chat = db.get(Chat, id) + tags = chat.meta.get("tags", []) + tag_id = tag_name.replace(" ", "_").lower() + + tags = [tag for tag in tags if tag != tag_id] + chat.meta = { + **chat.meta, + "tags": list(set(tags)), + } + db.commit() + return True + except Exception: + return False + + def delete_all_tags_by_id_and_user_id(self, id: str, user_id: str) -> bool: + try: + with get_db() as db: + chat = db.get(Chat, id) + chat.meta = { + **chat.meta, + "tags": [], + } + db.commit() + + return True + except Exception: + return False + def delete_chat_by_id(self, id: str) -> bool: try: with get_db() as db: @@ -369,6 +776,18 @@ class ChatTable: except Exception: return False + def delete_chats_by_user_id_and_folder_id( + self, user_id: str, folder_id: str + ) -> bool: + try: + with get_db() as db: + db.query(Chat).filter_by(user_id=user_id, folder_id=folder_id).delete() + db.commit() + + return True + except Exception: + return False + def delete_shared_chats_by_user_id(self, user_id: str) -> bool: try: with get_db() as db: diff --git a/backend/open_webui/apps/webui/models/feedbacks.py b/backend/open_webui/apps/webui/models/feedbacks.py new file mode 100644 index 0000000000..c2356dfd86 --- /dev/null +++ b/backend/open_webui/apps/webui/models/feedbacks.py @@ -0,0 +1,254 @@ +import logging +import time +import uuid +from typing import Optional + +from open_webui.apps.webui.internal.db import Base, get_db +from open_webui.apps.webui.models.chats import Chats + +from open_webui.env import SRC_LOG_LEVELS +from pydantic import BaseModel, ConfigDict +from sqlalchemy import BigInteger, Column, Text, JSON, Boolean + +log = logging.getLogger(__name__) +log.setLevel(SRC_LOG_LEVELS["MODELS"]) + + +#################### +# Feedback DB Schema +#################### + + +class Feedback(Base): + __tablename__ = "feedback" + id = Column(Text, primary_key=True) + user_id = Column(Text) + version = Column(BigInteger, default=0) + type = Column(Text) + data = Column(JSON, nullable=True) + meta = Column(JSON, nullable=True) + snapshot = Column(JSON, nullable=True) + created_at = Column(BigInteger) + updated_at = Column(BigInteger) + + +class FeedbackModel(BaseModel): + id: str + user_id: str + version: int + type: str + data: Optional[dict] = None + meta: Optional[dict] = None + snapshot: Optional[dict] = None + created_at: int + updated_at: int + + model_config = ConfigDict(from_attributes=True) + + +#################### +# Forms +#################### + + +class FeedbackResponse(BaseModel): + id: str + user_id: str + version: int + type: str + data: Optional[dict] = None + meta: Optional[dict] = None + created_at: int + updated_at: int + + +class RatingData(BaseModel): + rating: Optional[str | int] = None + model_id: Optional[str] = None + sibling_model_ids: Optional[list[str]] = None + reason: Optional[str] = None + comment: Optional[str] = None + model_config = ConfigDict(extra="allow", protected_namespaces=()) + + +class MetaData(BaseModel): + arena: Optional[bool] = None + chat_id: Optional[str] = None + message_id: Optional[str] = None + tags: Optional[list[str]] = None + model_config = ConfigDict(extra="allow") + + +class SnapshotData(BaseModel): + chat: Optional[dict] = None + model_config = ConfigDict(extra="allow") + + +class FeedbackForm(BaseModel): + type: str + data: Optional[RatingData] = None + meta: Optional[dict] = None + snapshot: Optional[SnapshotData] = None + model_config = ConfigDict(extra="allow") + + +class FeedbackTable: + def insert_new_feedback( + self, user_id: str, form_data: FeedbackForm + ) -> Optional[FeedbackModel]: + with get_db() as db: + id = str(uuid.uuid4()) + feedback = FeedbackModel( + **{ + "id": id, + "user_id": user_id, + "version": 0, + **form_data.model_dump(), + "created_at": int(time.time()), + "updated_at": int(time.time()), + } + ) + try: + result = Feedback(**feedback.model_dump()) + db.add(result) + db.commit() + db.refresh(result) + if result: + return FeedbackModel.model_validate(result) + else: + return None + except Exception as e: + print(e) + return None + + def get_feedback_by_id(self, id: str) -> Optional[FeedbackModel]: + try: + with get_db() as db: + feedback = db.query(Feedback).filter_by(id=id).first() + if not feedback: + return None + return FeedbackModel.model_validate(feedback) + except Exception: + return None + + def get_feedback_by_id_and_user_id( + self, id: str, user_id: str + ) -> Optional[FeedbackModel]: + try: + with get_db() as db: + feedback = db.query(Feedback).filter_by(id=id, user_id=user_id).first() + if not feedback: + return None + return FeedbackModel.model_validate(feedback) + except Exception: + return None + + def get_all_feedbacks(self) -> list[FeedbackModel]: + with get_db() as db: + return [ + FeedbackModel.model_validate(feedback) + for feedback in db.query(Feedback) + .order_by(Feedback.updated_at.desc()) + .all() + ] + + def get_feedbacks_by_type(self, type: str) -> list[FeedbackModel]: + with get_db() as db: + return [ + FeedbackModel.model_validate(feedback) + for feedback in db.query(Feedback) + .filter_by(type=type) + .order_by(Feedback.updated_at.desc()) + .all() + ] + + def get_feedbacks_by_user_id(self, user_id: str) -> list[FeedbackModel]: + with get_db() as db: + return [ + FeedbackModel.model_validate(feedback) + for feedback in db.query(Feedback) + .filter_by(user_id=user_id) + .order_by(Feedback.updated_at.desc()) + .all() + ] + + def update_feedback_by_id( + self, id: str, form_data: FeedbackForm + ) -> Optional[FeedbackModel]: + with get_db() as db: + feedback = db.query(Feedback).filter_by(id=id).first() + if not feedback: + return None + + if form_data.data: + feedback.data = form_data.data.model_dump() + if form_data.meta: + feedback.meta = form_data.meta + if form_data.snapshot: + feedback.snapshot = form_data.snapshot.model_dump() + + feedback.updated_at = int(time.time()) + + db.commit() + return FeedbackModel.model_validate(feedback) + + def update_feedback_by_id_and_user_id( + self, id: str, user_id: str, form_data: FeedbackForm + ) -> Optional[FeedbackModel]: + with get_db() as db: + feedback = db.query(Feedback).filter_by(id=id, user_id=user_id).first() + if not feedback: + return None + + if form_data.data: + feedback.data = form_data.data.model_dump() + if form_data.meta: + feedback.meta = form_data.meta + if form_data.snapshot: + feedback.snapshot = form_data.snapshot.model_dump() + + feedback.updated_at = int(time.time()) + + db.commit() + return FeedbackModel.model_validate(feedback) + + def delete_feedback_by_id(self, id: str) -> bool: + with get_db() as db: + feedback = db.query(Feedback).filter_by(id=id).first() + if not feedback: + return False + db.delete(feedback) + db.commit() + return True + + def delete_feedback_by_id_and_user_id(self, id: str, user_id: str) -> bool: + with get_db() as db: + feedback = db.query(Feedback).filter_by(id=id, user_id=user_id).first() + if not feedback: + return False + db.delete(feedback) + db.commit() + return True + + def delete_feedbacks_by_user_id(self, user_id: str) -> bool: + with get_db() as db: + feedbacks = db.query(Feedback).filter_by(user_id=user_id).all() + if not feedbacks: + return False + for feedback in feedbacks: + db.delete(feedback) + db.commit() + return True + + def delete_all_feedbacks(self) -> bool: + with get_db() as db: + feedbacks = db.query(Feedback).all() + if not feedbacks: + return False + for feedback in feedbacks: + db.delete(feedback) + db.commit() + return True + + +Feedbacks = FeedbackTable() diff --git a/backend/open_webui/apps/webui/models/files.py b/backend/open_webui/apps/webui/models/files.py index f8d4cf8e8e..f27fdd2594 100644 --- a/backend/open_webui/apps/webui/models/files.py +++ b/backend/open_webui/apps/webui/models/files.py @@ -17,14 +17,15 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"]) class File(Base): __tablename__ = "file" - id = Column(String, primary_key=True) user_id = Column(String) hash = Column(Text, nullable=True) filename = Column(Text) + path = Column(Text, nullable=True) + data = Column(JSON, nullable=True) - meta = Column(JSONField) + meta = Column(JSON, nullable=True) created_at = Column(BigInteger) updated_at = Column(BigInteger) @@ -38,8 +39,10 @@ class FileModel(BaseModel): hash: Optional[str] = None filename: str + path: Optional[str] = None + data: Optional[dict] = None - meta: dict + meta: Optional[dict] = None created_at: int # timestamp in epoch updated_at: int # timestamp in epoch @@ -50,6 +53,14 @@ class FileModel(BaseModel): #################### +class FileMeta(BaseModel): + name: Optional[str] = None + content_type: Optional[str] = None + size: Optional[int] = None + + model_config = ConfigDict(extra="allow") + + class FileModelResponse(BaseModel): id: str user_id: str @@ -57,16 +68,24 @@ class FileModelResponse(BaseModel): filename: str data: Optional[dict] = None - meta: dict + meta: FileMeta created_at: int # timestamp in epoch updated_at: int # timestamp in epoch +class FileMetadataResponse(BaseModel): + id: str + meta: dict + created_at: int # timestamp in epoch + updated_at: int # timestamp in epoch + + class FileForm(BaseModel): id: str hash: Optional[str] = None filename: str + path: str data: dict = {} meta: dict = {} @@ -104,6 +123,19 @@ class FilesTable: except Exception: return None + def get_file_metadata_by_id(self, id: str) -> Optional[FileMetadataResponse]: + with get_db() as db: + try: + file = db.get(File, id) + return FileMetadataResponse( + id=file.id, + meta=file.meta, + created_at=file.created_at, + updated_at=file.updated_at, + ) + except Exception: + return None + def get_files(self) -> list[FileModel]: with get_db() as db: return [FileModel.model_validate(file) for file in db.query(File).all()] @@ -118,6 +150,21 @@ class FilesTable: .all() ] + def get_file_metadatas_by_ids(self, ids: list[str]) -> list[FileMetadataResponse]: + with get_db() as db: + return [ + FileMetadataResponse( + id=file.id, + meta=file.meta, + created_at=file.created_at, + updated_at=file.updated_at, + ) + for file in db.query(File) + .filter(File.id.in_(ids)) + .order_by(File.updated_at.desc()) + .all() + ] + def get_files_by_user_id(self, user_id: str) -> list[FileModel]: with get_db() as db: return [ diff --git a/backend/open_webui/apps/webui/models/folders.py b/backend/open_webui/apps/webui/models/folders.py new file mode 100644 index 0000000000..90e8880aad --- /dev/null +++ b/backend/open_webui/apps/webui/models/folders.py @@ -0,0 +1,271 @@ +import logging +import time +import uuid +from typing import Optional + +from open_webui.apps.webui.internal.db import Base, get_db +from open_webui.apps.webui.models.chats import Chats + +from open_webui.env import SRC_LOG_LEVELS +from pydantic import BaseModel, ConfigDict +from sqlalchemy import BigInteger, Column, Text, JSON, Boolean + +log = logging.getLogger(__name__) +log.setLevel(SRC_LOG_LEVELS["MODELS"]) + + +#################### +# Folder DB Schema +#################### + + +class Folder(Base): + __tablename__ = "folder" + id = Column(Text, primary_key=True) + parent_id = Column(Text, nullable=True) + user_id = Column(Text) + name = Column(Text) + items = Column(JSON, nullable=True) + meta = Column(JSON, nullable=True) + is_expanded = Column(Boolean, default=False) + created_at = Column(BigInteger) + updated_at = Column(BigInteger) + + +class FolderModel(BaseModel): + id: str + parent_id: Optional[str] = None + user_id: str + name: str + items: Optional[dict] = None + meta: Optional[dict] = None + is_expanded: bool = False + created_at: int + updated_at: int + + model_config = ConfigDict(from_attributes=True) + + +#################### +# Forms +#################### + + +class FolderForm(BaseModel): + name: str + model_config = ConfigDict(extra="allow") + + +class FolderTable: + def insert_new_folder( + self, user_id: str, name: str, parent_id: Optional[str] = None + ) -> Optional[FolderModel]: + with get_db() as db: + id = str(uuid.uuid4()) + folder = FolderModel( + **{ + "id": id, + "user_id": user_id, + "name": name, + "parent_id": parent_id, + "created_at": int(time.time()), + "updated_at": int(time.time()), + } + ) + try: + result = Folder(**folder.model_dump()) + db.add(result) + db.commit() + db.refresh(result) + if result: + return FolderModel.model_validate(result) + else: + return None + except Exception as e: + print(e) + return None + + def get_folder_by_id_and_user_id( + self, id: str, user_id: str + ) -> Optional[FolderModel]: + try: + with get_db() as db: + folder = db.query(Folder).filter_by(id=id, user_id=user_id).first() + + if not folder: + return None + + return FolderModel.model_validate(folder) + except Exception: + return None + + def get_children_folders_by_id_and_user_id( + self, id: str, user_id: str + ) -> Optional[FolderModel]: + try: + with get_db() as db: + folders = [] + + def get_children(folder): + children = self.get_folders_by_parent_id_and_user_id( + folder.id, user_id + ) + for child in children: + get_children(child) + folders.append(child) + + folder = db.query(Folder).filter_by(id=id, user_id=user_id).first() + if not folder: + return None + + get_children(folder) + return folders + except Exception: + return None + + def get_folders_by_user_id(self, user_id: str) -> list[FolderModel]: + with get_db() as db: + return [ + FolderModel.model_validate(folder) + for folder in db.query(Folder).filter_by(user_id=user_id).all() + ] + + def get_folder_by_parent_id_and_user_id_and_name( + self, parent_id: Optional[str], user_id: str, name: str + ) -> Optional[FolderModel]: + try: + with get_db() as db: + # Check if folder exists + folder = ( + db.query(Folder) + .filter_by(parent_id=parent_id, user_id=user_id) + .filter(Folder.name.ilike(name)) + .first() + ) + + if not folder: + return None + + return FolderModel.model_validate(folder) + except Exception as e: + log.error(f"get_folder_by_parent_id_and_user_id_and_name: {e}") + return None + + def get_folders_by_parent_id_and_user_id( + self, parent_id: Optional[str], user_id: str + ) -> list[FolderModel]: + with get_db() as db: + return [ + FolderModel.model_validate(folder) + for folder in db.query(Folder) + .filter_by(parent_id=parent_id, user_id=user_id) + .all() + ] + + def update_folder_parent_id_by_id_and_user_id( + self, + id: str, + user_id: str, + parent_id: str, + ) -> Optional[FolderModel]: + try: + with get_db() as db: + folder = db.query(Folder).filter_by(id=id, user_id=user_id).first() + + if not folder: + return None + + folder.parent_id = parent_id + folder.updated_at = int(time.time()) + + db.commit() + + return FolderModel.model_validate(folder) + except Exception as e: + log.error(f"update_folder: {e}") + return + + def update_folder_name_by_id_and_user_id( + self, id: str, user_id: str, name: str + ) -> Optional[FolderModel]: + try: + with get_db() as db: + folder = db.query(Folder).filter_by(id=id, user_id=user_id).first() + + if not folder: + return None + + existing_folder = ( + db.query(Folder) + .filter_by(name=name, parent_id=folder.parent_id, user_id=user_id) + .first() + ) + + if existing_folder: + return None + + folder.name = name + folder.updated_at = int(time.time()) + + db.commit() + + return FolderModel.model_validate(folder) + except Exception as e: + log.error(f"update_folder: {e}") + return + + def update_folder_is_expanded_by_id_and_user_id( + self, id: str, user_id: str, is_expanded: bool + ) -> Optional[FolderModel]: + try: + with get_db() as db: + folder = db.query(Folder).filter_by(id=id, user_id=user_id).first() + + if not folder: + return None + + folder.is_expanded = is_expanded + folder.updated_at = int(time.time()) + + db.commit() + + return FolderModel.model_validate(folder) + except Exception as e: + log.error(f"update_folder: {e}") + return + + def delete_folder_by_id_and_user_id(self, id: str, user_id: str) -> bool: + try: + with get_db() as db: + folder = db.query(Folder).filter_by(id=id, user_id=user_id).first() + if not folder: + return False + + # Delete all chats in the folder + Chats.delete_chats_by_user_id_and_folder_id(user_id, folder.id) + + # Delete all children folders + def delete_children(folder): + folder_children = self.get_folders_by_parent_id_and_user_id( + folder.id, user_id + ) + for folder_child in folder_children: + Chats.delete_chats_by_user_id_and_folder_id( + user_id, folder_child.id + ) + delete_children(folder_child) + + folder = db.query(Folder).filter_by(id=folder_child.id).first() + db.delete(folder) + db.commit() + + delete_children(folder) + db.delete(folder) + db.commit() + return True + except Exception as e: + log.error(f"delete_folder: {e}") + return False + + +Folders = FolderTable() diff --git a/backend/open_webui/apps/webui/models/knowledge.py b/backend/open_webui/apps/webui/models/knowledge.py index 698cccda0d..269ad8cc3c 100644 --- a/backend/open_webui/apps/webui/models/knowledge.py +++ b/backend/open_webui/apps/webui/models/knowledge.py @@ -6,6 +6,10 @@ import uuid from open_webui.apps.webui.internal.db import Base, get_db from open_webui.env import SRC_LOG_LEVELS + +from open_webui.apps.webui.models.files import FileMetadataResponse + + from pydantic import BaseModel, ConfigDict from sqlalchemy import BigInteger, Column, String, Text, JSON @@ -64,6 +68,8 @@ class KnowledgeResponse(BaseModel): created_at: int # timestamp in epoch updated_at: int # timestamp in epoch + files: Optional[list[FileMetadataResponse | dict]] = None + class KnowledgeForm(BaseModel): name: str @@ -148,5 +154,15 @@ class KnowledgeTable: except Exception: return False + def delete_all_knowledge(self) -> bool: + with get_db() as db: + try: + db.query(Knowledge).delete() + db.commit() + + return True + except Exception: + return False + Knowledges = KnowledgeTable() diff --git a/backend/open_webui/apps/webui/models/tags.py b/backend/open_webui/apps/webui/models/tags.py index 985273ff1b..7424a26604 100644 --- a/backend/open_webui/apps/webui/models/tags.py +++ b/backend/open_webui/apps/webui/models/tags.py @@ -4,53 +4,35 @@ import uuid from typing import Optional from open_webui.apps.webui.internal.db import Base, get_db + + from open_webui.env import SRC_LOG_LEVELS from pydantic import BaseModel, ConfigDict -from sqlalchemy import BigInteger, Column, String, Text +from sqlalchemy import BigInteger, Column, String, JSON, PrimaryKeyConstraint log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["MODELS"]) + #################### # Tag DB Schema #################### - - class Tag(Base): __tablename__ = "tag" - - id = Column(String, primary_key=True) + id = Column(String) name = Column(String) user_id = Column(String) - data = Column(Text, nullable=True) + meta = Column(JSON, nullable=True) - -class ChatIdTag(Base): - __tablename__ = "chatidtag" - - id = Column(String, primary_key=True) - tag_name = Column(String) - chat_id = Column(String) - user_id = Column(String) - timestamp = Column(BigInteger) + # Unique constraint ensuring (id, user_id) is unique, not just the `id` column + __table_args__ = (PrimaryKeyConstraint("id", "user_id", name="pk_id_user_id"),) class TagModel(BaseModel): id: str name: str user_id: str - data: Optional[str] = None - - model_config = ConfigDict(from_attributes=True) - - -class ChatIdTagModel(BaseModel): - id: str - tag_name: str - chat_id: str - user_id: str - timestamp: int - + meta: Optional[dict] = None model_config = ConfigDict(from_attributes=True) @@ -59,23 +41,15 @@ class ChatIdTagModel(BaseModel): #################### -class ChatIdTagForm(BaseModel): - tag_name: str +class TagChatIdForm(BaseModel): + name: str chat_id: str -class TagChatIdsResponse(BaseModel): - chat_ids: list[str] - - -class ChatTagsResponse(BaseModel): - tags: list[str] - - class TagTable: def insert_new_tag(self, name: str, user_id: str) -> Optional[TagModel]: with get_db() as db: - id = str(uuid.uuid4()) + id = name.replace(" ", "_").lower() tag = TagModel(**{"id": id, "user_id": user_id, "name": name}) try: result = Tag(**tag.model_dump()) @@ -86,177 +60,50 @@ class TagTable: return TagModel.model_validate(result) else: return None - except Exception: + except Exception as e: + print(e) return None def get_tag_by_name_and_user_id( self, name: str, user_id: str ) -> Optional[TagModel]: try: + id = name.replace(" ", "_").lower() with get_db() as db: - tag = db.query(Tag).filter_by(name=name, user_id=user_id).first() + tag = db.query(Tag).filter_by(id=id, user_id=user_id).first() return TagModel.model_validate(tag) except Exception: return None - def add_tag_to_chat( - self, user_id: str, form_data: ChatIdTagForm - ) -> Optional[ChatIdTagModel]: - tag = self.get_tag_by_name_and_user_id(form_data.tag_name, user_id) - if tag is None: - tag = self.insert_new_tag(form_data.tag_name, user_id) - - id = str(uuid.uuid4()) - chatIdTag = ChatIdTagModel( - **{ - "id": id, - "user_id": user_id, - "chat_id": form_data.chat_id, - "tag_name": tag.name, - "timestamp": int(time.time()), - } - ) - try: - with get_db() as db: - result = ChatIdTag(**chatIdTag.model_dump()) - db.add(result) - db.commit() - db.refresh(result) - if result: - return ChatIdTagModel.model_validate(result) - else: - return None - except Exception: - return None - def get_tags_by_user_id(self, user_id: str) -> list[TagModel]: with get_db() as db: - tag_names = [ - chat_id_tag.tag_name - for chat_id_tag in ( - db.query(ChatIdTag) - .filter_by(user_id=user_id) - .order_by(ChatIdTag.timestamp.desc()) - .all() - ) - ] - return [ TagModel.model_validate(tag) - for tag in ( - db.query(Tag) - .filter_by(user_id=user_id) - .filter(Tag.name.in_(tag_names)) - .all() - ) + for tag in (db.query(Tag).filter_by(user_id=user_id).all()) ] - def get_tags_by_chat_id_and_user_id( - self, chat_id: str, user_id: str + def get_tags_by_ids_and_user_id( + self, ids: list[str], user_id: str ) -> list[TagModel]: with get_db() as db: - tag_names = [ - chat_id_tag.tag_name - for chat_id_tag in ( - db.query(ChatIdTag) - .filter_by(user_id=user_id, chat_id=chat_id) - .order_by(ChatIdTag.timestamp.desc()) - .all() - ) - ] - return [ TagModel.model_validate(tag) for tag in ( - db.query(Tag) - .filter_by(user_id=user_id) - .filter(Tag.name.in_(tag_names)) - .all() + db.query(Tag).filter(Tag.id.in_(ids), Tag.user_id == user_id).all() ) ] - def get_chat_ids_by_tag_name_and_user_id( - self, tag_name: str, user_id: str - ) -> list[ChatIdTagModel]: - with get_db() as db: - return [ - ChatIdTagModel.model_validate(chat_id_tag) - for chat_id_tag in ( - db.query(ChatIdTag) - .filter_by(user_id=user_id, tag_name=tag_name) - .order_by(ChatIdTag.timestamp.desc()) - .all() - ) - ] - - def count_chat_ids_by_tag_name_and_user_id( - self, tag_name: str, user_id: str - ) -> int: - with get_db() as db: - return ( - db.query(ChatIdTag) - .filter_by(tag_name=tag_name, user_id=user_id) - .count() - ) - - def delete_tag_by_tag_name_and_user_id(self, tag_name: str, user_id: str) -> bool: + def delete_tag_by_name_and_user_id(self, name: str, user_id: str) -> bool: try: with get_db() as db: - res = ( - db.query(ChatIdTag) - .filter_by(tag_name=tag_name, user_id=user_id) - .delete() - ) + id = name.replace(" ", "_").lower() + res = db.query(Tag).filter_by(id=id, user_id=user_id).delete() log.debug(f"res: {res}") db.commit() - - tag_count = self.count_chat_ids_by_tag_name_and_user_id( - tag_name, user_id - ) - if tag_count == 0: - # Remove tag item from Tag col as well - db.query(Tag).filter_by(name=tag_name, user_id=user_id).delete() - db.commit() return True except Exception as e: log.error(f"delete_tag: {e}") return False - def delete_tag_by_tag_name_and_chat_id_and_user_id( - self, tag_name: str, chat_id: str, user_id: str - ) -> bool: - try: - with get_db() as db: - res = ( - db.query(ChatIdTag) - .filter_by(tag_name=tag_name, chat_id=chat_id, user_id=user_id) - .delete() - ) - log.debug(f"res: {res}") - db.commit() - - tag_count = self.count_chat_ids_by_tag_name_and_user_id( - tag_name, user_id - ) - if tag_count == 0: - # Remove tag item from Tag col as well - db.query(Tag).filter_by(name=tag_name, user_id=user_id).delete() - db.commit() - - return True - except Exception as e: - log.error(f"delete_tag: {e}") - return False - - def delete_tags_by_chat_id_and_user_id(self, chat_id: str, user_id: str) -> bool: - tags = self.get_tags_by_chat_id_and_user_id(chat_id, user_id) - - for tag in tags: - self.delete_tag_by_tag_name_and_chat_id_and_user_id( - tag.tag_name, chat_id, user_id - ) - - return True - Tags = TagTable() diff --git a/backend/open_webui/apps/webui/routers/auths.py b/backend/open_webui/apps/webui/routers/auths.py index 68b86e7925..ef0a0d445b 100644 --- a/backend/open_webui/apps/webui/routers/auths.py +++ b/backend/open_webui/apps/webui/routers/auths.py @@ -1,10 +1,13 @@ import re import uuid +import time +import datetime from open_webui.apps.webui.models.auths import ( AddUserForm, ApiKey, Auths, + Token, SigninForm, SigninResponse, SignupForm, @@ -18,6 +21,8 @@ from open_webui.constants import ERROR_MESSAGES, WEBHOOK_MESSAGES from open_webui.env import ( WEBUI_AUTH_TRUSTED_EMAIL_HEADER, WEBUI_AUTH_TRUSTED_NAME_HEADER, + WEBUI_SESSION_COOKIE_SAME_SITE, + WEBUI_SESSION_COOKIE_SECURE, ) from fastapi import APIRouter, Depends, HTTPException, Request, status from fastapi.responses import Response @@ -32,6 +37,7 @@ from open_webui.utils.utils import ( get_password_hash, ) from open_webui.utils.webhook import post_webhook +from typing import Optional router = APIRouter() @@ -40,23 +46,44 @@ router = APIRouter() ############################ -@router.get("/", response_model=UserResponse) +class SessionUserResponse(Token, UserResponse): + expires_at: Optional[int] = None + + +@router.get("/", response_model=SessionUserResponse) async def get_session_user( request: Request, response: Response, user=Depends(get_current_user) ): + expires_delta = parse_duration(request.app.state.config.JWT_EXPIRES_IN) + expires_at = None + if expires_delta: + expires_at = int(time.time()) + int(expires_delta.total_seconds()) + token = create_token( data={"id": user.id}, - expires_delta=parse_duration(request.app.state.config.JWT_EXPIRES_IN), + expires_delta=expires_delta, + ) + + datetime_expires_at = ( + datetime.datetime.fromtimestamp(expires_at, datetime.timezone.utc) + if expires_at + else None ) # Set the cookie token response.set_cookie( key="token", value=token, + expires=datetime_expires_at, httponly=True, # Ensures the cookie is not accessible via JavaScript + samesite=WEBUI_SESSION_COOKIE_SAME_SITE, + secure=WEBUI_SESSION_COOKIE_SECURE, ) return { + "token": token, + "token_type": "Bearer", + "expires_at": expires_at, "id": user.id, "email": user.email, "name": user.name, @@ -115,7 +142,7 @@ async def update_password( ############################ -@router.post("/signin", response_model=SigninResponse) +@router.post("/signin", response_model=SessionUserResponse) async def signin(request: Request, response: Response, form_data: SigninForm): if WEBUI_AUTH_TRUSTED_EMAIL_HEADER: if WEBUI_AUTH_TRUSTED_EMAIL_HEADER not in request.headers: @@ -157,21 +184,37 @@ async def signin(request: Request, response: Response, form_data: SigninForm): user = Auths.authenticate_user(form_data.email.lower(), form_data.password) if user: + + expires_delta = parse_duration(request.app.state.config.JWT_EXPIRES_IN) + expires_at = None + if expires_delta: + expires_at = int(time.time()) + int(expires_delta.total_seconds()) + token = create_token( data={"id": user.id}, - expires_delta=parse_duration(request.app.state.config.JWT_EXPIRES_IN), + expires_delta=expires_delta, + ) + + datetime_expires_at = ( + datetime.datetime.fromtimestamp(expires_at, datetime.timezone.utc) + if expires_at + else None ) # Set the cookie token response.set_cookie( key="token", value=token, + expires=datetime_expires_at, httponly=True, # Ensures the cookie is not accessible via JavaScript + samesite=WEBUI_SESSION_COOKIE_SAME_SITE, + secure=WEBUI_SESSION_COOKIE_SECURE, ) return { "token": token, "token_type": "Bearer", + "expires_at": expires_at, "id": user.id, "email": user.email, "name": user.name, @@ -187,7 +230,7 @@ async def signin(request: Request, response: Response, form_data: SigninForm): ############################ -@router.post("/signup", response_model=SigninResponse) +@router.post("/signup", response_model=SessionUserResponse) async def signup(request: Request, response: Response, form_data: SignupForm): if WEBUI_AUTH: if ( @@ -227,16 +270,30 @@ async def signup(request: Request, response: Response, form_data: SignupForm): ) if user: + expires_delta = parse_duration(request.app.state.config.JWT_EXPIRES_IN) + expires_at = None + if expires_delta: + expires_at = int(time.time()) + int(expires_delta.total_seconds()) + token = create_token( data={"id": user.id}, - expires_delta=parse_duration(request.app.state.config.JWT_EXPIRES_IN), + expires_delta=expires_delta, + ) + + datetime_expires_at = ( + datetime.datetime.fromtimestamp(expires_at, datetime.timezone.utc) + if expires_at + else None ) # Set the cookie token response.set_cookie( key="token", value=token, + expires=datetime_expires_at, httponly=True, # Ensures the cookie is not accessible via JavaScript + samesite=WEBUI_SESSION_COOKIE_SAME_SITE, + secure=WEBUI_SESSION_COOKIE_SECURE, ) if request.app.state.config.WEBHOOK_URL: @@ -253,6 +310,7 @@ async def signup(request: Request, response: Response, form_data: SignupForm): return { "token": token, "token_type": "Bearer", + "expires_at": expires_at, "id": user.id, "email": user.email, "name": user.name, @@ -265,6 +323,12 @@ async def signup(request: Request, response: Response, form_data: SignupForm): raise HTTPException(500, detail=ERROR_MESSAGES.DEFAULT(err)) +@router.get("/signout") +async def signout(response: Response): + response.delete_cookie("token") + return {"status": True} + + ############################ # AddUser ############################ diff --git a/backend/open_webui/apps/webui/routers/chats.py b/backend/open_webui/apps/webui/routers/chats.py index ca7e95baf4..b149b2eb48 100644 --- a/backend/open_webui/apps/webui/routers/chats.py +++ b/backend/open_webui/apps/webui/routers/chats.py @@ -4,16 +4,14 @@ from typing import Optional from open_webui.apps.webui.models.chats import ( ChatForm, + ChatImportForm, ChatResponse, Chats, ChatTitleIdResponse, ) -from open_webui.apps.webui.models.tags import ( - ChatIdTagForm, - ChatIdTagModel, - TagModel, - Tags, -) +from open_webui.apps.webui.models.tags import TagModel, Tags +from open_webui.apps.webui.models.folders import Folders + from open_webui.config import ENABLE_ADMIN_CHAT_ACCESS, ENABLE_ADMIN_EXPORT from open_webui.constants import ERROR_MESSAGES from open_webui.env import SRC_LOG_LEVELS @@ -95,7 +93,35 @@ async def get_user_chat_list_by_user_id( async def create_new_chat(form_data: ChatForm, user=Depends(get_verified_user)): try: chat = Chats.insert_new_chat(user.id, form_data) - return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) + return ChatResponse(**chat.model_dump()) + except Exception as e: + log.exception(e) + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT() + ) + + +############################ +# ImportChat +############################ + + +@router.post("/import", response_model=Optional[ChatResponse]) +async def import_chat(form_data: ChatImportForm, user=Depends(get_verified_user)): + try: + chat = Chats.import_chat(user.id, form_data) + if chat: + tags = chat.meta.get("tags", []) + for tag_id in tags: + tag_id = tag_id.replace(" ", "_").lower() + tag_name = " ".join([word.capitalize() for word in tag_id.split("_")]) + if ( + tag_id != "none" + and Tags.get_tag_by_name_and_user_id(tag_name, user.id) is None + ): + Tags.insert_new_tag(tag_name, user.id) + + return ChatResponse(**chat.model_dump()) except Exception as e: log.exception(e) raise HTTPException( @@ -108,10 +134,77 @@ async def create_new_chat(form_data: ChatForm, user=Depends(get_verified_user)): ############################ +@router.get("/search", response_model=list[ChatTitleIdResponse]) +async def search_user_chats( + text: str, page: Optional[int] = None, user=Depends(get_verified_user) +): + if page is None: + page = 1 + + limit = 60 + skip = (page - 1) * limit + + chat_list = [ + ChatTitleIdResponse(**chat.model_dump()) + for chat in Chats.get_chats_by_user_id_and_search_text( + user.id, text, skip=skip, limit=limit + ) + ] + + # Delete tag if no chat is found + words = text.strip().split(" ") + if page == 1 and len(words) == 1 and words[0].startswith("tag:"): + tag_id = words[0].replace("tag:", "") + if len(chat_list) == 0: + if Tags.get_tag_by_name_and_user_id(tag_id, user.id): + log.debug(f"deleting tag: {tag_id}") + Tags.delete_tag_by_name_and_user_id(tag_id, user.id) + + return chat_list + + +############################ +# GetChatsByFolderId +############################ + + +@router.get("/folder/{folder_id}", response_model=list[ChatResponse]) +async def get_chats_by_folder_id(folder_id: str, user=Depends(get_verified_user)): + folder_ids = [folder_id] + children_folders = Folders.get_children_folders_by_id_and_user_id( + folder_id, user.id + ) + if children_folders: + folder_ids.extend([folder.id for folder in children_folders]) + + return [ + ChatResponse(**chat.model_dump()) + for chat in Chats.get_chats_by_folder_ids_and_user_id(folder_ids, user.id) + ] + + +############################ +# GetPinnedChats +############################ + + +@router.get("/pinned", response_model=list[ChatResponse]) +async def get_user_pinned_chats(user=Depends(get_verified_user)): + return [ + ChatResponse(**chat.model_dump()) + for chat in Chats.get_pinned_chats_by_user_id(user.id) + ] + + +############################ +# GetChats +############################ + + @router.get("/all", response_model=list[ChatResponse]) async def get_user_chats(user=Depends(get_verified_user)): return [ - ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) + ChatResponse(**chat.model_dump()) for chat in Chats.get_chats_by_user_id(user.id) ] @@ -124,11 +217,28 @@ async def get_user_chats(user=Depends(get_verified_user)): @router.get("/all/archived", response_model=list[ChatResponse]) async def get_user_archived_chats(user=Depends(get_verified_user)): return [ - ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) + ChatResponse(**chat.model_dump()) for chat in Chats.get_archived_chats_by_user_id(user.id) ] +############################ +# GetAllTags +############################ + + +@router.get("/all/tags", response_model=list[TagModel]) +async def get_all_user_tags(user=Depends(get_verified_user)): + try: + tags = Tags.get_tags_by_user_id(user.id) + return tags + except Exception as e: + log.exception(e) + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT() + ) + + ############################ # GetAllChatsInDB ############################ @@ -141,10 +251,7 @@ async def get_all_user_chats_in_db(user=Depends(get_admin_user)): status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.ACCESS_PROHIBITED, ) - return [ - ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) - for chat in Chats.get_chats() - ] + return [ChatResponse(**chat.model_dump()) for chat in Chats.get_chats()] ############################ @@ -187,7 +294,8 @@ async def get_shared_chat_by_id(share_id: str, user=Depends(get_verified_user)): chat = Chats.get_chat_by_id(share_id) if chat: - return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) + return ChatResponse(**chat.model_dump()) + else: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND @@ -199,48 +307,28 @@ async def get_shared_chat_by_id(share_id: str, user=Depends(get_verified_user)): ############################ -class TagNameForm(BaseModel): +class TagForm(BaseModel): name: str + + +class TagFilterForm(TagForm): skip: Optional[int] = 0 limit: Optional[int] = 50 @router.post("/tags", response_model=list[ChatTitleIdResponse]) async def get_user_chat_list_by_tag_name( - form_data: TagNameForm, user=Depends(get_verified_user) + form_data: TagFilterForm, user=Depends(get_verified_user) ): - chat_ids = [ - chat_id_tag.chat_id - for chat_id_tag in Tags.get_chat_ids_by_tag_name_and_user_id( - form_data.name, user.id - ) - ] - - chats = Chats.get_chat_list_by_chat_ids(chat_ids, form_data.skip, form_data.limit) - + chats = Chats.get_chat_list_by_user_id_and_tag_name( + user.id, form_data.name, form_data.skip, form_data.limit + ) if len(chats) == 0: - Tags.delete_tag_by_tag_name_and_user_id(form_data.name, user.id) + Tags.delete_tag_by_name_and_user_id(form_data.name, user.id) return chats -############################ -# GetAllTags -############################ - - -@router.get("/tags/all", response_model=list[TagModel]) -async def get_all_tags(user=Depends(get_verified_user)): - try: - tags = Tags.get_tags_by_user_id(user.id) - return tags - except Exception as e: - log.exception(e) - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT() - ) - - ############################ # GetChatById ############################ @@ -251,7 +339,8 @@ async def get_chat_by_id(id: str, user=Depends(get_verified_user)): chat = Chats.get_chat_by_id_and_user_id(id, user.id) if chat: - return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) + return ChatResponse(**chat.model_dump()) + else: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND @@ -269,10 +358,9 @@ async def update_chat_by_id( ): chat = Chats.get_chat_by_id_and_user_id(id, user.id) if chat: - updated_chat = {**json.loads(chat.chat), **form_data.chat} - + updated_chat = {**chat.chat, **form_data.chat} chat = Chats.update_chat_by_id(id, updated_chat) - return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) + return ChatResponse(**chat.model_dump()) else: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -288,7 +376,13 @@ async def update_chat_by_id( @router.delete("/{id}", response_model=bool) async def delete_chat_by_id(request: Request, id: str, user=Depends(get_verified_user)): if user.role == "admin": + chat = Chats.get_chat_by_id(id) + for tag in chat.meta.get("tags", []): + if Chats.count_chats_by_tag_name_and_user_id(tag, user.id) == 1: + Tags.delete_tag_by_name_and_user_id(tag, user.id) + result = Chats.delete_chat_by_id(id) + return result else: if not request.app.state.config.USER_PERMISSIONS.get("chat", {}).get( @@ -299,29 +393,66 @@ async def delete_chat_by_id(request: Request, id: str, user=Depends(get_verified detail=ERROR_MESSAGES.ACCESS_PROHIBITED, ) + chat = Chats.get_chat_by_id(id) + for tag in chat.meta.get("tags", []): + if Chats.count_chats_by_tag_name_and_user_id(tag, user.id) == 1: + Tags.delete_tag_by_name_and_user_id(tag, user.id) + result = Chats.delete_chat_by_id_and_user_id(id, user.id) return result +############################ +# GetPinnedStatusById +############################ + + +@router.get("/{id}/pinned", response_model=Optional[bool]) +async def get_pinned_status_by_id(id: str, user=Depends(get_verified_user)): + chat = Chats.get_chat_by_id_and_user_id(id, user.id) + if chat: + return chat.pinned + else: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT() + ) + + +############################ +# PinChatById +############################ + + +@router.post("/{id}/pin", response_model=Optional[ChatResponse]) +async def pin_chat_by_id(id: str, user=Depends(get_verified_user)): + chat = Chats.get_chat_by_id_and_user_id(id, user.id) + if chat: + chat = Chats.toggle_chat_pinned_by_id(id) + return chat + else: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT() + ) + + ############################ # CloneChat ############################ -@router.get("/{id}/clone", response_model=Optional[ChatResponse]) +@router.post("/{id}/clone", response_model=Optional[ChatResponse]) async def clone_chat_by_id(id: str, user=Depends(get_verified_user)): chat = Chats.get_chat_by_id_and_user_id(id, user.id) if chat: - chat_body = json.loads(chat.chat) updated_chat = { - **chat_body, + **chat.chat, "originalChatId": chat.id, - "branchPointMessageId": chat_body["history"]["currentId"], + "branchPointMessageId": chat.chat["history"]["currentId"], "title": f"Clone of {chat.title}", } chat = Chats.insert_new_chat(user.id, ChatForm(**{"chat": updated_chat})) - return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) + return ChatResponse(**chat.model_dump()) else: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT() @@ -333,12 +464,26 @@ async def clone_chat_by_id(id: str, user=Depends(get_verified_user)): ############################ -@router.get("/{id}/archive", response_model=Optional[ChatResponse]) +@router.post("/{id}/archive", response_model=Optional[ChatResponse]) async def archive_chat_by_id(id: str, user=Depends(get_verified_user)): chat = Chats.get_chat_by_id_and_user_id(id, user.id) if chat: chat = Chats.toggle_chat_archive_by_id(id) - return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) + + # Delete tags if chat is archived + if chat.archived: + for tag_id in chat.meta.get("tags", []): + if Chats.count_chats_by_tag_name_and_user_id(tag_id, user.id) == 0: + log.debug(f"deleting tag: {tag_id}") + Tags.delete_tag_by_name_and_user_id(tag_id, user.id) + else: + for tag_id in chat.meta.get("tags", []): + tag = Tags.get_tag_by_name_and_user_id(tag_id, user.id) + if tag is None: + log.debug(f"inserting tag: {tag_id}") + tag = Tags.insert_new_tag(tag_id, user.id) + + return ChatResponse(**chat.model_dump()) else: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT() @@ -356,9 +501,7 @@ async def share_chat_by_id(id: str, user=Depends(get_verified_user)): if chat: if chat.share_id: shared_chat = Chats.update_shared_chat_by_chat_id(chat.id) - return ChatResponse( - **{**shared_chat.model_dump(), "chat": json.loads(shared_chat.chat)} - ) + return ChatResponse(**shared_chat.model_dump()) shared_chat = Chats.insert_shared_chat_by_chat_id(chat.id) if not shared_chat: @@ -366,10 +509,8 @@ async def share_chat_by_id(id: str, user=Depends(get_verified_user)): status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=ERROR_MESSAGES.DEFAULT(), ) + return ChatResponse(**shared_chat.model_dump()) - return ChatResponse( - **{**shared_chat.model_dump(), "chat": json.loads(shared_chat.chat)} - ) else: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -400,6 +541,31 @@ async def delete_shared_chat_by_id(id: str, user=Depends(get_verified_user)): ) +############################ +# UpdateChatFolderIdById +############################ + + +class ChatFolderIdForm(BaseModel): + folder_id: Optional[str] = None + + +@router.post("/{id}/folder", response_model=Optional[ChatResponse]) +async def update_chat_folder_id_by_id( + id: str, form_data: ChatFolderIdForm, user=Depends(get_verified_user) +): + chat = Chats.get_chat_by_id_and_user_id(id, user.id) + if chat: + chat = Chats.update_chat_folder_id_by_id_and_user_id( + id, user.id, form_data.folder_id + ) + return ChatResponse(**chat.model_dump()) + else: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT() + ) + + ############################ # GetChatTagsById ############################ @@ -407,10 +573,10 @@ async def delete_shared_chat_by_id(id: str, user=Depends(get_verified_user)): @router.get("/{id}/tags", response_model=list[TagModel]) async def get_chat_tags_by_id(id: str, user=Depends(get_verified_user)): - tags = Tags.get_tags_by_chat_id_and_user_id(id, user.id) - - if tags != None: - return tags + chat = Chats.get_chat_by_id_and_user_id(id, user.id) + if chat: + tags = chat.meta.get("tags", []) + return Tags.get_tags_by_ids_and_user_id(tags, user.id) else: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND @@ -422,22 +588,30 @@ async def get_chat_tags_by_id(id: str, user=Depends(get_verified_user)): ############################ -@router.post("/{id}/tags", response_model=Optional[ChatIdTagModel]) -async def add_chat_tag_by_id( - id: str, form_data: ChatIdTagForm, user=Depends(get_verified_user) +@router.post("/{id}/tags", response_model=list[TagModel]) +async def add_tag_by_id_and_tag_name( + id: str, form_data: TagForm, user=Depends(get_verified_user) ): - tags = Tags.get_tags_by_chat_id_and_user_id(id, user.id) + chat = Chats.get_chat_by_id_and_user_id(id, user.id) + if chat: + tags = chat.meta.get("tags", []) + tag_id = form_data.name.replace(" ", "_").lower() - if form_data.tag_name not in tags: - tag = Tags.add_tag_to_chat(user.id, form_data) - - if tag: - return tag - else: + if tag_id == "none": raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail=ERROR_MESSAGES.NOT_FOUND, + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT("Tag name cannot be 'None'"), ) + + print(tags, tag_id) + if tag_id not in tags: + Chats.add_chat_tag_by_id_and_user_id_and_tag_name( + id, user.id, form_data.name + ) + + chat = Chats.get_chat_by_id_and_user_id(id, user.id) + tags = chat.meta.get("tags", []) + return Tags.get_tags_by_ids_and_user_id(tags, user.id) else: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT() @@ -449,16 +623,20 @@ async def add_chat_tag_by_id( ############################ -@router.delete("/{id}/tags", response_model=Optional[bool]) -async def delete_chat_tag_by_id( - id: str, form_data: ChatIdTagForm, user=Depends(get_verified_user) +@router.delete("/{id}/tags", response_model=list[TagModel]) +async def delete_tag_by_id_and_tag_name( + id: str, form_data: TagForm, user=Depends(get_verified_user) ): - result = Tags.delete_tag_by_tag_name_and_chat_id_and_user_id( - form_data.tag_name, id, user.id - ) + chat = Chats.get_chat_by_id_and_user_id(id, user.id) + if chat: + Chats.delete_tag_by_id_and_user_id_and_tag_name(id, user.id, form_data.name) - if result: - return result + if Chats.count_chats_by_tag_name_and_user_id(form_data.name, user.id) == 0: + Tags.delete_tag_by_name_and_user_id(form_data.name, user.id) + + chat = Chats.get_chat_by_id_and_user_id(id, user.id) + tags = chat.meta.get("tags", []) + return Tags.get_tags_by_ids_and_user_id(tags, user.id) else: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND @@ -466,16 +644,21 @@ async def delete_chat_tag_by_id( ############################ -# DeleteAllChatTagsById +# DeleteAllTagsById ############################ @router.delete("/{id}/tags/all", response_model=Optional[bool]) -async def delete_all_chat_tags_by_id(id: str, user=Depends(get_verified_user)): - result = Tags.delete_tags_by_chat_id_and_user_id(id, user.id) +async def delete_all_tags_by_id(id: str, user=Depends(get_verified_user)): + chat = Chats.get_chat_by_id_and_user_id(id, user.id) + if chat: + Chats.delete_all_tags_by_id_and_user_id(id, user.id) - if result: - return result + for tag in chat.meta.get("tags", []): + if Chats.count_chats_by_tag_name_and_user_id(tag, user.id) == 0: + Tags.delete_tag_by_name_and_user_id(tag, user.id) + + return True else: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND diff --git a/backend/open_webui/apps/webui/routers/evaluations.py b/backend/open_webui/apps/webui/routers/evaluations.py new file mode 100644 index 0000000000..9a6845efac --- /dev/null +++ b/backend/open_webui/apps/webui/routers/evaluations.py @@ -0,0 +1,147 @@ +from typing import Optional +from fastapi import APIRouter, Depends, HTTPException, status, Request +from pydantic import BaseModel + +from open_webui.apps.webui.models.users import Users, UserModel +from open_webui.apps.webui.models.feedbacks import ( + FeedbackModel, + FeedbackForm, + Feedbacks, +) + +from open_webui.constants import ERROR_MESSAGES +from open_webui.utils.utils import get_admin_user, get_verified_user + +router = APIRouter() + + +############################ +# GetConfig +############################ + + +@router.get("/config") +async def get_config(request: Request, user=Depends(get_admin_user)): + return { + "ENABLE_EVALUATION_ARENA_MODELS": request.app.state.config.ENABLE_EVALUATION_ARENA_MODELS, + "EVALUATION_ARENA_MODELS": request.app.state.config.EVALUATION_ARENA_MODELS, + } + + +############################ +# UpdateConfig +############################ + + +class UpdateConfigForm(BaseModel): + ENABLE_EVALUATION_ARENA_MODELS: Optional[bool] = None + EVALUATION_ARENA_MODELS: Optional[list[dict]] = None + + +@router.post("/config") +async def update_config( + request: Request, + form_data: UpdateConfigForm, + user=Depends(get_admin_user), +): + config = request.app.state.config + if form_data.ENABLE_EVALUATION_ARENA_MODELS is not None: + config.ENABLE_EVALUATION_ARENA_MODELS = form_data.ENABLE_EVALUATION_ARENA_MODELS + if form_data.EVALUATION_ARENA_MODELS is not None: + config.EVALUATION_ARENA_MODELS = form_data.EVALUATION_ARENA_MODELS + return { + "ENABLE_EVALUATION_ARENA_MODELS": config.ENABLE_EVALUATION_ARENA_MODELS, + "EVALUATION_ARENA_MODELS": config.EVALUATION_ARENA_MODELS, + } + + +@router.get("/feedbacks", response_model=list[FeedbackModel]) +async def get_feedbacks(user=Depends(get_verified_user)): + feedbacks = Feedbacks.get_feedbacks_by_user_id(user.id) + return feedbacks + + +@router.delete("/feedbacks", response_model=bool) +async def delete_feedbacks(user=Depends(get_verified_user)): + success = Feedbacks.delete_feedbacks_by_user_id(user.id) + return success + + +class FeedbackUserModel(FeedbackModel): + user: Optional[UserModel] = None + + +@router.get("/feedbacks/all", response_model=list[FeedbackUserModel]) +async def get_all_feedbacks(user=Depends(get_admin_user)): + feedbacks = Feedbacks.get_all_feedbacks() + return [ + FeedbackUserModel( + **feedback.model_dump(), user=Users.get_user_by_id(feedback.user_id) + ) + for feedback in feedbacks + ] + + +@router.delete("/feedbacks/all") +async def delete_all_feedbacks(user=Depends(get_admin_user)): + success = Feedbacks.delete_all_feedbacks() + return success + + +@router.post("/feedback", response_model=FeedbackModel) +async def create_feedback( + request: Request, + form_data: FeedbackForm, + user=Depends(get_verified_user), +): + feedback = Feedbacks.insert_new_feedback(user_id=user.id, form_data=form_data) + if not feedback: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT(), + ) + + return feedback + + +@router.get("/feedback/{id}", response_model=FeedbackModel) +async def get_feedback_by_id(id: str, user=Depends(get_verified_user)): + feedback = Feedbacks.get_feedback_by_id_and_user_id(id=id, user_id=user.id) + + if not feedback: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND + ) + + return feedback + + +@router.post("/feedback/{id}", response_model=FeedbackModel) +async def update_feedback_by_id( + id: str, form_data: FeedbackForm, user=Depends(get_verified_user) +): + feedback = Feedbacks.update_feedback_by_id_and_user_id( + id=id, user_id=user.id, form_data=form_data + ) + + if not feedback: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND + ) + + return feedback + + +@router.delete("/feedback/{id}") +async def delete_feedback_by_id(id: str, user=Depends(get_verified_user)): + if user.role == "admin": + success = Feedbacks.delete_feedback_by_id(id=id) + else: + success = Feedbacks.delete_feedback_by_id_and_user_id(id=id, user_id=user.id) + + if not success: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND + ) + + return success diff --git a/backend/open_webui/apps/webui/routers/files.py b/backend/open_webui/apps/webui/routers/files.py index 2761d2b12a..8294328cfb 100644 --- a/backend/open_webui/apps/webui/routers/files.py +++ b/backend/open_webui/apps/webui/routers/files.py @@ -1,14 +1,19 @@ import logging import os -import shutil import uuid from pathlib import Path from typing import Optional from pydantic import BaseModel import mimetypes +from open_webui.storage.provider import Storage -from open_webui.apps.webui.models.files import FileForm, FileModel, Files +from open_webui.apps.webui.models.files import ( + FileForm, + FileModel, + FileModelResponse, + Files, +) from open_webui.apps.retrieval.main import process_file, ProcessFileForm from open_webui.config import UPLOAD_DIR @@ -44,24 +49,19 @@ def upload_file(file: UploadFile = File(...), user=Depends(get_verified_user)): id = str(uuid.uuid4()) name = filename filename = f"{id}_{filename}" - file_path = f"{UPLOAD_DIR}/{filename}" + contents, file_path = Storage.upload_file(file.file, filename) - contents = file.file.read() - with open(file_path, "wb") as f: - f.write(contents) - f.close() - - file = Files.insert_new_file( + file_item = Files.insert_new_file( user.id, FileForm( **{ "id": id, "filename": filename, + "path": file_path, "meta": { "name": name, "content_type": file.content_type, "size": len(contents), - "path": file_path, }, } ), @@ -69,13 +69,13 @@ def upload_file(file: UploadFile = File(...), user=Depends(get_verified_user)): try: process_file(ProcessFileForm(file_id=id)) - file = Files.get_file_by_id(id=id) + file_item = Files.get_file_by_id(id=id) except Exception as e: log.exception(e) - log.error(f"Error processing file: {file.id}") + log.error(f"Error processing file: {file_item.id}") - if file: - return file + if file_item: + return file_item else: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, @@ -95,7 +95,7 @@ def upload_file(file: UploadFile = File(...), user=Depends(get_verified_user)): ############################ -@router.get("/", response_model=list[FileModel]) +@router.get("/", response_model=list[FileModelResponse]) async def list_files(user=Depends(get_verified_user)): if user.role == "admin": files = Files.get_files() @@ -112,27 +112,16 @@ async def list_files(user=Depends(get_verified_user)): @router.delete("/all") async def delete_all_files(user=Depends(get_admin_user)): result = Files.delete_all_files() - if result: - folder = f"{UPLOAD_DIR}" try: - # Check if the directory exists - if os.path.exists(folder): - # Iterate over all the files and directories in the specified directory - for filename in os.listdir(folder): - file_path = os.path.join(folder, filename) - try: - if os.path.isfile(file_path) or os.path.islink(file_path): - os.unlink(file_path) # Remove the file or link - elif os.path.isdir(file_path): - shutil.rmtree(file_path) # Remove the directory - except Exception as e: - print(f"Failed to delete {file_path}. Reason: {e}") - else: - print(f"The directory {folder} does not exist") + Storage.delete_all_files() except Exception as e: - print(f"Failed to process the directory {folder}. Reason: {e}") - + log.exception(e) + log.error(f"Error deleting files") + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT("Error deleting files"), + ) return {"message": "All files deleted successfully"} else: raise HTTPException( @@ -213,24 +202,32 @@ async def update_file_data_content_by_id( ############################ -@router.get("/{id}/content", response_model=Optional[FileModel]) +@router.get("/{id}/content") async def get_file_content_by_id(id: str, user=Depends(get_verified_user)): file = Files.get_file_by_id(id) - if file and (file.user_id == user.id or user.role == "admin"): - file_path = Path(file.meta["path"]) + try: + file_path = Storage.get_file(file.path) + file_path = Path(file_path) - # Check if the file already exists in the cache - if file_path.is_file(): - print(f"file_path: {file_path}") - headers = { - "Content-Disposition": f'attachment; filename="{file.meta.get("name", file.filename)}"' - } - return FileResponse(file_path, headers=headers) - else: + # Check if the file already exists in the cache + if file_path.is_file(): + print(f"file_path: {file_path}") + headers = { + "Content-Disposition": f'attachment; filename="{file.meta.get("name", file.filename)}"' + } + return FileResponse(file_path, headers=headers) + else: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + except Exception as e: + log.exception(e) + log.error(f"Error getting file content") raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail=ERROR_MESSAGES.NOT_FOUND, + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT("Error getting file content"), ) else: raise HTTPException( @@ -239,19 +236,23 @@ async def get_file_content_by_id(id: str, user=Depends(get_verified_user)): ) -@router.get("/{id}/content/{file_name}", response_model=Optional[FileModel]) +@router.get("/{id}/content/{file_name}") async def get_file_content_by_id(id: str, user=Depends(get_verified_user)): file = Files.get_file_by_id(id) if file and (file.user_id == user.id or user.role == "admin"): - file_path = file.meta.get("path") + file_path = file.path if file_path: + file_path = Storage.get_file(file_path) file_path = Path(file_path) # Check if the file already exists in the cache if file_path.is_file(): print(f"file_path: {file_path}") - return FileResponse(file_path) + headers = { + "Content-Disposition": f'attachment; filename="{file.meta.get("name", file.filename)}"' + } + return FileResponse(file_path, headers=headers) else: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, @@ -289,6 +290,15 @@ async def delete_file_by_id(id: str, user=Depends(get_verified_user)): if file and (file.user_id == user.id or user.role == "admin"): result = Files.delete_file_by_id(id) if result: + try: + Storage.delete_file(file.filename) + except Exception as e: + log.exception(e) + log.error(f"Error deleting files") + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT("Error deleting files"), + ) return {"message": "File deleted successfully"} else: raise HTTPException( diff --git a/backend/open_webui/apps/webui/routers/folders.py b/backend/open_webui/apps/webui/routers/folders.py new file mode 100644 index 0000000000..36075c357b --- /dev/null +++ b/backend/open_webui/apps/webui/routers/folders.py @@ -0,0 +1,251 @@ +import logging +import os +import shutil +import uuid +from pathlib import Path +from typing import Optional +from pydantic import BaseModel +import mimetypes + + +from open_webui.apps.webui.models.folders import ( + FolderForm, + FolderModel, + Folders, +) +from open_webui.apps.webui.models.chats import Chats + +from open_webui.config import UPLOAD_DIR +from open_webui.env import SRC_LOG_LEVELS +from open_webui.constants import ERROR_MESSAGES + + +from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, status +from fastapi.responses import FileResponse, StreamingResponse + + +from open_webui.utils.utils import get_admin_user, get_verified_user + +log = logging.getLogger(__name__) +log.setLevel(SRC_LOG_LEVELS["MODELS"]) + + +router = APIRouter() + + +############################ +# Get Folders +############################ + + +@router.get("/", response_model=list[FolderModel]) +async def get_folders(user=Depends(get_verified_user)): + folders = Folders.get_folders_by_user_id(user.id) + + return [ + { + **folder.model_dump(), + "items": { + "chats": [ + {"title": chat.title, "id": chat.id} + for chat in Chats.get_chats_by_folder_id_and_user_id( + folder.id, user.id + ) + ] + }, + } + for folder in folders + ] + + +############################ +# Create Folder +############################ + + +@router.post("/") +def create_folder(form_data: FolderForm, user=Depends(get_verified_user)): + folder = Folders.get_folder_by_parent_id_and_user_id_and_name( + None, user.id, form_data.name + ) + + if folder: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT("Folder already exists"), + ) + + try: + folder = Folders.insert_new_folder(user.id, form_data.name) + return folder + except Exception as e: + log.exception(e) + log.error("Error creating folder") + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT("Error creating folder"), + ) + + +############################ +# Get Folders By Id +############################ + + +@router.get("/{id}", response_model=Optional[FolderModel]) +async def get_folder_by_id(id: str, user=Depends(get_verified_user)): + folder = Folders.get_folder_by_id_and_user_id(id, user.id) + if folder: + return folder + else: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + + +############################ +# Update Folder Name By Id +############################ + + +@router.post("/{id}/update") +async def update_folder_name_by_id( + id: str, form_data: FolderForm, user=Depends(get_verified_user) +): + folder = Folders.get_folder_by_id_and_user_id(id, user.id) + if folder: + existing_folder = Folders.get_folder_by_parent_id_and_user_id_and_name( + folder.parent_id, user.id, form_data.name + ) + if existing_folder: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT("Folder already exists"), + ) + + try: + folder = Folders.update_folder_name_by_id_and_user_id( + id, user.id, form_data.name + ) + + return folder + except Exception as e: + log.exception(e) + log.error(f"Error updating folder: {id}") + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT("Error updating folder"), + ) + else: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + + +############################ +# Update Folder Parent Id By Id +############################ + + +class FolderParentIdForm(BaseModel): + parent_id: Optional[str] = None + + +@router.post("/{id}/update/parent") +async def update_folder_parent_id_by_id( + id: str, form_data: FolderParentIdForm, user=Depends(get_verified_user) +): + folder = Folders.get_folder_by_id_and_user_id(id, user.id) + if folder: + existing_folder = Folders.get_folder_by_parent_id_and_user_id_and_name( + form_data.parent_id, user.id, folder.name + ) + + if existing_folder: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT("Folder already exists"), + ) + + try: + folder = Folders.update_folder_parent_id_by_id_and_user_id( + id, user.id, form_data.parent_id + ) + return folder + except Exception as e: + log.exception(e) + log.error(f"Error updating folder: {id}") + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT("Error updating folder"), + ) + else: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + + +############################ +# Update Folder Is Expanded By Id +############################ + + +class FolderIsExpandedForm(BaseModel): + is_expanded: bool + + +@router.post("/{id}/update/expanded") +async def update_folder_is_expanded_by_id( + id: str, form_data: FolderIsExpandedForm, user=Depends(get_verified_user) +): + folder = Folders.get_folder_by_id_and_user_id(id, user.id) + if folder: + try: + folder = Folders.update_folder_is_expanded_by_id_and_user_id( + id, user.id, form_data.is_expanded + ) + return folder + except Exception as e: + log.exception(e) + log.error(f"Error updating folder: {id}") + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT("Error updating folder"), + ) + else: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + + +############################ +# Delete Folder By Id +############################ + + +@router.delete("/{id}") +async def delete_folder_by_id(id: str, user=Depends(get_verified_user)): + folder = Folders.get_folder_by_id_and_user_id(id, user.id) + if folder: + try: + result = Folders.delete_folder_by_id_and_user_id(id, user.id) + if result: + return result + else: + raise Exception("Error deleting folder") + except Exception as e: + log.exception(e) + log.error(f"Error deleting folder: {id}") + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT("Error deleting folder"), + ) + else: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=ERROR_MESSAGES.NOT_FOUND, + ) diff --git a/backend/open_webui/apps/webui/routers/functions.py b/backend/open_webui/apps/webui/routers/functions.py index 1306034625..aeaceecfb1 100644 --- a/backend/open_webui/apps/webui/routers/functions.py +++ b/backend/open_webui/apps/webui/routers/functions.py @@ -9,7 +9,7 @@ from open_webui.apps.webui.models.functions import ( Functions, ) from open_webui.apps.webui.utils import load_function_module_by_id, replace_imports -from open_webui.config import CACHE_DIR, FUNCTIONS_DIR +from open_webui.config import CACHE_DIR from open_webui.constants import ERROR_MESSAGES from fastapi import APIRouter, Depends, HTTPException, Request, status from open_webui.utils.utils import get_admin_user, get_verified_user diff --git a/backend/open_webui/apps/webui/routers/knowledge.py b/backend/open_webui/apps/webui/routers/knowledge.py index a792c24fa3..9cb38a8214 100644 --- a/backend/open_webui/apps/webui/routers/knowledge.py +++ b/backend/open_webui/apps/webui/routers/knowledge.py @@ -48,7 +48,12 @@ async def get_knowledge_items( ) else: return [ - KnowledgeResponse(**knowledge.model_dump()) + KnowledgeResponse( + **knowledge.model_dump(), + files=Files.get_file_metadatas_by_ids( + knowledge.data.get("file_ids", []) if knowledge.data else [] + ), + ) for knowledge in Knowledges.get_knowledge_items() ] diff --git a/backend/open_webui/apps/webui/routers/tools.py b/backend/open_webui/apps/webui/routers/tools.py index 0db21c8953..d1ad89deae 100644 --- a/backend/open_webui/apps/webui/routers/tools.py +++ b/backend/open_webui/apps/webui/routers/tools.py @@ -10,9 +10,6 @@ from fastapi import APIRouter, Depends, HTTPException, Request, status from open_webui.utils.tools import get_tools_specs from open_webui.utils.utils import get_admin_user, get_verified_user -TOOLS_DIR = f"{DATA_DIR}/tools" -os.makedirs(TOOLS_DIR, exist_ok=True) - router = APIRouter() diff --git a/backend/open_webui/apps/webui/routers/utils.py b/backend/open_webui/apps/webui/routers/utils.py index 82c294bd7b..0ab0f6b156 100644 --- a/backend/open_webui/apps/webui/routers/utils.py +++ b/backend/open_webui/apps/webui/routers/utils.py @@ -1,16 +1,14 @@ -import site -from pathlib import Path - import black import markdown + +from open_webui.apps.webui.models.chats import ChatTitleMessagesForm from open_webui.config import DATA_DIR, ENABLE_ADMIN_EXPORT -from open_webui.env import FONTS_DIR from open_webui.constants import ERROR_MESSAGES from fastapi import APIRouter, Depends, HTTPException, Response, status -from fpdf import FPDF from pydantic import BaseModel from starlette.responses import FileResponse from open_webui.utils.misc import get_gravatar_url +from open_webui.utils.pdf_generator import PDFGenerator from open_webui.utils.utils import get_admin_user router = APIRouter() @@ -56,58 +54,19 @@ class ChatForm(BaseModel): @router.post("/pdf") async def download_chat_as_pdf( - form_data: ChatForm, + form_data: ChatTitleMessagesForm, ): - global FONTS_DIR + try: + pdf_bytes = PDFGenerator(form_data).generate_chat_pdf() - pdf = FPDF() - pdf.add_page() - - # When running using `pip install` the static directory is in the site packages. - if not FONTS_DIR.exists(): - FONTS_DIR = Path(site.getsitepackages()[0]) / "static/fonts" - # When running using `pip install -e .` the static directory is in the site packages. - # This path only works if `open-webui serve` is run from the root of this project. - if not FONTS_DIR.exists(): - FONTS_DIR = Path("./backend/static/fonts") - - pdf.add_font("NotoSans", "", f"{FONTS_DIR}/NotoSans-Regular.ttf") - pdf.add_font("NotoSans", "b", f"{FONTS_DIR}/NotoSans-Bold.ttf") - pdf.add_font("NotoSans", "i", f"{FONTS_DIR}/NotoSans-Italic.ttf") - pdf.add_font("NotoSansKR", "", f"{FONTS_DIR}/NotoSansKR-Regular.ttf") - pdf.add_font("NotoSansJP", "", f"{FONTS_DIR}/NotoSansJP-Regular.ttf") - pdf.add_font("NotoSansSC", "", f"{FONTS_DIR}/NotoSansSC-Regular.ttf") - - pdf.set_font("NotoSans", size=12) - pdf.set_fallback_fonts(["NotoSansKR", "NotoSansJP", "NotoSansSC"]) - - pdf.set_auto_page_break(auto=True, margin=15) - - # Adjust the effective page width for multi_cell - effective_page_width = ( - pdf.w - 2 * pdf.l_margin - 10 - ) # Subtracted an additional 10 for extra padding - - # Add chat messages - for message in form_data.messages: - role = message["role"] - content = message["content"] - pdf.set_font("NotoSans", "B", size=14) # Bold for the role - pdf.multi_cell(effective_page_width, 10, f"{role.upper()}", 0, "L") - pdf.ln(1) # Extra space between messages - - pdf.set_font("NotoSans", size=10) # Regular for content - pdf.multi_cell(effective_page_width, 6, content, 0, "L") - pdf.ln(1.5) # Extra space between messages - - # Save the pdf with name .pdf - pdf_bytes = pdf.output() - - return Response( - content=bytes(pdf_bytes), - media_type="application/pdf", - headers={"Content-Disposition": "attachment;filename=chat.pdf"}, - ) + return Response( + content=pdf_bytes, + media_type="application/pdf", + headers={"Content-Disposition": "attachment;filename=chat.pdf"}, + ) + except Exception as e: + print(e) + raise HTTPException(status_code=400, detail=str(e)) @router.get("/db/download") diff --git a/backend/open_webui/apps/webui/utils.py b/backend/open_webui/apps/webui/utils.py index 969d5622c8..51d3796568 100644 --- a/backend/open_webui/apps/webui/utils.py +++ b/backend/open_webui/apps/webui/utils.py @@ -8,7 +8,6 @@ import tempfile from open_webui.apps.webui.models.functions import Functions from open_webui.apps.webui.models.tools import Tools -from open_webui.config import FUNCTIONS_DIR, TOOLS_DIR def extract_frontmatter(content): diff --git a/backend/open_webui/config.py b/backend/open_webui/config.py index bfc9a4ded5..9d1bd72d89 100644 --- a/backend/open_webui/config.py +++ b/backend/open_webui/config.py @@ -383,7 +383,7 @@ OAUTH_USERNAME_CLAIM = PersistentConfig( ) OAUTH_PICTURE_CLAIM = PersistentConfig( - "OAUTH_USERNAME_CLAIM", + "OAUTH_PICTURE_CLAIM", "oauth.oidc.avatar_claim", os.environ.get("OAUTH_PICTURE_CLAIM", "picture"), ) @@ -394,6 +394,33 @@ OAUTH_EMAIL_CLAIM = PersistentConfig( os.environ.get("OAUTH_EMAIL_CLAIM", "email"), ) +ENABLE_OAUTH_ROLE_MANAGEMENT = PersistentConfig( + "ENABLE_OAUTH_ROLE_MANAGEMENT", + "oauth.enable_role_mapping", + os.environ.get("ENABLE_OAUTH_ROLE_MANAGEMENT", "False").lower() == "true", +) + +OAUTH_ROLES_CLAIM = PersistentConfig( + "OAUTH_ROLES_CLAIM", + "oauth.roles_claim", + os.environ.get("OAUTH_ROLES_CLAIM", "roles"), +) + +OAUTH_ALLOWED_ROLES = PersistentConfig( + "OAUTH_ALLOWED_ROLES", + "oauth.allowed_roles", + [ + role.strip() + for role in os.environ.get("OAUTH_ALLOWED_ROLES", "user,admin").split(",") + ], +) + +OAUTH_ADMIN_ROLES = PersistentConfig( + "OAUTH_ADMIN_ROLES", + "oauth.admin_roles", + [role.strip() for role in os.environ.get("OAUTH_ADMIN_ROLES", "admin").split(",")], +) + def load_oauth_providers(): OAUTH_PROVIDERS.clear() @@ -506,6 +533,18 @@ if CUSTOM_NAME: pass +#################################### +# STORAGE PROVIDER +#################################### + +STORAGE_PROVIDER = os.environ.get("STORAGE_PROVIDER", "") # defaults to local, s3 + +S3_ACCESS_KEY_ID = os.environ.get("S3_ACCESS_KEY_ID", None) +S3_SECRET_ACCESS_KEY = os.environ.get("S3_SECRET_ACCESS_KEY", None) +S3_REGION_NAME = os.environ.get("S3_REGION_NAME", None) +S3_BUCKET_NAME = os.environ.get("S3_BUCKET_NAME", None) +S3_ENDPOINT_URL = os.environ.get("S3_ENDPOINT_URL", None) + #################################### # File Upload DIR #################################### @@ -521,26 +560,10 @@ Path(UPLOAD_DIR).mkdir(parents=True, exist_ok=True) CACHE_DIR = f"{DATA_DIR}/cache" Path(CACHE_DIR).mkdir(parents=True, exist_ok=True) -#################################### -# Tools DIR -#################################### - -TOOLS_DIR = os.getenv("TOOLS_DIR", f"{DATA_DIR}/tools") -Path(TOOLS_DIR).mkdir(parents=True, exist_ok=True) - - -#################################### -# Functions DIR -#################################### - -FUNCTIONS_DIR = os.getenv("FUNCTIONS_DIR", f"{DATA_DIR}/functions") -Path(FUNCTIONS_DIR).mkdir(parents=True, exist_ok=True) - #################################### # OLLAMA_BASE_URL #################################### - ENABLE_OLLAMA_API = PersistentConfig( "ENABLE_OLLAMA_API", "ollama.enable", @@ -728,6 +751,28 @@ USER_PERMISSIONS = PersistentConfig( }, ) + +ENABLE_EVALUATION_ARENA_MODELS = PersistentConfig( + "ENABLE_EVALUATION_ARENA_MODELS", + "evaluation.arena.enable", + os.environ.get("ENABLE_EVALUATION_ARENA_MODELS", "True").lower() == "true", +) +EVALUATION_ARENA_MODELS = PersistentConfig( + "EVALUATION_ARENA_MODELS", + "evaluation.arena.models", + [], +) + +DEFAULT_ARENA_MODEL = { + "id": "arena-model", + "name": "Arena Model", + "meta": { + "profile_image_url": "/favicon.png", + "description": "Submit your questions to anonymous AI chatbots and vote on the best response.", + "model_ids": None, + }, +} + ENABLE_MODEL_FILTER = PersistentConfig( "ENABLE_MODEL_FILTER", "model_filter.enable", @@ -853,6 +898,12 @@ TITLE_GENERATION_PROMPT_TEMPLATE = PersistentConfig( os.environ.get("TITLE_GENERATION_PROMPT_TEMPLATE", ""), ) +TAGS_GENERATION_PROMPT_TEMPLATE = PersistentConfig( + "TAGS_GENERATION_PROMPT_TEMPLATE", + "task.tags.prompt_template", + os.environ.get("TAGS_GENERATION_PROMPT_TEMPLATE", ""), +) + ENABLE_SEARCH_QUERY = PersistentConfig( "ENABLE_SEARCH_QUERY", "task.search.enable", @@ -901,6 +952,9 @@ CHROMA_HTTP_SSL = os.environ.get("CHROMA_HTTP_SSL", "false").lower() == "true" MILVUS_URI = os.environ.get("MILVUS_URI", f"{DATA_DIR}/vector_db/milvus.db") +# Qdrant +QDRANT_URI = os.environ.get("QDRANT_URI", None) + #################################### # Information Retrieval (RAG) #################################### @@ -986,10 +1040,13 @@ RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE = ( os.environ.get("RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE", "").lower() == "true" ) -RAG_EMBEDDING_OPENAI_BATCH_SIZE = PersistentConfig( - "RAG_EMBEDDING_OPENAI_BATCH_SIZE", - "rag.embedding_openai_batch_size", - int(os.environ.get("RAG_EMBEDDING_OPENAI_BATCH_SIZE", "1")), +RAG_EMBEDDING_BATCH_SIZE = PersistentConfig( + "RAG_EMBEDDING_BATCH_SIZE", + "rag.embedding_batch_size", + int( + os.environ.get("RAG_EMBEDDING_BATCH_SIZE") + or os.environ.get("RAG_EMBEDDING_OPENAI_BATCH_SIZE", "1") + ), ) RAG_RERANKING_MODEL = PersistentConfig( @@ -1008,6 +1065,22 @@ RAG_RERANKING_MODEL_TRUST_REMOTE_CODE = ( os.environ.get("RAG_RERANKING_MODEL_TRUST_REMOTE_CODE", "").lower() == "true" ) + +RAG_TEXT_SPLITTER = PersistentConfig( + "RAG_TEXT_SPLITTER", + "rag.text_splitter", + os.environ.get("RAG_TEXT_SPLITTER", ""), +) + + +TIKTOKEN_CACHE_DIR = os.environ.get("TIKTOKEN_CACHE_DIR", f"{CACHE_DIR}/tiktoken") +TIKTOKEN_ENCODING_NAME = PersistentConfig( + "TIKTOKEN_ENCODING_NAME", + "rag.tiktoken_encoding_name", + os.environ.get("TIKTOKEN_ENCODING_NAME", "cl100k_base"), +) + + CHUNK_SIZE = PersistentConfig( "CHUNK_SIZE", "rag.chunk_size", int(os.environ.get("CHUNK_SIZE", "1000")) ) @@ -1020,7 +1093,7 @@ CHUNK_OVERLAP = PersistentConfig( DEFAULT_RAG_TEMPLATE = """You are given a user query, some textual context and rules, all inside xml tags. You have to answer the query based on the context while respecting the rules. -[context] +{{CONTEXT}} @@ -1033,7 +1106,7 @@ DEFAULT_RAG_TEMPLATE = """You are given a user query, some textual context and r -[query] +{{QUERY}} """ @@ -1168,17 +1241,6 @@ RAG_WEB_SEARCH_CONCURRENT_REQUESTS = PersistentConfig( ) -#################################### -# Transcribe -#################################### - -WHISPER_MODEL = os.getenv("WHISPER_MODEL", "base") -WHISPER_MODEL_DIR = os.getenv("WHISPER_MODEL_DIR", f"{CACHE_DIR}/whisper/models") -WHISPER_MODEL_AUTO_UPDATE = ( - os.environ.get("WHISPER_MODEL_AUTO_UPDATE", "").lower() == "true" -) - - #################################### # Images #################################### @@ -1394,6 +1456,19 @@ IMAGE_GENERATION_MODEL = PersistentConfig( # Audio #################################### +# Transcription +WHISPER_MODEL = PersistentConfig( + "WHISPER_MODEL", + "audio.stt.whisper_model", + os.getenv("WHISPER_MODEL", "base"), +) + +WHISPER_MODEL_DIR = os.getenv("WHISPER_MODEL_DIR", f"{CACHE_DIR}/whisper/models") +WHISPER_MODEL_AUTO_UPDATE = ( + os.environ.get("WHISPER_MODEL_AUTO_UPDATE", "").lower() == "true" +) + + AUDIO_STT_OPENAI_API_BASE_URL = PersistentConfig( "AUDIO_STT_OPENAI_API_BASE_URL", "audio.stt.openai.api_base_url", @@ -1415,7 +1490,7 @@ AUDIO_STT_ENGINE = PersistentConfig( AUDIO_STT_MODEL = PersistentConfig( "AUDIO_STT_MODEL", "audio.stt.model", - os.getenv("AUDIO_STT_MODEL", "whisper-1"), + os.getenv("AUDIO_STT_MODEL", ""), ) AUDIO_TTS_OPENAI_API_BASE_URL = PersistentConfig( diff --git a/backend/open_webui/constants.py b/backend/open_webui/constants.py index 37461402b6..d6f33af4a3 100644 --- a/backend/open_webui/constants.py +++ b/backend/open_webui/constants.py @@ -20,7 +20,9 @@ class ERROR_MESSAGES(str, Enum): def __str__(self) -> str: return super().__str__() - DEFAULT = lambda err="": f"Something went wrong :/\n{err if err else ''}" + DEFAULT = ( + lambda err="": f'{"Something went wrong :/" if err == "" else "[ERROR: " + str(err) + "]"}' + ) ENV_VAR_NOT_FOUND = "Required environment variable not found. Terminating now." CREATE_USER_ERROR = "Oops! Something went wrong while creating your account. Please try again later. If the issue persists, contact support for assistance." DELETE_USER_ERROR = "Oops! Something went wrong. We encountered an issue while trying to delete the user. Please give it another shot." @@ -106,6 +108,7 @@ class TASKS(str, Enum): DEFAULT = lambda task="": f"{task if task else 'generation'}" TITLE_GENERATION = "title_generation" + TAGS_GENERATION = "tags_generation" EMOJI_GENERATION = "emoji_generation" QUERY_GENERATION = "query_generation" FUNCTION_CALLING = "function_calling" diff --git a/backend/open_webui/env.py b/backend/open_webui/env.py index fbf22d84d2..4b61e1a894 100644 --- a/backend/open_webui/env.py +++ b/backend/open_webui/env.py @@ -230,6 +230,8 @@ if FROM_INIT_PY: DATA_DIR = Path(os.getenv("DATA_DIR", OPEN_WEBUI_DIR / "data")) +STATIC_DIR = Path(os.getenv("STATIC_DIR", OPEN_WEBUI_DIR / "static")) + FONTS_DIR = Path(os.getenv("FONTS_DIR", OPEN_WEBUI_DIR / "static" / "fonts")) FRONTEND_BUILD_DIR = Path(os.getenv("FRONTEND_BUILD_DIR", BASE_DIR / "build")).resolve() @@ -302,6 +304,12 @@ RESET_CONFIG_ON_START = ( os.environ.get("RESET_CONFIG_ON_START", "False").lower() == "true" ) +#################################### +# REDIS +#################################### + +REDIS_URL = os.environ.get("REDIS_URL", "redis://localhost:6379/0") + #################################### # WEBUI_AUTH (Required for security) #################################### @@ -343,8 +351,7 @@ ENABLE_WEBSOCKET_SUPPORT = ( WEBSOCKET_MANAGER = os.environ.get("WEBSOCKET_MANAGER", "") -WEBSOCKET_REDIS_URL = os.environ.get("WEBSOCKET_REDIS_URL", "redis://localhost:6379/0") - +WEBSOCKET_REDIS_URL = os.environ.get("WEBSOCKET_REDIS_URL", REDIS_URL) AIOHTTP_CLIENT_TIMEOUT = os.environ.get("AIOHTTP_CLIENT_TIMEOUT", "") @@ -355,3 +362,23 @@ else: AIOHTTP_CLIENT_TIMEOUT = int(AIOHTTP_CLIENT_TIMEOUT) except Exception: AIOHTTP_CLIENT_TIMEOUT = 300 + +AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST = os.environ.get( + "AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST", "3" +) + +if AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST == "": + AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST = None +else: + try: + AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST = int( + AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST + ) + except Exception: + AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST = 3 + +#################################### +# OFFLINE_MODE +#################################### + +OFFLINE_MODE = os.environ.get("OFFLINE_MODE", "false").lower() == "true" diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index 7086a3cc9a..1c7e5dd214 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -1,4 +1,4 @@ -import base64 +import asyncio import inspect import json import logging @@ -7,20 +7,38 @@ import os import shutil import sys import time -import uuid -import asyncio - +import random from contextlib import asynccontextmanager from typing import Optional import aiohttp import requests +from fastapi import ( + Depends, + FastAPI, + File, + Form, + HTTPException, + Request, + UploadFile, + status, +) +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import JSONResponse, RedirectResponse +from fastapi.staticfiles import StaticFiles +from pydantic import BaseModel +from sqlalchemy import text +from starlette.exceptions import HTTPException as StarletteHTTPException +from starlette.middleware.base import BaseHTTPMiddleware +from starlette.middleware.sessions import SessionMiddleware +from starlette.responses import Response, StreamingResponse +from open_webui.apps.audio.main import app as audio_app +from open_webui.apps.images.main import app as images_app from open_webui.apps.ollama.main import ( app as ollama_app, get_all_models as get_ollama_models, generate_chat_completion as generate_ollama_chat_completion, - generate_openai_chat_completion as generate_ollama_openai_chat_completion, GenerateChatCompletionForm, ) from open_webui.apps.openai.main import ( @@ -28,38 +46,24 @@ from open_webui.apps.openai.main import ( generate_chat_completion as generate_openai_chat_completion, get_all_models as get_openai_models, ) - from open_webui.apps.retrieval.main import app as retrieval_app from open_webui.apps.retrieval.utils import get_rag_context, rag_template - from open_webui.apps.socket.main import ( app as socket_app, periodic_usage_pool_cleanup, get_event_call, get_event_emitter, ) - +from open_webui.apps.webui.internal.db import Session from open_webui.apps.webui.main import ( app as webui_app, generate_function_chat_completion, - get_pipe_models, + get_all_models as get_open_webui_models, ) -from open_webui.apps.webui.internal.db import Session - -from open_webui.apps.webui.models.auths import Auths from open_webui.apps.webui.models.functions import Functions from open_webui.apps.webui.models.models import Models from open_webui.apps.webui.models.users import UserModel, Users - from open_webui.apps.webui.utils import load_function_module_by_id - -from open_webui.apps.audio.main import app as audio_app -from open_webui.apps.images.main import app as images_app - -from authlib.integrations.starlette_client import OAuth -from authlib.oidc.core import UserInfo - - from open_webui.config import ( CACHE_DIR, CORS_ALLOW_ORIGIN, @@ -67,13 +71,11 @@ from open_webui.config import ( ENABLE_ADMIN_CHAT_ACCESS, ENABLE_ADMIN_EXPORT, ENABLE_MODEL_FILTER, - ENABLE_OAUTH_SIGNUP, ENABLE_OLLAMA_API, ENABLE_OPENAI_API, ENV, FRONTEND_BUILD_DIR, MODEL_FILTER_LIST, - OAUTH_MERGE_ACCOUNTS_BY_EMAIL, OAUTH_PROVIDERS, ENABLE_SEARCH_QUERY, SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE, @@ -81,15 +83,15 @@ from open_webui.config import ( TASK_MODEL, TASK_MODEL_EXTERNAL, TITLE_GENERATION_PROMPT_TEMPLATE, + TAGS_GENERATION_PROMPT_TEMPLATE, TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE, WEBHOOK_URL, WEBUI_AUTH, WEBUI_NAME, AppConfig, - run_migrations, reset_config, ) -from open_webui.constants import ERROR_MESSAGES, TASKS, WEBHOOK_MESSAGES +from open_webui.constants import TASKS from open_webui.env import ( CHANGELOG, GLOBAL_LOG_LEVEL, @@ -102,64 +104,41 @@ from open_webui.env import ( WEBUI_SESSION_COOKIE_SECURE, WEBUI_URL, RESET_CONFIG_ON_START, + OFFLINE_MODE, ) -from fastapi import ( - Depends, - FastAPI, - File, - Form, - HTTPException, - Request, - UploadFile, - status, -) -from fastapi.middleware.cors import CORSMiddleware -from fastapi.responses import JSONResponse -from fastapi.staticfiles import StaticFiles -from pydantic import BaseModel -from sqlalchemy import text -from starlette.exceptions import HTTPException as StarletteHTTPException -from starlette.middleware.base import BaseHTTPMiddleware -from starlette.middleware.sessions import SessionMiddleware -from starlette.responses import RedirectResponse, Response, StreamingResponse - -from open_webui.utils.security_headers import SecurityHeadersMiddleware - from open_webui.utils.misc import ( add_or_update_system_message, get_last_user_message, - parse_duration, prepend_to_first_user_message_content, ) -from open_webui.utils.task import ( - moa_response_generation_template, - search_query_generation_template, - title_generation_template, - tools_function_calling_generation_template, -) -from open_webui.utils.tools import get_tools -from open_webui.utils.utils import ( - create_token, - decode_token, - get_admin_user, - get_current_user, - get_http_authorization_cred, - get_password_hash, - get_verified_user, -) -from open_webui.utils.webhook import post_webhook - +from open_webui.utils.oauth import oauth_manager from open_webui.utils.payload import convert_payload_openai_to_ollama from open_webui.utils.response import ( convert_response_ollama_to_openai, convert_streaming_response_ollama_to_openai, ) +from open_webui.utils.security_headers import SecurityHeadersMiddleware +from open_webui.utils.task import ( + moa_response_generation_template, + tags_generation_template, + search_query_generation_template, + emoji_generation_template, + title_generation_template, + tools_function_calling_generation_template, +) +from open_webui.utils.tools import get_tools +from open_webui.utils.utils import ( + decode_token, + get_admin_user, + get_current_user, + get_http_authorization_cred, + get_verified_user, +) if SAFE_MODE: print("SAFE MODE ENABLED") Functions.deactivate_all_functions() - logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL) log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["MAIN"]) @@ -178,14 +157,14 @@ class SPAStaticFiles(StaticFiles): print( rf""" - ___ __ __ _ _ _ ___ + ___ __ __ _ _ _ ___ / _ \ _ __ ___ _ __ \ \ / /__| |__ | | | |_ _| -| | | | '_ \ / _ \ '_ \ \ \ /\ / / _ \ '_ \| | | || | -| |_| | |_) | __/ | | | \ V V / __/ |_) | |_| || | +| | | | '_ \ / _ \ '_ \ \ \ /\ / / _ \ '_ \| | | || | +| |_| | |_) | __/ | | | \ V V / __/ |_) | |_| || | \___/| .__/ \___|_| |_| \_/\_/ \___|_.__/ \___/|___| - |_| + |_| + - v{VERSION} - building the best open-source AI user interface. {f"Commit: {WEBUI_BUILD_HASH}" if WEBUI_BUILD_HASH != "dev-build" else ""} https://github.com/open-webui/open-webui @@ -216,10 +195,10 @@ app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST app.state.config.WEBHOOK_URL = WEBHOOK_URL - app.state.config.TASK_MODEL = TASK_MODEL app.state.config.TASK_MODEL_EXTERNAL = TASK_MODEL_EXTERNAL app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = TITLE_GENERATION_PROMPT_TEMPLATE +app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE = TAGS_GENERATION_PROMPT_TEMPLATE app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE = ( SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE ) @@ -577,7 +556,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): } # Initialize data_items to store additional data to be sent to the client - # Initalize contexts and citation + # Initialize contexts and citation data_items = [] contexts = [] citations = [] @@ -689,6 +668,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): app.add_middleware(ChatCompletionMiddleware) + ################################## # # Pipeline Middleware @@ -824,6 +804,32 @@ class PipelineMiddleware(BaseHTTPMiddleware): app.add_middleware(PipelineMiddleware) +from urllib.parse import urlencode, parse_qs, urlparse + + +class RedirectMiddleware(BaseHTTPMiddleware): + async def dispatch(self, request: Request, call_next): + # Check if the request is a GET request + if request.method == "GET": + path = request.url.path + query_params = dict(parse_qs(urlparse(str(request.url)).query)) + + # Check for the specific watch path and the presence of 'v' parameter + if path.endswith("/watch") and "v" in query_params: + video_id = query_params["v"][0] # Extract the first 'v' parameter + encoded_video_id = urlencode({"youtube": video_id}) + redirect_url = f"/?{encoded_video_id}" + return RedirectResponse(url=redirect_url) + + # Proceed with the normal flow of other requests + response = await call_next(request) + return response + + +# Add the middleware to the app +app.add_middleware(RedirectMiddleware) + + app.add_middleware( CORSMiddleware, allow_origins=CORS_ALLOW_ORIGIN, @@ -900,12 +906,10 @@ webui_app.state.EMBEDDING_FUNCTION = retrieval_app.state.EMBEDDING_FUNCTION async def get_all_models(): # TODO: Optimize this function - pipe_models = [] + open_webui_models = [] openai_models = [] ollama_models = [] - pipe_models = await get_pipe_models() - if app.state.config.ENABLE_OPENAI_API: openai_models = await get_openai_models() openai_models = openai_models["data"] @@ -924,7 +928,13 @@ async def get_all_models(): for model in ollama_models["models"] ] - models = pipe_models + openai_models + ollama_models + open_webui_models = await get_open_webui_models() + + models = open_webui_models + openai_models + ollama_models + + # If there are no models, return an empty list + if len([model for model in models if model["owned_by"] != "arena"]) == 0: + return [] global_action_ids = [ function.id for function in Functions.get_global_action_functions() @@ -963,11 +973,13 @@ async def get_all_models(): owned_by = model["owned_by"] if "pipe" in model: pipe = model["pipe"] - - if "info" in model and "meta" in model["info"]: - action_ids.extend(model["info"]["meta"].get("actionIds", [])) break + if custom_model.meta: + meta = custom_model.meta.model_dump() + if "actionIds" in meta: + action_ids.extend(meta["actionIds"]) + models.append( { "id": custom_model.id, @@ -1070,7 +1082,9 @@ async def get_models(user=Depends(get_verified_user)): @app.post("/api/chat/completions") -async def generate_chat_completions(form_data: dict, user=Depends(get_verified_user)): +async def generate_chat_completions( + form_data: dict, user=Depends(get_verified_user), bypass_filter: bool = False +): model_id = form_data["model"] if model_id not in app.state.MODELS: @@ -1079,7 +1093,7 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u detail="Model not found", ) - if app.state.config.ENABLE_MODEL_FILTER: + if not bypass_filter and app.state.config.ENABLE_MODEL_FILTER: if user.role == "user" and model_id not in app.state.config.MODEL_FILTER_LIST: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, @@ -1087,6 +1101,53 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u ) model = app.state.MODELS[model_id] + + if model["owned_by"] == "arena": + model_ids = model.get("info", {}).get("meta", {}).get("model_ids") + filter_mode = model.get("info", {}).get("meta", {}).get("filter_mode") + if model_ids and filter_mode == "exclude": + model_ids = [ + model["id"] + for model in await get_all_models() + if model.get("owned_by") != "arena" + and not model.get("info", {}).get("meta", {}).get("hidden", False) + and model["id"] not in model_ids + ] + + selected_model_id = None + if isinstance(model_ids, list) and model_ids: + selected_model_id = random.choice(model_ids) + else: + model_ids = [ + model["id"] + for model in await get_all_models() + if model.get("owned_by") != "arena" + and not model.get("info", {}).get("meta", {}).get("hidden", False) + ] + selected_model_id = random.choice(model_ids) + + form_data["model"] = selected_model_id + + if form_data.get("stream") == True: + + async def stream_wrapper(stream): + yield f"data: {json.dumps({'selected_model_id': selected_model_id})}\n\n" + async for chunk in stream: + yield chunk + + response = await generate_chat_completions( + form_data, user, bypass_filter=True + ) + return StreamingResponse( + stream_wrapper(response.body_iterator), media_type="text/event-stream" + ) + else: + return { + **( + await generate_chat_completions(form_data, user, bypass_filter=True) + ), + "selected_model_id": selected_model_id, + } if model.get("pipe"): return await generate_function_chat_completion(form_data, user=user) if model["owned_by"] == "ollama": @@ -1398,6 +1459,7 @@ async def get_task_config(user=Depends(get_verified_user)): "TASK_MODEL": app.state.config.TASK_MODEL, "TASK_MODEL_EXTERNAL": app.state.config.TASK_MODEL_EXTERNAL, "TITLE_GENERATION_PROMPT_TEMPLATE": app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE, + "TAGS_GENERATION_PROMPT_TEMPLATE": app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE, "ENABLE_SEARCH_QUERY": app.state.config.ENABLE_SEARCH_QUERY, "SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE": app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE, "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE": app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE, @@ -1408,6 +1470,7 @@ class TaskConfigForm(BaseModel): TASK_MODEL: Optional[str] TASK_MODEL_EXTERNAL: Optional[str] TITLE_GENERATION_PROMPT_TEMPLATE: str + TAGS_GENERATION_PROMPT_TEMPLATE: str SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE: str ENABLE_SEARCH_QUERY: bool TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE: str @@ -1420,6 +1483,10 @@ async def update_task_config(form_data: TaskConfigForm, user=Depends(get_admin_u app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = ( form_data.TITLE_GENERATION_PROMPT_TEMPLATE ) + app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE = ( + form_data.TAGS_GENERATION_PROMPT_TEMPLATE + ) + app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE = ( form_data.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE ) @@ -1432,6 +1499,7 @@ async def update_task_config(form_data: TaskConfigForm, user=Depends(get_admin_u "TASK_MODEL": app.state.config.TASK_MODEL, "TASK_MODEL_EXTERNAL": app.state.config.TASK_MODEL_EXTERNAL, "TITLE_GENERATION_PROMPT_TEMPLATE": app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE, + "TAGS_GENERATION_PROMPT_TEMPLATE": app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE, "SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE": app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE, "ENABLE_SEARCH_QUERY": app.state.config.ENABLE_SEARCH_QUERY, "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE": app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE, @@ -1459,7 +1527,7 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)): if app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE != "": template = app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE else: - template = """Create a concise, 3-5 word title with an emoji as a title for the prompt in the given language. Suitable Emojis for the summary can be used to enhance understanding but avoid quotation marks or special formatting. RESPOND ONLY WITH THE TITLE TEXT. + template = """Create a concise, 3-5 word title with an emoji as a title for the chat history, in the given language. Suitable Emojis for the summary can be used to enhance understanding but avoid quotation marks or special formatting. RESPOND ONLY WITH THE TITLE TEXT. Examples of titles: 📉 Stock Market Trends @@ -1469,11 +1537,13 @@ Remote Work Productivity Tips Artificial Intelligence in Healthcare 🎮 Video Game Development Insights -Prompt: {{prompt:middletruncate:8000}}""" + +{{MESSAGES:END:2}} +""" content = title_generation_template( template, - form_data["prompt"], + form_data["messages"], { "name": user.name, "location": user.info.get("location") if user.info else None, @@ -1516,6 +1586,75 @@ Prompt: {{prompt:middletruncate:8000}}""" return await generate_chat_completions(form_data=payload, user=user) +@app.post("/api/task/tags/completions") +async def generate_chat_tags(form_data: dict, user=Depends(get_verified_user)): + print("generate_chat_tags") + model_id = form_data["model"] + if model_id not in app.state.MODELS: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Model not found", + ) + + # Check if the user has a custom task model + # If the user has a custom task model, use that model + task_model_id = get_task_model_id(model_id) + print(task_model_id) + + if app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE != "": + template = app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE + else: + template = """### Task: +Generate 1-3 broad tags categorizing the main themes of the chat history, along with 1-3 more specific subtopic tags. + +### Guidelines: +- Start with high-level domains (e.g. Science, Technology, Philosophy, Arts, Politics, Business, Health, Sports, Entertainment, Education) +- Consider including relevant subfields/subdomains if they are strongly represented throughout the conversation +- If content is too short (less than 3 messages) or too diverse, use only ["General"] +- Use the chat's primary language; default to English if multilingual +- Prioritize accuracy over specificity + +### Output: +JSON format: { "tags": ["tag1", "tag2", "tag3"] } + +### Chat History: + +{{MESSAGES:END:6}} +""" + + content = tags_generation_template( + template, form_data["messages"], {"name": user.name} + ) + + print("content", content) + payload = { + "model": task_model_id, + "messages": [{"role": "user", "content": content}], + "stream": False, + "metadata": {"task": str(TASKS.TAGS_GENERATION), "task_body": form_data}, + } + log.debug(payload) + + # Handle pipeline filters + try: + payload = filter_pipeline(payload, user) + except Exception as e: + if len(e.args) > 1: + return JSONResponse( + status_code=e.args[0], + content={"detail": e.args[1]}, + ) + else: + return JSONResponse( + status_code=status.HTTP_400_BAD_REQUEST, + content={"detail": str(e)}, + ) + if "chat_id" in payload: + del payload["chat_id"] + + return await generate_chat_completions(form_data=payload, user=user) + + @app.post("/api/task/query/completions") async def generate_search_query(form_data: dict, user=Depends(get_verified_user)): print("generate_search_query") @@ -1616,7 +1755,7 @@ Your task is to reflect the speaker's likely facial expression through a fitting Message: """{{prompt}}""" ''' - content = title_generation_template( + content = emoji_generation_template( template, form_data["prompt"], { @@ -2181,6 +2320,11 @@ async def get_app_changelog(): @app.get("/api/version/updates") async def get_app_latest_release_version(): + if OFFLINE_MODE: + log.debug( + f"Offline mode is enabled, returning current version as latest version" + ) + return {"current": VERSION, "latest": VERSION} try: timeout = aiohttp.ClientTimeout(total=1) async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session: @@ -2201,20 +2345,6 @@ async def get_app_latest_release_version(): # OAuth Login & Callback ############################ -oauth = OAuth() - -for provider_name, provider_config in OAUTH_PROVIDERS.items(): - oauth.register( - name=provider_name, - client_id=provider_config["client_id"], - client_secret=provider_config["client_secret"], - server_metadata_url=provider_config["server_metadata_url"], - client_kwargs={ - "scope": provider_config["scope"], - }, - redirect_uri=provider_config["redirect_uri"], - ) - # SessionMiddleware is used by authlib for oauth if len(OAUTH_PROVIDERS) > 0: app.add_middleware( @@ -2228,16 +2358,7 @@ if len(OAUTH_PROVIDERS) > 0: @app.get("/oauth/{provider}/login") async def oauth_login(provider: str, request: Request): - if provider not in OAUTH_PROVIDERS: - raise HTTPException(404) - # If the provider has a custom redirect URL, use that, otherwise automatically generate one - redirect_uri = OAUTH_PROVIDERS[provider].get("redirect_uri") or request.url_for( - "oauth_callback", provider=provider - ) - client = oauth.create_client(provider) - if client is None: - raise HTTPException(404) - return await client.authorize_redirect(request, redirect_uri) + return await oauth_manager.handle_login(provider, request) # OAuth login logic is as follows: @@ -2245,119 +2366,10 @@ async def oauth_login(provider: str, request: Request): # 2. If OAUTH_MERGE_ACCOUNTS_BY_EMAIL is true, find a user with the email address provided via OAuth # - This is considered insecure in general, as OAuth providers do not always verify email addresses # 3. If there is no user, and ENABLE_OAUTH_SIGNUP is true, create a user -# - Email addresses are considered unique, so we fail registration if the email address is alreayd taken +# - Email addresses are considered unique, so we fail registration if the email address is already taken @app.get("/oauth/{provider}/callback") async def oauth_callback(provider: str, request: Request, response: Response): - if provider not in OAUTH_PROVIDERS: - raise HTTPException(404) - client = oauth.create_client(provider) - try: - token = await client.authorize_access_token(request) - except Exception as e: - log.warning(f"OAuth callback error: {e}") - raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED) - user_data: UserInfo = token["userinfo"] - - sub = user_data.get("sub") - if not sub: - log.warning(f"OAuth callback failed, sub is missing: {user_data}") - raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED) - provider_sub = f"{provider}@{sub}" - email_claim = webui_app.state.config.OAUTH_EMAIL_CLAIM - email = user_data.get(email_claim, "").lower() - # We currently mandate that email addresses are provided - if not email: - log.warning(f"OAuth callback failed, email is missing: {user_data}") - raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED) - - # Check if the user exists - user = Users.get_user_by_oauth_sub(provider_sub) - - if not user: - # If the user does not exist, check if merging is enabled - if OAUTH_MERGE_ACCOUNTS_BY_EMAIL.value: - # Check if the user exists by email - user = Users.get_user_by_email(email) - if user: - # Update the user with the new oauth sub - Users.update_user_oauth_sub_by_id(user.id, provider_sub) - - if not user: - # If the user does not exist, check if signups are enabled - if ENABLE_OAUTH_SIGNUP.value: - # Check if an existing user with the same email already exists - existing_user = Users.get_user_by_email(user_data.get("email", "").lower()) - if existing_user: - raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN) - - picture_claim = webui_app.state.config.OAUTH_PICTURE_CLAIM - picture_url = user_data.get(picture_claim, "") - if picture_url: - # Download the profile image into a base64 string - try: - async with aiohttp.ClientSession() as session: - async with session.get(picture_url) as resp: - picture = await resp.read() - base64_encoded_picture = base64.b64encode(picture).decode( - "utf-8" - ) - guessed_mime_type = mimetypes.guess_type(picture_url)[0] - if guessed_mime_type is None: - # assume JPG, browsers are tolerant enough of image formats - guessed_mime_type = "image/jpeg" - picture_url = f"data:{guessed_mime_type};base64,{base64_encoded_picture}" - except Exception as e: - log.error(f"Error downloading profile image '{picture_url}': {e}") - picture_url = "" - if not picture_url: - picture_url = "/user.png" - username_claim = webui_app.state.config.OAUTH_USERNAME_CLAIM - role = ( - "admin" - if Users.get_num_users() == 0 - else webui_app.state.config.DEFAULT_USER_ROLE - ) - user = Auths.insert_new_auth( - email=email, - password=get_password_hash( - str(uuid.uuid4()) - ), # Random password, not used - name=user_data.get(username_claim, "User"), - profile_image_url=picture_url, - role=role, - oauth_sub=provider_sub, - ) - - if webui_app.state.config.WEBHOOK_URL: - post_webhook( - webui_app.state.config.WEBHOOK_URL, - WEBHOOK_MESSAGES.USER_SIGNUP(user.name), - { - "action": "signup", - "message": WEBHOOK_MESSAGES.USER_SIGNUP(user.name), - "user": user.model_dump_json(exclude_none=True), - }, - ) - else: - raise HTTPException( - status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED - ) - - jwt_token = create_token( - data={"id": user.id}, - expires_delta=parse_duration(webui_app.state.config.JWT_EXPIRES_IN), - ) - - # Set the cookie token - response.set_cookie( - key="token", - value=jwt_token, - httponly=True, # Ensures the cookie is not accessible via JavaScript - ) - - # Redirect back to the frontend with the JWT token - redirect_url = f"{request.base_url}auth#token={jwt_token}" - return RedirectResponse(url=redirect_url) + return await oauth_manager.handle_callback(provider, request, response) @app.get("/manifest.json") @@ -2416,6 +2428,7 @@ async def healthcheck_with_db(): app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static") app.mount("/cache", StaticFiles(directory=CACHE_DIR), name="cache") + if os.path.exists(FRONTEND_BUILD_DIR): mimetypes.add_type("text/javascript", ".js") app.mount( diff --git a/backend/open_webui/migrations/versions/1af9b942657b_migrate_tags.py b/backend/open_webui/migrations/versions/1af9b942657b_migrate_tags.py new file mode 100644 index 0000000000..8a0ab1b491 --- /dev/null +++ b/backend/open_webui/migrations/versions/1af9b942657b_migrate_tags.py @@ -0,0 +1,151 @@ +"""Migrate tags + +Revision ID: 1af9b942657b +Revises: 242a2047eae0 +Create Date: 2024-10-09 21:02:35.241684 + +""" + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.sql import table, select, update, column +from sqlalchemy.engine.reflection import Inspector + +import json + +revision = "1af9b942657b" +down_revision = "242a2047eae0" +branch_labels = None +depends_on = None + + +def upgrade(): + # Setup an inspection on the existing table to avoid issues + conn = op.get_bind() + inspector = Inspector.from_engine(conn) + + # Clean up potential leftover temp table from previous failures + conn.execute(sa.text("DROP TABLE IF EXISTS _alembic_tmp_tag")) + + # Check if the 'tag' table exists + tables = inspector.get_table_names() + + # Step 1: Modify Tag table using batch mode for SQLite support + if "tag" in tables: + # Get the current columns in the 'tag' table + columns = [col["name"] for col in inspector.get_columns("tag")] + + # Get any existing unique constraints on the 'tag' table + current_constraints = inspector.get_unique_constraints("tag") + + with op.batch_alter_table("tag", schema=None) as batch_op: + # Check if the unique constraint already exists + if not any( + constraint["name"] == "uq_id_user_id" + for constraint in current_constraints + ): + # Create unique constraint if it doesn't exist + batch_op.create_unique_constraint("uq_id_user_id", ["id", "user_id"]) + + # Check if the 'data' column exists before trying to drop it + if "data" in columns: + batch_op.drop_column("data") + + # Check if the 'meta' column needs to be created + if "meta" not in columns: + # Add the 'meta' column if it doesn't already exist + batch_op.add_column(sa.Column("meta", sa.JSON(), nullable=True)) + + tag = table( + "tag", + column("id", sa.String()), + column("name", sa.String()), + column("user_id", sa.String()), + column("meta", sa.JSON()), + ) + + # Step 2: Migrate tags + conn = op.get_bind() + result = conn.execute(sa.select(tag.c.id, tag.c.name, tag.c.user_id)) + + tag_updates = {} + for row in result: + new_id = row.name.replace(" ", "_").lower() + tag_updates[row.id] = new_id + + for tag_id, new_tag_id in tag_updates.items(): + print(f"Updating tag {tag_id} to {new_tag_id}") + if new_tag_id == "pinned": + # delete tag + delete_stmt = sa.delete(tag).where(tag.c.id == tag_id) + conn.execute(delete_stmt) + else: + # Check if the new_tag_id already exists in the database + existing_tag_query = sa.select(tag.c.id).where(tag.c.id == new_tag_id) + existing_tag_result = conn.execute(existing_tag_query).fetchone() + + if existing_tag_result: + # Handle duplicate case: the new_tag_id already exists + print( + f"Tag {new_tag_id} already exists. Removing current tag with ID {tag_id} to avoid duplicates." + ) + # Option 1: Delete the current tag if an update to new_tag_id would cause duplication + delete_stmt = sa.delete(tag).where(tag.c.id == tag_id) + conn.execute(delete_stmt) + else: + update_stmt = sa.update(tag).where(tag.c.id == tag_id) + update_stmt = update_stmt.values(id=new_tag_id) + conn.execute(update_stmt) + + # Add columns `pinned` and `meta` to 'chat' + op.add_column("chat", sa.Column("pinned", sa.Boolean(), nullable=True)) + op.add_column( + "chat", sa.Column("meta", sa.JSON(), nullable=False, server_default="{}") + ) + + chatidtag = table( + "chatidtag", column("chat_id", sa.String()), column("tag_name", sa.String()) + ) + chat = table( + "chat", + column("id", sa.String()), + column("pinned", sa.Boolean()), + column("meta", sa.JSON()), + ) + + # Fetch existing tags + conn = op.get_bind() + result = conn.execute(sa.select(chatidtag.c.chat_id, chatidtag.c.tag_name)) + + chat_updates = {} + for row in result: + chat_id = row.chat_id + tag_name = row.tag_name.replace(" ", "_").lower() + + if tag_name == "pinned": + # Specifically handle 'pinned' tag + if chat_id not in chat_updates: + chat_updates[chat_id] = {"pinned": True, "meta": {}} + else: + chat_updates[chat_id]["pinned"] = True + else: + if chat_id not in chat_updates: + chat_updates[chat_id] = {"pinned": False, "meta": {"tags": [tag_name]}} + else: + tags = chat_updates[chat_id]["meta"].get("tags", []) + tags.append(tag_name) + + chat_updates[chat_id]["meta"]["tags"] = list(set(tags)) + + # Update chats based on accumulated changes + for chat_id, updates in chat_updates.items(): + update_stmt = sa.update(chat).where(chat.c.id == chat_id) + update_stmt = update_stmt.values( + meta=updates.get("meta", {}), pinned=updates.get("pinned", False) + ) + conn.execute(update_stmt) + pass + + +def downgrade(): + pass diff --git a/backend/open_webui/migrations/versions/242a2047eae0_update_chat_table.py b/backend/open_webui/migrations/versions/242a2047eae0_update_chat_table.py new file mode 100644 index 0000000000..596703dc2c --- /dev/null +++ b/backend/open_webui/migrations/versions/242a2047eae0_update_chat_table.py @@ -0,0 +1,82 @@ +"""Update chat table + +Revision ID: 242a2047eae0 +Revises: 6a39f3d8e55c +Create Date: 2024-10-09 21:02:35.241684 + +""" + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.sql import table, select, update + +import json + +revision = "242a2047eae0" +down_revision = "6a39f3d8e55c" +branch_labels = None +depends_on = None + + +def upgrade(): + # Step 1: Rename current 'chat' column to 'old_chat' + op.alter_column("chat", "chat", new_column_name="old_chat", existing_type=sa.Text) + + # Step 2: Add new 'chat' column of type JSON + op.add_column("chat", sa.Column("chat", sa.JSON(), nullable=True)) + + # Step 3: Migrate data from 'old_chat' to 'chat' + chat_table = table( + "chat", + sa.Column("id", sa.String, primary_key=True), + sa.Column("old_chat", sa.Text), + sa.Column("chat", sa.JSON()), + ) + + # - Selecting all data from the table + connection = op.get_bind() + results = connection.execute(select(chat_table.c.id, chat_table.c.old_chat)) + for row in results: + try: + # Convert text JSON to actual JSON object, assuming the text is in JSON format + json_data = json.loads(row.old_chat) + except json.JSONDecodeError: + json_data = None # Handle cases where the text cannot be converted to JSON + + connection.execute( + sa.update(chat_table) + .where(chat_table.c.id == row.id) + .values(chat=json_data) + ) + + # Step 4: Drop 'old_chat' column + op.drop_column("chat", "old_chat") + + +def downgrade(): + # Step 1: Add 'old_chat' column back as Text + op.add_column("chat", sa.Column("old_chat", sa.Text(), nullable=True)) + + # Step 2: Convert 'chat' JSON data back to text and store in 'old_chat' + chat_table = table( + "chat", + sa.Column("id", sa.String, primary_key=True), + sa.Column("chat", sa.JSON()), + sa.Column("old_chat", sa.Text()), + ) + + connection = op.get_bind() + results = connection.execute(select(chat_table.c.id, chat_table.c.chat)) + for row in results: + text_data = json.dumps(row.chat) if row.chat is not None else None + connection.execute( + sa.update(chat_table) + .where(chat_table.c.id == row.id) + .values(old_chat=text_data) + ) + + # Step 3: Remove the new 'chat' JSON column + op.drop_column("chat", "chat") + + # Step 4: Rename 'old_chat' back to 'chat' + op.alter_column("chat", "old_chat", new_column_name="chat", existing_type=sa.Text) diff --git a/backend/open_webui/migrations/versions/3ab32c4b8f59_update_tags.py b/backend/open_webui/migrations/versions/3ab32c4b8f59_update_tags.py new file mode 100644 index 0000000000..6e010424b0 --- /dev/null +++ b/backend/open_webui/migrations/versions/3ab32c4b8f59_update_tags.py @@ -0,0 +1,81 @@ +"""Update tags + +Revision ID: 3ab32c4b8f59 +Revises: 1af9b942657b +Create Date: 2024-10-09 21:02:35.241684 + +""" + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.sql import table, select, update, column +from sqlalchemy.engine.reflection import Inspector + +import json + +revision = "3ab32c4b8f59" +down_revision = "1af9b942657b" +branch_labels = None +depends_on = None + + +def upgrade(): + conn = op.get_bind() + inspector = Inspector.from_engine(conn) + + # Inspecting the 'tag' table constraints and structure + existing_pk = inspector.get_pk_constraint("tag") + unique_constraints = inspector.get_unique_constraints("tag") + existing_indexes = inspector.get_indexes("tag") + + print(f"Primary Key: {existing_pk}") + print(f"Unique Constraints: {unique_constraints}") + print(f"Indexes: {existing_indexes}") + + with op.batch_alter_table("tag", schema=None) as batch_op: + # Drop existing primary key constraint if it exists + if existing_pk and existing_pk.get("constrained_columns"): + pk_name = existing_pk.get("name") + if pk_name: + print(f"Dropping primary key constraint: {pk_name}") + batch_op.drop_constraint(pk_name, type_="primary") + + # Now create the new primary key with the combination of 'id' and 'user_id' + print("Creating new primary key with 'id' and 'user_id'.") + batch_op.create_primary_key("pk_id_user_id", ["id", "user_id"]) + + # Drop unique constraints that could conflict with the new primary key + for constraint in unique_constraints: + if ( + constraint["name"] == "uq_id_user_id" + ): # Adjust this name according to what is actually returned by the inspector + print(f"Dropping unique constraint: {constraint['name']}") + batch_op.drop_constraint(constraint["name"], type_="unique") + + for index in existing_indexes: + if index["unique"]: + if not any( + constraint["name"] == index["name"] + for constraint in unique_constraints + ): + # You are attempting to drop unique indexes + print(f"Dropping unique index: {index['name']}") + batch_op.drop_index(index["name"]) + + +def downgrade(): + conn = op.get_bind() + inspector = Inspector.from_engine(conn) + + current_pk = inspector.get_pk_constraint("tag") + + with op.batch_alter_table("tag", schema=None) as batch_op: + # Drop the current primary key first, if it matches the one we know we added in upgrade + if current_pk and "pk_id_user_id" == current_pk.get("name"): + batch_op.drop_constraint("pk_id_user_id", type_="primary") + + # Restore the original primary key + batch_op.create_primary_key("pk_id", ["id"]) + + # Since primary key on just 'id' is restored, we now add back any unique constraints if necessary + batch_op.create_unique_constraint("uq_id_user_id", ["id", "user_id"]) diff --git a/backend/open_webui/migrations/versions/4ace53fd72c8_update_folder_table_datetime.py b/backend/open_webui/migrations/versions/4ace53fd72c8_update_folder_table_datetime.py new file mode 100644 index 0000000000..16f7967c8e --- /dev/null +++ b/backend/open_webui/migrations/versions/4ace53fd72c8_update_folder_table_datetime.py @@ -0,0 +1,67 @@ +"""Update folder table and change DateTime to BigInteger for timestamp fields + +Revision ID: 4ace53fd72c8 +Revises: af906e964978 +Create Date: 2024-10-23 03:00:00.000000 + +""" + +from alembic import op +import sqlalchemy as sa + +revision = "4ace53fd72c8" +down_revision = "af906e964978" +branch_labels = None +depends_on = None + + +def upgrade(): + # Perform safe alterations using batch operation + with op.batch_alter_table("folder", schema=None) as batch_op: + # Step 1: Remove server defaults for created_at and updated_at + batch_op.alter_column( + "created_at", + server_default=None, # Removing server default + ) + batch_op.alter_column( + "updated_at", + server_default=None, # Removing server default + ) + + # Step 2: Change the column types to BigInteger for created_at + batch_op.alter_column( + "created_at", + type_=sa.BigInteger(), + existing_type=sa.DateTime(), + existing_nullable=False, + postgresql_using="extract(epoch from created_at)::bigint", # Conversion for PostgreSQL + ) + + # Change the column types to BigInteger for updated_at + batch_op.alter_column( + "updated_at", + type_=sa.BigInteger(), + existing_type=sa.DateTime(), + existing_nullable=False, + postgresql_using="extract(epoch from updated_at)::bigint", # Conversion for PostgreSQL + ) + + +def downgrade(): + # Downgrade: Convert columns back to DateTime and restore defaults + with op.batch_alter_table("folder", schema=None) as batch_op: + batch_op.alter_column( + "created_at", + type_=sa.DateTime(), + existing_type=sa.BigInteger(), + existing_nullable=False, + server_default=sa.func.now(), # Restoring server default on downgrade + ) + batch_op.alter_column( + "updated_at", + type_=sa.DateTime(), + existing_type=sa.BigInteger(), + existing_nullable=False, + server_default=sa.func.now(), # Restoring server default on downgrade + onupdate=sa.func.now(), # Restoring onupdate behavior if it was there + ) diff --git a/backend/open_webui/migrations/versions/af906e964978_add_feedback_table.py b/backend/open_webui/migrations/versions/af906e964978_add_feedback_table.py new file mode 100644 index 0000000000..9116aa3884 --- /dev/null +++ b/backend/open_webui/migrations/versions/af906e964978_add_feedback_table.py @@ -0,0 +1,51 @@ +"""Add feedback table + +Revision ID: af906e964978 +Revises: c29facfe716b +Create Date: 2024-10-20 17:02:35.241684 + +""" + +from alembic import op +import sqlalchemy as sa + +# Revision identifiers, used by Alembic. +revision = "af906e964978" +down_revision = "c29facfe716b" +branch_labels = None +depends_on = None + + +def upgrade(): + # ### Create feedback table ### + op.create_table( + "feedback", + sa.Column( + "id", sa.Text(), primary_key=True + ), # Unique identifier for each feedback (TEXT type) + sa.Column( + "user_id", sa.Text(), nullable=True + ), # ID of the user providing the feedback (TEXT type) + sa.Column( + "version", sa.BigInteger(), default=0 + ), # Version of feedback (BIGINT type) + sa.Column("type", sa.Text(), nullable=True), # Type of feedback (TEXT type) + sa.Column("data", sa.JSON(), nullable=True), # Feedback data (JSON type) + sa.Column( + "meta", sa.JSON(), nullable=True + ), # Metadata for feedback (JSON type) + sa.Column( + "snapshot", sa.JSON(), nullable=True + ), # snapshot data for feedback (JSON type) + sa.Column( + "created_at", sa.BigInteger(), nullable=False + ), # Feedback creation timestamp (BIGINT representing epoch) + sa.Column( + "updated_at", sa.BigInteger(), nullable=False + ), # Feedback update timestamp (BIGINT representing epoch) + ) + + +def downgrade(): + # ### Drop feedback table ### + op.drop_table("feedback") diff --git a/backend/open_webui/migrations/versions/c29facfe716b_update_file_table_path.py b/backend/open_webui/migrations/versions/c29facfe716b_update_file_table_path.py new file mode 100644 index 0000000000..de82854b88 --- /dev/null +++ b/backend/open_webui/migrations/versions/c29facfe716b_update_file_table_path.py @@ -0,0 +1,79 @@ +"""Update file table path + +Revision ID: c29facfe716b +Revises: c69f45358db4 +Create Date: 2024-10-20 17:02:35.241684 + +""" + +from alembic import op +import sqlalchemy as sa +import json +from sqlalchemy.sql import table, column +from sqlalchemy import String, Text, JSON, and_ + + +revision = "c29facfe716b" +down_revision = "c69f45358db4" +branch_labels = None +depends_on = None + + +def upgrade(): + # 1. Add the `path` column to the "file" table. + op.add_column("file", sa.Column("path", sa.Text(), nullable=True)) + + # 2. Convert the `meta` column from Text/JSONField to `JSON()` + # Use Alembic's default batch_op for dialect compatibility. + with op.batch_alter_table("file", schema=None) as batch_op: + batch_op.alter_column( + "meta", + type_=sa.JSON(), + existing_type=sa.Text(), + existing_nullable=True, + nullable=True, + postgresql_using="meta::json", + ) + + # 3. Migrate legacy data from `meta` JSONField + # Fetch and process `meta` data from the table, add values to the new `path` column as necessary. + # We will use SQLAlchemy core bindings to ensure safety across different databases. + + file_table = table( + "file", column("id", String), column("meta", JSON), column("path", Text) + ) + + # Create connection to the database + connection = op.get_bind() + + # Get the rows where `meta` has a path and `path` column is null (new column) + # Loop through each row in the result set to update the path + results = connection.execute( + sa.select(file_table.c.id, file_table.c.meta).where( + and_(file_table.c.path.is_(None), file_table.c.meta.isnot(None)) + ) + ).fetchall() + + # Iterate over each row to extract and update the `path` from `meta` column + for row in results: + if "path" in row.meta: + # Extract the `path` field from the `meta` JSON + path = row.meta.get("path") + + # Update the `file` table with the new `path` value + connection.execute( + file_table.update() + .where(file_table.c.id == row.id) + .values({"path": path}) + ) + + +def downgrade(): + # 1. Remove the `path` column + op.drop_column("file", "path") + + # 2. Revert the `meta` column back to Text/JSONField + with op.batch_alter_table("file", schema=None) as batch_op: + batch_op.alter_column( + "meta", type_=sa.Text(), existing_type=sa.JSON(), existing_nullable=True + ) diff --git a/backend/open_webui/migrations/versions/c69f45358db4_add_folder_table.py b/backend/open_webui/migrations/versions/c69f45358db4_add_folder_table.py new file mode 100644 index 0000000000..83e0dc28ed --- /dev/null +++ b/backend/open_webui/migrations/versions/c69f45358db4_add_folder_table.py @@ -0,0 +1,50 @@ +"""Add folder table + +Revision ID: c69f45358db4 +Revises: 3ab32c4b8f59 +Create Date: 2024-10-16 02:02:35.241684 + +""" + +from alembic import op +import sqlalchemy as sa + +revision = "c69f45358db4" +down_revision = "3ab32c4b8f59" +branch_labels = None +depends_on = None + + +def upgrade(): + op.create_table( + "folder", + sa.Column("id", sa.Text(), nullable=False), + sa.Column("parent_id", sa.Text(), nullable=True), + sa.Column("user_id", sa.Text(), nullable=False), + sa.Column("name", sa.Text(), nullable=False), + sa.Column("items", sa.JSON(), nullable=True), + sa.Column("meta", sa.JSON(), nullable=True), + sa.Column("is_expanded", sa.Boolean(), default=False, nullable=False), + sa.Column( + "created_at", sa.DateTime(), server_default=sa.func.now(), nullable=False + ), + sa.Column( + "updated_at", + sa.DateTime(), + nullable=False, + server_default=sa.func.now(), + onupdate=sa.func.now(), + ), + sa.PrimaryKeyConstraint("id", "user_id"), + ) + + op.add_column( + "chat", + sa.Column("folder_id", sa.Text(), nullable=True), + ) + + +def downgrade(): + op.drop_column("chat", "folder_id") + + op.drop_table("folder") diff --git a/backend/open_webui/static/assets/pdf-style.css b/backend/open_webui/static/assets/pdf-style.css new file mode 100644 index 0000000000..db9ac83ddb --- /dev/null +++ b/backend/open_webui/static/assets/pdf-style.css @@ -0,0 +1,319 @@ +/* HTML and Body */ +@font-face { + font-family: 'NotoSans'; + src: url('fonts/NotoSans-Variable.ttf'); +} + +@font-face { + font-family: 'NotoSansJP'; + src: url('fonts/NotoSansJP-Variable.ttf'); +} + +@font-face { + font-family: 'NotoSansKR'; + src: url('fonts/NotoSansKR-Variable.ttf'); +} + +@font-face { + font-family: 'NotoSansSC'; + src: url('fonts/NotoSansSC-Variable.ttf'); +} + +@font-face { + font-family: 'NotoSansSC-Regular'; + src: url('fonts/NotoSansSC-Regular.ttf'); +} + +html { + font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', 'NotoSans', 'NotoSansJP', 'NotoSansKR', + 'NotoSansSC', 'STSong-Light', 'MSung-Light', 'HeiseiMin-W3', 'HYSMyeongJo-Medium', Roboto, + 'Helvetica Neue', Arial, sans-serif; + font-size: 14px; /* Default font size */ + line-height: 1.5; +} + +*, +*::before, +*::after { + box-sizing: inherit; +} + +body { + margin: 0; + color: #212529; + background-color: #fff; + width: auto; +} + +/* Typography */ +h1, +h2, +h3, +h4, +h5, +h6 { + font-weight: 500; + margin: 0; +} + +h1 { + font-size: 2.5rem; +} + +h2 { + font-size: 2rem; +} + +h3 { + font-size: 1.75rem; +} + +h4 { + font-size: 1.5rem; +} + +h5 { + font-size: 1.25rem; +} + +h6 { + font-size: 1rem; +} + +p { + margin-top: 0; + margin-bottom: 1rem; +} + +/* Grid System */ +.container { + width: 100%; + padding-right: 15px; + padding-left: 15px; + margin-right: auto; + margin-left: auto; +} + +/* Utilities */ +.text-center { + text-align: center; +} + +/* Additional Text Utilities */ +.text-muted { + color: #6c757d; /* Muted text color */ +} + +/* Small Text */ +small { + font-size: 80%; /* Smaller font size relative to the base */ + color: #6c757d; /* Lighter text color for secondary information */ + margin-bottom: 0; + margin-top: 0; +} + +/* Strong Element Styles */ +strong { + font-weight: bolder; /* Ensures the text is bold */ + color: inherit; /* Inherits the color from its parent element */ +} + +/* link */ +a { + color: #007bff; + text-decoration: none; + background-color: transparent; +} + +a:hover { + color: #0056b3; + text-decoration: underline; +} + +/* General styles for lists */ +ol, +ul, +li { + padding-left: 40px; /* Increase padding to move bullet points to the right */ + margin-left: 20px; /* Indent lists from the left */ +} + +/* Ordered list styles */ +ol { + list-style-type: decimal; /* Use numbers for ordered lists */ + margin-bottom: 10px; /* Space after each list */ +} + +ol li { + margin-bottom: 0.5rem; /* Space between ordered list items */ +} + +/* Unordered list styles */ +ul { + list-style-type: disc; /* Use bullets for unordered lists */ + margin-bottom: 10px; /* Space after each list */ +} + +ul li { + margin-bottom: 0.5rem; /* Space between unordered list items */ +} + +/* List item styles */ +li { + margin-bottom: 5px; /* Space between list items */ + line-height: 1.5; /* Line height for better readability */ +} + +/* Nested lists */ +ol ol, +ol ul, +ul ol, +ul ul { + padding-left: 20px; + margin-left: 30px; /* Further indent nested lists */ + margin-bottom: 0; /* Remove extra margin at the bottom of nested lists */ +} + +/* Code blocks */ +pre { + background-color: #f4f4f4; + padding: 10px; + overflow-x: auto; + max-width: 100%; /* Ensure it doesn't overflow the page */ + width: 80%; /* Set a specific width for a container-like appearance */ + margin: 0 1em; /* Center the pre block */ + box-sizing: border-box; /* Include padding in the width */ + border: 1px solid #ccc; /* Optional: Add a border for better definition */ + border-radius: 4px; /* Optional: Add rounded corners */ +} + +code { + font-family: 'Courier New', Courier, monospace; + background-color: #f4f4f4; + padding: 2px 4px; + border-radius: 4px; + box-sizing: border-box; /* Include padding in the width */ +} + +.message { + margin-top: 8px; + margin-bottom: 8px; + max-width: 100%; + overflow-wrap: break-word; +} + +/* Table Styles */ +table { + width: 100%; + margin-bottom: 1rem; + color: #212529; + border-collapse: collapse; /* Removes the space between borders */ +} + +th, +td { + margin: 0; + padding: 0.75rem; + vertical-align: top; + border-top: 1px solid #dee2e6; +} + +thead th { + vertical-align: bottom; + border-bottom: 2px solid #dee2e6; +} + +tbody + tbody { + border-top: 2px solid #dee2e6; +} + +/* markdown-section styles */ +.markdown-section blockquote, +.markdown-section h1, +.markdown-section h2, +.markdown-section h3, +.markdown-section h4, +.markdown-section h5, +.markdown-section h6, +.markdown-section p, +.markdown-section pre, +.markdown-section table, +.markdown-section ul { + /* Give most block elements margin top and bottom */ + margin-top: 1rem; +} + +/* Remove top margin if it's the first child */ +.markdown-section blockquote:first-child, +.markdown-section h1:first-child, +.markdown-section h2:first-child, +.markdown-section h3:first-child, +.markdown-section h4:first-child, +.markdown-section h5:first-child, +.markdown-section h6:first-child, +.markdown-section p:first-child, +.markdown-section pre:first-child, +.markdown-section table:first-child, +.markdown-section ul:first-child { + margin-top: 0; +} + +/* Remove top margin of