Merge pull request #6002 from open-webui/dev

0.3.33
This commit is contained in:
Timothy Jaeryang Baek 2024-10-24 13:36:19 -07:00 committed by GitHub
commit 99dd7fb5a8
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
246 changed files with 19207 additions and 6854 deletions

View file

@ -5,6 +5,39 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## [0.3.33] - 2024-10-24
### Added
- **🏆 Evaluation Leaderboard**: Easily track your performance through a new leaderboard system where your ratings contribute to a real-time ranking based on the Elo system. Sibling responses (regenerations, many model chats) are required for your ratings to count in the leaderboard. Additionally, you can opt-in to share your feedback history and be part of the community-wide leaderboard. Expect further improvements as we refine the algorithm—help us build the best community leaderboard!
- **⚔️ Arena Model Evaluation**: Enable blind A/B testing of models directly from Admin Settings > Evaluation for a true side-by-side comparison. Ideal for pinpointing the best model for your needs.
- **🎯 Topic-Based Leaderboard**: Discover more accurate rankings with experimental topic-based reranking, which adjusts leaderboard standings based on tag similarity in feedback. Get more relevant insights based on specific topics!
- **📁 Folders Support for Chats**: Organize your chats better by grouping them into folders. Drag and drop chats between folders and export them seamlessly for easy sharing or analysis.
- **📤 Easy Chat Import via Drag & Drop**: Save time by simply dragging and dropping chat exports (JSON) directly onto the sidebar to import them into your workspace—streamlined, efficient, and intuitive!
- **📚 Enhanced Knowledge Collection**: Now, you can reference individual files from a knowledge collection—ideal for more precise Retrieval-Augmented Generations (RAG) queries and document analysis.
- **🏷️ Enhanced Tagging System**: Tags now take up less space! Utilize the new 'tag:' query system to manage, search, and organize your conversations more effectively without cluttering the interface.
- **🧠 Auto-Tagging for Chats**: Your conversations are now automatically tagged for improved organization, mirroring the efficiency of auto-generated titles.
- **🔍 Backend Chat Query System**: Chat filtering has become more efficient, now handled through the backend\*\* instead of your browser, improving search performance and accuracy.
- **🎮 Revamped Playground**: Experience a refreshed and optimized Playground for smoother testing, tweaks, and experimentation of your models and tools.
- **🧩 Token-Based Text Splitter**: Introducing token-based text splitting (tiktoken), giving you more precise control over how text is processed. Previously, only character-based splitting was available.
- **🔢 Ollama Batch Embeddings**: Leverage new batch embedding support for improved efficiency and performance with Ollama embedding models.
- **🔍 Enhanced Add Text Content Modal**: Enjoy a cleaner, more intuitive workflow for adding and curating knowledge content with an upgraded input modal from our Knowledge workspace.
- **🖋️ Rich Text Input for Chats**: Make your chat inputs more dynamic with support for rich text formatting. Your conversations just got a lot more polished and professional.
- **⚡ Faster Whisper Model Configurability**: Customize your local faster whisper model directly from the WebUI.
- **☁️ Experimental S3 Support**: Enable stateless WebUI instances with S3 support, greatly enhancing scalability and balancing heavy workloads.
- **🔕 Disable Update Toast**: Now you can streamline your workspace even further—choose to disable update notifications for a more focused experience.
- **🌟 RAG Citation Relevance Percentage**: Easily assess citation accuracy with the addition of relevance percentages in RAG results.
- **⚙️ Mermaid Copy Button**: Mermaid diagrams now come with a handy copy button, simplifying the extraction and use of diagram contents directly in your workflow.
- **🎨 UI Redesign**: Major interface redesign that will make navigation smoother, keep your focus where it matters, and ensure a modern look.
### Fixed
- **🎙️ Voice Note Mic Stopping Issue**: Fixed the issue where the microphone stayed active after ending a voice note recording, ensuring your audio workflow runs smoothly.
### Removed
- **👋 Goodbye Sidebar Tags**: Sidebar tag clutter is gone. Weve shifted tag buttons to more effective query-based tag filtering for a sleeker, more agile interface.
## [0.3.32] - 2024-10-06
### Added

View file

@ -1,6 +1,6 @@
# syntax=docker/dockerfile:1
# Initialize device type args
# use build args in the docker build commmand with --build-arg="BUILDARG=true"
# use build args in the docker build command with --build-arg="BUILDARG=true"
ARG USE_CUDA=false
ARG USE_OLLAMA=false
# Tested with cu117 for CUDA 11 and cu121 for CUDA 12 (default)
@ -11,6 +11,10 @@ ARG USE_CUDA_VER=cu121
# IMPORTANT: If you change the embedding model (sentence-transformers/all-MiniLM-L6-v2) and vice versa, you aren't able to use RAG Chat with your previous documents loaded in the WebUI! You need to re-embed them.
ARG USE_EMBEDDING_MODEL=sentence-transformers/all-MiniLM-L6-v2
ARG USE_RERANKING_MODEL=""
# Tiktoken encoding name; models to use can be found at https://huggingface.co/models?library=tiktoken
ARG USE_TIKTOKEN_ENCODING_NAME="cl100k_base"
ARG BUILD_HASH=dev-build
# Override at your own risk - non-root configurations are untested
ARG UID=0
@ -72,6 +76,10 @@ ENV RAG_EMBEDDING_MODEL="$USE_EMBEDDING_MODEL_DOCKER" \
RAG_RERANKING_MODEL="$USE_RERANKING_MODEL_DOCKER" \
SENTENCE_TRANSFORMERS_HOME="/app/backend/data/cache/embedding/models"
## Tiktoken model settings ##
ENV TIKTOKEN_ENCODING_NAME="$USE_TIKTOKEN_ENCODING_NAME" \
TIKTOKEN_CACHE_DIR="/app/backend/data/cache/tiktoken"
## Hugging Face download cache ##
ENV HF_HOME="/app/backend/data/cache/embedding/models"
@ -131,11 +139,13 @@ RUN pip3 install uv && \
uv pip install --system -r requirements.txt --no-cache-dir && \
python -c "import os; from sentence_transformers import SentenceTransformer; SentenceTransformer(os.environ['RAG_EMBEDDING_MODEL'], device='cpu')" && \
python -c "import os; from faster_whisper import WhisperModel; WhisperModel(os.environ['WHISPER_MODEL'], device='cpu', compute_type='int8', download_root=os.environ['WHISPER_MODEL_DIR'])"; \
python -c "import os; import tiktoken; tiktoken.get_encoding(os.environ['TIKTOKEN_ENCODING_NAME'])"; \
else \
pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu --no-cache-dir && \
uv pip install --system -r requirements.txt --no-cache-dir && \
python -c "import os; from sentence_transformers import SentenceTransformer; SentenceTransformer(os.environ['RAG_EMBEDDING_MODEL'], device='cpu')" && \
python -c "import os; from faster_whisper import WhisperModel; WhisperModel(os.environ['WHISPER_MODEL'], device='cpu', compute_type='int8', download_root=os.environ['WHISPER_MODEL_DIR'])"; \
python -c "import os; import tiktoken; tiktoken.get_encoding(os.environ['TIKTOKEN_ENCODING_NAME'])"; \
fi; \
chown -R $UID:$GID /app/backend/data/

View file

@ -18,7 +18,7 @@ If you're experiencing connection issues, its often due to the WebUI docker c
docker run -d --network=host -v open-webui:/app/backend/data -e OLLAMA_BASE_URL=http://127.0.0.1:11434 --name open-webui --restart always ghcr.io/open-webui/open-webui:main
```
### Error on Slow Reponses for Ollama
### Error on Slow Responses for Ollama
Open WebUI has a default timeout of 5 minutes for Ollama to finish generating the response. If needed, this can be adjusted via the environment variable AIOHTTP_CLIENT_TIMEOUT, which sets the timeout in seconds.

View file

@ -63,6 +63,9 @@ app.state.config.STT_OPENAI_API_KEY = AUDIO_STT_OPENAI_API_KEY
app.state.config.STT_ENGINE = AUDIO_STT_ENGINE
app.state.config.STT_MODEL = AUDIO_STT_MODEL
app.state.config.WHISPER_MODEL = WHISPER_MODEL
app.state.faster_whisper_model = None
app.state.config.TTS_OPENAI_API_BASE_URL = AUDIO_TTS_OPENAI_API_BASE_URL
app.state.config.TTS_OPENAI_API_KEY = AUDIO_TTS_OPENAI_API_KEY
app.state.config.TTS_ENGINE = AUDIO_TTS_ENGINE
@ -82,6 +85,31 @@ SPEECH_CACHE_DIR = Path(CACHE_DIR).joinpath("./audio/speech/")
SPEECH_CACHE_DIR.mkdir(parents=True, exist_ok=True)
def set_faster_whisper_model(model: str, auto_update: bool = False):
if model and app.state.config.STT_ENGINE == "":
from faster_whisper import WhisperModel
faster_whisper_kwargs = {
"model_size_or_path": model,
"device": whisper_device_type,
"compute_type": "int8",
"download_root": WHISPER_MODEL_DIR,
"local_files_only": not auto_update,
}
try:
app.state.faster_whisper_model = WhisperModel(**faster_whisper_kwargs)
except Exception:
log.warning(
"WhisperModel initialization failed, attempting download with local_files_only=False"
)
faster_whisper_kwargs["local_files_only"] = False
app.state.faster_whisper_model = WhisperModel(**faster_whisper_kwargs)
else:
app.state.faster_whisper_model = None
class TTSConfigForm(BaseModel):
OPENAI_API_BASE_URL: str
OPENAI_API_KEY: str
@ -99,6 +127,7 @@ class STTConfigForm(BaseModel):
OPENAI_API_KEY: str
ENGINE: str
MODEL: str
WHISPER_MODEL: str
class AudioConfigUpdateForm(BaseModel):
@ -152,6 +181,7 @@ async def get_audio_config(user=Depends(get_admin_user)):
"OPENAI_API_KEY": app.state.config.STT_OPENAI_API_KEY,
"ENGINE": app.state.config.STT_ENGINE,
"MODEL": app.state.config.STT_MODEL,
"WHISPER_MODEL": app.state.config.WHISPER_MODEL,
},
}
@ -176,6 +206,8 @@ async def update_audio_config(
app.state.config.STT_OPENAI_API_KEY = form_data.stt.OPENAI_API_KEY
app.state.config.STT_ENGINE = form_data.stt.ENGINE
app.state.config.STT_MODEL = form_data.stt.MODEL
app.state.config.WHISPER_MODEL = form_data.stt.WHISPER_MODEL
set_faster_whisper_model(form_data.stt.WHISPER_MODEL, WHISPER_MODEL_AUTO_UPDATE)
return {
"tts": {
@ -194,6 +226,7 @@ async def update_audio_config(
"OPENAI_API_KEY": app.state.config.STT_OPENAI_API_KEY,
"ENGINE": app.state.config.STT_ENGINE,
"MODEL": app.state.config.STT_MODEL,
"WHISPER_MODEL": app.state.config.WHISPER_MODEL,
},
}
@ -367,27 +400,10 @@ def transcribe(file_path):
id = filename.split(".")[0]
if app.state.config.STT_ENGINE == "":
from faster_whisper import WhisperModel
whisper_kwargs = {
"model_size_or_path": WHISPER_MODEL,
"device": whisper_device_type,
"compute_type": "int8",
"download_root": WHISPER_MODEL_DIR,
"local_files_only": not WHISPER_MODEL_AUTO_UPDATE,
}
log.debug(f"whisper_kwargs: {whisper_kwargs}")
try:
model = WhisperModel(**whisper_kwargs)
except Exception:
log.warning(
"WhisperModel initialization failed, attempting download with local_files_only=False"
)
whisper_kwargs["local_files_only"] = False
model = WhisperModel(**whisper_kwargs)
if app.state.faster_whisper_model is None:
set_faster_whisper_model(app.state.config.WHISPER_MODEL)
model = app.state.faster_whisper_model
segments, info = model.transcribe(file_path, beam_size=5)
log.info(
"Detected language '%s' with probability %f"
@ -395,7 +411,6 @@ def transcribe(file_path):
)
transcript = "".join([segment.text for segment in list(segments)])
data = {"text": transcript.strip()}
# save the transcript to a json file
@ -403,7 +418,7 @@ def transcribe(file_path):
with open(transcript_file, "w") as f:
json.dump(data, f)
print(data)
log.debug(data)
return data
elif app.state.config.STT_ENGINE == "openai":
if is_mp4_audio(file_path):
@ -417,7 +432,7 @@ def transcribe(file_path):
files = {"file": (filename, open(file_path, "rb"))}
data = {"model": app.state.config.STT_MODEL}
print(files, data)
log.debug(files, data)
r = None
try:
@ -450,7 +465,7 @@ def transcribe(file_path):
except Exception:
error_detail = f"External: {e}"
raise error_detail
raise Exception(error_detail)
@app.post("/transcriptions")

View file

@ -125,22 +125,34 @@ async def comfyui_generate_image(
workflow[node_id]["inputs"][node.key] = model
elif node.type == "prompt":
for node_id in node.node_ids:
workflow[node_id]["inputs"]["text"] = payload.prompt
workflow[node_id]["inputs"][
node.key if node.key else "text"
] = payload.prompt
elif node.type == "negative_prompt":
for node_id in node.node_ids:
workflow[node_id]["inputs"]["text"] = payload.negative_prompt
workflow[node_id]["inputs"][
node.key if node.key else "text"
] = payload.negative_prompt
elif node.type == "width":
for node_id in node.node_ids:
workflow[node_id]["inputs"]["width"] = payload.width
workflow[node_id]["inputs"][
node.key if node.key else "width"
] = payload.width
elif node.type == "height":
for node_id in node.node_ids:
workflow[node_id]["inputs"]["height"] = payload.height
workflow[node_id]["inputs"][
node.key if node.key else "height"
] = payload.height
elif node.type == "n":
for node_id in node.node_ids:
workflow[node_id]["inputs"]["batch_size"] = payload.n
workflow[node_id]["inputs"][
node.key if node.key else "batch_size"
] = payload.n
elif node.type == "steps":
for node_id in node.node_ids:
workflow[node_id]["inputs"]["steps"] = payload.steps
workflow[node_id]["inputs"][
node.key if node.key else "steps"
] = payload.steps
elif node.type == "seed":
seed = (
payload.seed

View file

@ -547,8 +547,8 @@ class GenerateEmbeddingsForm(BaseModel):
class GenerateEmbedForm(BaseModel):
model: str
input: str
truncate: Optional[bool]
input: list[str] | str
truncate: Optional[bool] = None
options: Optional[dict] = None
keep_alive: Optional[Union[int, str]] = None
@ -560,48 +560,7 @@ async def generate_embeddings(
url_idx: Optional[int] = None,
user=Depends(get_verified_user),
):
if url_idx is None:
model = form_data.model
if ":" not in model:
model = f"{model}:latest"
if model in app.state.MODELS:
url_idx = random.choice(app.state.MODELS[model]["urls"])
else:
raise HTTPException(
status_code=400,
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
)
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}")
r = requests.request(
method="POST",
url=f"{url}/api/embed",
headers={"Content-Type": "application/json"},
data=form_data.model_dump_json(exclude_none=True).encode(),
)
try:
r.raise_for_status()
return r.json()
except Exception as e:
log.exception(e)
error_detail = "Open WebUI: Server Connection Error"
if r is not None:
try:
res = r.json()
if "error" in res:
error_detail = f"Ollama: {res['error']}"
except Exception:
error_detail = f"Ollama: {e}"
raise HTTPException(
status_code=r.status_code if r else 500,
detail=error_detail,
)
return generate_ollama_batch_embeddings(form_data, url_idx)
@app.post("/api/embeddings")
@ -611,48 +570,7 @@ async def generate_embeddings(
url_idx: Optional[int] = None,
user=Depends(get_verified_user),
):
if url_idx is None:
model = form_data.model
if ":" not in model:
model = f"{model}:latest"
if model in app.state.MODELS:
url_idx = random.choice(app.state.MODELS[model]["urls"])
else:
raise HTTPException(
status_code=400,
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
)
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}")
r = requests.request(
method="POST",
url=f"{url}/api/embeddings",
headers={"Content-Type": "application/json"},
data=form_data.model_dump_json(exclude_none=True).encode(),
)
try:
r.raise_for_status()
return r.json()
except Exception as e:
log.exception(e)
error_detail = "Open WebUI: Server Connection Error"
if r is not None:
try:
res = r.json()
if "error" in res:
error_detail = f"Ollama: {res['error']}"
except Exception:
error_detail = f"Ollama: {e}"
raise HTTPException(
status_code=r.status_code if r else 500,
detail=error_detail,
)
return generate_ollama_embeddings(form_data=form_data, url_idx=url_idx)
def generate_ollama_embeddings(
@ -692,7 +610,64 @@ def generate_ollama_embeddings(
log.info(f"generate_ollama_embeddings {data}")
if "embedding" in data:
return data["embedding"]
return data
else:
raise Exception("Something went wrong :/")
except Exception as e:
log.exception(e)
error_detail = "Open WebUI: Server Connection Error"
if r is not None:
try:
res = r.json()
if "error" in res:
error_detail = f"Ollama: {res['error']}"
except Exception:
error_detail = f"Ollama: {e}"
raise HTTPException(
status_code=r.status_code if r else 500,
detail=error_detail,
)
def generate_ollama_batch_embeddings(
form_data: GenerateEmbedForm,
url_idx: Optional[int] = None,
):
log.info(f"generate_ollama_batch_embeddings {form_data}")
if url_idx is None:
model = form_data.model
if ":" not in model:
model = f"{model}:latest"
if model in app.state.MODELS:
url_idx = random.choice(app.state.MODELS[model]["urls"])
else:
raise HTTPException(
status_code=400,
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
)
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}")
r = requests.request(
method="POST",
url=f"{url}/api/embed",
headers={"Content-Type": "application/json"},
data=form_data.model_dump_json(exclude_none=True).encode(),
)
try:
r.raise_for_status()
data = r.json()
log.info(f"generate_ollama_batch_embeddings {data}")
if "embeddings" in data:
return data
else:
raise Exception("Something went wrong :/")
except Exception as e:
@ -788,8 +763,7 @@ async def generate_chat_completion(
user=Depends(get_verified_user),
):
payload = {**form_data.model_dump(exclude_none=True)}
log.debug(f"{payload = }")
log.debug(f"generate_chat_completion() - 1.payload = {payload}")
if "metadata" in payload:
del payload["metadata"]
@ -824,7 +798,7 @@ async def generate_chat_completion(
url = get_ollama_url(url_idx, payload["model"])
log.info(f"url: {url}")
log.debug(payload)
log.debug(f"generate_chat_completion() - 2.payload = {payload}")
return await post_streaming_url(
f"{url}/api/chat",

View file

@ -18,7 +18,10 @@ from open_webui.config import (
OPENAI_API_KEYS,
AppConfig,
)
from open_webui.env import AIOHTTP_CLIENT_TIMEOUT
from open_webui.env import (
AIOHTTP_CLIENT_TIMEOUT,
AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST,
)
from open_webui.constants import ERROR_MESSAGES
from open_webui.env import SRC_LOG_LEVELS
@ -179,7 +182,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
async def fetch_url(url, key):
timeout = aiohttp.ClientTimeout(total=3)
timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST)
try:
headers = {"Authorization": f"Bearer {key}"}
async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
@ -237,9 +240,7 @@ def merge_models_lists(model_lists):
def is_openai_api_disabled():
api_keys = app.state.config.OPENAI_API_KEYS
no_keys = len(api_keys) == 1 and api_keys[0] == ""
return no_keys or not app.state.config.ENABLE_OPENAI_API
return not app.state.config.ENABLE_OPENAI_API
async def get_all_models_raw() -> list:

View file

@ -15,6 +15,9 @@ from fastapi import Depends, FastAPI, File, Form, HTTPException, UploadFile, sta
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from open_webui.storage.provider import Storage
from open_webui.apps.webui.models.knowledge import Knowledges
from open_webui.apps.retrieval.vector.connector import VECTOR_DB_CLIENT
# Document loaders
@ -47,6 +50,8 @@ from open_webui.apps.retrieval.utils import (
from open_webui.apps.webui.models.files import Files
from open_webui.config import (
BRAVE_SEARCH_API_KEY,
TIKTOKEN_ENCODING_NAME,
RAG_TEXT_SPLITTER,
CHUNK_OVERLAP,
CHUNK_SIZE,
CONTENT_EXTRACTION_ENGINE,
@ -63,7 +68,7 @@ from open_webui.config import (
RAG_EMBEDDING_MODEL,
RAG_EMBEDDING_MODEL_AUTO_UPDATE,
RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
RAG_EMBEDDING_OPENAI_BATCH_SIZE,
RAG_EMBEDDING_BATCH_SIZE,
RAG_FILE_MAX_COUNT,
RAG_FILE_MAX_SIZE,
RAG_OPENAI_API_BASE_URL,
@ -102,7 +107,7 @@ from open_webui.utils.misc import (
)
from open_webui.utils.utils import get_admin_user, get_verified_user
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.text_splitter import RecursiveCharacterTextSplitter, TokenTextSplitter
from langchain_community.document_loaders import (
YoutubeLoader,
)
@ -129,12 +134,15 @@ app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
app.state.config.CONTENT_EXTRACTION_ENGINE = CONTENT_EXTRACTION_ENGINE
app.state.config.TIKA_SERVER_URL = TIKA_SERVER_URL
app.state.config.TEXT_SPLITTER = RAG_TEXT_SPLITTER
app.state.config.TIKTOKEN_ENCODING_NAME = TIKTOKEN_ENCODING_NAME
app.state.config.CHUNK_SIZE = CHUNK_SIZE
app.state.config.CHUNK_OVERLAP = CHUNK_OVERLAP
app.state.config.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE
app.state.config.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE = RAG_EMBEDDING_OPENAI_BATCH_SIZE
app.state.config.RAG_EMBEDDING_BATCH_SIZE = RAG_EMBEDDING_BATCH_SIZE
app.state.config.RAG_RERANKING_MODEL = RAG_RERANKING_MODEL
app.state.config.RAG_TEMPLATE = RAG_TEMPLATE
@ -171,9 +179,9 @@ def update_embedding_model(
auto_update: bool = False,
):
if embedding_model and app.state.config.RAG_EMBEDDING_ENGINE == "":
import sentence_transformers
from sentence_transformers import SentenceTransformer
app.state.sentence_transformer_ef = sentence_transformers.SentenceTransformer(
app.state.sentence_transformer_ef = SentenceTransformer(
get_model_path(embedding_model, auto_update),
device=DEVICE_TYPE,
trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
@ -233,7 +241,7 @@ app.state.EMBEDDING_FUNCTION = get_embedding_function(
app.state.sentence_transformer_ef,
app.state.config.OPENAI_API_KEY,
app.state.config.OPENAI_API_BASE_URL,
app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
app.state.config.RAG_EMBEDDING_BATCH_SIZE,
)
app.add_middleware(
@ -267,7 +275,7 @@ async def get_status():
"embedding_engine": app.state.config.RAG_EMBEDDING_ENGINE,
"embedding_model": app.state.config.RAG_EMBEDDING_MODEL,
"reranking_model": app.state.config.RAG_RERANKING_MODEL,
"openai_batch_size": app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
"embedding_batch_size": app.state.config.RAG_EMBEDDING_BATCH_SIZE,
}
@ -277,10 +285,10 @@ async def get_embedding_config(user=Depends(get_admin_user)):
"status": True,
"embedding_engine": app.state.config.RAG_EMBEDDING_ENGINE,
"embedding_model": app.state.config.RAG_EMBEDDING_MODEL,
"embedding_batch_size": app.state.config.RAG_EMBEDDING_BATCH_SIZE,
"openai_config": {
"url": app.state.config.OPENAI_API_BASE_URL,
"key": app.state.config.OPENAI_API_KEY,
"batch_size": app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
},
}
@ -296,13 +304,13 @@ async def get_reraanking_config(user=Depends(get_admin_user)):
class OpenAIConfigForm(BaseModel):
url: str
key: str
batch_size: Optional[int] = None
class EmbeddingModelUpdateForm(BaseModel):
openai_config: Optional[OpenAIConfigForm] = None
embedding_engine: str
embedding_model: str
embedding_batch_size: Optional[int] = 1
@app.post("/embedding/update")
@ -320,11 +328,7 @@ async def update_embedding_config(
if form_data.openai_config is not None:
app.state.config.OPENAI_API_BASE_URL = form_data.openai_config.url
app.state.config.OPENAI_API_KEY = form_data.openai_config.key
app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE = (
form_data.openai_config.batch_size
if form_data.openai_config.batch_size
else 1
)
app.state.config.RAG_EMBEDDING_BATCH_SIZE = form_data.embedding_batch_size
update_embedding_model(app.state.config.RAG_EMBEDDING_MODEL)
@ -334,17 +338,17 @@ async def update_embedding_config(
app.state.sentence_transformer_ef,
app.state.config.OPENAI_API_KEY,
app.state.config.OPENAI_API_BASE_URL,
app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
app.state.config.RAG_EMBEDDING_BATCH_SIZE,
)
return {
"status": True,
"embedding_engine": app.state.config.RAG_EMBEDDING_ENGINE,
"embedding_model": app.state.config.RAG_EMBEDDING_MODEL,
"embedding_batch_size": app.state.config.RAG_EMBEDDING_BATCH_SIZE,
"openai_config": {
"url": app.state.config.OPENAI_API_BASE_URL,
"key": app.state.config.OPENAI_API_KEY,
"batch_size": app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
},
}
except Exception as e:
@ -388,18 +392,19 @@ async def get_rag_config(user=Depends(get_admin_user)):
return {
"status": True,
"pdf_extract_images": app.state.config.PDF_EXTRACT_IMAGES,
"file": {
"max_size": app.state.config.FILE_MAX_SIZE,
"max_count": app.state.config.FILE_MAX_COUNT,
},
"content_extraction": {
"engine": app.state.config.CONTENT_EXTRACTION_ENGINE,
"tika_server_url": app.state.config.TIKA_SERVER_URL,
},
"chunk": {
"text_splitter": app.state.config.TEXT_SPLITTER,
"chunk_size": app.state.config.CHUNK_SIZE,
"chunk_overlap": app.state.config.CHUNK_OVERLAP,
},
"file": {
"max_size": app.state.config.FILE_MAX_SIZE,
"max_count": app.state.config.FILE_MAX_COUNT,
},
"youtube": {
"language": app.state.config.YOUTUBE_LOADER_LANGUAGE,
"translation": app.state.YOUTUBE_LOADER_TRANSLATION,
@ -438,6 +443,7 @@ class ContentExtractionConfig(BaseModel):
class ChunkParamUpdateForm(BaseModel):
text_splitter: Optional[str] = None
chunk_size: int
chunk_overlap: int
@ -497,6 +503,7 @@ async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_
app.state.config.TIKA_SERVER_URL = form_data.content_extraction.tika_server_url
if form_data.chunk is not None:
app.state.config.TEXT_SPLITTER = form_data.chunk.text_splitter
app.state.config.CHUNK_SIZE = form_data.chunk.chunk_size
app.state.config.CHUNK_OVERLAP = form_data.chunk.chunk_overlap
@ -543,6 +550,7 @@ async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_
"tika_server_url": app.state.config.TIKA_SERVER_URL,
},
"chunk": {
"text_splitter": app.state.config.TEXT_SPLITTER,
"chunk_size": app.state.config.CHUNK_SIZE,
"chunk_overlap": app.state.config.CHUNK_OVERLAP,
},
@ -603,11 +611,10 @@ class QuerySettingsForm(BaseModel):
async def update_query_settings(
form_data: QuerySettingsForm, user=Depends(get_admin_user)
):
app.state.config.RAG_TEMPLATE = (
form_data.template if form_data.template != "" else DEFAULT_RAG_TEMPLATE
)
app.state.config.RAG_TEMPLATE = form_data.template
app.state.config.TOP_K = form_data.k if form_data.k else 4
app.state.config.RELEVANCE_THRESHOLD = form_data.r if form_data.r else 0.0
app.state.config.ENABLE_RAG_HYBRID_SEARCH = (
form_data.hybrid if form_data.hybrid else False
)
@ -645,25 +652,48 @@ def save_docs_to_vector_db(
filter={"hash": metadata["hash"]},
)
if result:
if result is not None:
existing_doc_ids = result.ids[0]
if existing_doc_ids:
log.info(f"Document with hash {metadata['hash']} already exists")
raise ValueError(ERROR_MESSAGES.DUPLICATE_CONTENT)
if split:
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=app.state.config.CHUNK_SIZE,
chunk_overlap=app.state.config.CHUNK_OVERLAP,
add_start_index=True,
)
if app.state.config.TEXT_SPLITTER in ["", "character"]:
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=app.state.config.CHUNK_SIZE,
chunk_overlap=app.state.config.CHUNK_OVERLAP,
add_start_index=True,
)
elif app.state.config.TEXT_SPLITTER == "token":
text_splitter = TokenTextSplitter(
encoding_name=app.state.config.TIKTOKEN_ENCODING_NAME,
chunk_size=app.state.config.CHUNK_SIZE,
chunk_overlap=app.state.config.CHUNK_OVERLAP,
add_start_index=True,
)
else:
raise ValueError(ERROR_MESSAGES.DEFAULT("Invalid text splitter"))
docs = text_splitter.split_documents(docs)
if len(docs) == 0:
raise ValueError(ERROR_MESSAGES.EMPTY_CONTENT)
texts = [doc.page_content for doc in docs]
metadatas = [{**doc.metadata, **(metadata if metadata else {})} for doc in docs]
metadatas = [
{
**doc.metadata,
**(metadata if metadata else {}),
"embedding_config": json.dumps(
{
"engine": app.state.config.RAG_EMBEDDING_ENGINE,
"model": app.state.config.RAG_EMBEDDING_MODEL,
}
),
}
for doc in docs
]
# ChromaDB does not like datetime formats
# for meta-data so convert them to string.
@ -679,8 +709,10 @@ def save_docs_to_vector_db(
if overwrite:
VECTOR_DB_CLIENT.delete_collection(collection_name=collection_name)
log.info(f"deleting existing collection {collection_name}")
if add is False:
elif add is False:
log.info(
f"collection {collection_name} already exists, overwrite is False and add is False"
)
return True
log.info(f"adding to collection {collection_name}")
@ -690,7 +722,7 @@ def save_docs_to_vector_db(
app.state.sentence_transformer_ef,
app.state.config.OPENAI_API_KEY,
app.state.config.OPENAI_API_BASE_URL,
app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
app.state.config.RAG_EMBEDDING_BATCH_SIZE,
)
embeddings = embedding_function(
@ -767,7 +799,7 @@ def process_file(
collection_name=f"file-{file.id}", filter={"file_id": file.id}
)
if len(result.ids[0]) > 0:
if result is not None and len(result.ids[0]) > 0:
docs = [
Document(
page_content=result.documents[0][idx],
@ -792,15 +824,14 @@ def process_file(
else:
# Process the file and save the content
# Usage: /files/
file_path = file.meta.get("path", None)
file_path = file.path
if file_path:
file_path = Storage.get_file(file_path)
loader = Loader(
engine=app.state.config.CONTENT_EXTRACTION_ENGINE,
TIKA_SERVER_URL=app.state.config.TIKA_SERVER_URL,
PDF_EXTRACT_IMAGES=app.state.config.PDF_EXTRACT_IMAGES,
)
docs = loader.load(
file.filename, file.meta.get("content_type"), file_path
)
@ -816,7 +847,6 @@ def process_file(
},
)
]
text_content = " ".join([doc.page_content for doc in docs])
log.debug(f"text_content: {text_content}")
@ -1259,6 +1289,7 @@ def delete_entries_from_collection(form_data: DeleteForm, user=Depends(get_admin
@app.post("/reset/db")
def reset_vector_db(user=Depends(get_admin_user)):
VECTOR_DB_CLIENT.reset()
Knowledges.delete_all_knowledge()
@app.post("/reset/uploads")
@ -1281,28 +1312,6 @@ def reset_upload_dir(user=Depends(get_admin_user)) -> bool:
print(f"The directory {folder} does not exist")
except Exception as e:
print(f"Failed to process the directory {folder}. Reason: {e}")
return True
@app.post("/reset")
def reset(user=Depends(get_admin_user)) -> bool:
folder = f"{UPLOAD_DIR}"
for filename in os.listdir(folder):
file_path = os.path.join(folder, filename)
try:
if os.path.isfile(file_path) or os.path.islink(file_path):
os.unlink(file_path)
elif os.path.isdir(file_path):
shutil.rmtree(file_path)
except Exception as e:
log.error("Failed to delete %s. Reason: %s" % (file_path, e))
try:
VECTOR_DB_CLIENT.reset()
except Exception as e:
log.exception(e)
return True

View file

@ -12,13 +12,14 @@ from langchain_core.documents import Document
from open_webui.apps.ollama.main import (
GenerateEmbeddingsForm,
generate_ollama_embeddings,
GenerateEmbedForm,
generate_ollama_batch_embeddings,
)
from open_webui.apps.retrieval.vector.connector import VECTOR_DB_CLIENT
from open_webui.utils.misc import get_last_user_message
from open_webui.env import SRC_LOG_LEVELS
from open_webui.config import DEFAULT_RAG_TEMPLATE
log = logging.getLogger(__name__)
@ -193,7 +194,8 @@ def query_collection(
k=k,
query_embedding=query_embedding,
)
results.append(result.model_dump())
if result is not None:
results.append(result.model_dump())
except Exception as e:
log.exception(f"Error when querying the collection: {e}")
else:
@ -238,8 +240,13 @@ def query_collection_with_hybrid_search(
def rag_template(template: str, context: str, query: str):
count = template.count("[context]")
assert "[context]" in template, "RAG template does not contain '[context]'"
if template == "":
template = DEFAULT_RAG_TEMPLATE
if "[context]" not in template and "{{CONTEXT}}" not in template:
log.debug(
"WARNING: The RAG template does not contain the '[context]' or '{{CONTEXT}}' placeholder."
)
if "<context>" in context and "</context>" in context:
log.debug(
@ -248,14 +255,25 @@ def rag_template(template: str, context: str, query: str):
"nothing, or the user might be trying to hack something."
)
query_placeholders = []
if "[query]" in context:
query_placeholder = f"[query-{str(uuid.uuid4())}]"
query_placeholder = "{{QUERY" + str(uuid.uuid4()) + "}}"
template = template.replace("[query]", query_placeholder)
template = template.replace("[context]", context)
query_placeholders.append(query_placeholder)
if "{{QUERY}}" in context:
query_placeholder = "{{QUERY" + str(uuid.uuid4()) + "}}"
template = template.replace("{{QUERY}}", query_placeholder)
query_placeholders.append(query_placeholder)
template = template.replace("[context]", context)
template = template.replace("{{CONTEXT}}", context)
template = template.replace("[query]", query)
template = template.replace("{{QUERY}}", query)
for query_placeholder in query_placeholders:
template = template.replace(query_placeholder, query)
else:
template = template.replace("[context]", context)
template = template.replace("[query]", query)
return template
@ -265,39 +283,27 @@ def get_embedding_function(
embedding_function,
openai_key,
openai_url,
batch_size,
embedding_batch_size,
):
if embedding_engine == "":
return lambda query: embedding_function.encode(query).tolist()
elif embedding_engine in ["ollama", "openai"]:
if embedding_engine == "ollama":
func = lambda query: generate_ollama_embeddings(
GenerateEmbeddingsForm(
**{
"model": embedding_model,
"prompt": query,
}
)
)
elif embedding_engine == "openai":
func = lambda query: generate_openai_embeddings(
model=embedding_model,
text=query,
key=openai_key,
url=openai_url,
)
func = lambda query: generate_embeddings(
engine=embedding_engine,
model=embedding_model,
text=query,
key=openai_key if embedding_engine == "openai" else "",
url=openai_url if embedding_engine == "openai" else "",
)
def generate_multiple(query, f):
def generate_multiple(query, func):
if isinstance(query, list):
if embedding_engine == "openai":
embeddings = []
for i in range(0, len(query), batch_size):
embeddings.extend(f(query[i : i + batch_size]))
return embeddings
else:
return [f(q) for q in query]
embeddings = []
for i in range(0, len(query), embedding_batch_size):
embeddings.extend(func(query[i : i + embedding_batch_size]))
return embeddings
else:
return f(query)
return func(query)
return lambda query: generate_multiple(query, func)
@ -379,6 +385,8 @@ def get_rag_context(
extracted_collections.extend(collection_names)
if context:
if "data" in file:
del file["data"]
relevant_contexts.append({**context, "file": file})
contexts = []
@ -386,23 +394,37 @@ def get_rag_context(
for context in relevant_contexts:
try:
if "documents" in context:
file_names = list(
set(
[
metadata["name"]
for metadata in context["metadatas"][0]
if metadata is not None and "name" in metadata
]
)
)
contexts.append(
"\n\n".join(
((", ".join(file_names) + ":\n\n") if file_names else "")
+ "\n\n".join(
[text for text in context["documents"][0] if text is not None]
)
)
if "metadatas" in context:
citations.append(
{
"source": context["file"],
"document": context["documents"][0],
"metadata": context["metadatas"][0],
}
)
citation = {
"source": context["file"],
"document": context["documents"][0],
"metadata": context["metadatas"][0],
}
if "distances" in context and context["distances"]:
citation["distances"] = context["distances"][0]
citations.append(citation)
except Exception as e:
log.exception(e)
print("contexts", contexts)
print("citations", citations)
return contexts, citations
@ -444,20 +466,6 @@ def get_model_path(model: str, update_model: bool = False):
return model
def generate_openai_embeddings(
model: str,
text: Union[str, list[str]],
key: str,
url: str = "https://api.openai.com/v1",
):
if isinstance(text, list):
embeddings = generate_openai_batch_embeddings(model, text, key, url)
else:
embeddings = generate_openai_batch_embeddings(model, [text], key, url)
return embeddings[0] if isinstance(text, str) else embeddings
def generate_openai_batch_embeddings(
model: str, texts: list[str], key: str, url: str = "https://api.openai.com/v1"
) -> Optional[list[list[float]]]:
@ -481,6 +489,33 @@ def generate_openai_batch_embeddings(
return None
def generate_embeddings(engine: str, model: str, text: Union[str, list[str]], **kwargs):
if engine == "ollama":
if isinstance(text, list):
embeddings = generate_ollama_batch_embeddings(
GenerateEmbedForm(**{"model": model, "input": text})
)
else:
embeddings = generate_ollama_batch_embeddings(
GenerateEmbedForm(**{"model": model, "input": [text]})
)
return (
embeddings["embeddings"][0]
if isinstance(text, str)
else embeddings["embeddings"]
)
elif engine == "openai":
key = kwargs.get("key", "")
url = kwargs.get("url", "https://api.openai.com/v1")
if isinstance(text, list):
embeddings = generate_openai_batch_embeddings(model, text, key, url)
else:
embeddings = generate_openai_batch_embeddings(model, [text], key, url)
return embeddings[0] if isinstance(text, str) else embeddings
import operator
from typing import Optional, Sequence

View file

@ -4,6 +4,10 @@ if VECTOR_DB == "milvus":
from open_webui.apps.retrieval.vector.dbs.milvus import MilvusClient
VECTOR_DB_CLIENT = MilvusClient()
elif VECTOR_DB == "qdrant":
from open_webui.apps.retrieval.vector.dbs.qdrant import QdrantClient
VECTOR_DB_CLIENT = QdrantClient()
else:
from open_webui.apps.retrieval.vector.dbs.chroma import ChromaClient

View file

@ -109,7 +109,9 @@ class ChromaClient:
def insert(self, collection_name: str, items: list[VectorItem]):
# Insert the items into the collection, if the collection does not exist, it will be created.
collection = self.client.get_or_create_collection(name=collection_name)
collection = self.client.get_or_create_collection(
name=collection_name, metadata={"hnsw:space": "cosine"}
)
ids = [item["id"] for item in items]
documents = [item["text"] for item in items]
@ -127,7 +129,9 @@ class ChromaClient:
def upsert(self, collection_name: str, items: list[VectorItem]):
# Update the items in the collection, if the items are not present, insert them. If the collection does not exist, it will be created.
collection = self.client.get_or_create_collection(name=collection_name)
collection = self.client.get_or_create_collection(
name=collection_name, metadata={"hnsw:space": "cosine"}
)
ids = [item["id"] for item in items]
documents = [item["text"] for item in items]

View file

@ -0,0 +1,179 @@
from typing import Optional
from qdrant_client import QdrantClient as Qclient
from qdrant_client.http.models import PointStruct
from qdrant_client.models import models
from open_webui.apps.retrieval.vector.main import VectorItem, SearchResult, GetResult
from open_webui.config import QDRANT_URI
NO_LIMIT = 999999999
class QdrantClient:
def __init__(self):
self.collection_prefix = "open-webui"
self.QDRANT_URI = QDRANT_URI
self.client = Qclient(url=self.QDRANT_URI) if self.QDRANT_URI else None
def _result_to_get_result(self, points) -> GetResult:
ids = []
documents = []
metadatas = []
for point in points:
payload = point.payload
ids.append(point.id)
documents.append(payload["text"])
metadatas.append(payload["metadata"])
return GetResult(
**{
"ids": [ids],
"documents": [documents],
"metadatas": [metadatas],
}
)
def _create_collection(self, collection_name: str, dimension: int):
collection_name_with_prefix = f"{self.collection_prefix}_{collection_name}"
self.client.create_collection(
collection_name=collection_name_with_prefix,
vectors_config=models.VectorParams(
size=dimension, distance=models.Distance.COSINE
),
)
print(f"collection {collection_name_with_prefix} successfully created!")
def _create_collection_if_not_exists(self, collection_name, dimension):
if not self.has_collection(collection_name=collection_name):
self._create_collection(
collection_name=collection_name, dimension=dimension
)
def _create_points(self, items: list[VectorItem]):
return [
PointStruct(
id=item["id"],
vector=item["vector"],
payload={"text": item["text"], "metadata": item["metadata"]},
)
for item in items
]
def has_collection(self, collection_name: str) -> bool:
return self.client.collection_exists(
f"{self.collection_prefix}_{collection_name}"
)
def delete_collection(self, collection_name: str):
return self.client.delete_collection(
collection_name=f"{self.collection_prefix}_{collection_name}"
)
def search(
self, collection_name: str, vectors: list[list[float | int]], limit: int
) -> Optional[SearchResult]:
# Search for the nearest neighbor items based on the vectors and return 'limit' number of results.
if limit is None:
limit = NO_LIMIT # otherwise qdrant would set limit to 10!
query_response = self.client.query_points(
collection_name=f"{self.collection_prefix}_{collection_name}",
query=vectors[0],
limit=limit,
)
get_result = self._result_to_get_result(query_response.points)
return SearchResult(
ids=get_result.ids,
documents=get_result.documents,
metadatas=get_result.metadatas,
distances=[[point.score for point in query_response.points]],
)
def query(self, collection_name: str, filter: dict, limit: Optional[int] = None):
# Construct the filter string for querying
if not self.has_collection(collection_name):
return None
try:
if limit is None:
limit = NO_LIMIT # otherwise qdrant would set limit to 10!
field_conditions = []
for key, value in filter.items():
field_conditions.append(
models.FieldCondition(
key=f"metadata.{key}", match=models.MatchValue(value=value)
)
)
points = self.client.query_points(
collection_name=f"{self.collection_prefix}_{collection_name}",
query_filter=models.Filter(should=field_conditions),
limit=limit,
)
return self._result_to_get_result(points.points)
except Exception as e:
print(e)
return None
def get(self, collection_name: str) -> Optional[GetResult]:
# Get all the items in the collection.
points = self.client.query_points(
collection_name=f"{self.collection_prefix}_{collection_name}",
limit=NO_LIMIT, # otherwise qdrant would set limit to 10!
)
return self._result_to_get_result(points.points)
def insert(self, collection_name: str, items: list[VectorItem]):
# Insert the items into the collection, if the collection does not exist, it will be created.
self._create_collection_if_not_exists(collection_name, len(items[0]["vector"]))
points = self._create_points(items)
self.client.upload_points(f"{self.collection_prefix}_{collection_name}", points)
def upsert(self, collection_name: str, items: list[VectorItem]):
# Update the items in the collection, if the items are not present, insert them. If the collection does not exist, it will be created.
self._create_collection_if_not_exists(collection_name, len(items[0]["vector"]))
points = self._create_points(items)
return self.client.upsert(f"{self.collection_prefix}_{collection_name}", points)
def delete(
self,
collection_name: str,
ids: Optional[list[str]] = None,
filter: Optional[dict] = None,
):
# Delete the items from the collection based on the ids.
field_conditions = []
if ids:
for id_value in ids:
field_conditions.append(
models.FieldCondition(
key="metadata.id",
match=models.MatchValue(value=id_value),
),
),
elif filter:
for key, value in filter.items():
field_conditions.append(
models.FieldCondition(
key=f"metadata.{key}",
match=models.MatchValue(value=value),
),
),
return self.client.delete(
collection_name=f"{self.collection_prefix}_{collection_name}",
points_selector=models.FilterSelector(
filter=models.Filter(must=field_conditions)
),
)
def reset(self):
# Resets the database. This will delete all collections and item entries.
collection_names = self.client.get_collections().collections
for collection_name in collection_names:
if collection_name.name.startswith(self.collection_prefix):
self.client.delete_collection(collection_name=collection_name.name)

View file

@ -1,6 +1,7 @@
import inspect
import json
import logging
import time
from typing import AsyncGenerator, Generator, Iterator
from open_webui.apps.socket.main import get_event_call, get_event_emitter
@ -9,6 +10,7 @@ from open_webui.apps.webui.models.models import Models
from open_webui.apps.webui.routers import (
auths,
chats,
folders,
configs,
files,
functions,
@ -16,6 +18,7 @@ from open_webui.apps.webui.routers import (
models,
knowledge,
prompts,
evaluations,
tools,
users,
utils,
@ -31,10 +34,17 @@ from open_webui.config import (
ENABLE_LOGIN_FORM,
ENABLE_MESSAGE_RATING,
ENABLE_SIGNUP,
ENABLE_EVALUATION_ARENA_MODELS,
EVALUATION_ARENA_MODELS,
DEFAULT_ARENA_MODEL,
JWT_EXPIRES_IN,
ENABLE_OAUTH_ROLE_MANAGEMENT,
OAUTH_ROLES_CLAIM,
OAUTH_EMAIL_CLAIM,
OAUTH_PICTURE_CLAIM,
OAUTH_USERNAME_CLAIM,
OAUTH_ALLOWED_ROLES,
OAUTH_ADMIN_ROLES,
SHOW_ADMIN_DETAILS,
USER_PERMISSIONS,
WEBHOOK_URL,
@ -89,10 +99,18 @@ app.state.config.BANNERS = WEBUI_BANNERS
app.state.config.ENABLE_COMMUNITY_SHARING = ENABLE_COMMUNITY_SHARING
app.state.config.ENABLE_MESSAGE_RATING = ENABLE_MESSAGE_RATING
app.state.config.ENABLE_EVALUATION_ARENA_MODELS = ENABLE_EVALUATION_ARENA_MODELS
app.state.config.EVALUATION_ARENA_MODELS = EVALUATION_ARENA_MODELS
app.state.config.OAUTH_USERNAME_CLAIM = OAUTH_USERNAME_CLAIM
app.state.config.OAUTH_PICTURE_CLAIM = OAUTH_PICTURE_CLAIM
app.state.config.OAUTH_EMAIL_CLAIM = OAUTH_EMAIL_CLAIM
app.state.config.ENABLE_OAUTH_ROLE_MANAGEMENT = ENABLE_OAUTH_ROLE_MANAGEMENT
app.state.config.OAUTH_ROLES_CLAIM = OAUTH_ROLES_CLAIM
app.state.config.OAUTH_ALLOWED_ROLES = OAUTH_ALLOWED_ROLES
app.state.config.OAUTH_ADMIN_ROLES = OAUTH_ADMIN_ROLES
app.state.MODELS = {}
app.state.TOOLS = {}
app.state.FUNCTIONS = {}
@ -107,19 +125,24 @@ app.add_middleware(
app.include_router(configs.router, prefix="/configs", tags=["configs"])
app.include_router(auths.router, prefix="/auths", tags=["auths"])
app.include_router(users.router, prefix="/users", tags=["users"])
app.include_router(chats.router, prefix="/chats", tags=["chats"])
app.include_router(models.router, prefix="/models", tags=["models"])
app.include_router(knowledge.router, prefix="/knowledge", tags=["knowledge"])
app.include_router(prompts.router, prefix="/prompts", tags=["prompts"])
app.include_router(files.router, prefix="/files", tags=["files"])
app.include_router(tools.router, prefix="/tools", tags=["tools"])
app.include_router(functions.router, prefix="/functions", tags=["functions"])
app.include_router(memories.router, prefix="/memories", tags=["memories"])
app.include_router(evaluations.router, prefix="/evaluations", tags=["evaluations"])
app.include_router(folders.router, prefix="/folders", tags=["folders"])
app.include_router(files.router, prefix="/files", tags=["files"])
app.include_router(utils.router, prefix="/utils", tags=["utils"])
@ -133,6 +156,47 @@ async def get_status():
}
async def get_all_models():
models = []
pipe_models = await get_pipe_models()
models = models + pipe_models
if app.state.config.ENABLE_EVALUATION_ARENA_MODELS:
arena_models = []
if len(app.state.config.EVALUATION_ARENA_MODELS) > 0:
arena_models = [
{
"id": model["id"],
"name": model["name"],
"info": {
"meta": model["meta"],
},
"object": "model",
"created": int(time.time()),
"owned_by": "arena",
"arena": True,
}
for model in app.state.config.EVALUATION_ARENA_MODELS
]
else:
# Add default arena model
arena_models = [
{
"id": DEFAULT_ARENA_MODEL["id"],
"name": DEFAULT_ARENA_MODEL["name"],
"info": {
"meta": DEFAULT_ARENA_MODEL["meta"],
},
"object": "model",
"created": int(time.time()),
"owned_by": "arena",
"arena": True,
}
]
models = models + arena_models
return models
def get_function_module(pipe_id: str):
# Check if function is already loaded
if pipe_id not in app.state.FUNCTIONS:

View file

@ -4,8 +4,13 @@ import uuid
from typing import Optional
from open_webui.apps.webui.internal.db import Base, get_db
from open_webui.apps.webui.models.tags import TagModel, Tag, Tags
from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Boolean, Column, String, Text
from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON
from sqlalchemy import or_, func, select, and_, text
from sqlalchemy.sql import exists
####################
# Chat DB Schema
@ -18,13 +23,17 @@ class Chat(Base):
id = Column(String, primary_key=True)
user_id = Column(String)
title = Column(Text)
chat = Column(Text) # Save Chat JSON as Text
chat = Column(JSON)
created_at = Column(BigInteger)
updated_at = Column(BigInteger)
share_id = Column(Text, unique=True, nullable=True)
archived = Column(Boolean, default=False)
pinned = Column(Boolean, default=False, nullable=True)
meta = Column(JSON, server_default="{}")
folder_id = Column(Text, nullable=True)
class ChatModel(BaseModel):
@ -33,13 +42,17 @@ class ChatModel(BaseModel):
id: str
user_id: str
title: str
chat: str
chat: dict
created_at: int # timestamp in epoch
updated_at: int # timestamp in epoch
share_id: Optional[str] = None
archived: bool = False
pinned: Optional[bool] = False
meta: dict = {}
folder_id: Optional[str] = None
####################
@ -51,6 +64,17 @@ class ChatForm(BaseModel):
chat: dict
class ChatImportForm(ChatForm):
meta: Optional[dict] = {}
pinned: Optional[bool] = False
folder_id: Optional[str] = None
class ChatTitleMessagesForm(BaseModel):
title: str
messages: list[dict]
class ChatTitleForm(BaseModel):
title: str
@ -64,6 +88,9 @@ class ChatResponse(BaseModel):
created_at: int # timestamp in epoch
share_id: Optional[str] = None # id of the chat to be shared
archived: bool
pinned: Optional[bool] = False
meta: dict = {}
folder_id: Optional[str] = None
class ChatTitleIdResponse(BaseModel):
@ -86,7 +113,36 @@ class ChatTable:
if "title" in form_data.chat
else "New Chat"
),
"chat": json.dumps(form_data.chat),
"chat": form_data.chat,
"created_at": int(time.time()),
"updated_at": int(time.time()),
}
)
result = Chat(**chat.model_dump())
db.add(result)
db.commit()
db.refresh(result)
return ChatModel.model_validate(result) if result else None
def import_chat(
self, user_id: str, form_data: ChatImportForm
) -> Optional[ChatModel]:
with get_db() as db:
id = str(uuid.uuid4())
chat = ChatModel(
**{
"id": id,
"user_id": user_id,
"title": (
form_data.chat["title"]
if "title" in form_data.chat
else "New Chat"
),
"chat": form_data.chat,
"meta": form_data.meta,
"pinned": form_data.pinned,
"folder_id": form_data.folder_id,
"created_at": int(time.time()),
"updated_at": int(time.time()),
}
@ -101,14 +157,14 @@ class ChatTable:
def update_chat_by_id(self, id: str, chat: dict) -> Optional[ChatModel]:
try:
with get_db() as db:
chat_obj = db.get(Chat, id)
chat_obj.chat = json.dumps(chat)
chat_obj.title = chat["title"] if "title" in chat else "New Chat"
chat_obj.updated_at = int(time.time())
chat_item = db.get(Chat, id)
chat_item.chat = chat
chat_item.title = chat["title"] if "title" in chat else "New Chat"
chat_item.updated_at = int(time.time())
db.commit()
db.refresh(chat_obj)
db.refresh(chat_item)
return ChatModel.model_validate(chat_obj)
return ChatModel.model_validate(chat_item)
except Exception:
return None
@ -182,11 +238,24 @@ class ChatTable:
except Exception:
return None
def toggle_chat_pinned_by_id(self, id: str) -> Optional[ChatModel]:
try:
with get_db() as db:
chat = db.get(Chat, id)
chat.pinned = not chat.pinned
chat.updated_at = int(time.time())
db.commit()
db.refresh(chat)
return ChatModel.model_validate(chat)
except Exception:
return None
def toggle_chat_archive_by_id(self, id: str) -> Optional[ChatModel]:
try:
with get_db() as db:
chat = db.get(Chat, id)
chat.archived = not chat.archived
chat.updated_at = int(time.time())
db.commit()
db.refresh(chat)
return ChatModel.model_validate(chat)
@ -223,14 +292,18 @@ class ChatTable:
limit: int = 50,
) -> list[ChatModel]:
with get_db() as db:
query = db.query(Chat).filter_by(user_id=user_id)
query = db.query(Chat).filter_by(user_id=user_id).filter_by(folder_id=None)
if not include_archived:
query = query.filter_by(archived=False)
all_chats = (
query.order_by(Chat.updated_at.desc())
# .limit(limit).offset(skip)
.all()
)
query = query.order_by(Chat.updated_at.desc())
if skip:
query = query.offset(skip)
if limit:
query = query.limit(limit)
all_chats = query.all()
return [ChatModel.model_validate(chat) for chat in all_chats]
def get_chat_title_id_list_by_user_id(
@ -241,7 +314,9 @@ class ChatTable:
limit: Optional[int] = None,
) -> list[ChatTitleIdResponse]:
with get_db() as db:
query = db.query(Chat).filter_by(user_id=user_id)
query = db.query(Chat).filter_by(user_id=user_id).filter_by(folder_id=None)
query = query.filter(or_(Chat.pinned == False, Chat.pinned == None))
if not include_archived:
query = query.filter_by(archived=False)
@ -249,10 +324,10 @@ class ChatTable:
Chat.id, Chat.title, Chat.updated_at, Chat.created_at
)
if limit:
query = query.limit(limit)
if skip:
query = query.offset(skip)
if limit:
query = query.limit(limit)
all_chats = query.all()
@ -328,6 +403,15 @@ class ChatTable:
)
return [ChatModel.model_validate(chat) for chat in all_chats]
def get_pinned_chats_by_user_id(self, user_id: str) -> list[ChatModel]:
with get_db() as db:
all_chats = (
db.query(Chat)
.filter_by(user_id=user_id, pinned=True, archived=False)
.order_by(Chat.updated_at.desc())
)
return [ChatModel.model_validate(chat) for chat in all_chats]
def get_archived_chats_by_user_id(self, user_id: str) -> list[ChatModel]:
with get_db() as db:
all_chats = (
@ -337,6 +421,329 @@ class ChatTable:
)
return [ChatModel.model_validate(chat) for chat in all_chats]
def get_chats_by_user_id_and_search_text(
self,
user_id: str,
search_text: str,
include_archived: bool = False,
skip: int = 0,
limit: int = 60,
) -> list[ChatModel]:
"""
Filters chats based on a search query using Python, allowing pagination using skip and limit.
"""
search_text = search_text.lower().strip()
if not search_text:
return self.get_chat_list_by_user_id(user_id, include_archived, skip, limit)
search_text_words = search_text.split(" ")
# search_text might contain 'tag:tag_name' format so we need to extract the tag_name, split the search_text and remove the tags
tag_ids = [
word.replace("tag:", "").replace(" ", "_").lower()
for word in search_text_words
if word.startswith("tag:")
]
search_text_words = [
word for word in search_text_words if not word.startswith("tag:")
]
search_text = " ".join(search_text_words)
with get_db() as db:
query = db.query(Chat).filter(Chat.user_id == user_id)
if not include_archived:
query = query.filter(Chat.archived == False)
query = query.order_by(Chat.updated_at.desc())
# Check if the database dialect is either 'sqlite' or 'postgresql'
dialect_name = db.bind.dialect.name
if dialect_name == "sqlite":
# SQLite case: using JSON1 extension for JSON searching
query = query.filter(
(
Chat.title.ilike(
f"%{search_text}%"
) # Case-insensitive search in title
| text(
"""
EXISTS (
SELECT 1
FROM json_each(Chat.chat, '$.messages') AS message
WHERE LOWER(message.value->>'content') LIKE '%' || :search_text || '%'
)
"""
)
).params(search_text=search_text)
)
# Check if there are any tags to filter, it should have all the tags
if "none" in tag_ids:
query = query.filter(
text(
"""
NOT EXISTS (
SELECT 1
FROM json_each(Chat.meta, '$.tags') AS tag
)
"""
)
)
elif tag_ids:
query = query.filter(
and_(
*[
text(
f"""
EXISTS (
SELECT 1
FROM json_each(Chat.meta, '$.tags') AS tag
WHERE tag.value = :tag_id_{tag_idx}
)
"""
).params(**{f"tag_id_{tag_idx}": tag_id})
for tag_idx, tag_id in enumerate(tag_ids)
]
)
)
elif dialect_name == "postgresql":
# PostgreSQL relies on proper JSON query for search
query = query.filter(
(
Chat.title.ilike(
f"%{search_text}%"
) # Case-insensitive search in title
| text(
"""
EXISTS (
SELECT 1
FROM json_array_elements(Chat.chat->'messages') AS message
WHERE LOWER(message->>'content') LIKE '%' || :search_text || '%'
)
"""
)
).params(search_text=search_text)
)
# Check if there are any tags to filter, it should have all the tags
if "none" in tag_ids:
query = query.filter(
text(
"""
NOT EXISTS (
SELECT 1
FROM json_array_elements_text(Chat.meta->'tags') AS tag
)
"""
)
)
elif tag_ids:
query = query.filter(
and_(
*[
text(
f"""
EXISTS (
SELECT 1
FROM json_array_elements_text(Chat.meta->'tags') AS tag
WHERE tag = :tag_id_{tag_idx}
)
"""
).params(**{f"tag_id_{tag_idx}": tag_id})
for tag_idx, tag_id in enumerate(tag_ids)
]
)
)
else:
raise NotImplementedError(
f"Unsupported dialect: {db.bind.dialect.name}"
)
# Perform pagination at the SQL level
all_chats = query.offset(skip).limit(limit).all()
print(len(all_chats))
# Validate and return chats
return [ChatModel.model_validate(chat) for chat in all_chats]
def get_chats_by_folder_id_and_user_id(
self, folder_id: str, user_id: str
) -> list[ChatModel]:
with get_db() as db:
query = db.query(Chat).filter_by(folder_id=folder_id, user_id=user_id)
query = query.filter(or_(Chat.pinned == False, Chat.pinned == None))
query = query.filter_by(archived=False)
query = query.order_by(Chat.updated_at.desc())
all_chats = query.all()
return [ChatModel.model_validate(chat) for chat in all_chats]
def get_chats_by_folder_ids_and_user_id(
self, folder_ids: list[str], user_id: str
) -> list[ChatModel]:
with get_db() as db:
query = db.query(Chat).filter(
Chat.folder_id.in_(folder_ids), Chat.user_id == user_id
)
query = query.filter(or_(Chat.pinned == False, Chat.pinned == None))
query = query.filter_by(archived=False)
query = query.order_by(Chat.updated_at.desc())
all_chats = query.all()
return [ChatModel.model_validate(chat) for chat in all_chats]
def update_chat_folder_id_by_id_and_user_id(
self, id: str, user_id: str, folder_id: str
) -> Optional[ChatModel]:
try:
with get_db() as db:
chat = db.get(Chat, id)
chat.folder_id = folder_id
chat.updated_at = int(time.time())
chat.pinned = False
db.commit()
db.refresh(chat)
return ChatModel.model_validate(chat)
except Exception:
return None
def get_chat_tags_by_id_and_user_id(self, id: str, user_id: str) -> list[TagModel]:
with get_db() as db:
chat = db.get(Chat, id)
tags = chat.meta.get("tags", [])
return [Tags.get_tag_by_name_and_user_id(tag, user_id) for tag in tags]
def get_chat_list_by_user_id_and_tag_name(
self, user_id: str, tag_name: str, skip: int = 0, limit: int = 50
) -> list[ChatModel]:
with get_db() as db:
query = db.query(Chat).filter_by(user_id=user_id)
tag_id = tag_name.replace(" ", "_").lower()
print(db.bind.dialect.name)
if db.bind.dialect.name == "sqlite":
# SQLite JSON1 querying for tags within the meta JSON field
query = query.filter(
text(
f"EXISTS (SELECT 1 FROM json_each(Chat.meta, '$.tags') WHERE json_each.value = :tag_id)"
)
).params(tag_id=tag_id)
elif db.bind.dialect.name == "postgresql":
# PostgreSQL JSON query for tags within the meta JSON field (for `json` type)
query = query.filter(
text(
"EXISTS (SELECT 1 FROM json_array_elements_text(Chat.meta->'tags') elem WHERE elem = :tag_id)"
)
).params(tag_id=tag_id)
else:
raise NotImplementedError(
f"Unsupported dialect: {db.bind.dialect.name}"
)
all_chats = query.all()
print("all_chats", all_chats)
return [ChatModel.model_validate(chat) for chat in all_chats]
def add_chat_tag_by_id_and_user_id_and_tag_name(
self, id: str, user_id: str, tag_name: str
) -> Optional[ChatModel]:
tag = Tags.get_tag_by_name_and_user_id(tag_name, user_id)
if tag is None:
tag = Tags.insert_new_tag(tag_name, user_id)
try:
with get_db() as db:
chat = db.get(Chat, id)
tag_id = tag.id
if tag_id not in chat.meta.get("tags", []):
chat.meta = {
**chat.meta,
"tags": list(set(chat.meta.get("tags", []) + [tag_id])),
}
db.commit()
db.refresh(chat)
return ChatModel.model_validate(chat)
except Exception:
return None
def count_chats_by_tag_name_and_user_id(self, tag_name: str, user_id: str) -> int:
with get_db() as db: # Assuming `get_db()` returns a session object
query = db.query(Chat).filter_by(user_id=user_id, archived=False)
# Normalize the tag_name for consistency
tag_id = tag_name.replace(" ", "_").lower()
if db.bind.dialect.name == "sqlite":
# SQLite JSON1 support for querying the tags inside the `meta` JSON field
query = query.filter(
text(
f"EXISTS (SELECT 1 FROM json_each(Chat.meta, '$.tags') WHERE json_each.value = :tag_id)"
)
).params(tag_id=tag_id)
elif db.bind.dialect.name == "postgresql":
# PostgreSQL JSONB support for querying the tags inside the `meta` JSON field
query = query.filter(
text(
"EXISTS (SELECT 1 FROM json_array_elements_text(Chat.meta->'tags') elem WHERE elem = :tag_id)"
)
).params(tag_id=tag_id)
else:
raise NotImplementedError(
f"Unsupported dialect: {db.bind.dialect.name}"
)
# Get the count of matching records
count = query.count()
# Debugging output for inspection
print(f"Count of chats for tag '{tag_name}':", count)
return count
def delete_tag_by_id_and_user_id_and_tag_name(
self, id: str, user_id: str, tag_name: str
) -> bool:
try:
with get_db() as db:
chat = db.get(Chat, id)
tags = chat.meta.get("tags", [])
tag_id = tag_name.replace(" ", "_").lower()
tags = [tag for tag in tags if tag != tag_id]
chat.meta = {
**chat.meta,
"tags": list(set(tags)),
}
db.commit()
return True
except Exception:
return False
def delete_all_tags_by_id_and_user_id(self, id: str, user_id: str) -> bool:
try:
with get_db() as db:
chat = db.get(Chat, id)
chat.meta = {
**chat.meta,
"tags": [],
}
db.commit()
return True
except Exception:
return False
def delete_chat_by_id(self, id: str) -> bool:
try:
with get_db() as db:
@ -369,6 +776,18 @@ class ChatTable:
except Exception:
return False
def delete_chats_by_user_id_and_folder_id(
self, user_id: str, folder_id: str
) -> bool:
try:
with get_db() as db:
db.query(Chat).filter_by(user_id=user_id, folder_id=folder_id).delete()
db.commit()
return True
except Exception:
return False
def delete_shared_chats_by_user_id(self, user_id: str) -> bool:
try:
with get_db() as db:

View file

@ -0,0 +1,254 @@
import logging
import time
import uuid
from typing import Optional
from open_webui.apps.webui.internal.db import Base, get_db
from open_webui.apps.webui.models.chats import Chats
from open_webui.env import SRC_LOG_LEVELS
from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Column, Text, JSON, Boolean
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
####################
# Feedback DB Schema
####################
class Feedback(Base):
__tablename__ = "feedback"
id = Column(Text, primary_key=True)
user_id = Column(Text)
version = Column(BigInteger, default=0)
type = Column(Text)
data = Column(JSON, nullable=True)
meta = Column(JSON, nullable=True)
snapshot = Column(JSON, nullable=True)
created_at = Column(BigInteger)
updated_at = Column(BigInteger)
class FeedbackModel(BaseModel):
id: str
user_id: str
version: int
type: str
data: Optional[dict] = None
meta: Optional[dict] = None
snapshot: Optional[dict] = None
created_at: int
updated_at: int
model_config = ConfigDict(from_attributes=True)
####################
# Forms
####################
class FeedbackResponse(BaseModel):
id: str
user_id: str
version: int
type: str
data: Optional[dict] = None
meta: Optional[dict] = None
created_at: int
updated_at: int
class RatingData(BaseModel):
rating: Optional[str | int] = None
model_id: Optional[str] = None
sibling_model_ids: Optional[list[str]] = None
reason: Optional[str] = None
comment: Optional[str] = None
model_config = ConfigDict(extra="allow", protected_namespaces=())
class MetaData(BaseModel):
arena: Optional[bool] = None
chat_id: Optional[str] = None
message_id: Optional[str] = None
tags: Optional[list[str]] = None
model_config = ConfigDict(extra="allow")
class SnapshotData(BaseModel):
chat: Optional[dict] = None
model_config = ConfigDict(extra="allow")
class FeedbackForm(BaseModel):
type: str
data: Optional[RatingData] = None
meta: Optional[dict] = None
snapshot: Optional[SnapshotData] = None
model_config = ConfigDict(extra="allow")
class FeedbackTable:
def insert_new_feedback(
self, user_id: str, form_data: FeedbackForm
) -> Optional[FeedbackModel]:
with get_db() as db:
id = str(uuid.uuid4())
feedback = FeedbackModel(
**{
"id": id,
"user_id": user_id,
"version": 0,
**form_data.model_dump(),
"created_at": int(time.time()),
"updated_at": int(time.time()),
}
)
try:
result = Feedback(**feedback.model_dump())
db.add(result)
db.commit()
db.refresh(result)
if result:
return FeedbackModel.model_validate(result)
else:
return None
except Exception as e:
print(e)
return None
def get_feedback_by_id(self, id: str) -> Optional[FeedbackModel]:
try:
with get_db() as db:
feedback = db.query(Feedback).filter_by(id=id).first()
if not feedback:
return None
return FeedbackModel.model_validate(feedback)
except Exception:
return None
def get_feedback_by_id_and_user_id(
self, id: str, user_id: str
) -> Optional[FeedbackModel]:
try:
with get_db() as db:
feedback = db.query(Feedback).filter_by(id=id, user_id=user_id).first()
if not feedback:
return None
return FeedbackModel.model_validate(feedback)
except Exception:
return None
def get_all_feedbacks(self) -> list[FeedbackModel]:
with get_db() as db:
return [
FeedbackModel.model_validate(feedback)
for feedback in db.query(Feedback)
.order_by(Feedback.updated_at.desc())
.all()
]
def get_feedbacks_by_type(self, type: str) -> list[FeedbackModel]:
with get_db() as db:
return [
FeedbackModel.model_validate(feedback)
for feedback in db.query(Feedback)
.filter_by(type=type)
.order_by(Feedback.updated_at.desc())
.all()
]
def get_feedbacks_by_user_id(self, user_id: str) -> list[FeedbackModel]:
with get_db() as db:
return [
FeedbackModel.model_validate(feedback)
for feedback in db.query(Feedback)
.filter_by(user_id=user_id)
.order_by(Feedback.updated_at.desc())
.all()
]
def update_feedback_by_id(
self, id: str, form_data: FeedbackForm
) -> Optional[FeedbackModel]:
with get_db() as db:
feedback = db.query(Feedback).filter_by(id=id).first()
if not feedback:
return None
if form_data.data:
feedback.data = form_data.data.model_dump()
if form_data.meta:
feedback.meta = form_data.meta
if form_data.snapshot:
feedback.snapshot = form_data.snapshot.model_dump()
feedback.updated_at = int(time.time())
db.commit()
return FeedbackModel.model_validate(feedback)
def update_feedback_by_id_and_user_id(
self, id: str, user_id: str, form_data: FeedbackForm
) -> Optional[FeedbackModel]:
with get_db() as db:
feedback = db.query(Feedback).filter_by(id=id, user_id=user_id).first()
if not feedback:
return None
if form_data.data:
feedback.data = form_data.data.model_dump()
if form_data.meta:
feedback.meta = form_data.meta
if form_data.snapshot:
feedback.snapshot = form_data.snapshot.model_dump()
feedback.updated_at = int(time.time())
db.commit()
return FeedbackModel.model_validate(feedback)
def delete_feedback_by_id(self, id: str) -> bool:
with get_db() as db:
feedback = db.query(Feedback).filter_by(id=id).first()
if not feedback:
return False
db.delete(feedback)
db.commit()
return True
def delete_feedback_by_id_and_user_id(self, id: str, user_id: str) -> bool:
with get_db() as db:
feedback = db.query(Feedback).filter_by(id=id, user_id=user_id).first()
if not feedback:
return False
db.delete(feedback)
db.commit()
return True
def delete_feedbacks_by_user_id(self, user_id: str) -> bool:
with get_db() as db:
feedbacks = db.query(Feedback).filter_by(user_id=user_id).all()
if not feedbacks:
return False
for feedback in feedbacks:
db.delete(feedback)
db.commit()
return True
def delete_all_feedbacks(self) -> bool:
with get_db() as db:
feedbacks = db.query(Feedback).all()
if not feedbacks:
return False
for feedback in feedbacks:
db.delete(feedback)
db.commit()
return True
Feedbacks = FeedbackTable()

View file

@ -17,14 +17,15 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"])
class File(Base):
__tablename__ = "file"
id = Column(String, primary_key=True)
user_id = Column(String)
hash = Column(Text, nullable=True)
filename = Column(Text)
path = Column(Text, nullable=True)
data = Column(JSON, nullable=True)
meta = Column(JSONField)
meta = Column(JSON, nullable=True)
created_at = Column(BigInteger)
updated_at = Column(BigInteger)
@ -38,8 +39,10 @@ class FileModel(BaseModel):
hash: Optional[str] = None
filename: str
path: Optional[str] = None
data: Optional[dict] = None
meta: dict
meta: Optional[dict] = None
created_at: int # timestamp in epoch
updated_at: int # timestamp in epoch
@ -50,6 +53,14 @@ class FileModel(BaseModel):
####################
class FileMeta(BaseModel):
name: Optional[str] = None
content_type: Optional[str] = None
size: Optional[int] = None
model_config = ConfigDict(extra="allow")
class FileModelResponse(BaseModel):
id: str
user_id: str
@ -57,16 +68,24 @@ class FileModelResponse(BaseModel):
filename: str
data: Optional[dict] = None
meta: dict
meta: FileMeta
created_at: int # timestamp in epoch
updated_at: int # timestamp in epoch
class FileMetadataResponse(BaseModel):
id: str
meta: dict
created_at: int # timestamp in epoch
updated_at: int # timestamp in epoch
class FileForm(BaseModel):
id: str
hash: Optional[str] = None
filename: str
path: str
data: dict = {}
meta: dict = {}
@ -104,6 +123,19 @@ class FilesTable:
except Exception:
return None
def get_file_metadata_by_id(self, id: str) -> Optional[FileMetadataResponse]:
with get_db() as db:
try:
file = db.get(File, id)
return FileMetadataResponse(
id=file.id,
meta=file.meta,
created_at=file.created_at,
updated_at=file.updated_at,
)
except Exception:
return None
def get_files(self) -> list[FileModel]:
with get_db() as db:
return [FileModel.model_validate(file) for file in db.query(File).all()]
@ -118,6 +150,21 @@ class FilesTable:
.all()
]
def get_file_metadatas_by_ids(self, ids: list[str]) -> list[FileMetadataResponse]:
with get_db() as db:
return [
FileMetadataResponse(
id=file.id,
meta=file.meta,
created_at=file.created_at,
updated_at=file.updated_at,
)
for file in db.query(File)
.filter(File.id.in_(ids))
.order_by(File.updated_at.desc())
.all()
]
def get_files_by_user_id(self, user_id: str) -> list[FileModel]:
with get_db() as db:
return [

View file

@ -0,0 +1,271 @@
import logging
import time
import uuid
from typing import Optional
from open_webui.apps.webui.internal.db import Base, get_db
from open_webui.apps.webui.models.chats import Chats
from open_webui.env import SRC_LOG_LEVELS
from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Column, Text, JSON, Boolean
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
####################
# Folder DB Schema
####################
class Folder(Base):
__tablename__ = "folder"
id = Column(Text, primary_key=True)
parent_id = Column(Text, nullable=True)
user_id = Column(Text)
name = Column(Text)
items = Column(JSON, nullable=True)
meta = Column(JSON, nullable=True)
is_expanded = Column(Boolean, default=False)
created_at = Column(BigInteger)
updated_at = Column(BigInteger)
class FolderModel(BaseModel):
id: str
parent_id: Optional[str] = None
user_id: str
name: str
items: Optional[dict] = None
meta: Optional[dict] = None
is_expanded: bool = False
created_at: int
updated_at: int
model_config = ConfigDict(from_attributes=True)
####################
# Forms
####################
class FolderForm(BaseModel):
name: str
model_config = ConfigDict(extra="allow")
class FolderTable:
def insert_new_folder(
self, user_id: str, name: str, parent_id: Optional[str] = None
) -> Optional[FolderModel]:
with get_db() as db:
id = str(uuid.uuid4())
folder = FolderModel(
**{
"id": id,
"user_id": user_id,
"name": name,
"parent_id": parent_id,
"created_at": int(time.time()),
"updated_at": int(time.time()),
}
)
try:
result = Folder(**folder.model_dump())
db.add(result)
db.commit()
db.refresh(result)
if result:
return FolderModel.model_validate(result)
else:
return None
except Exception as e:
print(e)
return None
def get_folder_by_id_and_user_id(
self, id: str, user_id: str
) -> Optional[FolderModel]:
try:
with get_db() as db:
folder = db.query(Folder).filter_by(id=id, user_id=user_id).first()
if not folder:
return None
return FolderModel.model_validate(folder)
except Exception:
return None
def get_children_folders_by_id_and_user_id(
self, id: str, user_id: str
) -> Optional[FolderModel]:
try:
with get_db() as db:
folders = []
def get_children(folder):
children = self.get_folders_by_parent_id_and_user_id(
folder.id, user_id
)
for child in children:
get_children(child)
folders.append(child)
folder = db.query(Folder).filter_by(id=id, user_id=user_id).first()
if not folder:
return None
get_children(folder)
return folders
except Exception:
return None
def get_folders_by_user_id(self, user_id: str) -> list[FolderModel]:
with get_db() as db:
return [
FolderModel.model_validate(folder)
for folder in db.query(Folder).filter_by(user_id=user_id).all()
]
def get_folder_by_parent_id_and_user_id_and_name(
self, parent_id: Optional[str], user_id: str, name: str
) -> Optional[FolderModel]:
try:
with get_db() as db:
# Check if folder exists
folder = (
db.query(Folder)
.filter_by(parent_id=parent_id, user_id=user_id)
.filter(Folder.name.ilike(name))
.first()
)
if not folder:
return None
return FolderModel.model_validate(folder)
except Exception as e:
log.error(f"get_folder_by_parent_id_and_user_id_and_name: {e}")
return None
def get_folders_by_parent_id_and_user_id(
self, parent_id: Optional[str], user_id: str
) -> list[FolderModel]:
with get_db() as db:
return [
FolderModel.model_validate(folder)
for folder in db.query(Folder)
.filter_by(parent_id=parent_id, user_id=user_id)
.all()
]
def update_folder_parent_id_by_id_and_user_id(
self,
id: str,
user_id: str,
parent_id: str,
) -> Optional[FolderModel]:
try:
with get_db() as db:
folder = db.query(Folder).filter_by(id=id, user_id=user_id).first()
if not folder:
return None
folder.parent_id = parent_id
folder.updated_at = int(time.time())
db.commit()
return FolderModel.model_validate(folder)
except Exception as e:
log.error(f"update_folder: {e}")
return
def update_folder_name_by_id_and_user_id(
self, id: str, user_id: str, name: str
) -> Optional[FolderModel]:
try:
with get_db() as db:
folder = db.query(Folder).filter_by(id=id, user_id=user_id).first()
if not folder:
return None
existing_folder = (
db.query(Folder)
.filter_by(name=name, parent_id=folder.parent_id, user_id=user_id)
.first()
)
if existing_folder:
return None
folder.name = name
folder.updated_at = int(time.time())
db.commit()
return FolderModel.model_validate(folder)
except Exception as e:
log.error(f"update_folder: {e}")
return
def update_folder_is_expanded_by_id_and_user_id(
self, id: str, user_id: str, is_expanded: bool
) -> Optional[FolderModel]:
try:
with get_db() as db:
folder = db.query(Folder).filter_by(id=id, user_id=user_id).first()
if not folder:
return None
folder.is_expanded = is_expanded
folder.updated_at = int(time.time())
db.commit()
return FolderModel.model_validate(folder)
except Exception as e:
log.error(f"update_folder: {e}")
return
def delete_folder_by_id_and_user_id(self, id: str, user_id: str) -> bool:
try:
with get_db() as db:
folder = db.query(Folder).filter_by(id=id, user_id=user_id).first()
if not folder:
return False
# Delete all chats in the folder
Chats.delete_chats_by_user_id_and_folder_id(user_id, folder.id)
# Delete all children folders
def delete_children(folder):
folder_children = self.get_folders_by_parent_id_and_user_id(
folder.id, user_id
)
for folder_child in folder_children:
Chats.delete_chats_by_user_id_and_folder_id(
user_id, folder_child.id
)
delete_children(folder_child)
folder = db.query(Folder).filter_by(id=folder_child.id).first()
db.delete(folder)
db.commit()
delete_children(folder)
db.delete(folder)
db.commit()
return True
except Exception as e:
log.error(f"delete_folder: {e}")
return False
Folders = FolderTable()

View file

@ -6,6 +6,10 @@ import uuid
from open_webui.apps.webui.internal.db import Base, get_db
from open_webui.env import SRC_LOG_LEVELS
from open_webui.apps.webui.models.files import FileMetadataResponse
from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Column, String, Text, JSON
@ -64,6 +68,8 @@ class KnowledgeResponse(BaseModel):
created_at: int # timestamp in epoch
updated_at: int # timestamp in epoch
files: Optional[list[FileMetadataResponse | dict]] = None
class KnowledgeForm(BaseModel):
name: str
@ -148,5 +154,15 @@ class KnowledgeTable:
except Exception:
return False
def delete_all_knowledge(self) -> bool:
with get_db() as db:
try:
db.query(Knowledge).delete()
db.commit()
return True
except Exception:
return False
Knowledges = KnowledgeTable()

View file

@ -4,53 +4,35 @@ import uuid
from typing import Optional
from open_webui.apps.webui.internal.db import Base, get_db
from open_webui.env import SRC_LOG_LEVELS
from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Column, String, Text
from sqlalchemy import BigInteger, Column, String, JSON, PrimaryKeyConstraint
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
####################
# Tag DB Schema
####################
class Tag(Base):
__tablename__ = "tag"
id = Column(String, primary_key=True)
id = Column(String)
name = Column(String)
user_id = Column(String)
data = Column(Text, nullable=True)
meta = Column(JSON, nullable=True)
class ChatIdTag(Base):
__tablename__ = "chatidtag"
id = Column(String, primary_key=True)
tag_name = Column(String)
chat_id = Column(String)
user_id = Column(String)
timestamp = Column(BigInteger)
# Unique constraint ensuring (id, user_id) is unique, not just the `id` column
__table_args__ = (PrimaryKeyConstraint("id", "user_id", name="pk_id_user_id"),)
class TagModel(BaseModel):
id: str
name: str
user_id: str
data: Optional[str] = None
model_config = ConfigDict(from_attributes=True)
class ChatIdTagModel(BaseModel):
id: str
tag_name: str
chat_id: str
user_id: str
timestamp: int
meta: Optional[dict] = None
model_config = ConfigDict(from_attributes=True)
@ -59,23 +41,15 @@ class ChatIdTagModel(BaseModel):
####################
class ChatIdTagForm(BaseModel):
tag_name: str
class TagChatIdForm(BaseModel):
name: str
chat_id: str
class TagChatIdsResponse(BaseModel):
chat_ids: list[str]
class ChatTagsResponse(BaseModel):
tags: list[str]
class TagTable:
def insert_new_tag(self, name: str, user_id: str) -> Optional[TagModel]:
with get_db() as db:
id = str(uuid.uuid4())
id = name.replace(" ", "_").lower()
tag = TagModel(**{"id": id, "user_id": user_id, "name": name})
try:
result = Tag(**tag.model_dump())
@ -86,177 +60,50 @@ class TagTable:
return TagModel.model_validate(result)
else:
return None
except Exception:
except Exception as e:
print(e)
return None
def get_tag_by_name_and_user_id(
self, name: str, user_id: str
) -> Optional[TagModel]:
try:
id = name.replace(" ", "_").lower()
with get_db() as db:
tag = db.query(Tag).filter_by(name=name, user_id=user_id).first()
tag = db.query(Tag).filter_by(id=id, user_id=user_id).first()
return TagModel.model_validate(tag)
except Exception:
return None
def add_tag_to_chat(
self, user_id: str, form_data: ChatIdTagForm
) -> Optional[ChatIdTagModel]:
tag = self.get_tag_by_name_and_user_id(form_data.tag_name, user_id)
if tag is None:
tag = self.insert_new_tag(form_data.tag_name, user_id)
id = str(uuid.uuid4())
chatIdTag = ChatIdTagModel(
**{
"id": id,
"user_id": user_id,
"chat_id": form_data.chat_id,
"tag_name": tag.name,
"timestamp": int(time.time()),
}
)
try:
with get_db() as db:
result = ChatIdTag(**chatIdTag.model_dump())
db.add(result)
db.commit()
db.refresh(result)
if result:
return ChatIdTagModel.model_validate(result)
else:
return None
except Exception:
return None
def get_tags_by_user_id(self, user_id: str) -> list[TagModel]:
with get_db() as db:
tag_names = [
chat_id_tag.tag_name
for chat_id_tag in (
db.query(ChatIdTag)
.filter_by(user_id=user_id)
.order_by(ChatIdTag.timestamp.desc())
.all()
)
]
return [
TagModel.model_validate(tag)
for tag in (
db.query(Tag)
.filter_by(user_id=user_id)
.filter(Tag.name.in_(tag_names))
.all()
)
for tag in (db.query(Tag).filter_by(user_id=user_id).all())
]
def get_tags_by_chat_id_and_user_id(
self, chat_id: str, user_id: str
def get_tags_by_ids_and_user_id(
self, ids: list[str], user_id: str
) -> list[TagModel]:
with get_db() as db:
tag_names = [
chat_id_tag.tag_name
for chat_id_tag in (
db.query(ChatIdTag)
.filter_by(user_id=user_id, chat_id=chat_id)
.order_by(ChatIdTag.timestamp.desc())
.all()
)
]
return [
TagModel.model_validate(tag)
for tag in (
db.query(Tag)
.filter_by(user_id=user_id)
.filter(Tag.name.in_(tag_names))
.all()
db.query(Tag).filter(Tag.id.in_(ids), Tag.user_id == user_id).all()
)
]
def get_chat_ids_by_tag_name_and_user_id(
self, tag_name: str, user_id: str
) -> list[ChatIdTagModel]:
with get_db() as db:
return [
ChatIdTagModel.model_validate(chat_id_tag)
for chat_id_tag in (
db.query(ChatIdTag)
.filter_by(user_id=user_id, tag_name=tag_name)
.order_by(ChatIdTag.timestamp.desc())
.all()
)
]
def count_chat_ids_by_tag_name_and_user_id(
self, tag_name: str, user_id: str
) -> int:
with get_db() as db:
return (
db.query(ChatIdTag)
.filter_by(tag_name=tag_name, user_id=user_id)
.count()
)
def delete_tag_by_tag_name_and_user_id(self, tag_name: str, user_id: str) -> bool:
def delete_tag_by_name_and_user_id(self, name: str, user_id: str) -> bool:
try:
with get_db() as db:
res = (
db.query(ChatIdTag)
.filter_by(tag_name=tag_name, user_id=user_id)
.delete()
)
id = name.replace(" ", "_").lower()
res = db.query(Tag).filter_by(id=id, user_id=user_id).delete()
log.debug(f"res: {res}")
db.commit()
tag_count = self.count_chat_ids_by_tag_name_and_user_id(
tag_name, user_id
)
if tag_count == 0:
# Remove tag item from Tag col as well
db.query(Tag).filter_by(name=tag_name, user_id=user_id).delete()
db.commit()
return True
except Exception as e:
log.error(f"delete_tag: {e}")
return False
def delete_tag_by_tag_name_and_chat_id_and_user_id(
self, tag_name: str, chat_id: str, user_id: str
) -> bool:
try:
with get_db() as db:
res = (
db.query(ChatIdTag)
.filter_by(tag_name=tag_name, chat_id=chat_id, user_id=user_id)
.delete()
)
log.debug(f"res: {res}")
db.commit()
tag_count = self.count_chat_ids_by_tag_name_and_user_id(
tag_name, user_id
)
if tag_count == 0:
# Remove tag item from Tag col as well
db.query(Tag).filter_by(name=tag_name, user_id=user_id).delete()
db.commit()
return True
except Exception as e:
log.error(f"delete_tag: {e}")
return False
def delete_tags_by_chat_id_and_user_id(self, chat_id: str, user_id: str) -> bool:
tags = self.get_tags_by_chat_id_and_user_id(chat_id, user_id)
for tag in tags:
self.delete_tag_by_tag_name_and_chat_id_and_user_id(
tag.tag_name, chat_id, user_id
)
return True
Tags = TagTable()

View file

@ -1,10 +1,13 @@
import re
import uuid
import time
import datetime
from open_webui.apps.webui.models.auths import (
AddUserForm,
ApiKey,
Auths,
Token,
SigninForm,
SigninResponse,
SignupForm,
@ -18,6 +21,8 @@ from open_webui.constants import ERROR_MESSAGES, WEBHOOK_MESSAGES
from open_webui.env import (
WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
WEBUI_AUTH_TRUSTED_NAME_HEADER,
WEBUI_SESSION_COOKIE_SAME_SITE,
WEBUI_SESSION_COOKIE_SECURE,
)
from fastapi import APIRouter, Depends, HTTPException, Request, status
from fastapi.responses import Response
@ -32,6 +37,7 @@ from open_webui.utils.utils import (
get_password_hash,
)
from open_webui.utils.webhook import post_webhook
from typing import Optional
router = APIRouter()
@ -40,23 +46,44 @@ router = APIRouter()
############################
@router.get("/", response_model=UserResponse)
class SessionUserResponse(Token, UserResponse):
expires_at: Optional[int] = None
@router.get("/", response_model=SessionUserResponse)
async def get_session_user(
request: Request, response: Response, user=Depends(get_current_user)
):
expires_delta = parse_duration(request.app.state.config.JWT_EXPIRES_IN)
expires_at = None
if expires_delta:
expires_at = int(time.time()) + int(expires_delta.total_seconds())
token = create_token(
data={"id": user.id},
expires_delta=parse_duration(request.app.state.config.JWT_EXPIRES_IN),
expires_delta=expires_delta,
)
datetime_expires_at = (
datetime.datetime.fromtimestamp(expires_at, datetime.timezone.utc)
if expires_at
else None
)
# Set the cookie token
response.set_cookie(
key="token",
value=token,
expires=datetime_expires_at,
httponly=True, # Ensures the cookie is not accessible via JavaScript
samesite=WEBUI_SESSION_COOKIE_SAME_SITE,
secure=WEBUI_SESSION_COOKIE_SECURE,
)
return {
"token": token,
"token_type": "Bearer",
"expires_at": expires_at,
"id": user.id,
"email": user.email,
"name": user.name,
@ -115,7 +142,7 @@ async def update_password(
############################
@router.post("/signin", response_model=SigninResponse)
@router.post("/signin", response_model=SessionUserResponse)
async def signin(request: Request, response: Response, form_data: SigninForm):
if WEBUI_AUTH_TRUSTED_EMAIL_HEADER:
if WEBUI_AUTH_TRUSTED_EMAIL_HEADER not in request.headers:
@ -157,21 +184,37 @@ async def signin(request: Request, response: Response, form_data: SigninForm):
user = Auths.authenticate_user(form_data.email.lower(), form_data.password)
if user:
expires_delta = parse_duration(request.app.state.config.JWT_EXPIRES_IN)
expires_at = None
if expires_delta:
expires_at = int(time.time()) + int(expires_delta.total_seconds())
token = create_token(
data={"id": user.id},
expires_delta=parse_duration(request.app.state.config.JWT_EXPIRES_IN),
expires_delta=expires_delta,
)
datetime_expires_at = (
datetime.datetime.fromtimestamp(expires_at, datetime.timezone.utc)
if expires_at
else None
)
# Set the cookie token
response.set_cookie(
key="token",
value=token,
expires=datetime_expires_at,
httponly=True, # Ensures the cookie is not accessible via JavaScript
samesite=WEBUI_SESSION_COOKIE_SAME_SITE,
secure=WEBUI_SESSION_COOKIE_SECURE,
)
return {
"token": token,
"token_type": "Bearer",
"expires_at": expires_at,
"id": user.id,
"email": user.email,
"name": user.name,
@ -187,7 +230,7 @@ async def signin(request: Request, response: Response, form_data: SigninForm):
############################
@router.post("/signup", response_model=SigninResponse)
@router.post("/signup", response_model=SessionUserResponse)
async def signup(request: Request, response: Response, form_data: SignupForm):
if WEBUI_AUTH:
if (
@ -227,16 +270,30 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
)
if user:
expires_delta = parse_duration(request.app.state.config.JWT_EXPIRES_IN)
expires_at = None
if expires_delta:
expires_at = int(time.time()) + int(expires_delta.total_seconds())
token = create_token(
data={"id": user.id},
expires_delta=parse_duration(request.app.state.config.JWT_EXPIRES_IN),
expires_delta=expires_delta,
)
datetime_expires_at = (
datetime.datetime.fromtimestamp(expires_at, datetime.timezone.utc)
if expires_at
else None
)
# Set the cookie token
response.set_cookie(
key="token",
value=token,
expires=datetime_expires_at,
httponly=True, # Ensures the cookie is not accessible via JavaScript
samesite=WEBUI_SESSION_COOKIE_SAME_SITE,
secure=WEBUI_SESSION_COOKIE_SECURE,
)
if request.app.state.config.WEBHOOK_URL:
@ -253,6 +310,7 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
return {
"token": token,
"token_type": "Bearer",
"expires_at": expires_at,
"id": user.id,
"email": user.email,
"name": user.name,
@ -265,6 +323,12 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
raise HTTPException(500, detail=ERROR_MESSAGES.DEFAULT(err))
@router.get("/signout")
async def signout(response: Response):
response.delete_cookie("token")
return {"status": True}
############################
# AddUser
############################

View file

@ -4,16 +4,14 @@ from typing import Optional
from open_webui.apps.webui.models.chats import (
ChatForm,
ChatImportForm,
ChatResponse,
Chats,
ChatTitleIdResponse,
)
from open_webui.apps.webui.models.tags import (
ChatIdTagForm,
ChatIdTagModel,
TagModel,
Tags,
)
from open_webui.apps.webui.models.tags import TagModel, Tags
from open_webui.apps.webui.models.folders import Folders
from open_webui.config import ENABLE_ADMIN_CHAT_ACCESS, ENABLE_ADMIN_EXPORT
from open_webui.constants import ERROR_MESSAGES
from open_webui.env import SRC_LOG_LEVELS
@ -95,7 +93,35 @@ async def get_user_chat_list_by_user_id(
async def create_new_chat(form_data: ChatForm, user=Depends(get_verified_user)):
try:
chat = Chats.insert_new_chat(user.id, form_data)
return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
return ChatResponse(**chat.model_dump())
except Exception as e:
log.exception(e)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
)
############################
# ImportChat
############################
@router.post("/import", response_model=Optional[ChatResponse])
async def import_chat(form_data: ChatImportForm, user=Depends(get_verified_user)):
try:
chat = Chats.import_chat(user.id, form_data)
if chat:
tags = chat.meta.get("tags", [])
for tag_id in tags:
tag_id = tag_id.replace(" ", "_").lower()
tag_name = " ".join([word.capitalize() for word in tag_id.split("_")])
if (
tag_id != "none"
and Tags.get_tag_by_name_and_user_id(tag_name, user.id) is None
):
Tags.insert_new_tag(tag_name, user.id)
return ChatResponse(**chat.model_dump())
except Exception as e:
log.exception(e)
raise HTTPException(
@ -108,10 +134,77 @@ async def create_new_chat(form_data: ChatForm, user=Depends(get_verified_user)):
############################
@router.get("/search", response_model=list[ChatTitleIdResponse])
async def search_user_chats(
text: str, page: Optional[int] = None, user=Depends(get_verified_user)
):
if page is None:
page = 1
limit = 60
skip = (page - 1) * limit
chat_list = [
ChatTitleIdResponse(**chat.model_dump())
for chat in Chats.get_chats_by_user_id_and_search_text(
user.id, text, skip=skip, limit=limit
)
]
# Delete tag if no chat is found
words = text.strip().split(" ")
if page == 1 and len(words) == 1 and words[0].startswith("tag:"):
tag_id = words[0].replace("tag:", "")
if len(chat_list) == 0:
if Tags.get_tag_by_name_and_user_id(tag_id, user.id):
log.debug(f"deleting tag: {tag_id}")
Tags.delete_tag_by_name_and_user_id(tag_id, user.id)
return chat_list
############################
# GetChatsByFolderId
############################
@router.get("/folder/{folder_id}", response_model=list[ChatResponse])
async def get_chats_by_folder_id(folder_id: str, user=Depends(get_verified_user)):
folder_ids = [folder_id]
children_folders = Folders.get_children_folders_by_id_and_user_id(
folder_id, user.id
)
if children_folders:
folder_ids.extend([folder.id for folder in children_folders])
return [
ChatResponse(**chat.model_dump())
for chat in Chats.get_chats_by_folder_ids_and_user_id(folder_ids, user.id)
]
############################
# GetPinnedChats
############################
@router.get("/pinned", response_model=list[ChatResponse])
async def get_user_pinned_chats(user=Depends(get_verified_user)):
return [
ChatResponse(**chat.model_dump())
for chat in Chats.get_pinned_chats_by_user_id(user.id)
]
############################
# GetChats
############################
@router.get("/all", response_model=list[ChatResponse])
async def get_user_chats(user=Depends(get_verified_user)):
return [
ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
ChatResponse(**chat.model_dump())
for chat in Chats.get_chats_by_user_id(user.id)
]
@ -124,11 +217,28 @@ async def get_user_chats(user=Depends(get_verified_user)):
@router.get("/all/archived", response_model=list[ChatResponse])
async def get_user_archived_chats(user=Depends(get_verified_user)):
return [
ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
ChatResponse(**chat.model_dump())
for chat in Chats.get_archived_chats_by_user_id(user.id)
]
############################
# GetAllTags
############################
@router.get("/all/tags", response_model=list[TagModel])
async def get_all_user_tags(user=Depends(get_verified_user)):
try:
tags = Tags.get_tags_by_user_id(user.id)
return tags
except Exception as e:
log.exception(e)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
)
############################
# GetAllChatsInDB
############################
@ -141,10 +251,7 @@ async def get_all_user_chats_in_db(user=Depends(get_admin_user)):
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
return [
ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
for chat in Chats.get_chats()
]
return [ChatResponse(**chat.model_dump()) for chat in Chats.get_chats()]
############################
@ -187,7 +294,8 @@ async def get_shared_chat_by_id(share_id: str, user=Depends(get_verified_user)):
chat = Chats.get_chat_by_id(share_id)
if chat:
return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
return ChatResponse(**chat.model_dump())
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
@ -199,48 +307,28 @@ async def get_shared_chat_by_id(share_id: str, user=Depends(get_verified_user)):
############################
class TagNameForm(BaseModel):
class TagForm(BaseModel):
name: str
class TagFilterForm(TagForm):
skip: Optional[int] = 0
limit: Optional[int] = 50
@router.post("/tags", response_model=list[ChatTitleIdResponse])
async def get_user_chat_list_by_tag_name(
form_data: TagNameForm, user=Depends(get_verified_user)
form_data: TagFilterForm, user=Depends(get_verified_user)
):
chat_ids = [
chat_id_tag.chat_id
for chat_id_tag in Tags.get_chat_ids_by_tag_name_and_user_id(
form_data.name, user.id
)
]
chats = Chats.get_chat_list_by_chat_ids(chat_ids, form_data.skip, form_data.limit)
chats = Chats.get_chat_list_by_user_id_and_tag_name(
user.id, form_data.name, form_data.skip, form_data.limit
)
if len(chats) == 0:
Tags.delete_tag_by_tag_name_and_user_id(form_data.name, user.id)
Tags.delete_tag_by_name_and_user_id(form_data.name, user.id)
return chats
############################
# GetAllTags
############################
@router.get("/tags/all", response_model=list[TagModel])
async def get_all_tags(user=Depends(get_verified_user)):
try:
tags = Tags.get_tags_by_user_id(user.id)
return tags
except Exception as e:
log.exception(e)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
)
############################
# GetChatById
############################
@ -251,7 +339,8 @@ async def get_chat_by_id(id: str, user=Depends(get_verified_user)):
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
if chat:
return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
return ChatResponse(**chat.model_dump())
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
@ -269,10 +358,9 @@ async def update_chat_by_id(
):
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
if chat:
updated_chat = {**json.loads(chat.chat), **form_data.chat}
updated_chat = {**chat.chat, **form_data.chat}
chat = Chats.update_chat_by_id(id, updated_chat)
return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
return ChatResponse(**chat.model_dump())
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
@ -288,7 +376,13 @@ async def update_chat_by_id(
@router.delete("/{id}", response_model=bool)
async def delete_chat_by_id(request: Request, id: str, user=Depends(get_verified_user)):
if user.role == "admin":
chat = Chats.get_chat_by_id(id)
for tag in chat.meta.get("tags", []):
if Chats.count_chats_by_tag_name_and_user_id(tag, user.id) == 1:
Tags.delete_tag_by_name_and_user_id(tag, user.id)
result = Chats.delete_chat_by_id(id)
return result
else:
if not request.app.state.config.USER_PERMISSIONS.get("chat", {}).get(
@ -299,29 +393,66 @@ async def delete_chat_by_id(request: Request, id: str, user=Depends(get_verified
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
chat = Chats.get_chat_by_id(id)
for tag in chat.meta.get("tags", []):
if Chats.count_chats_by_tag_name_and_user_id(tag, user.id) == 1:
Tags.delete_tag_by_name_and_user_id(tag, user.id)
result = Chats.delete_chat_by_id_and_user_id(id, user.id)
return result
############################
# GetPinnedStatusById
############################
@router.get("/{id}/pinned", response_model=Optional[bool])
async def get_pinned_status_by_id(id: str, user=Depends(get_verified_user)):
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
if chat:
return chat.pinned
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT()
)
############################
# PinChatById
############################
@router.post("/{id}/pin", response_model=Optional[ChatResponse])
async def pin_chat_by_id(id: str, user=Depends(get_verified_user)):
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
if chat:
chat = Chats.toggle_chat_pinned_by_id(id)
return chat
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT()
)
############################
# CloneChat
############################
@router.get("/{id}/clone", response_model=Optional[ChatResponse])
@router.post("/{id}/clone", response_model=Optional[ChatResponse])
async def clone_chat_by_id(id: str, user=Depends(get_verified_user)):
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
if chat:
chat_body = json.loads(chat.chat)
updated_chat = {
**chat_body,
**chat.chat,
"originalChatId": chat.id,
"branchPointMessageId": chat_body["history"]["currentId"],
"branchPointMessageId": chat.chat["history"]["currentId"],
"title": f"Clone of {chat.title}",
}
chat = Chats.insert_new_chat(user.id, ChatForm(**{"chat": updated_chat}))
return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
return ChatResponse(**chat.model_dump())
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT()
@ -333,12 +464,26 @@ async def clone_chat_by_id(id: str, user=Depends(get_verified_user)):
############################
@router.get("/{id}/archive", response_model=Optional[ChatResponse])
@router.post("/{id}/archive", response_model=Optional[ChatResponse])
async def archive_chat_by_id(id: str, user=Depends(get_verified_user)):
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
if chat:
chat = Chats.toggle_chat_archive_by_id(id)
return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
# Delete tags if chat is archived
if chat.archived:
for tag_id in chat.meta.get("tags", []):
if Chats.count_chats_by_tag_name_and_user_id(tag_id, user.id) == 0:
log.debug(f"deleting tag: {tag_id}")
Tags.delete_tag_by_name_and_user_id(tag_id, user.id)
else:
for tag_id in chat.meta.get("tags", []):
tag = Tags.get_tag_by_name_and_user_id(tag_id, user.id)
if tag is None:
log.debug(f"inserting tag: {tag_id}")
tag = Tags.insert_new_tag(tag_id, user.id)
return ChatResponse(**chat.model_dump())
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT()
@ -356,9 +501,7 @@ async def share_chat_by_id(id: str, user=Depends(get_verified_user)):
if chat:
if chat.share_id:
shared_chat = Chats.update_shared_chat_by_chat_id(chat.id)
return ChatResponse(
**{**shared_chat.model_dump(), "chat": json.loads(shared_chat.chat)}
)
return ChatResponse(**shared_chat.model_dump())
shared_chat = Chats.insert_shared_chat_by_chat_id(chat.id)
if not shared_chat:
@ -366,10 +509,8 @@ async def share_chat_by_id(id: str, user=Depends(get_verified_user)):
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=ERROR_MESSAGES.DEFAULT(),
)
return ChatResponse(**shared_chat.model_dump())
return ChatResponse(
**{**shared_chat.model_dump(), "chat": json.loads(shared_chat.chat)}
)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
@ -400,6 +541,31 @@ async def delete_shared_chat_by_id(id: str, user=Depends(get_verified_user)):
)
############################
# UpdateChatFolderIdById
############################
class ChatFolderIdForm(BaseModel):
folder_id: Optional[str] = None
@router.post("/{id}/folder", response_model=Optional[ChatResponse])
async def update_chat_folder_id_by_id(
id: str, form_data: ChatFolderIdForm, user=Depends(get_verified_user)
):
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
if chat:
chat = Chats.update_chat_folder_id_by_id_and_user_id(
id, user.id, form_data.folder_id
)
return ChatResponse(**chat.model_dump())
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT()
)
############################
# GetChatTagsById
############################
@ -407,10 +573,10 @@ async def delete_shared_chat_by_id(id: str, user=Depends(get_verified_user)):
@router.get("/{id}/tags", response_model=list[TagModel])
async def get_chat_tags_by_id(id: str, user=Depends(get_verified_user)):
tags = Tags.get_tags_by_chat_id_and_user_id(id, user.id)
if tags != None:
return tags
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
if chat:
tags = chat.meta.get("tags", [])
return Tags.get_tags_by_ids_and_user_id(tags, user.id)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
@ -422,22 +588,30 @@ async def get_chat_tags_by_id(id: str, user=Depends(get_verified_user)):
############################
@router.post("/{id}/tags", response_model=Optional[ChatIdTagModel])
async def add_chat_tag_by_id(
id: str, form_data: ChatIdTagForm, user=Depends(get_verified_user)
@router.post("/{id}/tags", response_model=list[TagModel])
async def add_tag_by_id_and_tag_name(
id: str, form_data: TagForm, user=Depends(get_verified_user)
):
tags = Tags.get_tags_by_chat_id_and_user_id(id, user.id)
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
if chat:
tags = chat.meta.get("tags", [])
tag_id = form_data.name.replace(" ", "_").lower()
if form_data.tag_name not in tags:
tag = Tags.add_tag_to_chat(user.id, form_data)
if tag:
return tag
else:
if tag_id == "none":
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.NOT_FOUND,
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT("Tag name cannot be 'None'"),
)
print(tags, tag_id)
if tag_id not in tags:
Chats.add_chat_tag_by_id_and_user_id_and_tag_name(
id, user.id, form_data.name
)
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
tags = chat.meta.get("tags", [])
return Tags.get_tags_by_ids_and_user_id(tags, user.id)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT()
@ -449,16 +623,20 @@ async def add_chat_tag_by_id(
############################
@router.delete("/{id}/tags", response_model=Optional[bool])
async def delete_chat_tag_by_id(
id: str, form_data: ChatIdTagForm, user=Depends(get_verified_user)
@router.delete("/{id}/tags", response_model=list[TagModel])
async def delete_tag_by_id_and_tag_name(
id: str, form_data: TagForm, user=Depends(get_verified_user)
):
result = Tags.delete_tag_by_tag_name_and_chat_id_and_user_id(
form_data.tag_name, id, user.id
)
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
if chat:
Chats.delete_tag_by_id_and_user_id_and_tag_name(id, user.id, form_data.name)
if result:
return result
if Chats.count_chats_by_tag_name_and_user_id(form_data.name, user.id) == 0:
Tags.delete_tag_by_name_and_user_id(form_data.name, user.id)
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
tags = chat.meta.get("tags", [])
return Tags.get_tags_by_ids_and_user_id(tags, user.id)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
@ -466,16 +644,21 @@ async def delete_chat_tag_by_id(
############################
# DeleteAllChatTagsById
# DeleteAllTagsById
############################
@router.delete("/{id}/tags/all", response_model=Optional[bool])
async def delete_all_chat_tags_by_id(id: str, user=Depends(get_verified_user)):
result = Tags.delete_tags_by_chat_id_and_user_id(id, user.id)
async def delete_all_tags_by_id(id: str, user=Depends(get_verified_user)):
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
if chat:
Chats.delete_all_tags_by_id_and_user_id(id, user.id)
if result:
return result
for tag in chat.meta.get("tags", []):
if Chats.count_chats_by_tag_name_and_user_id(tag, user.id) == 0:
Tags.delete_tag_by_name_and_user_id(tag, user.id)
return True
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND

View file

@ -0,0 +1,147 @@
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, status, Request
from pydantic import BaseModel
from open_webui.apps.webui.models.users import Users, UserModel
from open_webui.apps.webui.models.feedbacks import (
FeedbackModel,
FeedbackForm,
Feedbacks,
)
from open_webui.constants import ERROR_MESSAGES
from open_webui.utils.utils import get_admin_user, get_verified_user
router = APIRouter()
############################
# GetConfig
############################
@router.get("/config")
async def get_config(request: Request, user=Depends(get_admin_user)):
return {
"ENABLE_EVALUATION_ARENA_MODELS": request.app.state.config.ENABLE_EVALUATION_ARENA_MODELS,
"EVALUATION_ARENA_MODELS": request.app.state.config.EVALUATION_ARENA_MODELS,
}
############################
# UpdateConfig
############################
class UpdateConfigForm(BaseModel):
ENABLE_EVALUATION_ARENA_MODELS: Optional[bool] = None
EVALUATION_ARENA_MODELS: Optional[list[dict]] = None
@router.post("/config")
async def update_config(
request: Request,
form_data: UpdateConfigForm,
user=Depends(get_admin_user),
):
config = request.app.state.config
if form_data.ENABLE_EVALUATION_ARENA_MODELS is not None:
config.ENABLE_EVALUATION_ARENA_MODELS = form_data.ENABLE_EVALUATION_ARENA_MODELS
if form_data.EVALUATION_ARENA_MODELS is not None:
config.EVALUATION_ARENA_MODELS = form_data.EVALUATION_ARENA_MODELS
return {
"ENABLE_EVALUATION_ARENA_MODELS": config.ENABLE_EVALUATION_ARENA_MODELS,
"EVALUATION_ARENA_MODELS": config.EVALUATION_ARENA_MODELS,
}
@router.get("/feedbacks", response_model=list[FeedbackModel])
async def get_feedbacks(user=Depends(get_verified_user)):
feedbacks = Feedbacks.get_feedbacks_by_user_id(user.id)
return feedbacks
@router.delete("/feedbacks", response_model=bool)
async def delete_feedbacks(user=Depends(get_verified_user)):
success = Feedbacks.delete_feedbacks_by_user_id(user.id)
return success
class FeedbackUserModel(FeedbackModel):
user: Optional[UserModel] = None
@router.get("/feedbacks/all", response_model=list[FeedbackUserModel])
async def get_all_feedbacks(user=Depends(get_admin_user)):
feedbacks = Feedbacks.get_all_feedbacks()
return [
FeedbackUserModel(
**feedback.model_dump(), user=Users.get_user_by_id(feedback.user_id)
)
for feedback in feedbacks
]
@router.delete("/feedbacks/all")
async def delete_all_feedbacks(user=Depends(get_admin_user)):
success = Feedbacks.delete_all_feedbacks()
return success
@router.post("/feedback", response_model=FeedbackModel)
async def create_feedback(
request: Request,
form_data: FeedbackForm,
user=Depends(get_verified_user),
):
feedback = Feedbacks.insert_new_feedback(user_id=user.id, form_data=form_data)
if not feedback:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(),
)
return feedback
@router.get("/feedback/{id}", response_model=FeedbackModel)
async def get_feedback_by_id(id: str, user=Depends(get_verified_user)):
feedback = Feedbacks.get_feedback_by_id_and_user_id(id=id, user_id=user.id)
if not feedback:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
)
return feedback
@router.post("/feedback/{id}", response_model=FeedbackModel)
async def update_feedback_by_id(
id: str, form_data: FeedbackForm, user=Depends(get_verified_user)
):
feedback = Feedbacks.update_feedback_by_id_and_user_id(
id=id, user_id=user.id, form_data=form_data
)
if not feedback:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
)
return feedback
@router.delete("/feedback/{id}")
async def delete_feedback_by_id(id: str, user=Depends(get_verified_user)):
if user.role == "admin":
success = Feedbacks.delete_feedback_by_id(id=id)
else:
success = Feedbacks.delete_feedback_by_id_and_user_id(id=id, user_id=user.id)
if not success:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
)
return success

View file

@ -1,14 +1,19 @@
import logging
import os
import shutil
import uuid
from pathlib import Path
from typing import Optional
from pydantic import BaseModel
import mimetypes
from open_webui.storage.provider import Storage
from open_webui.apps.webui.models.files import FileForm, FileModel, Files
from open_webui.apps.webui.models.files import (
FileForm,
FileModel,
FileModelResponse,
Files,
)
from open_webui.apps.retrieval.main import process_file, ProcessFileForm
from open_webui.config import UPLOAD_DIR
@ -44,24 +49,19 @@ def upload_file(file: UploadFile = File(...), user=Depends(get_verified_user)):
id = str(uuid.uuid4())
name = filename
filename = f"{id}_{filename}"
file_path = f"{UPLOAD_DIR}/{filename}"
contents, file_path = Storage.upload_file(file.file, filename)
contents = file.file.read()
with open(file_path, "wb") as f:
f.write(contents)
f.close()
file = Files.insert_new_file(
file_item = Files.insert_new_file(
user.id,
FileForm(
**{
"id": id,
"filename": filename,
"path": file_path,
"meta": {
"name": name,
"content_type": file.content_type,
"size": len(contents),
"path": file_path,
},
}
),
@ -69,13 +69,13 @@ def upload_file(file: UploadFile = File(...), user=Depends(get_verified_user)):
try:
process_file(ProcessFileForm(file_id=id))
file = Files.get_file_by_id(id=id)
file_item = Files.get_file_by_id(id=id)
except Exception as e:
log.exception(e)
log.error(f"Error processing file: {file.id}")
log.error(f"Error processing file: {file_item.id}")
if file:
return file
if file_item:
return file_item
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
@ -95,7 +95,7 @@ def upload_file(file: UploadFile = File(...), user=Depends(get_verified_user)):
############################
@router.get("/", response_model=list[FileModel])
@router.get("/", response_model=list[FileModelResponse])
async def list_files(user=Depends(get_verified_user)):
if user.role == "admin":
files = Files.get_files()
@ -112,27 +112,16 @@ async def list_files(user=Depends(get_verified_user)):
@router.delete("/all")
async def delete_all_files(user=Depends(get_admin_user)):
result = Files.delete_all_files()
if result:
folder = f"{UPLOAD_DIR}"
try:
# Check if the directory exists
if os.path.exists(folder):
# Iterate over all the files and directories in the specified directory
for filename in os.listdir(folder):
file_path = os.path.join(folder, filename)
try:
if os.path.isfile(file_path) or os.path.islink(file_path):
os.unlink(file_path) # Remove the file or link
elif os.path.isdir(file_path):
shutil.rmtree(file_path) # Remove the directory
except Exception as e:
print(f"Failed to delete {file_path}. Reason: {e}")
else:
print(f"The directory {folder} does not exist")
Storage.delete_all_files()
except Exception as e:
print(f"Failed to process the directory {folder}. Reason: {e}")
log.exception(e)
log.error(f"Error deleting files")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT("Error deleting files"),
)
return {"message": "All files deleted successfully"}
else:
raise HTTPException(
@ -213,24 +202,32 @@ async def update_file_data_content_by_id(
############################
@router.get("/{id}/content", response_model=Optional[FileModel])
@router.get("/{id}/content")
async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
file = Files.get_file_by_id(id)
if file and (file.user_id == user.id or user.role == "admin"):
file_path = Path(file.meta["path"])
try:
file_path = Storage.get_file(file.path)
file_path = Path(file_path)
# Check if the file already exists in the cache
if file_path.is_file():
print(f"file_path: {file_path}")
headers = {
"Content-Disposition": f'attachment; filename="{file.meta.get("name", file.filename)}"'
}
return FileResponse(file_path, headers=headers)
else:
# Check if the file already exists in the cache
if file_path.is_file():
print(f"file_path: {file_path}")
headers = {
"Content-Disposition": f'attachment; filename="{file.meta.get("name", file.filename)}"'
}
return FileResponse(file_path, headers=headers)
else:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=ERROR_MESSAGES.NOT_FOUND,
)
except Exception as e:
log.exception(e)
log.error(f"Error getting file content")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=ERROR_MESSAGES.NOT_FOUND,
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT("Error getting file content"),
)
else:
raise HTTPException(
@ -239,19 +236,23 @@ async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
)
@router.get("/{id}/content/{file_name}", response_model=Optional[FileModel])
@router.get("/{id}/content/{file_name}")
async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
file = Files.get_file_by_id(id)
if file and (file.user_id == user.id or user.role == "admin"):
file_path = file.meta.get("path")
file_path = file.path
if file_path:
file_path = Storage.get_file(file_path)
file_path = Path(file_path)
# Check if the file already exists in the cache
if file_path.is_file():
print(f"file_path: {file_path}")
return FileResponse(file_path)
headers = {
"Content-Disposition": f'attachment; filename="{file.meta.get("name", file.filename)}"'
}
return FileResponse(file_path, headers=headers)
else:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
@ -289,6 +290,15 @@ async def delete_file_by_id(id: str, user=Depends(get_verified_user)):
if file and (file.user_id == user.id or user.role == "admin"):
result = Files.delete_file_by_id(id)
if result:
try:
Storage.delete_file(file.filename)
except Exception as e:
log.exception(e)
log.error(f"Error deleting files")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT("Error deleting files"),
)
return {"message": "File deleted successfully"}
else:
raise HTTPException(

View file

@ -0,0 +1,251 @@
import logging
import os
import shutil
import uuid
from pathlib import Path
from typing import Optional
from pydantic import BaseModel
import mimetypes
from open_webui.apps.webui.models.folders import (
FolderForm,
FolderModel,
Folders,
)
from open_webui.apps.webui.models.chats import Chats
from open_webui.config import UPLOAD_DIR
from open_webui.env import SRC_LOG_LEVELS
from open_webui.constants import ERROR_MESSAGES
from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, status
from fastapi.responses import FileResponse, StreamingResponse
from open_webui.utils.utils import get_admin_user, get_verified_user
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
router = APIRouter()
############################
# Get Folders
############################
@router.get("/", response_model=list[FolderModel])
async def get_folders(user=Depends(get_verified_user)):
folders = Folders.get_folders_by_user_id(user.id)
return [
{
**folder.model_dump(),
"items": {
"chats": [
{"title": chat.title, "id": chat.id}
for chat in Chats.get_chats_by_folder_id_and_user_id(
folder.id, user.id
)
]
},
}
for folder in folders
]
############################
# Create Folder
############################
@router.post("/")
def create_folder(form_data: FolderForm, user=Depends(get_verified_user)):
folder = Folders.get_folder_by_parent_id_and_user_id_and_name(
None, user.id, form_data.name
)
if folder:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT("Folder already exists"),
)
try:
folder = Folders.insert_new_folder(user.id, form_data.name)
return folder
except Exception as e:
log.exception(e)
log.error("Error creating folder")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT("Error creating folder"),
)
############################
# Get Folders By Id
############################
@router.get("/{id}", response_model=Optional[FolderModel])
async def get_folder_by_id(id: str, user=Depends(get_verified_user)):
folder = Folders.get_folder_by_id_and_user_id(id, user.id)
if folder:
return folder
else:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=ERROR_MESSAGES.NOT_FOUND,
)
############################
# Update Folder Name By Id
############################
@router.post("/{id}/update")
async def update_folder_name_by_id(
id: str, form_data: FolderForm, user=Depends(get_verified_user)
):
folder = Folders.get_folder_by_id_and_user_id(id, user.id)
if folder:
existing_folder = Folders.get_folder_by_parent_id_and_user_id_and_name(
folder.parent_id, user.id, form_data.name
)
if existing_folder:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT("Folder already exists"),
)
try:
folder = Folders.update_folder_name_by_id_and_user_id(
id, user.id, form_data.name
)
return folder
except Exception as e:
log.exception(e)
log.error(f"Error updating folder: {id}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT("Error updating folder"),
)
else:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=ERROR_MESSAGES.NOT_FOUND,
)
############################
# Update Folder Parent Id By Id
############################
class FolderParentIdForm(BaseModel):
parent_id: Optional[str] = None
@router.post("/{id}/update/parent")
async def update_folder_parent_id_by_id(
id: str, form_data: FolderParentIdForm, user=Depends(get_verified_user)
):
folder = Folders.get_folder_by_id_and_user_id(id, user.id)
if folder:
existing_folder = Folders.get_folder_by_parent_id_and_user_id_and_name(
form_data.parent_id, user.id, folder.name
)
if existing_folder:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT("Folder already exists"),
)
try:
folder = Folders.update_folder_parent_id_by_id_and_user_id(
id, user.id, form_data.parent_id
)
return folder
except Exception as e:
log.exception(e)
log.error(f"Error updating folder: {id}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT("Error updating folder"),
)
else:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=ERROR_MESSAGES.NOT_FOUND,
)
############################
# Update Folder Is Expanded By Id
############################
class FolderIsExpandedForm(BaseModel):
is_expanded: bool
@router.post("/{id}/update/expanded")
async def update_folder_is_expanded_by_id(
id: str, form_data: FolderIsExpandedForm, user=Depends(get_verified_user)
):
folder = Folders.get_folder_by_id_and_user_id(id, user.id)
if folder:
try:
folder = Folders.update_folder_is_expanded_by_id_and_user_id(
id, user.id, form_data.is_expanded
)
return folder
except Exception as e:
log.exception(e)
log.error(f"Error updating folder: {id}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT("Error updating folder"),
)
else:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=ERROR_MESSAGES.NOT_FOUND,
)
############################
# Delete Folder By Id
############################
@router.delete("/{id}")
async def delete_folder_by_id(id: str, user=Depends(get_verified_user)):
folder = Folders.get_folder_by_id_and_user_id(id, user.id)
if folder:
try:
result = Folders.delete_folder_by_id_and_user_id(id, user.id)
if result:
return result
else:
raise Exception("Error deleting folder")
except Exception as e:
log.exception(e)
log.error(f"Error deleting folder: {id}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT("Error deleting folder"),
)
else:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=ERROR_MESSAGES.NOT_FOUND,
)

View file

@ -9,7 +9,7 @@ from open_webui.apps.webui.models.functions import (
Functions,
)
from open_webui.apps.webui.utils import load_function_module_by_id, replace_imports
from open_webui.config import CACHE_DIR, FUNCTIONS_DIR
from open_webui.config import CACHE_DIR
from open_webui.constants import ERROR_MESSAGES
from fastapi import APIRouter, Depends, HTTPException, Request, status
from open_webui.utils.utils import get_admin_user, get_verified_user

View file

@ -48,7 +48,12 @@ async def get_knowledge_items(
)
else:
return [
KnowledgeResponse(**knowledge.model_dump())
KnowledgeResponse(
**knowledge.model_dump(),
files=Files.get_file_metadatas_by_ids(
knowledge.data.get("file_ids", []) if knowledge.data else []
),
)
for knowledge in Knowledges.get_knowledge_items()
]

View file

@ -10,9 +10,6 @@ from fastapi import APIRouter, Depends, HTTPException, Request, status
from open_webui.utils.tools import get_tools_specs
from open_webui.utils.utils import get_admin_user, get_verified_user
TOOLS_DIR = f"{DATA_DIR}/tools"
os.makedirs(TOOLS_DIR, exist_ok=True)
router = APIRouter()

View file

@ -1,16 +1,14 @@
import site
from pathlib import Path
import black
import markdown
from open_webui.apps.webui.models.chats import ChatTitleMessagesForm
from open_webui.config import DATA_DIR, ENABLE_ADMIN_EXPORT
from open_webui.env import FONTS_DIR
from open_webui.constants import ERROR_MESSAGES
from fastapi import APIRouter, Depends, HTTPException, Response, status
from fpdf import FPDF
from pydantic import BaseModel
from starlette.responses import FileResponse
from open_webui.utils.misc import get_gravatar_url
from open_webui.utils.pdf_generator import PDFGenerator
from open_webui.utils.utils import get_admin_user
router = APIRouter()
@ -56,58 +54,19 @@ class ChatForm(BaseModel):
@router.post("/pdf")
async def download_chat_as_pdf(
form_data: ChatForm,
form_data: ChatTitleMessagesForm,
):
global FONTS_DIR
try:
pdf_bytes = PDFGenerator(form_data).generate_chat_pdf()
pdf = FPDF()
pdf.add_page()
# When running using `pip install` the static directory is in the site packages.
if not FONTS_DIR.exists():
FONTS_DIR = Path(site.getsitepackages()[0]) / "static/fonts"
# When running using `pip install -e .` the static directory is in the site packages.
# This path only works if `open-webui serve` is run from the root of this project.
if not FONTS_DIR.exists():
FONTS_DIR = Path("./backend/static/fonts")
pdf.add_font("NotoSans", "", f"{FONTS_DIR}/NotoSans-Regular.ttf")
pdf.add_font("NotoSans", "b", f"{FONTS_DIR}/NotoSans-Bold.ttf")
pdf.add_font("NotoSans", "i", f"{FONTS_DIR}/NotoSans-Italic.ttf")
pdf.add_font("NotoSansKR", "", f"{FONTS_DIR}/NotoSansKR-Regular.ttf")
pdf.add_font("NotoSansJP", "", f"{FONTS_DIR}/NotoSansJP-Regular.ttf")
pdf.add_font("NotoSansSC", "", f"{FONTS_DIR}/NotoSansSC-Regular.ttf")
pdf.set_font("NotoSans", size=12)
pdf.set_fallback_fonts(["NotoSansKR", "NotoSansJP", "NotoSansSC"])
pdf.set_auto_page_break(auto=True, margin=15)
# Adjust the effective page width for multi_cell
effective_page_width = (
pdf.w - 2 * pdf.l_margin - 10
) # Subtracted an additional 10 for extra padding
# Add chat messages
for message in form_data.messages:
role = message["role"]
content = message["content"]
pdf.set_font("NotoSans", "B", size=14) # Bold for the role
pdf.multi_cell(effective_page_width, 10, f"{role.upper()}", 0, "L")
pdf.ln(1) # Extra space between messages
pdf.set_font("NotoSans", size=10) # Regular for content
pdf.multi_cell(effective_page_width, 6, content, 0, "L")
pdf.ln(1.5) # Extra space between messages
# Save the pdf with name .pdf
pdf_bytes = pdf.output()
return Response(
content=bytes(pdf_bytes),
media_type="application/pdf",
headers={"Content-Disposition": "attachment;filename=chat.pdf"},
)
return Response(
content=pdf_bytes,
media_type="application/pdf",
headers={"Content-Disposition": "attachment;filename=chat.pdf"},
)
except Exception as e:
print(e)
raise HTTPException(status_code=400, detail=str(e))
@router.get("/db/download")

View file

@ -8,7 +8,6 @@ import tempfile
from open_webui.apps.webui.models.functions import Functions
from open_webui.apps.webui.models.tools import Tools
from open_webui.config import FUNCTIONS_DIR, TOOLS_DIR
def extract_frontmatter(content):

View file

@ -383,7 +383,7 @@ OAUTH_USERNAME_CLAIM = PersistentConfig(
)
OAUTH_PICTURE_CLAIM = PersistentConfig(
"OAUTH_USERNAME_CLAIM",
"OAUTH_PICTURE_CLAIM",
"oauth.oidc.avatar_claim",
os.environ.get("OAUTH_PICTURE_CLAIM", "picture"),
)
@ -394,6 +394,33 @@ OAUTH_EMAIL_CLAIM = PersistentConfig(
os.environ.get("OAUTH_EMAIL_CLAIM", "email"),
)
ENABLE_OAUTH_ROLE_MANAGEMENT = PersistentConfig(
"ENABLE_OAUTH_ROLE_MANAGEMENT",
"oauth.enable_role_mapping",
os.environ.get("ENABLE_OAUTH_ROLE_MANAGEMENT", "False").lower() == "true",
)
OAUTH_ROLES_CLAIM = PersistentConfig(
"OAUTH_ROLES_CLAIM",
"oauth.roles_claim",
os.environ.get("OAUTH_ROLES_CLAIM", "roles"),
)
OAUTH_ALLOWED_ROLES = PersistentConfig(
"OAUTH_ALLOWED_ROLES",
"oauth.allowed_roles",
[
role.strip()
for role in os.environ.get("OAUTH_ALLOWED_ROLES", "user,admin").split(",")
],
)
OAUTH_ADMIN_ROLES = PersistentConfig(
"OAUTH_ADMIN_ROLES",
"oauth.admin_roles",
[role.strip() for role in os.environ.get("OAUTH_ADMIN_ROLES", "admin").split(",")],
)
def load_oauth_providers():
OAUTH_PROVIDERS.clear()
@ -506,6 +533,18 @@ if CUSTOM_NAME:
pass
####################################
# STORAGE PROVIDER
####################################
STORAGE_PROVIDER = os.environ.get("STORAGE_PROVIDER", "") # defaults to local, s3
S3_ACCESS_KEY_ID = os.environ.get("S3_ACCESS_KEY_ID", None)
S3_SECRET_ACCESS_KEY = os.environ.get("S3_SECRET_ACCESS_KEY", None)
S3_REGION_NAME = os.environ.get("S3_REGION_NAME", None)
S3_BUCKET_NAME = os.environ.get("S3_BUCKET_NAME", None)
S3_ENDPOINT_URL = os.environ.get("S3_ENDPOINT_URL", None)
####################################
# File Upload DIR
####################################
@ -521,26 +560,10 @@ Path(UPLOAD_DIR).mkdir(parents=True, exist_ok=True)
CACHE_DIR = f"{DATA_DIR}/cache"
Path(CACHE_DIR).mkdir(parents=True, exist_ok=True)
####################################
# Tools DIR
####################################
TOOLS_DIR = os.getenv("TOOLS_DIR", f"{DATA_DIR}/tools")
Path(TOOLS_DIR).mkdir(parents=True, exist_ok=True)
####################################
# Functions DIR
####################################
FUNCTIONS_DIR = os.getenv("FUNCTIONS_DIR", f"{DATA_DIR}/functions")
Path(FUNCTIONS_DIR).mkdir(parents=True, exist_ok=True)
####################################
# OLLAMA_BASE_URL
####################################
ENABLE_OLLAMA_API = PersistentConfig(
"ENABLE_OLLAMA_API",
"ollama.enable",
@ -728,6 +751,28 @@ USER_PERMISSIONS = PersistentConfig(
},
)
ENABLE_EVALUATION_ARENA_MODELS = PersistentConfig(
"ENABLE_EVALUATION_ARENA_MODELS",
"evaluation.arena.enable",
os.environ.get("ENABLE_EVALUATION_ARENA_MODELS", "True").lower() == "true",
)
EVALUATION_ARENA_MODELS = PersistentConfig(
"EVALUATION_ARENA_MODELS",
"evaluation.arena.models",
[],
)
DEFAULT_ARENA_MODEL = {
"id": "arena-model",
"name": "Arena Model",
"meta": {
"profile_image_url": "/favicon.png",
"description": "Submit your questions to anonymous AI chatbots and vote on the best response.",
"model_ids": None,
},
}
ENABLE_MODEL_FILTER = PersistentConfig(
"ENABLE_MODEL_FILTER",
"model_filter.enable",
@ -853,6 +898,12 @@ TITLE_GENERATION_PROMPT_TEMPLATE = PersistentConfig(
os.environ.get("TITLE_GENERATION_PROMPT_TEMPLATE", ""),
)
TAGS_GENERATION_PROMPT_TEMPLATE = PersistentConfig(
"TAGS_GENERATION_PROMPT_TEMPLATE",
"task.tags.prompt_template",
os.environ.get("TAGS_GENERATION_PROMPT_TEMPLATE", ""),
)
ENABLE_SEARCH_QUERY = PersistentConfig(
"ENABLE_SEARCH_QUERY",
"task.search.enable",
@ -901,6 +952,9 @@ CHROMA_HTTP_SSL = os.environ.get("CHROMA_HTTP_SSL", "false").lower() == "true"
MILVUS_URI = os.environ.get("MILVUS_URI", f"{DATA_DIR}/vector_db/milvus.db")
# Qdrant
QDRANT_URI = os.environ.get("QDRANT_URI", None)
####################################
# Information Retrieval (RAG)
####################################
@ -986,10 +1040,13 @@ RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE = (
os.environ.get("RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE", "").lower() == "true"
)
RAG_EMBEDDING_OPENAI_BATCH_SIZE = PersistentConfig(
"RAG_EMBEDDING_OPENAI_BATCH_SIZE",
"rag.embedding_openai_batch_size",
int(os.environ.get("RAG_EMBEDDING_OPENAI_BATCH_SIZE", "1")),
RAG_EMBEDDING_BATCH_SIZE = PersistentConfig(
"RAG_EMBEDDING_BATCH_SIZE",
"rag.embedding_batch_size",
int(
os.environ.get("RAG_EMBEDDING_BATCH_SIZE")
or os.environ.get("RAG_EMBEDDING_OPENAI_BATCH_SIZE", "1")
),
)
RAG_RERANKING_MODEL = PersistentConfig(
@ -1008,6 +1065,22 @@ RAG_RERANKING_MODEL_TRUST_REMOTE_CODE = (
os.environ.get("RAG_RERANKING_MODEL_TRUST_REMOTE_CODE", "").lower() == "true"
)
RAG_TEXT_SPLITTER = PersistentConfig(
"RAG_TEXT_SPLITTER",
"rag.text_splitter",
os.environ.get("RAG_TEXT_SPLITTER", ""),
)
TIKTOKEN_CACHE_DIR = os.environ.get("TIKTOKEN_CACHE_DIR", f"{CACHE_DIR}/tiktoken")
TIKTOKEN_ENCODING_NAME = PersistentConfig(
"TIKTOKEN_ENCODING_NAME",
"rag.tiktoken_encoding_name",
os.environ.get("TIKTOKEN_ENCODING_NAME", "cl100k_base"),
)
CHUNK_SIZE = PersistentConfig(
"CHUNK_SIZE", "rag.chunk_size", int(os.environ.get("CHUNK_SIZE", "1000"))
)
@ -1020,7 +1093,7 @@ CHUNK_OVERLAP = PersistentConfig(
DEFAULT_RAG_TEMPLATE = """You are given a user query, some textual context and rules, all inside xml tags. You have to answer the query based on the context while respecting the rules.
<context>
[context]
{{CONTEXT}}
</context>
<rules>
@ -1033,7 +1106,7 @@ DEFAULT_RAG_TEMPLATE = """You are given a user query, some textual context and r
</rules>
<user_query>
[query]
{{QUERY}}
</user_query>
"""
@ -1168,17 +1241,6 @@ RAG_WEB_SEARCH_CONCURRENT_REQUESTS = PersistentConfig(
)
####################################
# Transcribe
####################################
WHISPER_MODEL = os.getenv("WHISPER_MODEL", "base")
WHISPER_MODEL_DIR = os.getenv("WHISPER_MODEL_DIR", f"{CACHE_DIR}/whisper/models")
WHISPER_MODEL_AUTO_UPDATE = (
os.environ.get("WHISPER_MODEL_AUTO_UPDATE", "").lower() == "true"
)
####################################
# Images
####################################
@ -1394,6 +1456,19 @@ IMAGE_GENERATION_MODEL = PersistentConfig(
# Audio
####################################
# Transcription
WHISPER_MODEL = PersistentConfig(
"WHISPER_MODEL",
"audio.stt.whisper_model",
os.getenv("WHISPER_MODEL", "base"),
)
WHISPER_MODEL_DIR = os.getenv("WHISPER_MODEL_DIR", f"{CACHE_DIR}/whisper/models")
WHISPER_MODEL_AUTO_UPDATE = (
os.environ.get("WHISPER_MODEL_AUTO_UPDATE", "").lower() == "true"
)
AUDIO_STT_OPENAI_API_BASE_URL = PersistentConfig(
"AUDIO_STT_OPENAI_API_BASE_URL",
"audio.stt.openai.api_base_url",
@ -1415,7 +1490,7 @@ AUDIO_STT_ENGINE = PersistentConfig(
AUDIO_STT_MODEL = PersistentConfig(
"AUDIO_STT_MODEL",
"audio.stt.model",
os.getenv("AUDIO_STT_MODEL", "whisper-1"),
os.getenv("AUDIO_STT_MODEL", ""),
)
AUDIO_TTS_OPENAI_API_BASE_URL = PersistentConfig(

View file

@ -20,7 +20,9 @@ class ERROR_MESSAGES(str, Enum):
def __str__(self) -> str:
return super().__str__()
DEFAULT = lambda err="": f"Something went wrong :/\n{err if err else ''}"
DEFAULT = (
lambda err="": f'{"Something went wrong :/" if err == "" else "[ERROR: " + str(err) + "]"}'
)
ENV_VAR_NOT_FOUND = "Required environment variable not found. Terminating now."
CREATE_USER_ERROR = "Oops! Something went wrong while creating your account. Please try again later. If the issue persists, contact support for assistance."
DELETE_USER_ERROR = "Oops! Something went wrong. We encountered an issue while trying to delete the user. Please give it another shot."
@ -106,6 +108,7 @@ class TASKS(str, Enum):
DEFAULT = lambda task="": f"{task if task else 'generation'}"
TITLE_GENERATION = "title_generation"
TAGS_GENERATION = "tags_generation"
EMOJI_GENERATION = "emoji_generation"
QUERY_GENERATION = "query_generation"
FUNCTION_CALLING = "function_calling"

View file

@ -230,6 +230,8 @@ if FROM_INIT_PY:
DATA_DIR = Path(os.getenv("DATA_DIR", OPEN_WEBUI_DIR / "data"))
STATIC_DIR = Path(os.getenv("STATIC_DIR", OPEN_WEBUI_DIR / "static"))
FONTS_DIR = Path(os.getenv("FONTS_DIR", OPEN_WEBUI_DIR / "static" / "fonts"))
FRONTEND_BUILD_DIR = Path(os.getenv("FRONTEND_BUILD_DIR", BASE_DIR / "build")).resolve()
@ -302,6 +304,12 @@ RESET_CONFIG_ON_START = (
os.environ.get("RESET_CONFIG_ON_START", "False").lower() == "true"
)
####################################
# REDIS
####################################
REDIS_URL = os.environ.get("REDIS_URL", "redis://localhost:6379/0")
####################################
# WEBUI_AUTH (Required for security)
####################################
@ -343,8 +351,7 @@ ENABLE_WEBSOCKET_SUPPORT = (
WEBSOCKET_MANAGER = os.environ.get("WEBSOCKET_MANAGER", "")
WEBSOCKET_REDIS_URL = os.environ.get("WEBSOCKET_REDIS_URL", "redis://localhost:6379/0")
WEBSOCKET_REDIS_URL = os.environ.get("WEBSOCKET_REDIS_URL", REDIS_URL)
AIOHTTP_CLIENT_TIMEOUT = os.environ.get("AIOHTTP_CLIENT_TIMEOUT", "")
@ -355,3 +362,23 @@ else:
AIOHTTP_CLIENT_TIMEOUT = int(AIOHTTP_CLIENT_TIMEOUT)
except Exception:
AIOHTTP_CLIENT_TIMEOUT = 300
AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST = os.environ.get(
"AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST", "3"
)
if AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST == "":
AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST = None
else:
try:
AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST = int(
AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST
)
except Exception:
AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST = 3
####################################
# OFFLINE_MODE
####################################
OFFLINE_MODE = os.environ.get("OFFLINE_MODE", "false").lower() == "true"

View file

@ -1,4 +1,4 @@
import base64
import asyncio
import inspect
import json
import logging
@ -7,20 +7,38 @@ import os
import shutil
import sys
import time
import uuid
import asyncio
import random
from contextlib import asynccontextmanager
from typing import Optional
import aiohttp
import requests
from fastapi import (
Depends,
FastAPI,
File,
Form,
HTTPException,
Request,
UploadFile,
status,
)
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, RedirectResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
from sqlalchemy import text
from starlette.exceptions import HTTPException as StarletteHTTPException
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.middleware.sessions import SessionMiddleware
from starlette.responses import Response, StreamingResponse
from open_webui.apps.audio.main import app as audio_app
from open_webui.apps.images.main import app as images_app
from open_webui.apps.ollama.main import (
app as ollama_app,
get_all_models as get_ollama_models,
generate_chat_completion as generate_ollama_chat_completion,
generate_openai_chat_completion as generate_ollama_openai_chat_completion,
GenerateChatCompletionForm,
)
from open_webui.apps.openai.main import (
@ -28,38 +46,24 @@ from open_webui.apps.openai.main import (
generate_chat_completion as generate_openai_chat_completion,
get_all_models as get_openai_models,
)
from open_webui.apps.retrieval.main import app as retrieval_app
from open_webui.apps.retrieval.utils import get_rag_context, rag_template
from open_webui.apps.socket.main import (
app as socket_app,
periodic_usage_pool_cleanup,
get_event_call,
get_event_emitter,
)
from open_webui.apps.webui.internal.db import Session
from open_webui.apps.webui.main import (
app as webui_app,
generate_function_chat_completion,
get_pipe_models,
get_all_models as get_open_webui_models,
)
from open_webui.apps.webui.internal.db import Session
from open_webui.apps.webui.models.auths import Auths
from open_webui.apps.webui.models.functions import Functions
from open_webui.apps.webui.models.models import Models
from open_webui.apps.webui.models.users import UserModel, Users
from open_webui.apps.webui.utils import load_function_module_by_id
from open_webui.apps.audio.main import app as audio_app
from open_webui.apps.images.main import app as images_app
from authlib.integrations.starlette_client import OAuth
from authlib.oidc.core import UserInfo
from open_webui.config import (
CACHE_DIR,
CORS_ALLOW_ORIGIN,
@ -67,13 +71,11 @@ from open_webui.config import (
ENABLE_ADMIN_CHAT_ACCESS,
ENABLE_ADMIN_EXPORT,
ENABLE_MODEL_FILTER,
ENABLE_OAUTH_SIGNUP,
ENABLE_OLLAMA_API,
ENABLE_OPENAI_API,
ENV,
FRONTEND_BUILD_DIR,
MODEL_FILTER_LIST,
OAUTH_MERGE_ACCOUNTS_BY_EMAIL,
OAUTH_PROVIDERS,
ENABLE_SEARCH_QUERY,
SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
@ -81,15 +83,15 @@ from open_webui.config import (
TASK_MODEL,
TASK_MODEL_EXTERNAL,
TITLE_GENERATION_PROMPT_TEMPLATE,
TAGS_GENERATION_PROMPT_TEMPLATE,
TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
WEBHOOK_URL,
WEBUI_AUTH,
WEBUI_NAME,
AppConfig,
run_migrations,
reset_config,
)
from open_webui.constants import ERROR_MESSAGES, TASKS, WEBHOOK_MESSAGES
from open_webui.constants import TASKS
from open_webui.env import (
CHANGELOG,
GLOBAL_LOG_LEVEL,
@ -102,64 +104,41 @@ from open_webui.env import (
WEBUI_SESSION_COOKIE_SECURE,
WEBUI_URL,
RESET_CONFIG_ON_START,
OFFLINE_MODE,
)
from fastapi import (
Depends,
FastAPI,
File,
Form,
HTTPException,
Request,
UploadFile,
status,
)
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
from sqlalchemy import text
from starlette.exceptions import HTTPException as StarletteHTTPException
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.middleware.sessions import SessionMiddleware
from starlette.responses import RedirectResponse, Response, StreamingResponse
from open_webui.utils.security_headers import SecurityHeadersMiddleware
from open_webui.utils.misc import (
add_or_update_system_message,
get_last_user_message,
parse_duration,
prepend_to_first_user_message_content,
)
from open_webui.utils.task import (
moa_response_generation_template,
search_query_generation_template,
title_generation_template,
tools_function_calling_generation_template,
)
from open_webui.utils.tools import get_tools
from open_webui.utils.utils import (
create_token,
decode_token,
get_admin_user,
get_current_user,
get_http_authorization_cred,
get_password_hash,
get_verified_user,
)
from open_webui.utils.webhook import post_webhook
from open_webui.utils.oauth import oauth_manager
from open_webui.utils.payload import convert_payload_openai_to_ollama
from open_webui.utils.response import (
convert_response_ollama_to_openai,
convert_streaming_response_ollama_to_openai,
)
from open_webui.utils.security_headers import SecurityHeadersMiddleware
from open_webui.utils.task import (
moa_response_generation_template,
tags_generation_template,
search_query_generation_template,
emoji_generation_template,
title_generation_template,
tools_function_calling_generation_template,
)
from open_webui.utils.tools import get_tools
from open_webui.utils.utils import (
decode_token,
get_admin_user,
get_current_user,
get_http_authorization_cred,
get_verified_user,
)
if SAFE_MODE:
print("SAFE MODE ENABLED")
Functions.deactivate_all_functions()
logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MAIN"])
@ -178,14 +157,14 @@ class SPAStaticFiles(StaticFiles):
print(
rf"""
___ __ __ _ _ _ ___
___ __ __ _ _ _ ___
/ _ \ _ __ ___ _ __ \ \ / /__| |__ | | | |_ _|
| | | | '_ \ / _ \ '_ \ \ \ /\ / / _ \ '_ \| | | || |
| |_| | |_) | __/ | | | \ V V / __/ |_) | |_| || |
| | | | '_ \ / _ \ '_ \ \ \ /\ / / _ \ '_ \| | | || |
| |_| | |_) | __/ | | | \ V V / __/ |_) | |_| || |
\___/| .__/ \___|_| |_| \_/\_/ \___|_.__/ \___/|___|
|_|
|_|
v{VERSION} - building the best open-source AI user interface.
{f"Commit: {WEBUI_BUILD_HASH}" if WEBUI_BUILD_HASH != "dev-build" else ""}
https://github.com/open-webui/open-webui
@ -216,10 +195,10 @@ app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST
app.state.config.WEBHOOK_URL = WEBHOOK_URL
app.state.config.TASK_MODEL = TASK_MODEL
app.state.config.TASK_MODEL_EXTERNAL = TASK_MODEL_EXTERNAL
app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = TITLE_GENERATION_PROMPT_TEMPLATE
app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE = TAGS_GENERATION_PROMPT_TEMPLATE
app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE = (
SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE
)
@ -577,7 +556,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
}
# Initialize data_items to store additional data to be sent to the client
# Initalize contexts and citation
# Initialize contexts and citation
data_items = []
contexts = []
citations = []
@ -689,6 +668,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
app.add_middleware(ChatCompletionMiddleware)
##################################
#
# Pipeline Middleware
@ -824,6 +804,32 @@ class PipelineMiddleware(BaseHTTPMiddleware):
app.add_middleware(PipelineMiddleware)
from urllib.parse import urlencode, parse_qs, urlparse
class RedirectMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next):
# Check if the request is a GET request
if request.method == "GET":
path = request.url.path
query_params = dict(parse_qs(urlparse(str(request.url)).query))
# Check for the specific watch path and the presence of 'v' parameter
if path.endswith("/watch") and "v" in query_params:
video_id = query_params["v"][0] # Extract the first 'v' parameter
encoded_video_id = urlencode({"youtube": video_id})
redirect_url = f"/?{encoded_video_id}"
return RedirectResponse(url=redirect_url)
# Proceed with the normal flow of other requests
response = await call_next(request)
return response
# Add the middleware to the app
app.add_middleware(RedirectMiddleware)
app.add_middleware(
CORSMiddleware,
allow_origins=CORS_ALLOW_ORIGIN,
@ -900,12 +906,10 @@ webui_app.state.EMBEDDING_FUNCTION = retrieval_app.state.EMBEDDING_FUNCTION
async def get_all_models():
# TODO: Optimize this function
pipe_models = []
open_webui_models = []
openai_models = []
ollama_models = []
pipe_models = await get_pipe_models()
if app.state.config.ENABLE_OPENAI_API:
openai_models = await get_openai_models()
openai_models = openai_models["data"]
@ -924,7 +928,13 @@ async def get_all_models():
for model in ollama_models["models"]
]
models = pipe_models + openai_models + ollama_models
open_webui_models = await get_open_webui_models()
models = open_webui_models + openai_models + ollama_models
# If there are no models, return an empty list
if len([model for model in models if model["owned_by"] != "arena"]) == 0:
return []
global_action_ids = [
function.id for function in Functions.get_global_action_functions()
@ -963,11 +973,13 @@ async def get_all_models():
owned_by = model["owned_by"]
if "pipe" in model:
pipe = model["pipe"]
if "info" in model and "meta" in model["info"]:
action_ids.extend(model["info"]["meta"].get("actionIds", []))
break
if custom_model.meta:
meta = custom_model.meta.model_dump()
if "actionIds" in meta:
action_ids.extend(meta["actionIds"])
models.append(
{
"id": custom_model.id,
@ -1070,7 +1082,9 @@ async def get_models(user=Depends(get_verified_user)):
@app.post("/api/chat/completions")
async def generate_chat_completions(form_data: dict, user=Depends(get_verified_user)):
async def generate_chat_completions(
form_data: dict, user=Depends(get_verified_user), bypass_filter: bool = False
):
model_id = form_data["model"]
if model_id not in app.state.MODELS:
@ -1079,7 +1093,7 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u
detail="Model not found",
)
if app.state.config.ENABLE_MODEL_FILTER:
if not bypass_filter and app.state.config.ENABLE_MODEL_FILTER:
if user.role == "user" and model_id not in app.state.config.MODEL_FILTER_LIST:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
@ -1087,6 +1101,53 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u
)
model = app.state.MODELS[model_id]
if model["owned_by"] == "arena":
model_ids = model.get("info", {}).get("meta", {}).get("model_ids")
filter_mode = model.get("info", {}).get("meta", {}).get("filter_mode")
if model_ids and filter_mode == "exclude":
model_ids = [
model["id"]
for model in await get_all_models()
if model.get("owned_by") != "arena"
and not model.get("info", {}).get("meta", {}).get("hidden", False)
and model["id"] not in model_ids
]
selected_model_id = None
if isinstance(model_ids, list) and model_ids:
selected_model_id = random.choice(model_ids)
else:
model_ids = [
model["id"]
for model in await get_all_models()
if model.get("owned_by") != "arena"
and not model.get("info", {}).get("meta", {}).get("hidden", False)
]
selected_model_id = random.choice(model_ids)
form_data["model"] = selected_model_id
if form_data.get("stream") == True:
async def stream_wrapper(stream):
yield f"data: {json.dumps({'selected_model_id': selected_model_id})}\n\n"
async for chunk in stream:
yield chunk
response = await generate_chat_completions(
form_data, user, bypass_filter=True
)
return StreamingResponse(
stream_wrapper(response.body_iterator), media_type="text/event-stream"
)
else:
return {
**(
await generate_chat_completions(form_data, user, bypass_filter=True)
),
"selected_model_id": selected_model_id,
}
if model.get("pipe"):
return await generate_function_chat_completion(form_data, user=user)
if model["owned_by"] == "ollama":
@ -1398,6 +1459,7 @@ async def get_task_config(user=Depends(get_verified_user)):
"TASK_MODEL": app.state.config.TASK_MODEL,
"TASK_MODEL_EXTERNAL": app.state.config.TASK_MODEL_EXTERNAL,
"TITLE_GENERATION_PROMPT_TEMPLATE": app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE,
"TAGS_GENERATION_PROMPT_TEMPLATE": app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE,
"ENABLE_SEARCH_QUERY": app.state.config.ENABLE_SEARCH_QUERY,
"SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE": app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
"TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE": app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
@ -1408,6 +1470,7 @@ class TaskConfigForm(BaseModel):
TASK_MODEL: Optional[str]
TASK_MODEL_EXTERNAL: Optional[str]
TITLE_GENERATION_PROMPT_TEMPLATE: str
TAGS_GENERATION_PROMPT_TEMPLATE: str
SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE: str
ENABLE_SEARCH_QUERY: bool
TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE: str
@ -1420,6 +1483,10 @@ async def update_task_config(form_data: TaskConfigForm, user=Depends(get_admin_u
app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = (
form_data.TITLE_GENERATION_PROMPT_TEMPLATE
)
app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE = (
form_data.TAGS_GENERATION_PROMPT_TEMPLATE
)
app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE = (
form_data.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE
)
@ -1432,6 +1499,7 @@ async def update_task_config(form_data: TaskConfigForm, user=Depends(get_admin_u
"TASK_MODEL": app.state.config.TASK_MODEL,
"TASK_MODEL_EXTERNAL": app.state.config.TASK_MODEL_EXTERNAL,
"TITLE_GENERATION_PROMPT_TEMPLATE": app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE,
"TAGS_GENERATION_PROMPT_TEMPLATE": app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE,
"SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE": app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
"ENABLE_SEARCH_QUERY": app.state.config.ENABLE_SEARCH_QUERY,
"TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE": app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
@ -1459,7 +1527,7 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)):
if app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE != "":
template = app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE
else:
template = """Create a concise, 3-5 word title with an emoji as a title for the prompt in the given language. Suitable Emojis for the summary can be used to enhance understanding but avoid quotation marks or special formatting. RESPOND ONLY WITH THE TITLE TEXT.
template = """Create a concise, 3-5 word title with an emoji as a title for the chat history, in the given language. Suitable Emojis for the summary can be used to enhance understanding but avoid quotation marks or special formatting. RESPOND ONLY WITH THE TITLE TEXT.
Examples of titles:
📉 Stock Market Trends
@ -1469,11 +1537,13 @@ Remote Work Productivity Tips
Artificial Intelligence in Healthcare
🎮 Video Game Development Insights
Prompt: {{prompt:middletruncate:8000}}"""
<chat_history>
{{MESSAGES:END:2}}
</chat_history>"""
content = title_generation_template(
template,
form_data["prompt"],
form_data["messages"],
{
"name": user.name,
"location": user.info.get("location") if user.info else None,
@ -1516,6 +1586,75 @@ Prompt: {{prompt:middletruncate:8000}}"""
return await generate_chat_completions(form_data=payload, user=user)
@app.post("/api/task/tags/completions")
async def generate_chat_tags(form_data: dict, user=Depends(get_verified_user)):
print("generate_chat_tags")
model_id = form_data["model"]
if model_id not in app.state.MODELS:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Model not found",
)
# Check if the user has a custom task model
# If the user has a custom task model, use that model
task_model_id = get_task_model_id(model_id)
print(task_model_id)
if app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE != "":
template = app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE
else:
template = """### Task:
Generate 1-3 broad tags categorizing the main themes of the chat history, along with 1-3 more specific subtopic tags.
### Guidelines:
- Start with high-level domains (e.g. Science, Technology, Philosophy, Arts, Politics, Business, Health, Sports, Entertainment, Education)
- Consider including relevant subfields/subdomains if they are strongly represented throughout the conversation
- If content is too short (less than 3 messages) or too diverse, use only ["General"]
- Use the chat's primary language; default to English if multilingual
- Prioritize accuracy over specificity
### Output:
JSON format: { "tags": ["tag1", "tag2", "tag3"] }
### Chat History:
<chat_history>
{{MESSAGES:END:6}}
</chat_history>"""
content = tags_generation_template(
template, form_data["messages"], {"name": user.name}
)
print("content", content)
payload = {
"model": task_model_id,
"messages": [{"role": "user", "content": content}],
"stream": False,
"metadata": {"task": str(TASKS.TAGS_GENERATION), "task_body": form_data},
}
log.debug(payload)
# Handle pipeline filters
try:
payload = filter_pipeline(payload, user)
except Exception as e:
if len(e.args) > 1:
return JSONResponse(
status_code=e.args[0],
content={"detail": e.args[1]},
)
else:
return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content={"detail": str(e)},
)
if "chat_id" in payload:
del payload["chat_id"]
return await generate_chat_completions(form_data=payload, user=user)
@app.post("/api/task/query/completions")
async def generate_search_query(form_data: dict, user=Depends(get_verified_user)):
print("generate_search_query")
@ -1616,7 +1755,7 @@ Your task is to reflect the speaker's likely facial expression through a fitting
Message: """{{prompt}}"""
'''
content = title_generation_template(
content = emoji_generation_template(
template,
form_data["prompt"],
{
@ -2181,6 +2320,11 @@ async def get_app_changelog():
@app.get("/api/version/updates")
async def get_app_latest_release_version():
if OFFLINE_MODE:
log.debug(
f"Offline mode is enabled, returning current version as latest version"
)
return {"current": VERSION, "latest": VERSION}
try:
timeout = aiohttp.ClientTimeout(total=1)
async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
@ -2201,20 +2345,6 @@ async def get_app_latest_release_version():
# OAuth Login & Callback
############################
oauth = OAuth()
for provider_name, provider_config in OAUTH_PROVIDERS.items():
oauth.register(
name=provider_name,
client_id=provider_config["client_id"],
client_secret=provider_config["client_secret"],
server_metadata_url=provider_config["server_metadata_url"],
client_kwargs={
"scope": provider_config["scope"],
},
redirect_uri=provider_config["redirect_uri"],
)
# SessionMiddleware is used by authlib for oauth
if len(OAUTH_PROVIDERS) > 0:
app.add_middleware(
@ -2228,16 +2358,7 @@ if len(OAUTH_PROVIDERS) > 0:
@app.get("/oauth/{provider}/login")
async def oauth_login(provider: str, request: Request):
if provider not in OAUTH_PROVIDERS:
raise HTTPException(404)
# If the provider has a custom redirect URL, use that, otherwise automatically generate one
redirect_uri = OAUTH_PROVIDERS[provider].get("redirect_uri") or request.url_for(
"oauth_callback", provider=provider
)
client = oauth.create_client(provider)
if client is None:
raise HTTPException(404)
return await client.authorize_redirect(request, redirect_uri)
return await oauth_manager.handle_login(provider, request)
# OAuth login logic is as follows:
@ -2245,119 +2366,10 @@ async def oauth_login(provider: str, request: Request):
# 2. If OAUTH_MERGE_ACCOUNTS_BY_EMAIL is true, find a user with the email address provided via OAuth
# - This is considered insecure in general, as OAuth providers do not always verify email addresses
# 3. If there is no user, and ENABLE_OAUTH_SIGNUP is true, create a user
# - Email addresses are considered unique, so we fail registration if the email address is alreayd taken
# - Email addresses are considered unique, so we fail registration if the email address is already taken
@app.get("/oauth/{provider}/callback")
async def oauth_callback(provider: str, request: Request, response: Response):
if provider not in OAUTH_PROVIDERS:
raise HTTPException(404)
client = oauth.create_client(provider)
try:
token = await client.authorize_access_token(request)
except Exception as e:
log.warning(f"OAuth callback error: {e}")
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
user_data: UserInfo = token["userinfo"]
sub = user_data.get("sub")
if not sub:
log.warning(f"OAuth callback failed, sub is missing: {user_data}")
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
provider_sub = f"{provider}@{sub}"
email_claim = webui_app.state.config.OAUTH_EMAIL_CLAIM
email = user_data.get(email_claim, "").lower()
# We currently mandate that email addresses are provided
if not email:
log.warning(f"OAuth callback failed, email is missing: {user_data}")
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
# Check if the user exists
user = Users.get_user_by_oauth_sub(provider_sub)
if not user:
# If the user does not exist, check if merging is enabled
if OAUTH_MERGE_ACCOUNTS_BY_EMAIL.value:
# Check if the user exists by email
user = Users.get_user_by_email(email)
if user:
# Update the user with the new oauth sub
Users.update_user_oauth_sub_by_id(user.id, provider_sub)
if not user:
# If the user does not exist, check if signups are enabled
if ENABLE_OAUTH_SIGNUP.value:
# Check if an existing user with the same email already exists
existing_user = Users.get_user_by_email(user_data.get("email", "").lower())
if existing_user:
raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
picture_claim = webui_app.state.config.OAUTH_PICTURE_CLAIM
picture_url = user_data.get(picture_claim, "")
if picture_url:
# Download the profile image into a base64 string
try:
async with aiohttp.ClientSession() as session:
async with session.get(picture_url) as resp:
picture = await resp.read()
base64_encoded_picture = base64.b64encode(picture).decode(
"utf-8"
)
guessed_mime_type = mimetypes.guess_type(picture_url)[0]
if guessed_mime_type is None:
# assume JPG, browsers are tolerant enough of image formats
guessed_mime_type = "image/jpeg"
picture_url = f"data:{guessed_mime_type};base64,{base64_encoded_picture}"
except Exception as e:
log.error(f"Error downloading profile image '{picture_url}': {e}")
picture_url = ""
if not picture_url:
picture_url = "/user.png"
username_claim = webui_app.state.config.OAUTH_USERNAME_CLAIM
role = (
"admin"
if Users.get_num_users() == 0
else webui_app.state.config.DEFAULT_USER_ROLE
)
user = Auths.insert_new_auth(
email=email,
password=get_password_hash(
str(uuid.uuid4())
), # Random password, not used
name=user_data.get(username_claim, "User"),
profile_image_url=picture_url,
role=role,
oauth_sub=provider_sub,
)
if webui_app.state.config.WEBHOOK_URL:
post_webhook(
webui_app.state.config.WEBHOOK_URL,
WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
{
"action": "signup",
"message": WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
"user": user.model_dump_json(exclude_none=True),
},
)
else:
raise HTTPException(
status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED
)
jwt_token = create_token(
data={"id": user.id},
expires_delta=parse_duration(webui_app.state.config.JWT_EXPIRES_IN),
)
# Set the cookie token
response.set_cookie(
key="token",
value=jwt_token,
httponly=True, # Ensures the cookie is not accessible via JavaScript
)
# Redirect back to the frontend with the JWT token
redirect_url = f"{request.base_url}auth#token={jwt_token}"
return RedirectResponse(url=redirect_url)
return await oauth_manager.handle_callback(provider, request, response)
@app.get("/manifest.json")
@ -2416,6 +2428,7 @@ async def healthcheck_with_db():
app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
app.mount("/cache", StaticFiles(directory=CACHE_DIR), name="cache")
if os.path.exists(FRONTEND_BUILD_DIR):
mimetypes.add_type("text/javascript", ".js")
app.mount(

View file

@ -0,0 +1,151 @@
"""Migrate tags
Revision ID: 1af9b942657b
Revises: 242a2047eae0
Create Date: 2024-10-09 21:02:35.241684
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.sql import table, select, update, column
from sqlalchemy.engine.reflection import Inspector
import json
revision = "1af9b942657b"
down_revision = "242a2047eae0"
branch_labels = None
depends_on = None
def upgrade():
# Setup an inspection on the existing table to avoid issues
conn = op.get_bind()
inspector = Inspector.from_engine(conn)
# Clean up potential leftover temp table from previous failures
conn.execute(sa.text("DROP TABLE IF EXISTS _alembic_tmp_tag"))
# Check if the 'tag' table exists
tables = inspector.get_table_names()
# Step 1: Modify Tag table using batch mode for SQLite support
if "tag" in tables:
# Get the current columns in the 'tag' table
columns = [col["name"] for col in inspector.get_columns("tag")]
# Get any existing unique constraints on the 'tag' table
current_constraints = inspector.get_unique_constraints("tag")
with op.batch_alter_table("tag", schema=None) as batch_op:
# Check if the unique constraint already exists
if not any(
constraint["name"] == "uq_id_user_id"
for constraint in current_constraints
):
# Create unique constraint if it doesn't exist
batch_op.create_unique_constraint("uq_id_user_id", ["id", "user_id"])
# Check if the 'data' column exists before trying to drop it
if "data" in columns:
batch_op.drop_column("data")
# Check if the 'meta' column needs to be created
if "meta" not in columns:
# Add the 'meta' column if it doesn't already exist
batch_op.add_column(sa.Column("meta", sa.JSON(), nullable=True))
tag = table(
"tag",
column("id", sa.String()),
column("name", sa.String()),
column("user_id", sa.String()),
column("meta", sa.JSON()),
)
# Step 2: Migrate tags
conn = op.get_bind()
result = conn.execute(sa.select(tag.c.id, tag.c.name, tag.c.user_id))
tag_updates = {}
for row in result:
new_id = row.name.replace(" ", "_").lower()
tag_updates[row.id] = new_id
for tag_id, new_tag_id in tag_updates.items():
print(f"Updating tag {tag_id} to {new_tag_id}")
if new_tag_id == "pinned":
# delete tag
delete_stmt = sa.delete(tag).where(tag.c.id == tag_id)
conn.execute(delete_stmt)
else:
# Check if the new_tag_id already exists in the database
existing_tag_query = sa.select(tag.c.id).where(tag.c.id == new_tag_id)
existing_tag_result = conn.execute(existing_tag_query).fetchone()
if existing_tag_result:
# Handle duplicate case: the new_tag_id already exists
print(
f"Tag {new_tag_id} already exists. Removing current tag with ID {tag_id} to avoid duplicates."
)
# Option 1: Delete the current tag if an update to new_tag_id would cause duplication
delete_stmt = sa.delete(tag).where(tag.c.id == tag_id)
conn.execute(delete_stmt)
else:
update_stmt = sa.update(tag).where(tag.c.id == tag_id)
update_stmt = update_stmt.values(id=new_tag_id)
conn.execute(update_stmt)
# Add columns `pinned` and `meta` to 'chat'
op.add_column("chat", sa.Column("pinned", sa.Boolean(), nullable=True))
op.add_column(
"chat", sa.Column("meta", sa.JSON(), nullable=False, server_default="{}")
)
chatidtag = table(
"chatidtag", column("chat_id", sa.String()), column("tag_name", sa.String())
)
chat = table(
"chat",
column("id", sa.String()),
column("pinned", sa.Boolean()),
column("meta", sa.JSON()),
)
# Fetch existing tags
conn = op.get_bind()
result = conn.execute(sa.select(chatidtag.c.chat_id, chatidtag.c.tag_name))
chat_updates = {}
for row in result:
chat_id = row.chat_id
tag_name = row.tag_name.replace(" ", "_").lower()
if tag_name == "pinned":
# Specifically handle 'pinned' tag
if chat_id not in chat_updates:
chat_updates[chat_id] = {"pinned": True, "meta": {}}
else:
chat_updates[chat_id]["pinned"] = True
else:
if chat_id not in chat_updates:
chat_updates[chat_id] = {"pinned": False, "meta": {"tags": [tag_name]}}
else:
tags = chat_updates[chat_id]["meta"].get("tags", [])
tags.append(tag_name)
chat_updates[chat_id]["meta"]["tags"] = list(set(tags))
# Update chats based on accumulated changes
for chat_id, updates in chat_updates.items():
update_stmt = sa.update(chat).where(chat.c.id == chat_id)
update_stmt = update_stmt.values(
meta=updates.get("meta", {}), pinned=updates.get("pinned", False)
)
conn.execute(update_stmt)
pass
def downgrade():
pass

View file

@ -0,0 +1,82 @@
"""Update chat table
Revision ID: 242a2047eae0
Revises: 6a39f3d8e55c
Create Date: 2024-10-09 21:02:35.241684
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.sql import table, select, update
import json
revision = "242a2047eae0"
down_revision = "6a39f3d8e55c"
branch_labels = None
depends_on = None
def upgrade():
# Step 1: Rename current 'chat' column to 'old_chat'
op.alter_column("chat", "chat", new_column_name="old_chat", existing_type=sa.Text)
# Step 2: Add new 'chat' column of type JSON
op.add_column("chat", sa.Column("chat", sa.JSON(), nullable=True))
# Step 3: Migrate data from 'old_chat' to 'chat'
chat_table = table(
"chat",
sa.Column("id", sa.String, primary_key=True),
sa.Column("old_chat", sa.Text),
sa.Column("chat", sa.JSON()),
)
# - Selecting all data from the table
connection = op.get_bind()
results = connection.execute(select(chat_table.c.id, chat_table.c.old_chat))
for row in results:
try:
# Convert text JSON to actual JSON object, assuming the text is in JSON format
json_data = json.loads(row.old_chat)
except json.JSONDecodeError:
json_data = None # Handle cases where the text cannot be converted to JSON
connection.execute(
sa.update(chat_table)
.where(chat_table.c.id == row.id)
.values(chat=json_data)
)
# Step 4: Drop 'old_chat' column
op.drop_column("chat", "old_chat")
def downgrade():
# Step 1: Add 'old_chat' column back as Text
op.add_column("chat", sa.Column("old_chat", sa.Text(), nullable=True))
# Step 2: Convert 'chat' JSON data back to text and store in 'old_chat'
chat_table = table(
"chat",
sa.Column("id", sa.String, primary_key=True),
sa.Column("chat", sa.JSON()),
sa.Column("old_chat", sa.Text()),
)
connection = op.get_bind()
results = connection.execute(select(chat_table.c.id, chat_table.c.chat))
for row in results:
text_data = json.dumps(row.chat) if row.chat is not None else None
connection.execute(
sa.update(chat_table)
.where(chat_table.c.id == row.id)
.values(old_chat=text_data)
)
# Step 3: Remove the new 'chat' JSON column
op.drop_column("chat", "chat")
# Step 4: Rename 'old_chat' back to 'chat'
op.alter_column("chat", "old_chat", new_column_name="chat", existing_type=sa.Text)

View file

@ -0,0 +1,81 @@
"""Update tags
Revision ID: 3ab32c4b8f59
Revises: 1af9b942657b
Create Date: 2024-10-09 21:02:35.241684
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.sql import table, select, update, column
from sqlalchemy.engine.reflection import Inspector
import json
revision = "3ab32c4b8f59"
down_revision = "1af9b942657b"
branch_labels = None
depends_on = None
def upgrade():
conn = op.get_bind()
inspector = Inspector.from_engine(conn)
# Inspecting the 'tag' table constraints and structure
existing_pk = inspector.get_pk_constraint("tag")
unique_constraints = inspector.get_unique_constraints("tag")
existing_indexes = inspector.get_indexes("tag")
print(f"Primary Key: {existing_pk}")
print(f"Unique Constraints: {unique_constraints}")
print(f"Indexes: {existing_indexes}")
with op.batch_alter_table("tag", schema=None) as batch_op:
# Drop existing primary key constraint if it exists
if existing_pk and existing_pk.get("constrained_columns"):
pk_name = existing_pk.get("name")
if pk_name:
print(f"Dropping primary key constraint: {pk_name}")
batch_op.drop_constraint(pk_name, type_="primary")
# Now create the new primary key with the combination of 'id' and 'user_id'
print("Creating new primary key with 'id' and 'user_id'.")
batch_op.create_primary_key("pk_id_user_id", ["id", "user_id"])
# Drop unique constraints that could conflict with the new primary key
for constraint in unique_constraints:
if (
constraint["name"] == "uq_id_user_id"
): # Adjust this name according to what is actually returned by the inspector
print(f"Dropping unique constraint: {constraint['name']}")
batch_op.drop_constraint(constraint["name"], type_="unique")
for index in existing_indexes:
if index["unique"]:
if not any(
constraint["name"] == index["name"]
for constraint in unique_constraints
):
# You are attempting to drop unique indexes
print(f"Dropping unique index: {index['name']}")
batch_op.drop_index(index["name"])
def downgrade():
conn = op.get_bind()
inspector = Inspector.from_engine(conn)
current_pk = inspector.get_pk_constraint("tag")
with op.batch_alter_table("tag", schema=None) as batch_op:
# Drop the current primary key first, if it matches the one we know we added in upgrade
if current_pk and "pk_id_user_id" == current_pk.get("name"):
batch_op.drop_constraint("pk_id_user_id", type_="primary")
# Restore the original primary key
batch_op.create_primary_key("pk_id", ["id"])
# Since primary key on just 'id' is restored, we now add back any unique constraints if necessary
batch_op.create_unique_constraint("uq_id_user_id", ["id", "user_id"])

View file

@ -0,0 +1,67 @@
"""Update folder table and change DateTime to BigInteger for timestamp fields
Revision ID: 4ace53fd72c8
Revises: af906e964978
Create Date: 2024-10-23 03:00:00.000000
"""
from alembic import op
import sqlalchemy as sa
revision = "4ace53fd72c8"
down_revision = "af906e964978"
branch_labels = None
depends_on = None
def upgrade():
# Perform safe alterations using batch operation
with op.batch_alter_table("folder", schema=None) as batch_op:
# Step 1: Remove server defaults for created_at and updated_at
batch_op.alter_column(
"created_at",
server_default=None, # Removing server default
)
batch_op.alter_column(
"updated_at",
server_default=None, # Removing server default
)
# Step 2: Change the column types to BigInteger for created_at
batch_op.alter_column(
"created_at",
type_=sa.BigInteger(),
existing_type=sa.DateTime(),
existing_nullable=False,
postgresql_using="extract(epoch from created_at)::bigint", # Conversion for PostgreSQL
)
# Change the column types to BigInteger for updated_at
batch_op.alter_column(
"updated_at",
type_=sa.BigInteger(),
existing_type=sa.DateTime(),
existing_nullable=False,
postgresql_using="extract(epoch from updated_at)::bigint", # Conversion for PostgreSQL
)
def downgrade():
# Downgrade: Convert columns back to DateTime and restore defaults
with op.batch_alter_table("folder", schema=None) as batch_op:
batch_op.alter_column(
"created_at",
type_=sa.DateTime(),
existing_type=sa.BigInteger(),
existing_nullable=False,
server_default=sa.func.now(), # Restoring server default on downgrade
)
batch_op.alter_column(
"updated_at",
type_=sa.DateTime(),
existing_type=sa.BigInteger(),
existing_nullable=False,
server_default=sa.func.now(), # Restoring server default on downgrade
onupdate=sa.func.now(), # Restoring onupdate behavior if it was there
)

View file

@ -0,0 +1,51 @@
"""Add feedback table
Revision ID: af906e964978
Revises: c29facfe716b
Create Date: 2024-10-20 17:02:35.241684
"""
from alembic import op
import sqlalchemy as sa
# Revision identifiers, used by Alembic.
revision = "af906e964978"
down_revision = "c29facfe716b"
branch_labels = None
depends_on = None
def upgrade():
# ### Create feedback table ###
op.create_table(
"feedback",
sa.Column(
"id", sa.Text(), primary_key=True
), # Unique identifier for each feedback (TEXT type)
sa.Column(
"user_id", sa.Text(), nullable=True
), # ID of the user providing the feedback (TEXT type)
sa.Column(
"version", sa.BigInteger(), default=0
), # Version of feedback (BIGINT type)
sa.Column("type", sa.Text(), nullable=True), # Type of feedback (TEXT type)
sa.Column("data", sa.JSON(), nullable=True), # Feedback data (JSON type)
sa.Column(
"meta", sa.JSON(), nullable=True
), # Metadata for feedback (JSON type)
sa.Column(
"snapshot", sa.JSON(), nullable=True
), # snapshot data for feedback (JSON type)
sa.Column(
"created_at", sa.BigInteger(), nullable=False
), # Feedback creation timestamp (BIGINT representing epoch)
sa.Column(
"updated_at", sa.BigInteger(), nullable=False
), # Feedback update timestamp (BIGINT representing epoch)
)
def downgrade():
# ### Drop feedback table ###
op.drop_table("feedback")

View file

@ -0,0 +1,79 @@
"""Update file table path
Revision ID: c29facfe716b
Revises: c69f45358db4
Create Date: 2024-10-20 17:02:35.241684
"""
from alembic import op
import sqlalchemy as sa
import json
from sqlalchemy.sql import table, column
from sqlalchemy import String, Text, JSON, and_
revision = "c29facfe716b"
down_revision = "c69f45358db4"
branch_labels = None
depends_on = None
def upgrade():
# 1. Add the `path` column to the "file" table.
op.add_column("file", sa.Column("path", sa.Text(), nullable=True))
# 2. Convert the `meta` column from Text/JSONField to `JSON()`
# Use Alembic's default batch_op for dialect compatibility.
with op.batch_alter_table("file", schema=None) as batch_op:
batch_op.alter_column(
"meta",
type_=sa.JSON(),
existing_type=sa.Text(),
existing_nullable=True,
nullable=True,
postgresql_using="meta::json",
)
# 3. Migrate legacy data from `meta` JSONField
# Fetch and process `meta` data from the table, add values to the new `path` column as necessary.
# We will use SQLAlchemy core bindings to ensure safety across different databases.
file_table = table(
"file", column("id", String), column("meta", JSON), column("path", Text)
)
# Create connection to the database
connection = op.get_bind()
# Get the rows where `meta` has a path and `path` column is null (new column)
# Loop through each row in the result set to update the path
results = connection.execute(
sa.select(file_table.c.id, file_table.c.meta).where(
and_(file_table.c.path.is_(None), file_table.c.meta.isnot(None))
)
).fetchall()
# Iterate over each row to extract and update the `path` from `meta` column
for row in results:
if "path" in row.meta:
# Extract the `path` field from the `meta` JSON
path = row.meta.get("path")
# Update the `file` table with the new `path` value
connection.execute(
file_table.update()
.where(file_table.c.id == row.id)
.values({"path": path})
)
def downgrade():
# 1. Remove the `path` column
op.drop_column("file", "path")
# 2. Revert the `meta` column back to Text/JSONField
with op.batch_alter_table("file", schema=None) as batch_op:
batch_op.alter_column(
"meta", type_=sa.Text(), existing_type=sa.JSON(), existing_nullable=True
)

View file

@ -0,0 +1,50 @@
"""Add folder table
Revision ID: c69f45358db4
Revises: 3ab32c4b8f59
Create Date: 2024-10-16 02:02:35.241684
"""
from alembic import op
import sqlalchemy as sa
revision = "c69f45358db4"
down_revision = "3ab32c4b8f59"
branch_labels = None
depends_on = None
def upgrade():
op.create_table(
"folder",
sa.Column("id", sa.Text(), nullable=False),
sa.Column("parent_id", sa.Text(), nullable=True),
sa.Column("user_id", sa.Text(), nullable=False),
sa.Column("name", sa.Text(), nullable=False),
sa.Column("items", sa.JSON(), nullable=True),
sa.Column("meta", sa.JSON(), nullable=True),
sa.Column("is_expanded", sa.Boolean(), default=False, nullable=False),
sa.Column(
"created_at", sa.DateTime(), server_default=sa.func.now(), nullable=False
),
sa.Column(
"updated_at",
sa.DateTime(),
nullable=False,
server_default=sa.func.now(),
onupdate=sa.func.now(),
),
sa.PrimaryKeyConstraint("id", "user_id"),
)
op.add_column(
"chat",
sa.Column("folder_id", sa.Text(), nullable=True),
)
def downgrade():
op.drop_column("chat", "folder_id")
op.drop_table("folder")

View file

@ -0,0 +1,319 @@
/* HTML and Body */
@font-face {
font-family: 'NotoSans';
src: url('fonts/NotoSans-Variable.ttf');
}
@font-face {
font-family: 'NotoSansJP';
src: url('fonts/NotoSansJP-Variable.ttf');
}
@font-face {
font-family: 'NotoSansKR';
src: url('fonts/NotoSansKR-Variable.ttf');
}
@font-face {
font-family: 'NotoSansSC';
src: url('fonts/NotoSansSC-Variable.ttf');
}
@font-face {
font-family: 'NotoSansSC-Regular';
src: url('fonts/NotoSansSC-Regular.ttf');
}
html {
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', 'NotoSans', 'NotoSansJP', 'NotoSansKR',
'NotoSansSC', 'STSong-Light', 'MSung-Light', 'HeiseiMin-W3', 'HYSMyeongJo-Medium', Roboto,
'Helvetica Neue', Arial, sans-serif;
font-size: 14px; /* Default font size */
line-height: 1.5;
}
*,
*::before,
*::after {
box-sizing: inherit;
}
body {
margin: 0;
color: #212529;
background-color: #fff;
width: auto;
}
/* Typography */
h1,
h2,
h3,
h4,
h5,
h6 {
font-weight: 500;
margin: 0;
}
h1 {
font-size: 2.5rem;
}
h2 {
font-size: 2rem;
}
h3 {
font-size: 1.75rem;
}
h4 {
font-size: 1.5rem;
}
h5 {
font-size: 1.25rem;
}
h6 {
font-size: 1rem;
}
p {
margin-top: 0;
margin-bottom: 1rem;
}
/* Grid System */
.container {
width: 100%;
padding-right: 15px;
padding-left: 15px;
margin-right: auto;
margin-left: auto;
}
/* Utilities */
.text-center {
text-align: center;
}
/* Additional Text Utilities */
.text-muted {
color: #6c757d; /* Muted text color */
}
/* Small Text */
small {
font-size: 80%; /* Smaller font size relative to the base */
color: #6c757d; /* Lighter text color for secondary information */
margin-bottom: 0;
margin-top: 0;
}
/* Strong Element Styles */
strong {
font-weight: bolder; /* Ensures the text is bold */
color: inherit; /* Inherits the color from its parent element */
}
/* link */
a {
color: #007bff;
text-decoration: none;
background-color: transparent;
}
a:hover {
color: #0056b3;
text-decoration: underline;
}
/* General styles for lists */
ol,
ul,
li {
padding-left: 40px; /* Increase padding to move bullet points to the right */
margin-left: 20px; /* Indent lists from the left */
}
/* Ordered list styles */
ol {
list-style-type: decimal; /* Use numbers for ordered lists */
margin-bottom: 10px; /* Space after each list */
}
ol li {
margin-bottom: 0.5rem; /* Space between ordered list items */
}
/* Unordered list styles */
ul {
list-style-type: disc; /* Use bullets for unordered lists */
margin-bottom: 10px; /* Space after each list */
}
ul li {
margin-bottom: 0.5rem; /* Space between unordered list items */
}
/* List item styles */
li {
margin-bottom: 5px; /* Space between list items */
line-height: 1.5; /* Line height for better readability */
}
/* Nested lists */
ol ol,
ol ul,
ul ol,
ul ul {
padding-left: 20px;
margin-left: 30px; /* Further indent nested lists */
margin-bottom: 0; /* Remove extra margin at the bottom of nested lists */
}
/* Code blocks */
pre {
background-color: #f4f4f4;
padding: 10px;
overflow-x: auto;
max-width: 100%; /* Ensure it doesn't overflow the page */
width: 80%; /* Set a specific width for a container-like appearance */
margin: 0 1em; /* Center the pre block */
box-sizing: border-box; /* Include padding in the width */
border: 1px solid #ccc; /* Optional: Add a border for better definition */
border-radius: 4px; /* Optional: Add rounded corners */
}
code {
font-family: 'Courier New', Courier, monospace;
background-color: #f4f4f4;
padding: 2px 4px;
border-radius: 4px;
box-sizing: border-box; /* Include padding in the width */
}
.message {
margin-top: 8px;
margin-bottom: 8px;
max-width: 100%;
overflow-wrap: break-word;
}
/* Table Styles */
table {
width: 100%;
margin-bottom: 1rem;
color: #212529;
border-collapse: collapse; /* Removes the space between borders */
}
th,
td {
margin: 0;
padding: 0.75rem;
vertical-align: top;
border-top: 1px solid #dee2e6;
}
thead th {
vertical-align: bottom;
border-bottom: 2px solid #dee2e6;
}
tbody + tbody {
border-top: 2px solid #dee2e6;
}
/* markdown-section styles */
.markdown-section blockquote,
.markdown-section h1,
.markdown-section h2,
.markdown-section h3,
.markdown-section h4,
.markdown-section h5,
.markdown-section h6,
.markdown-section p,
.markdown-section pre,
.markdown-section table,
.markdown-section ul {
/* Give most block elements margin top and bottom */
margin-top: 1rem;
}
/* Remove top margin if it's the first child */
.markdown-section blockquote:first-child,
.markdown-section h1:first-child,
.markdown-section h2:first-child,
.markdown-section h3:first-child,
.markdown-section h4:first-child,
.markdown-section h5:first-child,
.markdown-section h6:first-child,
.markdown-section p:first-child,
.markdown-section pre:first-child,
.markdown-section table:first-child,
.markdown-section ul:first-child {
margin-top: 0;
}
/* Remove top margin of <ul> following a <p> */
.markdown-section p + ul {
margin-top: 0;
}
/* Remove bottom margin of <p> if it is followed by a <ul> */
/* Note: :has is not supported in CSS, so you would need JavaScript for this behavior */
.markdown-section p {
margin-bottom: 0;
}
/* Add a rule to reset margin-bottom for <p> not followed by <ul> */
.markdown-section p + ul {
margin-top: 0;
}
/* List item styles */
.markdown-section li {
padding: 2px;
}
.markdown-section li p {
margin-bottom: 0;
padding: 0;
}
/* Avoid margins for nested lists */
.markdown-section li > ul {
margin-top: 0;
margin-bottom: 0;
}
/* Table styles */
.markdown-section table {
width: 100%;
border-collapse: collapse;
margin: 1rem 0;
}
.markdown-section th,
.markdown-section td {
border: 1px solid #ddd;
padding: 0.5rem;
text-align: left;
}
.markdown-section th {
background-color: #f2f2f2;
}
.markdown-section pre {
padding: 10px;
margin: 10px;
}
.markdown-section pre code {
position: relative;
color: rgb(172, 0, 95);
}

Binary file not shown.

View file

@ -0,0 +1,163 @@
import os
import boto3
from botocore.exceptions import ClientError
import shutil
from typing import BinaryIO, Tuple, Optional, Union
from open_webui.constants import ERROR_MESSAGES
from open_webui.config import (
STORAGE_PROVIDER,
S3_ACCESS_KEY_ID,
S3_SECRET_ACCESS_KEY,
S3_BUCKET_NAME,
S3_REGION_NAME,
S3_ENDPOINT_URL,
UPLOAD_DIR,
)
import boto3
from botocore.exceptions import ClientError
from typing import BinaryIO, Tuple, Optional
class StorageProvider:
def __init__(self, provider: Optional[str] = None):
self.storage_provider: str = provider or STORAGE_PROVIDER
self.s3_client = None
self.s3_bucket_name: Optional[str] = None
if self.storage_provider == "s3":
self._initialize_s3()
def _initialize_s3(self) -> None:
"""Initializes the S3 client and bucket name if using S3 storage."""
self.s3_client = boto3.client(
"s3",
region_name=S3_REGION_NAME,
endpoint_url=S3_ENDPOINT_URL,
aws_access_key_id=S3_ACCESS_KEY_ID,
aws_secret_access_key=S3_SECRET_ACCESS_KEY,
)
self.bucket_name = S3_BUCKET_NAME
def _upload_to_s3(self, file: BinaryIO, filename: str) -> Tuple[bytes, str]:
"""Handles uploading of the file to S3 storage."""
if not self.s3_client:
raise RuntimeError("S3 Client is not initialized.")
try:
self.s3_client.upload_fileobj(file, self.bucket_name, filename)
return file.read(), f"s3://{self.bucket_name}/{filename}"
except ClientError as e:
raise RuntimeError(f"Error uploading file to S3: {e}")
def _upload_to_local(self, contents: bytes, filename: str) -> Tuple[bytes, str]:
"""Handles uploading of the file to local storage."""
file_path = f"{UPLOAD_DIR}/{filename}"
with open(file_path, "wb") as f:
f.write(contents)
return contents, file_path
def _get_file_from_s3(self, file_path: str) -> str:
"""Handles downloading of the file from S3 storage."""
if not self.s3_client:
raise RuntimeError("S3 Client is not initialized.")
try:
bucket_name, key = file_path.split("//")[1].split("/")
local_file_path = f"{UPLOAD_DIR}/{key}"
self.s3_client.download_file(bucket_name, key, local_file_path)
return local_file_path
except ClientError as e:
raise RuntimeError(f"Error downloading file from S3: {e}")
def _get_file_from_local(self, file_path: str) -> str:
"""Handles downloading of the file from local storage."""
return file_path
def _delete_from_s3(self, filename: str) -> None:
"""Handles deletion of the file from S3 storage."""
if not self.s3_client:
raise RuntimeError("S3 Client is not initialized.")
try:
self.s3_client.delete_object(Bucket=self.bucket_name, Key=filename)
except ClientError as e:
raise RuntimeError(f"Error deleting file from S3: {e}")
def _delete_from_local(self, filename: str) -> None:
"""Handles deletion of the file from local storage."""
file_path = f"{UPLOAD_DIR}/{filename}"
if os.path.isfile(file_path):
os.remove(file_path)
else:
print(f"File {file_path} not found in local storage.")
def _delete_all_from_s3(self) -> None:
"""Handles deletion of all files from S3 storage."""
if not self.s3_client:
raise RuntimeError("S3 Client is not initialized.")
try:
response = self.s3_client.list_objects_v2(Bucket=self.bucket_name)
if "Contents" in response:
for content in response["Contents"]:
self.s3_client.delete_object(
Bucket=self.bucket_name, Key=content["Key"]
)
except ClientError as e:
raise RuntimeError(f"Error deleting all files from S3: {e}")
def _delete_all_from_local(self) -> None:
"""Handles deletion of all files from local storage."""
if os.path.exists(UPLOAD_DIR):
for filename in os.listdir(UPLOAD_DIR):
file_path = os.path.join(UPLOAD_DIR, filename)
try:
if os.path.isfile(file_path) or os.path.islink(file_path):
os.unlink(file_path) # Remove the file or link
elif os.path.isdir(file_path):
shutil.rmtree(file_path) # Remove the directory
except Exception as e:
print(f"Failed to delete {file_path}. Reason: {e}")
else:
print(f"Directory {UPLOAD_DIR} not found in local storage.")
def upload_file(self, file: BinaryIO, filename: str) -> Tuple[bytes, str]:
"""Uploads a file either to S3 or the local file system."""
contents = file.read()
if not contents:
raise ValueError(ERROR_MESSAGES.EMPTY_CONTENT)
if self.storage_provider == "s3":
return self._upload_to_s3(file, filename)
return self._upload_to_local(contents, filename)
def get_file(self, file_path: str) -> str:
"""Downloads a file either from S3 or the local file system and returns the file path."""
if self.storage_provider == "s3":
return self._get_file_from_s3(file_path)
return self._get_file_from_local(file_path)
def delete_file(self, filename: str) -> None:
"""Deletes a file either from S3 or the local file system."""
if self.storage_provider == "s3":
self._delete_from_s3(filename)
# Always delete from local storage
self._delete_from_local(filename)
def delete_all_files(self) -> None:
"""Deletes all files from the storage."""
if self.storage_provider == "s3":
self._delete_all_from_s3()
# Always delete from local storage
self._delete_all_from_local()
Storage = StorageProvider(provider=STORAGE_PROVIDER)

View file

@ -1,105 +0,0 @@
from test.util.abstract_integration_test import AbstractPostgresTest
from test.util.mock_user import mock_webui_user
class TestDocuments(AbstractPostgresTest):
BASE_PATH = "/api/v1/documents"
def setup_class(cls):
super().setup_class()
from open_webui.apps.webui.models.documents import Documents
cls.documents = Documents
def test_documents(self):
# Empty database
assert len(self.documents.get_docs()) == 0
with mock_webui_user(id="2"):
response = self.fast_api_client.get(self.create_url("/"))
assert response.status_code == 200
assert len(response.json()) == 0
# Create a new document
with mock_webui_user(id="2"):
response = self.fast_api_client.post(
self.create_url("/create"),
json={
"name": "doc_name",
"title": "doc title",
"collection_name": "custom collection",
"filename": "doc_name.pdf",
"content": "",
},
)
assert response.status_code == 200
assert response.json()["name"] == "doc_name"
assert len(self.documents.get_docs()) == 1
# Get the document
with mock_webui_user(id="2"):
response = self.fast_api_client.get(self.create_url("/doc?name=doc_name"))
assert response.status_code == 200
data = response.json()
assert data["collection_name"] == "custom collection"
assert data["name"] == "doc_name"
assert data["title"] == "doc title"
assert data["filename"] == "doc_name.pdf"
assert data["content"] == {}
# Create another document
with mock_webui_user(id="2"):
response = self.fast_api_client.post(
self.create_url("/create"),
json={
"name": "doc_name 2",
"title": "doc title 2",
"collection_name": "custom collection 2",
"filename": "doc_name2.pdf",
"content": "",
},
)
assert response.status_code == 200
assert response.json()["name"] == "doc_name 2"
assert len(self.documents.get_docs()) == 2
# Get all documents
with mock_webui_user(id="2"):
response = self.fast_api_client.get(self.create_url("/"))
assert response.status_code == 200
assert len(response.json()) == 2
# Update the first document
with mock_webui_user(id="2"):
response = self.fast_api_client.post(
self.create_url("/doc/update?name=doc_name"),
json={"name": "doc_name rework", "title": "updated title"},
)
assert response.status_code == 200
data = response.json()
assert data["name"] == "doc_name rework"
assert data["title"] == "updated title"
# Tag the first document
with mock_webui_user(id="2"):
response = self.fast_api_client.post(
self.create_url("/doc/tags"),
json={
"name": "doc_name rework",
"tags": [{"name": "testing-tag"}, {"name": "another-tag"}],
},
)
assert response.status_code == 200
data = response.json()
assert data["name"] == "doc_name rework"
assert data["content"] == {
"tags": [{"name": "testing-tag"}, {"name": "another-tag"}]
}
assert len(self.documents.get_docs()) == 2
# Delete the first document
with mock_webui_user(id="2"):
response = self.fast_api_client.delete(
self.create_url("/doc/delete?name=doc_name rework")
)
assert response.status_code == 200
assert len(self.documents.get_docs()) == 1

View file

@ -0,0 +1,261 @@
import base64
import logging
import mimetypes
import uuid
import aiohttp
from authlib.integrations.starlette_client import OAuth
from authlib.oidc.core import UserInfo
from fastapi import (
HTTPException,
status,
)
from starlette.responses import RedirectResponse
from open_webui.apps.webui.models.auths import Auths
from open_webui.apps.webui.models.users import Users
from open_webui.config import (
DEFAULT_USER_ROLE,
ENABLE_OAUTH_SIGNUP,
OAUTH_MERGE_ACCOUNTS_BY_EMAIL,
OAUTH_PROVIDERS,
ENABLE_OAUTH_ROLE_MANAGEMENT,
OAUTH_ROLES_CLAIM,
OAUTH_EMAIL_CLAIM,
OAUTH_PICTURE_CLAIM,
OAUTH_USERNAME_CLAIM,
OAUTH_ALLOWED_ROLES,
OAUTH_ADMIN_ROLES,
WEBHOOK_URL,
JWT_EXPIRES_IN,
AppConfig,
)
from open_webui.constants import ERROR_MESSAGES
from open_webui.env import WEBUI_SESSION_COOKIE_SAME_SITE, WEBUI_SESSION_COOKIE_SECURE
from open_webui.utils.misc import parse_duration
from open_webui.utils.utils import get_password_hash, create_token
from open_webui.utils.webhook import post_webhook
log = logging.getLogger(__name__)
auth_manager_config = AppConfig()
auth_manager_config.DEFAULT_USER_ROLE = DEFAULT_USER_ROLE
auth_manager_config.ENABLE_OAUTH_SIGNUP = ENABLE_OAUTH_SIGNUP
auth_manager_config.OAUTH_MERGE_ACCOUNTS_BY_EMAIL = OAUTH_MERGE_ACCOUNTS_BY_EMAIL
auth_manager_config.ENABLE_OAUTH_ROLE_MANAGEMENT = ENABLE_OAUTH_ROLE_MANAGEMENT
auth_manager_config.OAUTH_ROLES_CLAIM = OAUTH_ROLES_CLAIM
auth_manager_config.OAUTH_EMAIL_CLAIM = OAUTH_EMAIL_CLAIM
auth_manager_config.OAUTH_PICTURE_CLAIM = OAUTH_PICTURE_CLAIM
auth_manager_config.OAUTH_USERNAME_CLAIM = OAUTH_USERNAME_CLAIM
auth_manager_config.OAUTH_ALLOWED_ROLES = OAUTH_ALLOWED_ROLES
auth_manager_config.OAUTH_ADMIN_ROLES = OAUTH_ADMIN_ROLES
auth_manager_config.WEBHOOK_URL = WEBHOOK_URL
auth_manager_config.JWT_EXPIRES_IN = JWT_EXPIRES_IN
class OAuthManager:
def __init__(self):
self.oauth = OAuth()
for provider_name, provider_config in OAUTH_PROVIDERS.items():
self.oauth.register(
name=provider_name,
client_id=provider_config["client_id"],
client_secret=provider_config["client_secret"],
server_metadata_url=provider_config["server_metadata_url"],
client_kwargs={
"scope": provider_config["scope"],
},
redirect_uri=provider_config["redirect_uri"],
)
def get_client(self, provider_name):
return self.oauth.create_client(provider_name)
def get_user_role(self, user, user_data):
if user and Users.get_num_users() == 1:
# If the user is the only user, assign the role "admin" - actually repairs role for single user on login
return "admin"
if not user and Users.get_num_users() == 0:
# If there are no users, assign the role "admin", as the first user will be an admin
return "admin"
if auth_manager_config.ENABLE_OAUTH_ROLE_MANAGEMENT:
oauth_claim = auth_manager_config.OAUTH_ROLES_CLAIM
oauth_allowed_roles = auth_manager_config.OAUTH_ALLOWED_ROLES
oauth_admin_roles = auth_manager_config.OAUTH_ADMIN_ROLES
oauth_roles = None
role = "pending" # Default/fallback role if no matching roles are found
# Next block extracts the roles from the user data, accepting nested claims of any depth
if oauth_claim and oauth_allowed_roles and oauth_admin_roles:
claim_data = user_data
nested_claims = oauth_claim.split(".")
for nested_claim in nested_claims:
claim_data = claim_data.get(nested_claim, {})
oauth_roles = claim_data if isinstance(claim_data, list) else None
# If any roles are found, check if they match the allowed or admin roles
if oauth_roles:
# If role management is enabled, and matching roles are provided, use the roles
for allowed_role in oauth_allowed_roles:
# If the user has any of the allowed roles, assign the role "user"
if allowed_role in oauth_roles:
role = "user"
break
for admin_role in oauth_admin_roles:
# If the user has any of the admin roles, assign the role "admin"
if admin_role in oauth_roles:
role = "admin"
break
else:
if not user:
# If role management is disabled, use the default role for new users
role = auth_manager_config.DEFAULT_USER_ROLE
else:
# If role management is disabled, use the existing role for existing users
role = user.role
return role
async def handle_login(self, provider, request):
if provider not in OAUTH_PROVIDERS:
raise HTTPException(404)
# If the provider has a custom redirect URL, use that, otherwise automatically generate one
redirect_uri = OAUTH_PROVIDERS[provider].get("redirect_uri") or request.url_for(
"oauth_callback", provider=provider
)
client = self.get_client(provider)
if client is None:
raise HTTPException(404)
return await client.authorize_redirect(request, redirect_uri)
async def handle_callback(self, provider, request, response):
if provider not in OAUTH_PROVIDERS:
raise HTTPException(404)
client = self.get_client(provider)
try:
token = await client.authorize_access_token(request)
except Exception as e:
log.warning(f"OAuth callback error: {e}")
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
user_data: UserInfo = token["userinfo"]
if not user_data:
user_data: UserInfo = await client.userinfo(token=token)
if not user_data:
log.warning(f"OAuth callback failed, user data is missing: {token}")
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
sub = user_data.get("sub")
if not sub:
log.warning(f"OAuth callback failed, sub is missing: {user_data}")
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
provider_sub = f"{provider}@{sub}"
email_claim = auth_manager_config.OAUTH_EMAIL_CLAIM
email = user_data.get(email_claim, "").lower()
# We currently mandate that email addresses are provided
if not email:
log.warning(f"OAuth callback failed, email is missing: {user_data}")
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
# Check if the user exists
user = Users.get_user_by_oauth_sub(provider_sub)
if not user:
# If the user does not exist, check if merging is enabled
if auth_manager_config.OAUTH_MERGE_ACCOUNTS_BY_EMAIL.value:
# Check if the user exists by email
user = Users.get_user_by_email(email)
if user:
# Update the user with the new oauth sub
Users.update_user_oauth_sub_by_id(user.id, provider_sub)
if user:
determined_role = self.get_user_role(user, user_data)
if user.role != determined_role:
Users.update_user_role_by_id(user.id, determined_role)
if not user:
# If the user does not exist, check if signups are enabled
if auth_manager_config.ENABLE_OAUTH_SIGNUP.value:
# Check if an existing user with the same email already exists
existing_user = Users.get_user_by_email(
user_data.get("email", "").lower()
)
if existing_user:
raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
picture_claim = auth_manager_config.OAUTH_PICTURE_CLAIM
picture_url = user_data.get(picture_claim, "")
if picture_url:
# Download the profile image into a base64 string
try:
async with aiohttp.ClientSession() as session:
async with session.get(picture_url) as resp:
picture = await resp.read()
base64_encoded_picture = base64.b64encode(
picture
).decode("utf-8")
guessed_mime_type = mimetypes.guess_type(picture_url)[0]
if guessed_mime_type is None:
# assume JPG, browsers are tolerant enough of image formats
guessed_mime_type = "image/jpeg"
picture_url = f"data:{guessed_mime_type};base64,{base64_encoded_picture}"
except Exception as e:
log.error(
f"Error downloading profile image '{picture_url}': {e}"
)
picture_url = ""
if not picture_url:
picture_url = "/user.png"
username_claim = auth_manager_config.OAUTH_USERNAME_CLAIM
role = self.get_user_role(None, user_data)
user = Auths.insert_new_auth(
email=email,
password=get_password_hash(
str(uuid.uuid4())
), # Random password, not used
name=user_data.get(username_claim, "User"),
profile_image_url=picture_url,
role=role,
oauth_sub=provider_sub,
)
if auth_manager_config.WEBHOOK_URL:
post_webhook(
auth_manager_config.WEBHOOK_URL,
auth_manager_config.WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
{
"action": "signup",
"message": auth_manager_config.WEBHOOK_MESSAGES.USER_SIGNUP(
user.name
),
"user": user.model_dump_json(exclude_none=True),
},
)
else:
raise HTTPException(
status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED
)
jwt_token = create_token(
data={"id": user.id},
expires_delta=parse_duration(auth_manager_config.JWT_EXPIRES_IN),
)
# Set the cookie token
response.set_cookie(
key="token",
value=jwt_token,
httponly=True, # Ensures the cookie is not accessible via JavaScript
samesite=WEBUI_SESSION_COOKIE_SAME_SITE,
secure=WEBUI_SESSION_COOKIE_SECURE,
)
# Redirect back to the frontend with the JWT token
redirect_url = f"{request.base_url}auth#token={jwt_token}"
return RedirectResponse(url=redirect_url)
oauth_manager = OAuthManager()

View file

@ -88,6 +88,53 @@ def apply_model_params_to_body_ollama(params: dict, form_data: dict) -> dict:
return form_data
def convert_messages_openai_to_ollama(messages: list[dict]) -> list[dict]:
ollama_messages = []
for message in messages:
# Initialize the new message structure with the role
new_message = {"role": message["role"]}
content = message.get("content", [])
# Check if the content is a string (just a simple message)
if isinstance(content, str):
# If the content is a string, it's pure text
new_message["content"] = content
else:
# Otherwise, assume the content is a list of dicts, e.g., text followed by an image URL
content_text = ""
images = []
# Iterate through the list of content items
for item in content:
# Check if it's a text type
if item.get("type") == "text":
content_text += item.get("text", "")
# Check if it's an image URL type
elif item.get("type") == "image_url":
img_url = item.get("image_url", {}).get("url", "")
if img_url:
# If the image url starts with data:, it's a base64 image and should be trimmed
if img_url.startswith("data:"):
img_url = img_url.split(",")[-1]
images.append(img_url)
# Add content text (if any)
if content_text:
new_message["content"] = content_text.strip()
# Add images (if any)
if images:
new_message["images"] = images
# Append the new formatted message to the result
ollama_messages.append(new_message)
return ollama_messages
def convert_payload_openai_to_ollama(openai_payload: dict) -> dict:
"""
Converts a payload formatted for OpenAI's API to be compatible with Ollama's API endpoint for chat completions.
@ -102,7 +149,9 @@ def convert_payload_openai_to_ollama(openai_payload: dict) -> dict:
# Mapping basic model and message details
ollama_payload["model"] = openai_payload.get("model")
ollama_payload["messages"] = openai_payload.get("messages")
ollama_payload["messages"] = convert_messages_openai_to_ollama(
openai_payload.get("messages")
)
ollama_payload["stream"] = openai_payload.get("stream", False)
# If there are advanced parameters in the payload, format them in Ollama's options field

View file

@ -0,0 +1,139 @@
from datetime import datetime
from io import BytesIO
from pathlib import Path
from typing import Dict, Any, List
from markdown import markdown
import site
from fpdf import FPDF
from open_webui.env import STATIC_DIR, FONTS_DIR
from open_webui.apps.webui.models.chats import ChatTitleMessagesForm
class PDFGenerator:
"""
Description:
The `PDFGenerator` class is designed to create PDF documents from chat messages.
The process involves transforming markdown content into HTML and then into a PDF format
Attributes:
- `form_data`: An instance of `ChatTitleMessagesForm` containing title and messages.
"""
def __init__(self, form_data: ChatTitleMessagesForm):
self.html_body = None
self.messages_html = None
self.form_data = form_data
self.css = Path(STATIC_DIR / "assets" / "pdf-style.css").read_text()
def format_timestamp(self, timestamp: float) -> str:
"""Convert a UNIX timestamp to a formatted date string."""
try:
date_time = datetime.fromtimestamp(timestamp)
return date_time.strftime("%Y-%m-%d, %H:%M:%S")
except (ValueError, TypeError) as e:
# Log the error if necessary
return ""
def _build_html_message(self, message: Dict[str, Any]) -> str:
"""Build HTML for a single message."""
role = message.get("role", "user")
content = message.get("content", "")
timestamp = message.get("timestamp")
model = message.get("model") if role == "assistant" else ""
date_str = self.format_timestamp(timestamp) if timestamp else ""
# extends pymdownx extension to convert markdown to html.
# - https://facelessuser.github.io/pymdown-extensions/usage_notes/
html_content = markdown(content, extensions=["pymdownx.extra"])
html_message = f"""
<div class="message">
<small> {date_str} </small>
<div>
<h2>
<strong>{role.title()}</strong>
<small class="text-muted">{model}</small>
</h2>
</div>
<div class="markdown-section">
{html_content}
</div>
</div>
"""
return html_message
def _generate_html_body(self) -> str:
"""Generate the full HTML body for the PDF."""
return f"""
<html>
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
</head>
<body>
<div class="container">
<div class="text-center">
<h1>{self.form_data.title}</h1>
</div>
<div>
{self.messages_html}
</div>
</div>
</body>
</html>
"""
def generate_chat_pdf(self) -> bytes:
"""
Generate a PDF from chat messages.
"""
try:
global FONTS_DIR
pdf = FPDF()
pdf.add_page()
# When running using `pip install` the static directory is in the site packages.
if not FONTS_DIR.exists():
FONTS_DIR = Path(site.getsitepackages()[0]) / "static/fonts"
# When running using `pip install -e .` the static directory is in the site packages.
# This path only works if `open-webui serve` is run from the root of this project.
if not FONTS_DIR.exists():
FONTS_DIR = Path("./backend/static/fonts")
pdf.add_font("NotoSans", "", f"{FONTS_DIR}/NotoSans-Regular.ttf")
pdf.add_font("NotoSans", "b", f"{FONTS_DIR}/NotoSans-Bold.ttf")
pdf.add_font("NotoSans", "i", f"{FONTS_DIR}/NotoSans-Italic.ttf")
pdf.add_font("NotoSansKR", "", f"{FONTS_DIR}/NotoSansKR-Regular.ttf")
pdf.add_font("NotoSansJP", "", f"{FONTS_DIR}/NotoSansJP-Regular.ttf")
pdf.add_font("NotoSansSC", "", f"{FONTS_DIR}/NotoSansSC-Regular.ttf")
pdf.set_font("NotoSans", size=12)
pdf.set_fallback_fonts(["NotoSansKR", "NotoSansJP", "NotoSansSC"])
pdf.set_auto_page_break(auto=True, margin=15)
# Build HTML messages
messages_html_list: List[str] = [
self._build_html_message(msg) for msg in self.form_data.messages
]
self.messages_html = "<div>" + "".join(messages_html_list) + "</div>"
# Generate full HTML body
self.html_body = self._generate_html_body()
pdf.write_html(self.html_body)
# Save the pdf with name .pdf
pdf_bytes = pdf.output()
return bytes(pdf_bytes)
except Exception as e:
raise e

View file

@ -26,7 +26,6 @@ async def convert_streaming_response_ollama_to_openai(ollama_streaming_response)
)
line = f"data: {json.dumps(data)}\n\n"
if done:
line += "data: [DONE]\n\n"
yield line
yield "data: [DONE]\n\n"

View file

@ -70,22 +70,6 @@ def replace_prompt_variable(template: str, prompt: str) -> str:
return template
def title_generation_template(
template: str, prompt: str, user: Optional[dict] = None
) -> str:
template = replace_prompt_variable(template, prompt)
template = prompt_template(
template,
**(
{"user_name": user.get("name"), "user_location": user.get("location")}
if user
else {}
),
)
return template
def replace_messages_variable(template: str, messages: list[str]) -> str:
def replacement_function(match):
full_match = match.group(0)
@ -123,6 +107,62 @@ def replace_messages_variable(template: str, messages: list[str]) -> str:
return template
# {{prompt:middletruncate:8000}}
def title_generation_template(
template: str, messages: list[dict], user: Optional[dict] = None
) -> str:
prompt = get_last_user_message(messages)
template = replace_prompt_variable(template, prompt)
template = replace_messages_variable(template, messages)
template = prompt_template(
template,
**(
{"user_name": user.get("name"), "user_location": user.get("location")}
if user
else {}
),
)
return template
def tags_generation_template(
template: str, messages: list[dict], user: Optional[dict] = None
) -> str:
prompt = get_last_user_message(messages)
template = replace_prompt_variable(template, prompt)
template = replace_messages_variable(template, messages)
template = prompt_template(
template,
**(
{"user_name": user.get("name"), "user_location": user.get("location")}
if user
else {}
),
)
return template
def emoji_generation_template(
template: str, prompt: str, user: Optional[dict] = None
) -> str:
template = replace_prompt_variable(template, prompt)
template = prompt_template(
template,
**(
{"user_name": user.get("name"), "user_location": user.get("location")}
if user
else {}
),
)
return template
def search_query_generation_template(
template: str, messages: list[dict], user: Optional[dict] = None
) -> str:

View file

@ -7,7 +7,7 @@ import jwt
from open_webui.apps.webui.models.users import Users
from open_webui.constants import ERROR_MESSAGES
from open_webui.env import WEBUI_SECRET_KEY
from fastapi import Depends, HTTPException, Request, status
from fastapi import Depends, HTTPException, Request, Response, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from passlib.context import CryptContext

View file

@ -12,6 +12,7 @@ passlib[bcrypt]==1.7.4
requests==2.32.3
aiohttp==3.10.8
async-timeout
sqlalchemy==2.0.32
alembic==1.13.2
@ -41,14 +42,17 @@ langchain-chroma==0.1.4
fake-useragent==1.5.1
chromadb==0.5.9
pymilvus==2.4.7
qdrant-client~=1.12.0
sentence-transformers==3.0.1
sentence-transformers==3.2.0
colbert-ai==0.2.21
einops==0.8.0
ftfy==6.2.3
pypdf==4.3.1
xhtml2pdf==0.2.16
pymdown-extensions==10.11.2
docx2txt==0.8
python-pptx==1.0.0
unstructured==0.15.9
@ -86,3 +90,5 @@ duckduckgo-search~=6.2.13
docker~=7.1.0
pytest~=8.3.2
pytest-docker~=3.1.1
googleapis-common-protos==1.63.2

View file

@ -30,7 +30,7 @@ describe('Settings', () => {
// Select the first model
cy.get('button[aria-label="model-item"]').first().click();
// Type a message
cy.get('#chat-textarea').type('Hi, what can you do? A single sentence only please.', {
cy.get('#chat-input').type('Hi, what can you do? A single sentence only please.', {
force: true
});
// Send the message
@ -50,7 +50,7 @@ describe('Settings', () => {
// Select the first model
cy.get('button[aria-label="model-item"]').first().click();
// Type a message
cy.get('#chat-textarea').type('Hi, what can you do? A single sentence only please.', {
cy.get('#chat-input').type('Hi, what can you do? A single sentence only please.', {
force: true
});
// Send the message
@ -85,7 +85,7 @@ describe('Settings', () => {
// Select the first model
cy.get('button[aria-label="model-item"]').first().click();
// Type a message
cy.get('#chat-textarea').type('Hi, what can you do? A single sentence only please.', {
cy.get('#chat-input').type('Hi, what can you do? A single sentence only please.', {
force: true
});
// Send the message

View file

@ -1,46 +1,2 @@
// eslint-disable-next-line @typescript-eslint/triple-slash-reference
/// <reference path="../support/index.d.ts" />
describe('Documents', () => {
const timestamp = Date.now();
before(() => {
cy.uploadTestDocument(timestamp);
});
after(() => {
cy.deleteTestDocument(timestamp);
});
context('Admin', () => {
beforeEach(() => {
// Login as the admin user
cy.loginAdmin();
// Visit the home page
cy.visit('/workspace/documents');
cy.get('button').contains('#cypress-test').click();
});
it('can see documents', () => {
cy.get('div').contains(`document-test-initial-${timestamp}.txt`).should('have.length', 1);
});
it('can see edit button', () => {
cy.get('div')
.contains(`document-test-initial-${timestamp}.txt`)
.get("button[aria-label='Edit Doc']")
.should('exist');
});
it('can see delete button', () => {
cy.get('div')
.contains(`document-test-initial-${timestamp}.txt`)
.get("button[aria-label='Delete Doc']")
.should('exist');
});
it('can see upload button', () => {
cy.get("button[aria-label='Add Docs']").should('exist');
});
});
});

View file

@ -73,50 +73,6 @@ Cypress.Commands.add('register', (name, email, password) => register(name, email
Cypress.Commands.add('registerAdmin', () => registerAdmin());
Cypress.Commands.add('loginAdmin', () => loginAdmin());
Cypress.Commands.add('uploadTestDocument', (suffix: any) => {
// Login as admin
cy.loginAdmin();
// upload example document
cy.visit('/workspace/documents');
// Create a document
cy.get("button[aria-label='Add Docs']").click();
cy.readFile('cypress/data/example-doc.txt').then((text) => {
// select file
cy.get('#upload-doc-input').selectFile(
{
contents: Cypress.Buffer.from(text + Date.now()),
fileName: `document-test-initial-${suffix}.txt`,
mimeType: 'text/plain',
lastModified: Date.now()
},
{
force: true
}
);
// open tag input
cy.get("button[aria-label='Add Tag']").click();
cy.get("input[placeholder='Add a tag']").type('cypress-test');
cy.get("button[aria-label='Save Tag']").click();
// submit to upload
cy.get("button[type='submit']").click();
// wait for upload to finish
cy.get('button').contains('#cypress-test').should('exist');
cy.get('div').contains(`document-test-initial-${suffix}.txt`).should('exist');
});
});
Cypress.Commands.add('deleteTestDocument', (suffix: any) => {
cy.loginAdmin();
cy.visit('/workspace/documents');
// clean up uploaded documents
cy.get('div')
.contains(`document-test-initial-${suffix}.txt`)
.find("button[aria-label='Delete Doc']")
.click();
});
before(() => {
cy.registerAdmin();
});

View file

@ -118,7 +118,7 @@ Navigate to the apache sites-available directory:
`nano models.server.city.conf` # match this with your ollama server domain
Add the folloing virtualhost containing this example (modify as needed):
Add the following virtualhost containing this example (modify as needed):
```

1040
package-lock.json generated

File diff suppressed because it is too large Load diff

View file

@ -1,6 +1,6 @@
{
"name": "open-webui",
"version": "0.3.32",
"version": "0.3.33",
"private": true,
"scripts": {
"dev": "npm run pyodide:fetch && vite dev --host",
@ -52,6 +52,7 @@
"@codemirror/lang-python": "^6.1.6",
"@codemirror/language-data": "^6.5.1",
"@codemirror/theme-one-dark": "^6.1.2",
"@huggingface/transformers": "^3.0.0",
"@pyscript/core": "^0.4.32",
"@sveltejs/adapter-node": "^2.0.0",
"@xyflow/svelte": "^0.1.19",
@ -72,9 +73,19 @@
"js-sha256": "^0.10.1",
"katex": "^0.16.9",
"marked": "^9.1.0",
"mermaid": "^10.9.1",
"mermaid": "^10.9.3",
"paneforge": "^0.0.6",
"panzoom": "^9.4.3",
"prosemirror-commands": "^1.6.0",
"prosemirror-example-setup": "^1.2.3",
"prosemirror-history": "^1.4.1",
"prosemirror-keymap": "^1.2.2",
"prosemirror-markdown": "^1.13.1",
"prosemirror-model": "^1.23.0",
"prosemirror-schema-basic": "^1.2.3",
"prosemirror-schema-list": "^1.4.1",
"prosemirror-state": "^1.4.3",
"prosemirror-view": "^1.34.3",
"pyodide": "^0.26.1",
"socket.io-client": "^4.2.0",
"sortablejs": "^1.15.2",

View file

@ -20,6 +20,7 @@ dependencies = [
"requests==2.32.3",
"aiohttp==3.10.8",
"async-timeout",
"sqlalchemy==2.0.32",
"alembic==1.13.2",
@ -49,13 +50,14 @@ dependencies = [
"chromadb==0.5.9",
"pymilvus==2.4.7",
"sentence-transformers==3.0.1",
"sentence-transformers==3.2.0",
"colbert-ai==0.2.21",
"einops==0.8.0",
"ftfy==6.2.3",
"pypdf==4.3.1",
"xhtml2pdf==0.2.16",
"pymdown-extensions==10.11.2",
"docx2txt==0.8",
"python-pptx==1.0.0",
"unstructured==0.15.9",
@ -91,7 +93,9 @@ dependencies = [
"docker~=7.1.0",
"pytest~=8.3.2",
"pytest-docker~=3.1.1"
"pytest-docker~=3.1.1",
"googleapis-common-protos==1.63.2"
]
readme = "README.md"
requires-python = ">= 3.11, < 3.12.0a1"

View file

@ -34,6 +34,14 @@ math {
@apply rounded-lg;
}
.input-prose {
@apply prose dark:prose-invert prose-p:my-0 prose-img:my-1 prose-headings:my-1 prose-pre:my-0 prose-table:my-0 prose-blockquote:my-0 prose-ul:-my-0 prose-ol:-my-0 prose-li:-my-0 whitespace-pre-line;
}
.input-prose-sm {
@apply prose dark:prose-invert prose-p:my-0 prose-img:my-1 prose-headings:my-1 prose-pre:my-0 prose-table:my-0 prose-blockquote:my-0 prose-ul:-my-0 prose-ol:-my-0 prose-li:-my-0 whitespace-pre-line text-sm;
}
.markdown-prose {
@apply prose dark:prose-invert prose-p:my-0 prose-img:my-1 prose-headings:my-1 prose-pre:my-0 prose-table:my-0 prose-blockquote:my-0 prose-ul:-my-0 prose-ol:-my-0 prose-li:-my-0 whitespace-pre-line;
}
@ -56,7 +64,7 @@ li p {
::-webkit-scrollbar-thumb {
--tw-border-opacity: 1;
background-color: rgba(217, 217, 227, 0.8);
background-color: rgba(236, 236, 236, 0.8);
border-color: rgba(255, 255, 255, var(--tw-border-opacity));
border-radius: 9999px;
border-width: 1px;
@ -64,7 +72,7 @@ li p {
/* Dark theme scrollbar styles */
.dark ::-webkit-scrollbar-thumb {
background-color: rgba(69, 69, 74, 0.8); /* Darker color for dark theme */
background-color: rgba(33, 33, 33, 0.8); /* Darker color for dark theme */
border-color: rgba(0, 0, 0, var(--tw-border-opacity));
}
@ -179,3 +187,21 @@ input[type='number'] {
.bg-gray-950-90 {
background-color: rgba(var(--color-gray-950, #0d0d0d), 0.9);
}
.ProseMirror {
@apply h-full min-h-fit max-h-full;
}
.ProseMirror:focus {
outline: none;
}
.placeholder::after {
content: attr(data-placeholder);
cursor: text;
pointer-events: none;
float: left;
@apply absolute inset-0 z-0 text-gray-500;
}

View file

@ -180,6 +180,31 @@ export const userSignUp = async (
return res;
};
export const userSignOut = async () => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/auths/signout`, {
method: 'GET',
headers: {
'Content-Type': 'application/json'
},
credentials: 'include'
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res;
})
.catch((err) => {
console.log(err);
error = err.detail;
return null;
});
if (error) {
throw error;
}
};
export const addUser = async (
token: string,
name: string,

View file

@ -32,6 +32,46 @@ export const createNewChat = async (token: string, chat: object) => {
return res;
};
export const importChat = async (
token: string,
chat: object,
meta: object | null,
pinned?: boolean,
folderId?: string | null
) => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/chats/import`, {
method: 'POST',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
authorization: `Bearer ${token}`
},
body: JSON.stringify({
chat: chat,
meta: meta ?? {},
pinned: pinned,
folder_id: folderId
})
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.catch((err) => {
error = err;
console.log(err);
return null;
});
if (error) {
throw error;
}
return res;
};
export const getChatList = async (token: string = '', page: number | null = null) => {
let error = null;
const searchParams = new URLSearchParams();
@ -167,6 +207,75 @@ export const getAllChats = async (token: string) => {
return res;
};
export const getChatListBySearchText = async (token: string, text: string, page: number = 1) => {
let error = null;
const searchParams = new URLSearchParams();
searchParams.append('text', text);
searchParams.append('page', `${page}`);
const res = await fetch(`${WEBUI_API_BASE_URL}/chats/search?${searchParams.toString()}`, {
method: 'GET',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
...(token && { authorization: `Bearer ${token}` })
}
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.then((json) => {
return json;
})
.catch((err) => {
error = err;
console.log(err);
return null;
});
if (error) {
throw error;
}
return res.map((chat) => ({
...chat,
time_range: getTimeRange(chat.updated_at)
}));
};
export const getChatsByFolderId = async (token: string, folderId: string) => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/chats/folder/${folderId}`, {
method: 'GET',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
...(token && { authorization: `Bearer ${token}` })
}
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.then((json) => {
return json;
})
.catch((err) => {
error = err;
console.log(err);
return null;
});
if (error) {
throw error;
}
return res;
};
export const getAllArchivedChats = async (token: string) => {
let error = null;
@ -229,10 +338,10 @@ export const getAllUserChats = async (token: string) => {
return res;
};
export const getAllChatTags = async (token: string) => {
export const getAllTags = async (token: string) => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/chats/tags/all`, {
const res = await fetch(`${WEBUI_API_BASE_URL}/chats/all/tags`, {
method: 'GET',
headers: {
Accept: 'application/json',
@ -260,6 +369,40 @@ export const getAllChatTags = async (token: string) => {
return res;
};
export const getPinnedChatList = async (token: string = '') => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/chats/pinned`, {
method: 'GET',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
...(token && { authorization: `Bearer ${token}` })
}
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.then((json) => {
return json;
})
.catch((err) => {
error = err;
console.log(err);
return null;
});
if (error) {
throw error;
}
return res.map((chat) => ({
...chat,
time_range: getTimeRange(chat.updated_at)
}));
};
export const getChatListByTagName = async (token: string = '', tagName: string) => {
let error = null;
@ -361,11 +504,87 @@ export const getChatByShareId = async (token: string, share_id: string) => {
return res;
};
export const getChatPinnedStatusById = async (token: string, id: string) => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/chats/${id}/pinned`, {
method: 'GET',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
...(token && { authorization: `Bearer ${token}` })
}
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.then((json) => {
return json;
})
.catch((err) => {
error = err;
if ('detail' in err) {
error = err.detail;
} else {
error = err;
}
console.log(err);
return null;
});
if (error) {
throw error;
}
return res;
};
export const toggleChatPinnedStatusById = async (token: string, id: string) => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/chats/${id}/pin`, {
method: 'POST',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
...(token && { authorization: `Bearer ${token}` })
}
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.then((json) => {
return json;
})
.catch((err) => {
error = err;
if ('detail' in err) {
error = err.detail;
} else {
error = err;
}
console.log(err);
return null;
});
if (error) {
throw error;
}
return res;
};
export const cloneChatById = async (token: string, id: string) => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/chats/${id}/clone`, {
method: 'GET',
method: 'POST',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
@ -431,11 +650,46 @@ export const shareChatById = async (token: string, id: string) => {
return res;
};
export const updateChatFolderIdById = async (token: string, id: string, folderId?: string) => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/chats/${id}/folder`, {
method: 'POST',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
...(token && { authorization: `Bearer ${token}` })
},
body: JSON.stringify({
folder_id: folderId
})
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.then((json) => {
return json;
})
.catch((err) => {
error = err;
console.log(err);
return null;
});
if (error) {
throw error;
}
return res;
};
export const archiveChatById = async (token: string, id: string) => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/chats/${id}/archive`, {
method: 'GET',
method: 'POST',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
@ -605,8 +859,7 @@ export const addTagById = async (token: string, id: string, tagName: string) =>
...(token && { authorization: `Bearer ${token}` })
},
body: JSON.stringify({
tag_name: tagName,
chat_id: id
name: tagName
})
})
.then(async (res) => {
@ -617,8 +870,7 @@ export const addTagById = async (token: string, id: string, tagName: string) =>
return json;
})
.catch((err) => {
error = err;
error = err.detail;
console.log(err);
return null;
});
@ -641,8 +893,7 @@ export const deleteTagById = async (token: string, id: string, tagName: string)
...(token && { authorization: `Bearer ${token}` })
},
body: JSON.stringify({
tag_name: tagName,
chat_id: id
name: tagName
})
})
.then(async (res) => {

View file

@ -1,51 +1,9 @@
import { WEBUI_API_BASE_URL } from '$lib/constants';
export const createNewDoc = async (
token: string,
collection_name: string,
filename: string,
name: string,
title: string,
content: object | null = null
) => {
export const getConfig = async (token: string = '') => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/documents/create`, {
method: 'POST',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
authorization: `Bearer ${token}`
},
body: JSON.stringify({
collection_name: collection_name,
filename: filename,
name: name,
title: title,
...(content ? { content: JSON.stringify(content) } : {})
})
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.catch((err) => {
error = err.detail;
console.log(err);
return null;
});
if (error) {
throw error;
}
return res;
};
export const getDocs = async (token: string = '') => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/documents/`, {
const res = await fetch(`${WEBUI_API_BASE_URL}/evaluations/config`, {
method: 'GET',
headers: {
Accept: 'application/json',
@ -73,13 +31,41 @@ export const getDocs = async (token: string = '') => {
return res;
};
export const getDocByName = async (token: string, name: string) => {
export const updateConfig = async (token: string, config: object) => {
let error = null;
const searchParams = new URLSearchParams();
searchParams.append('name', name);
const res = await fetch(`${WEBUI_API_BASE_URL}/evaluations/config`, {
method: 'POST',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
authorization: `Bearer ${token}`
},
body: JSON.stringify({
...config
})
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.catch((err) => {
error = err.detail;
console.log(err);
return null;
});
const res = await fetch(`${WEBUI_API_BASE_URL}/documents/docs?${searchParams.toString()}`, {
if (error) {
throw error;
}
return res;
};
export const getAllFeedbacks = async (token: string = '') => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/evaluations/feedbacks/all`, {
method: 'GET',
headers: {
Accept: 'application/json',
@ -96,7 +82,6 @@ export const getDocByName = async (token: string, name: string) => {
})
.catch((err) => {
error = err.detail;
console.log(err);
return null;
});
@ -108,18 +93,10 @@ export const getDocByName = async (token: string, name: string) => {
return res;
};
type DocUpdateForm = {
name: string;
title: string;
};
export const updateDocByName = async (token: string, name: string, form: DocUpdateForm) => {
export const createNewFeedback = async (token: string, feedback: object) => {
let error = null;
const searchParams = new URLSearchParams();
searchParams.append('name', name);
const res = await fetch(`${WEBUI_API_BASE_URL}/documents/doc/update?${searchParams.toString()}`, {
const res = await fetch(`${WEBUI_API_BASE_URL}/evaluations/feedback`, {
method: 'POST',
headers: {
Accept: 'application/json',
@ -127,9 +104,36 @@ export const updateDocByName = async (token: string, name: string, form: DocUpda
authorization: `Bearer ${token}`
},
body: JSON.stringify({
name: form.name,
title: form.title
...feedback
})
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.catch((err) => {
error = err.detail;
console.log(err);
return null;
});
if (error) {
throw error;
}
return res;
};
export const getFeedbackById = async (token: string, feedbackId: string) => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/evaluations/feedback/${feedbackId}`, {
method: 'GET',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
authorization: `Bearer ${token}`
}
})
.then(async (res) => {
if (!res.ok) throw await res.json();
@ -140,7 +144,6 @@ export const updateDocByName = async (token: string, name: string, form: DocUpda
})
.catch((err) => {
error = err.detail;
console.log(err);
return null;
});
@ -152,18 +155,10 @@ export const updateDocByName = async (token: string, name: string, form: DocUpda
return res;
};
type TagDocForm = {
name: string;
tags: string[];
};
export const tagDocByName = async (token: string, name: string, form: TagDocForm) => {
export const updateFeedbackById = async (token: string, feedbackId: string, feedback: object) => {
let error = null;
const searchParams = new URLSearchParams();
searchParams.append('name', name);
const res = await fetch(`${WEBUI_API_BASE_URL}/documents/doc/tags?${searchParams.toString()}`, {
const res = await fetch(`${WEBUI_API_BASE_URL}/evaluations/feedback/${feedbackId}`, {
method: 'POST',
headers: {
Accept: 'application/json',
@ -171,20 +166,15 @@ export const tagDocByName = async (token: string, name: string, form: TagDocForm
authorization: `Bearer ${token}`
},
body: JSON.stringify({
name: form.name,
tags: form.tags
...feedback
})
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.then((json) => {
return json;
})
.catch((err) => {
error = err.detail;
console.log(err);
return null;
});
@ -196,13 +186,10 @@ export const tagDocByName = async (token: string, name: string, form: TagDocForm
return res;
};
export const deleteDocByName = async (token: string, name: string) => {
export const deleteFeedbackById = async (token: string, feedbackId: string) => {
let error = null;
const searchParams = new URLSearchParams();
searchParams.append('name', name);
const res = await fetch(`${WEBUI_API_BASE_URL}/documents/doc/delete?${searchParams.toString()}`, {
const res = await fetch(`${WEBUI_API_BASE_URL}/evaluations/feedback/${feedbackId}`, {
method: 'DELETE',
headers: {
Accept: 'application/json',
@ -214,12 +201,8 @@ export const deleteDocByName = async (token: string, name: string) => {
if (!res.ok) throw await res.json();
return res.json();
})
.then((json) => {
return json;
})
.catch((err) => {
error = err.detail;
console.log(err);
return null;
});

View file

@ -0,0 +1,269 @@
import { WEBUI_API_BASE_URL } from '$lib/constants';
export const createNewFolder = async (token: string, name: string) => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/folders/`, {
method: 'POST',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
authorization: `Bearer ${token}`
},
body: JSON.stringify({
name: name
})
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.catch((err) => {
error = err.detail;
return null;
});
if (error) {
throw error;
}
return res;
};
export const getFolders = async (token: string = '') => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/folders/`, {
method: 'GET',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
authorization: `Bearer ${token}`
}
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.then((json) => {
return json;
})
.catch((err) => {
error = err.detail;
console.log(err);
return null;
});
if (error) {
throw error;
}
return res;
};
export const getFolderById = async (token: string, id: string) => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/folders/${id}`, {
method: 'GET',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
authorization: `Bearer ${token}`
}
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.then((json) => {
return json;
})
.catch((err) => {
error = err.detail;
console.log(err);
return null;
});
if (error) {
throw error;
}
return res;
};
export const updateFolderNameById = async (token: string, id: string, name: string) => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/folders/${id}/update`, {
method: 'POST',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
authorization: `Bearer ${token}`
},
body: JSON.stringify({
name: name
})
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.then((json) => {
return json;
})
.catch((err) => {
error = err.detail;
console.log(err);
return null;
});
if (error) {
throw error;
}
return res;
};
export const updateFolderIsExpandedById = async (
token: string,
id: string,
isExpanded: boolean
) => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/folders/${id}/update/expanded`, {
method: 'POST',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
authorization: `Bearer ${token}`
},
body: JSON.stringify({
is_expanded: isExpanded
})
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.then((json) => {
return json;
})
.catch((err) => {
error = err.detail;
console.log(err);
return null;
});
if (error) {
throw error;
}
return res;
};
export const updateFolderParentIdById = async (token: string, id: string, parentId?: string) => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/folders/${id}/update/parent`, {
method: 'POST',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
authorization: `Bearer ${token}`
},
body: JSON.stringify({
parent_id: parentId
})
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.then((json) => {
return json;
})
.catch((err) => {
error = err.detail;
console.log(err);
return null;
});
if (error) {
throw error;
}
return res;
};
type FolderItems = {
chat_ids: string[];
file_ids: string[];
};
export const updateFolderItemsById = async (token: string, id: string, items: FolderItems) => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/folders/${id}/update/items`, {
method: 'POST',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
authorization: `Bearer ${token}`
},
body: JSON.stringify({
items: items
})
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.then((json) => {
return json;
})
.catch((err) => {
error = err.detail;
console.log(err);
return null;
});
if (error) {
throw error;
}
return res;
};
export const deleteFolderById = async (token: string, id: string) => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/folders/${id}`, {
method: 'DELETE',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
authorization: `Bearer ${token}`
}
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.then((json) => {
return json;
})
.catch((err) => {
error = err.detail;
console.log(err);
return null;
});
if (error) {
throw error;
}
return res;
};

View file

@ -208,7 +208,7 @@ export const updateTaskConfig = async (token: string, config: object) => {
export const generateTitle = async (
token: string = '',
model: string,
prompt: string,
messages: string[],
chat_id?: string
) => {
let error = null;
@ -222,7 +222,7 @@ export const generateTitle = async (
},
body: JSON.stringify({
model: model,
prompt: prompt,
messages: messages,
...(chat_id && { chat_id: chat_id })
})
})
@ -245,6 +245,78 @@ export const generateTitle = async (
return res?.choices[0]?.message?.content.replace(/["']/g, '') ?? 'New Chat';
};
export const generateTags = async (
token: string = '',
model: string,
messages: string,
chat_id?: string
) => {
let error = null;
const res = await fetch(`${WEBUI_BASE_URL}/api/task/tags/completions`, {
method: 'POST',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
Authorization: `Bearer ${token}`
},
body: JSON.stringify({
model: model,
messages: messages,
...(chat_id && { chat_id: chat_id })
})
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.catch((err) => {
console.log(err);
if ('detail' in err) {
error = err.detail;
}
return null;
});
if (error) {
throw error;
}
try {
// Step 1: Safely extract the response string
const response = res?.choices[0]?.message?.content ?? '';
// Step 2: Attempt to fix common JSON format issues like single quotes
const sanitizedResponse = response.replace(/['`]/g, '"'); // Convert single quotes to double quotes for valid JSON
// Step 3: Find the relevant JSON block within the response
const jsonStartIndex = sanitizedResponse.indexOf('{');
const jsonEndIndex = sanitizedResponse.lastIndexOf('}');
// Step 4: Check if we found a valid JSON block (with both `{` and `}`)
if (jsonStartIndex !== -1 && jsonEndIndex !== -1) {
const jsonResponse = sanitizedResponse.substring(jsonStartIndex, jsonEndIndex + 1);
// Step 5: Parse the JSON block
const parsed = JSON.parse(jsonResponse);
// Step 6: If there's a "tags" key, return the tags array; otherwise, return an empty array
if (parsed && parsed.tags) {
return Array.isArray(parsed.tags) ? parsed.tags : [];
} else {
return [];
}
}
// If no valid JSON block found, return an empty array
return [];
} catch (e) {
// Catch and safely return empty array on any parsing errors
console.error('Failed to parse response: ', e);
return [];
}
};
export const generateEmoji = async (
token: string = '',
model: string,

View file

@ -200,13 +200,13 @@ export const getEmbeddingConfig = async (token: string) => {
type OpenAIConfigForm = {
key: string;
url: string;
batch_size: number;
};
type EmbeddingModelUpdateForm = {
openai_config?: OpenAIConfigForm;
embedding_engine: string;
embedding_model: string;
embedding_batch_size?: number;
};
export const updateEmbeddingConfig = async (token: string, payload: EmbeddingModelUpdateForm) => {

View file

@ -7,6 +7,7 @@ type TextStreamUpdate = {
// eslint-disable-next-line @typescript-eslint/no-explicit-any
citations?: any;
// eslint-disable-next-line @typescript-eslint/no-explicit-any
selectedModelId?: any;
error?: any;
usage?: ResponseUsage;
};
@ -71,6 +72,11 @@ async function* openAIStreamToIterator(
continue;
}
if (parsedData.selected_model_id) {
yield { done: false, value: '', selectedModelId: parsedData.selected_model_id };
continue;
}
yield {
done: false,
value: parsedData.choices?.[0]?.delta?.content ?? '',

View file

@ -2,20 +2,27 @@
import { getContext } from 'svelte';
export let title = '';
export let content = '';
const i18n = getContext('i18n');
</script>
<div class=" text-center text-6xl mb-3">📄</div>
<div class="text-center dark:text-white text-2xl font-semibold z-50">
{#if title}
{title}
{:else}
{$i18n.t('Add Files')}
{/if}
</div>
<slot
><div class=" mt-2 text-center text-sm dark:text-gray-200 w-full">
{$i18n.t('Drop any files here to add to the conversation')}
<div class="px-3">
<div class="text-center text-6xl mb-3">📄</div>
<div class="text-center dark:text-white text-xl font-semibold z-50">
{#if title}
{title}
{:else}
{$i18n.t('Add Files')}
{/if}
</div>
</slot>
<slot
><div class="px-2 mt-2 text-center text-sm dark:text-gray-200 w-full">
{#if content}
{content}
{:else}
{$i18n.t('Drop any files here to add to the conversation')}
{/if}
</div>
</slot>
</div>

View file

@ -2,12 +2,13 @@
import { onMount, getContext } from 'svelte';
import { Confetti } from 'svelte-confetti';
import { WEBUI_NAME, config } from '$lib/stores';
import { WEBUI_NAME, config, settings } from '$lib/stores';
import { WEBUI_VERSION } from '$lib/constants';
import { getChangelog } from '$lib/apis';
import Modal from './common/Modal.svelte';
import { updateUserSettings } from '$lib/apis/users';
const i18n = getContext('i18n');
@ -104,8 +105,10 @@
</div>
<div class="flex justify-end pt-3 text-sm font-medium">
<button
on:click={() => {
on:click={async () => {
localStorage.version = $config.version;
await settings.set({ ...$settings, ...{ version: $config.version } });
await updateUserSettings(localStorage.token, { ui: $settings });
show = false;
}}
class=" px-4 py-2 bg-emerald-700 hover:bg-emerald-800 text-gray-100 transition rounded-lg"

View file

@ -139,7 +139,7 @@
</button>
</div>
<div class="flex flex-col md:flex-row w-full px-5 pb-4 md:space-x-4 dark:text-gray-200">
<div class="flex flex-col md:flex-row w-full px-4 pb-3 md:space-x-4 dark:text-gray-200">
<div class=" flex flex-col w-full sm:flex-row sm:justify-center sm:space-x-6">
<form
class="flex flex-col w-full"
@ -147,9 +147,9 @@
submitHandler();
}}
>
<div class="flex text-center text-sm font-medium rounded-xl bg-transparent/10 p-1 mb-2">
<div class="flex text-center text-sm font-medium rounded-full bg-transparent/10 p-1 mb-2">
<button
class="w-full rounded-lg p-1.5 {tab === '' ? 'bg-gray-50 dark:bg-gray-850' : ''}"
class="w-full rounded-full p-1.5 {tab === '' ? 'bg-gray-50 dark:bg-gray-850' : ''}"
type="button"
on:click={() => {
tab = '';
@ -157,7 +157,9 @@
>
<button
class="w-full rounded-lg p-1 {tab === 'import' ? 'bg-gray-50 dark:bg-gray-850' : ''}"
class="w-full rounded-full p-1 {tab === 'import'
? 'bg-gray-50 dark:bg-gray-850'
: ''}"
type="button"
on:click={() => {
tab = 'import';
@ -183,7 +185,7 @@
</div>
</div>
<div class="flex flex-col w-full mt-2">
<div class="flex flex-col w-full mt-1">
<div class=" mb-1 text-xs text-gray-500">{$i18n.t('Name')}</div>
<div class="flex-1">
@ -198,7 +200,7 @@
</div>
</div>
<hr class=" dark:border-gray-800 my-3 w-full" />
<hr class=" border-gray-50 dark:border-gray-850 my-2.5 w-full" />
<div class="flex flex-col w-full">
<div class=" mb-1 text-xs text-gray-500">{$i18n.t('Email')}</div>
@ -209,13 +211,12 @@
type="email"
bind:value={_user.email}
placeholder={$i18n.t('Enter Your Email')}
autocomplete="off"
required
/>
</div>
</div>
<div class="flex flex-col w-full mt-2">
<div class="flex flex-col w-full mt-1">
<div class=" mb-1 text-xs text-gray-500">{$i18n.t('Password')}</div>
<div class="flex-1">
@ -271,13 +272,13 @@
<div class="flex justify-end pt-3 text-sm font-medium">
<button
class=" px-4 py-2 bg-emerald-700 hover:bg-emerald-800 text-gray-100 transition rounded-lg flex flex-row space-x-1 items-center {loading
class="px-3.5 py-1.5 text-sm font-medium bg-black hover:bg-gray-900 text-white dark:bg-white dark:text-black dark:hover:bg-gray-100 transition rounded-full flex flex-row space-x-1 items-center {loading
? ' cursor-not-allowed'
: ''}"
type="submit"
disabled={loading}
>
{$i18n.t('Submit')}
{$i18n.t('Save')}
{#if loading}
<div class="ml-2 self-center">

View file

@ -0,0 +1,633 @@
<script lang="ts">
import { onMount, getContext } from 'svelte';
import dayjs from 'dayjs';
import relativeTime from 'dayjs/plugin/relativeTime';
dayjs.extend(relativeTime);
import * as ort from 'onnxruntime-web';
import { AutoModel, AutoTokenizer } from '@huggingface/transformers';
const EMBEDDING_MODEL = 'TaylorAI/bge-micro-v2';
let tokenizer = null;
let model = null;
import { models } from '$lib/stores';
import { deleteFeedbackById, getAllFeedbacks } from '$lib/apis/evaluations';
import FeedbackMenu from './Evaluations/FeedbackMenu.svelte';
import EllipsisHorizontal from '../icons/EllipsisHorizontal.svelte';
import Tooltip from '../common/Tooltip.svelte';
import Badge from '../common/Badge.svelte';
import Pagination from '../common/Pagination.svelte';
import MagnifyingGlass from '../icons/MagnifyingGlass.svelte';
import Share from '../icons/Share.svelte';
import CloudArrowUp from '../icons/CloudArrowUp.svelte';
import { toast } from 'svelte-sonner';
const i18n = getContext('i18n');
let rankedModels = [];
let feedbacks = [];
let query = '';
let page = 1;
let tagEmbeddings = new Map();
let loaded = false;
let debounceTimer;
$: paginatedFeedbacks = feedbacks.slice((page - 1) * 10, page * 10);
type Feedback = {
id: string;
data: {
rating: number;
model_id: string;
sibling_model_ids: string[] | null;
reason: string;
comment: string;
tags: string[];
};
user: {
name: string;
profile_image_url: string;
};
updated_at: number;
};
type ModelStats = {
rating: number;
won: number;
lost: number;
};
//////////////////////
//
// Rank models by Elo rating
//
//////////////////////
const rankHandler = async (similarities: Map<string, number> = new Map()) => {
const modelStats = calculateModelStats(feedbacks, similarities);
rankedModels = $models
.filter((m) => m?.owned_by !== 'arena' && (m?.info?.meta?.hidden ?? false) !== true)
.map((model) => {
const stats = modelStats.get(model.id);
return {
...model,
rating: stats ? Math.round(stats.rating) : '-',
stats: {
count: stats ? stats.won + stats.lost : 0,
won: stats ? stats.won.toString() : '-',
lost: stats ? stats.lost.toString() : '-'
}
};
})
.sort((a, b) => {
if (a.rating === '-' && b.rating !== '-') return 1;
if (b.rating === '-' && a.rating !== '-') return -1;
if (a.rating !== '-' && b.rating !== '-') return b.rating - a.rating;
return a.name.localeCompare(b.name);
});
};
function calculateModelStats(
feedbacks: Feedback[],
similarities: Map<string, number>
): Map<string, ModelStats> {
const stats = new Map<string, ModelStats>();
const K = 32;
function getOrDefaultStats(modelId: string): ModelStats {
return stats.get(modelId) || { rating: 1000, won: 0, lost: 0 };
}
function updateStats(modelId: string, ratingChange: number, outcome: number) {
const currentStats = getOrDefaultStats(modelId);
currentStats.rating += ratingChange;
if (outcome === 1) currentStats.won++;
else if (outcome === 0) currentStats.lost++;
stats.set(modelId, currentStats);
}
function calculateEloChange(
ratingA: number,
ratingB: number,
outcome: number,
similarity: number
): number {
const expectedScore = 1 / (1 + Math.pow(10, (ratingB - ratingA) / 400));
return K * (outcome - expectedScore) * similarity;
}
feedbacks.forEach((feedback) => {
const modelA = feedback.data.model_id;
const statsA = getOrDefaultStats(modelA);
let outcome: number;
switch (feedback.data.rating.toString()) {
case '1':
outcome = 1;
break;
case '-1':
outcome = 0;
break;
default:
return; // Skip invalid ratings
}
// If the query is empty, set similarity to 1, else get the similarity from the map
const similarity = query !== '' ? similarities.get(feedback.id) || 0 : 1;
const opponents = feedback.data.sibling_model_ids || [];
opponents.forEach((modelB) => {
const statsB = getOrDefaultStats(modelB);
const changeA = calculateEloChange(statsA.rating, statsB.rating, outcome, similarity);
const changeB = calculateEloChange(statsB.rating, statsA.rating, 1 - outcome, similarity);
updateStats(modelA, changeA, outcome);
updateStats(modelB, changeB, 1 - outcome);
});
});
return stats;
}
//////////////////////
//
// Calculate cosine similarity
//
//////////////////////
const cosineSimilarity = (vecA, vecB) => {
// Ensure the lengths of the vectors are the same
if (vecA.length !== vecB.length) {
throw new Error('Vectors must be the same length');
}
// Calculate the dot product
let dotProduct = 0;
let normA = 0;
let normB = 0;
for (let i = 0; i < vecA.length; i++) {
dotProduct += vecA[i] * vecB[i];
normA += vecA[i] ** 2;
normB += vecB[i] ** 2;
}
// Calculate the magnitudes
normA = Math.sqrt(normA);
normB = Math.sqrt(normB);
// Avoid division by zero
if (normA === 0 || normB === 0) {
return 0;
}
// Return the cosine similarity
return dotProduct / (normA * normB);
};
const calculateMaxSimilarity = (queryEmbedding, tagEmbeddings: Map<string, number[]>) => {
let maxSimilarity = 0;
for (const tagEmbedding of tagEmbeddings.values()) {
const similarity = cosineSimilarity(queryEmbedding, tagEmbedding);
maxSimilarity = Math.max(maxSimilarity, similarity);
}
return maxSimilarity;
};
//////////////////////
//
// Embedding functions
//
//////////////////////
const getEmbeddings = async (text: string) => {
const tokens = await tokenizer(text);
const output = await model(tokens);
// Perform mean pooling on the last hidden states
const embeddings = output.last_hidden_state.mean(1);
return embeddings.ort_tensor.data;
};
const getTagEmbeddings = async (tags: string[]) => {
const embeddings = new Map();
for (const tag of tags) {
if (!tagEmbeddings.has(tag)) {
tagEmbeddings.set(tag, await getEmbeddings(tag));
}
embeddings.set(tag, tagEmbeddings.get(tag));
}
return embeddings;
};
const debouncedQueryHandler = async () => {
if (query.trim() === '') {
rankHandler();
return;
}
clearTimeout(debounceTimer);
debounceTimer = setTimeout(async () => {
const queryEmbedding = await getEmbeddings(query);
const similarities = new Map<string, number>();
for (const feedback of feedbacks) {
const feedbackTags = feedback.data.tags || [];
const tagEmbeddings = await getTagEmbeddings(feedbackTags);
const maxSimilarity = calculateMaxSimilarity(queryEmbedding, tagEmbeddings);
similarities.set(feedback.id, maxSimilarity);
}
rankHandler(similarities);
}, 1500); // Debounce for 1.5 seconds
};
$: query, debouncedQueryHandler();
//////////////////////
//
// CRUD operations
//
//////////////////////
const deleteFeedbackHandler = async (feedbackId: string) => {
const response = await deleteFeedbackById(localStorage.token, feedbackId).catch((err) => {
toast.error(err);
return null;
});
if (response) {
feedbacks = feedbacks.filter((f) => f.id !== feedbackId);
}
};
const shareHandler = async () => {
toast.success($i18n.t('Redirecting you to OpenWebUI Community'));
// remove snapshot from feedbacks
const feedbacksToShare = feedbacks.map((f) => {
const { snapshot, user, ...rest } = f;
return rest;
});
console.log(feedbacksToShare);
const url = 'https://openwebui.com';
const tab = await window.open(`${url}/leaderboard`, '_blank');
// Define the event handler function
const messageHandler = (event) => {
if (event.origin !== url) return;
if (event.data === 'loaded') {
tab.postMessage(JSON.stringify(feedbacksToShare), '*');
// Remove the event listener after handling the message
window.removeEventListener('message', messageHandler);
}
};
window.addEventListener('message', messageHandler, false);
};
onMount(async () => {
feedbacks = await getAllFeedbacks(localStorage.token);
loaded = true;
// Check if the tokenizer and model are already loaded and stored in the window object
if (!window.tokenizer) {
window.tokenizer = await AutoTokenizer.from_pretrained(EMBEDDING_MODEL);
}
if (!window.model) {
window.model = await AutoModel.from_pretrained(EMBEDDING_MODEL);
}
// Use the tokenizer and model from the window object
tokenizer = window.tokenizer;
model = window.model;
// Pre-compute embeddings for all unique tags
const allTags = new Set(feedbacks.flatMap((feedback) => feedback.data.tags || []));
await getTagEmbeddings(Array.from(allTags));
rankHandler();
});
</script>
{#if loaded}
<div class="mt-0.5 mb-2 gap-1 flex flex-col md:flex-row justify-between">
<div class="flex md:self-center text-lg font-medium px-0.5 shrink-0 items-center">
<div class=" gap-1">
{$i18n.t('Leaderboard')}
</div>
<div class="flex self-center w-[1px] h-6 mx-2.5 bg-gray-50 dark:bg-gray-850" />
<span class="text-lg font-medium text-gray-500 dark:text-gray-300 mr-1.5"
>{rankedModels.length}</span
>
</div>
<div class=" flex space-x-2">
<Tooltip content={$i18n.t('Re-rank models by topic similarity')}>
<div class="flex flex-1">
<div class=" self-center ml-1 mr-3">
<MagnifyingGlass className="size-3" />
</div>
<input
class=" w-full text-sm pr-4 py-1 rounded-r-xl outline-none bg-transparent"
bind:value={query}
placeholder={$i18n.t('Search')}
/>
</div>
</Tooltip>
</div>
</div>
<div
class="scrollbar-hidden relative whitespace-nowrap overflow-x-auto max-w-full rounded pt-0.5"
>
{#if (rankedModels ?? []).length === 0}
<div class="text-center text-xs text-gray-500 dark:text-gray-400 py-1">
{$i18n.t('No models found')}
</div>
{:else}
<table
class="w-full text-sm text-left text-gray-500 dark:text-gray-400 table-auto max-w-full rounded"
>
<thead
class="text-xs text-gray-700 uppercase bg-gray-50 dark:bg-gray-850 dark:text-gray-400 -translate-y-0.5"
>
<tr class="">
<th scope="col" class="px-3 py-1.5 cursor-pointer select-none w-3">
{$i18n.t('RK')}
</th>
<th scope="col" class="px-3 py-1.5 cursor-pointer select-none">
{$i18n.t('Model')}
</th>
<th scope="col" class="px-3 py-1.5 text-right cursor-pointer select-none w-fit">
{$i18n.t('Rating')}
</th>
<th scope="col" class="px-3 py-1.5 text-right cursor-pointer select-none w-5">
{$i18n.t('Won')}
</th>
<th scope="col" class="px-3 py-1.5 text-right cursor-pointer select-none w-5">
{$i18n.t('Lost')}
</th>
</tr>
</thead>
<tbody class="">
{#each rankedModels as model, modelIdx (model.id)}
<tr class="bg-white dark:bg-gray-900 dark:border-gray-850 text-xs group">
<td class="px-3 py-1.5 text-left font-medium text-gray-900 dark:text-white w-fit">
<div class=" line-clamp-1">
{model?.rating !== '-' ? modelIdx + 1 : '-'}
</div>
</td>
<td class="px-3 py-1.5 flex flex-col justify-center">
<div class="flex items-center gap-2">
<div class="flex-shrink-0">
<img
src={model?.info?.meta?.profile_image_url ?? '/favicon.png'}
alt={model.name}
class="size-5 rounded-full object-cover shrink-0"
/>
</div>
<div class="font-medium text-gray-800 dark:text-gray-200 pr-4">
{model.name}
</div>
</div>
</td>
<td class="px-3 py-1.5 text-right font-medium text-gray-900 dark:text-white w-max">
{model.rating}
</td>
<td class=" px-3 py-1.5 text-right font-semibold text-green-500">
<div class=" w-10">
{#if model.stats.won === '-'}
-
{:else}
<span class="hidden group-hover:inline"
>{((model.stats.won / model.stats.count) * 100).toFixed(1)}%</span
>
<span class=" group-hover:hidden">{model.stats.won}</span>
{/if}
</div>
</td>
<td class="px-3 py-1.5 text-right font-semibold text-red-500">
<div class=" w-10">
{#if model.stats.lost === '-'}
-
{:else}
<span class="hidden group-hover:inline"
>{((model.stats.lost / model.stats.count) * 100).toFixed(1)}%</span
>
<span class=" group-hover:hidden">{model.stats.lost}</span>
{/if}
</div>
</td>
</tr>
{/each}
</tbody>
</table>
{/if}
</div>
<div class=" text-gray-500 text-xs mt-1.5 w-full flex justify-end">
<div class=" text-right">
<div class="line-clamp-1">
{$i18n.t(
'The evaluation leaderboard is based on the Elo rating system and is updated in real-time.'
)}
</div>
{$i18n.t(
'The leaderboard is currently in beta, and we may adjust the rating calculations as we refine the algorithm.'
)}
</div>
</div>
<div class="pb-4"></div>
<div class="mt-0.5 mb-2 gap-1 flex flex-col md:flex-row justify-between">
<div class="flex md:self-center text-lg font-medium px-0.5">
{$i18n.t('Feedback History')}
<div class="flex self-center w-[1px] h-6 mx-2.5 bg-gray-50 dark:bg-gray-850" />
<span class="text-lg font-medium text-gray-500 dark:text-gray-300">{feedbacks.length}</span>
</div>
</div>
<div
class="scrollbar-hidden relative whitespace-nowrap overflow-x-auto max-w-full rounded pt-0.5"
>
{#if (feedbacks ?? []).length === 0}
<div class="text-center text-xs text-gray-500 dark:text-gray-400 py-1">
{$i18n.t('No feedbacks found')}
</div>
{:else}
<table
class="w-full text-sm text-left text-gray-500 dark:text-gray-400 table-auto max-w-full rounded"
>
<thead
class="text-xs text-gray-700 uppercase bg-gray-50 dark:bg-gray-850 dark:text-gray-400 -translate-y-0.5"
>
<tr class="">
<th scope="col" class="px-3 text-right cursor-pointer select-none w-0">
{$i18n.t('User')}
</th>
<th scope="col" class="px-3 pr-1.5 cursor-pointer select-none">
{$i18n.t('Models')}
</th>
<th scope="col" class="px-3 py-1.5 text-right cursor-pointer select-none w-fit">
{$i18n.t('Result')}
</th>
<th scope="col" class="px-3 py-1.5 text-right cursor-pointer select-none w-0">
{$i18n.t('Updated At')}
</th>
<th scope="col" class="px-3 py-1.5 text-right cursor-pointer select-none w-0"> </th>
</tr>
</thead>
<tbody class="">
{#each paginatedFeedbacks as feedback (feedback.id)}
<tr class="bg-white dark:bg-gray-900 dark:border-gray-850 text-xs">
<td class=" py-0.5 text-right font-semibold">
<div class="flex justify-center">
<Tooltip content={feedback?.user?.name}>
<div class="flex-shrink-0">
<img
src={feedback?.user?.profile_image_url ?? '/user.png'}
alt={feedback?.user?.name}
class="size-5 rounded-full object-cover shrink-0"
/>
</div>
</Tooltip>
</div>
</td>
<td class=" py-1 pl-3 flex flex-col">
<div class="flex flex-col items-start gap-0.5 h-full">
<div class="flex flex-col h-full">
{#if feedback.data?.sibling_model_ids}
<div class="font-semibold text-gray-600 dark:text-gray-400 flex-1">
{feedback.data?.model_id}
</div>
<Tooltip content={feedback.data.sibling_model_ids.join(', ')}>
<div class=" text-[0.65rem] text-gray-600 dark:text-gray-400 line-clamp-1">
{#if feedback.data.sibling_model_ids.length > 2}
<!-- {$i18n.t('and {{COUNT}} more')} -->
{feedback.data.sibling_model_ids.slice(0, 2).join(', ')}, {$i18n.t(
'and {{COUNT}} more',
{ COUNT: feedback.data.sibling_model_ids.length - 2 }
)}
{:else}
{feedback.data.sibling_model_ids.join(', ')}
{/if}
</div>
</Tooltip>
{:else}
<div
class=" text-sm font-medium text-gray-600 dark:text-gray-400 flex-1 py-1.5"
>
{feedback.data?.model_id}
</div>
{/if}
</div>
</div>
</td>
<td class="px-3 py-1 text-right font-medium text-gray-900 dark:text-white w-max">
<div class=" flex justify-end">
{#if feedback.data.rating.toString() === '1'}
<Badge type="info" content={$i18n.t('Won')} />
{:else if feedback.data.rating.toString() === '0'}
<Badge type="muted" content={$i18n.t('Draw')} />
{:else if feedback.data.rating.toString() === '-1'}
<Badge type="error" content={$i18n.t('Lost')} />
{/if}
</div>
</td>
<td class=" px-3 py-1 text-right font-medium">
{dayjs(feedback.updated_at * 1000).fromNow()}
</td>
<td class=" px-3 py-1 text-right font-semibold">
<FeedbackMenu
on:delete={(e) => {
deleteFeedbackHandler(feedback.id);
}}
>
<button
class="self-center w-fit text-sm p-1.5 dark:text-gray-300 dark:hover:text-white hover:bg-black/5 dark:hover:bg-white/5 rounded-xl"
>
<EllipsisHorizontal />
</button>
</FeedbackMenu>
</td>
</tr>
{/each}
</tbody>
</table>
{/if}
</div>
{#if feedbacks.length > 0}
<div class=" flex flex-col justify-end w-full text-right gap-1">
<div class="line-clamp-1 text-gray-500 text-xs">
{$i18n.t('Help us create the best community leaderboard by sharing your feedback history!')}
</div>
<div class="flex space-x-1 ml-auto">
<Tooltip
content={$i18n.t(
'To protect your privacy, only ratings, model IDs, tags, and metadata are shared from your feedback—your chat logs remain private and are not included.'
)}
>
<button
class="flex text-xs items-center px-3 py-1.5 rounded-xl bg-gray-50 hover:bg-gray-100 dark:bg-gray-850 dark:hover:bg-gray-800 dark:text-gray-200 transition"
on:click={async () => {
shareHandler();
}}
>
<div class=" self-center mr-2 font-medium line-clamp-1">
{$i18n.t('Share to OpenWebUI Community')}
</div>
<div class=" self-center">
<svg
xmlns="http://www.w3.org/2000/svg"
viewBox="0 0 16 16"
fill="currentColor"
class="w-3.5 h-3.5"
>
<path
fill-rule="evenodd"
d="M4 2a1.5 1.5 0 0 0-1.5 1.5v9A1.5 1.5 0 0 0 4 14h8a1.5 1.5 0 0 0 1.5-1.5V6.621a1.5 1.5 0 0 0-.44-1.06L9.94 2.439A1.5 1.5 0 0 0 8.878 2H4Zm4 9.5a.75.75 0 0 1-.75-.75V8.06l-.72.72a.75.75 0 0 1-1.06-1.06l2-2a.75.75 0 0 1 1.06 0l2 2a.75.75 0 1 1-1.06 1.06l-.72-.72v2.69a.75.75 0 0 1-.75.75Z"
clip-rule="evenodd"
/>
</svg>
</div>
</button>
</Tooltip>
</div>
</div>
{/if}
{#if feedbacks.length > 10}
<Pagination bind:page count={feedbacks.length} perPage={10} />
{/if}
<div class="pb-12"></div>
{/if}

View file

@ -0,0 +1,46 @@
<script lang="ts">
import { DropdownMenu } from 'bits-ui';
import { flyAndScale } from '$lib/utils/transitions';
import { getContext, createEventDispatcher } from 'svelte';
import fileSaver from 'file-saver';
const { saveAs } = fileSaver;
const dispatch = createEventDispatcher();
const i18n = getContext('i18n');
import Dropdown from '$lib/components/common/Dropdown.svelte';
import GarbageBin from '$lib/components/icons/GarbageBin.svelte';
import Pencil from '$lib/components/icons/Pencil.svelte';
import Tooltip from '$lib/components/common/Tooltip.svelte';
import Download from '$lib/components/icons/Download.svelte';
let show = false;
</script>
<Dropdown bind:show on:change={(e) => {}}>
<Tooltip content={$i18n.t('More')}>
<slot />
</Tooltip>
<div slot="content">
<DropdownMenu.Content
class="w-full max-w-[150px] rounded-xl px-1 py-1.5 z-50 bg-white dark:bg-gray-850 dark:text-white shadow-lg"
sideOffset={-2}
side="bottom"
align="start"
transition={flyAndScale}
>
<DropdownMenu.Item
class="flex gap-2 items-center px-3 py-1.5 text-sm cursor-pointer hover:bg-gray-50 dark:hover:bg-gray-800 rounded-md"
on:click={() => {
dispatch('delete');
show = false;
}}
>
<GarbageBin strokeWidth="2" />
<div class="flex items-center">{$i18n.t('Delete')}</div>
</DropdownMenu.Item>
</DropdownMenu.Content>
</div>
</Dropdown>

View file

@ -17,6 +17,9 @@
import WebSearch from './Settings/WebSearch.svelte';
import { config } from '$lib/stores';
import { getBackendConfig } from '$lib/apis';
import ChartBar from '../icons/ChartBar.svelte';
import DocumentChartBar from '../icons/DocumentChartBar.svelte';
import Evaluations from './Settings/Evaluations.svelte';
const i18n = getContext('i18n');
@ -42,7 +45,7 @@
class="tabs flex flex-row overflow-x-auto space-x-1 max-w-full lg:space-x-0 lg:space-y-1 lg:flex-col lg:flex-none lg:w-44 dark:text-gray-200 text-xs text-left scrollbar-none"
>
<button
class="px-2.5 py-2.5 min-w-fit rounded-lg flex-1 lg:flex-none flex text-right transition {selectedTab ===
class="px-2.5 py-2 min-w-fit rounded-lg flex-1 lg:flex-none flex text-right transition {selectedTab ===
'general'
? 'bg-gray-100 dark:bg-gray-800'
: ' hover:bg-gray-50 dark:hover:bg-gray-850'}"
@ -68,7 +71,7 @@
</button>
<button
class="px-2.5 py-2.5 min-w-fit rounded-lg flex-1 md:flex-none flex text-right transition {selectedTab ===
class="px-2.5 py-2 min-w-fit rounded-lg flex-1 md:flex-none flex text-right transition {selectedTab ===
'users'
? 'bg-gray-100 dark:bg-gray-800'
: ' hover:bg-gray-50 dark:hover:bg-gray-850'}"
@ -92,7 +95,7 @@
</button>
<button
class="px-2.5 py-2.5 min-w-fit rounded-lg flex-1 md:flex-none flex text-right transition {selectedTab ===
class="px-2.5 py-2 min-w-fit rounded-lg flex-1 md:flex-none flex text-right transition {selectedTab ===
'connections'
? 'bg-gray-100 dark:bg-gray-800'
: ' hover:bg-gray-50 dark:hover:bg-gray-850'}"
@ -116,7 +119,7 @@
</button>
<button
class="px-2.5 py-2.5 min-w-fit rounded-lg flex-1 md:flex-none flex text-right transition {selectedTab ===
class="px-2.5 py-2 min-w-fit rounded-lg flex-1 md:flex-none flex text-right transition {selectedTab ===
'models'
? 'bg-gray-100 dark:bg-gray-800'
: ' hover:bg-gray-50 dark:hover:bg-gray-850'}"
@ -142,7 +145,22 @@
</button>
<button
class="px-2.5 py-2.5 min-w-fit rounded-lg flex-1 md:flex-none flex text-right transition {selectedTab ===
class="px-2.5 py-2 min-w-fit rounded-lg flex-1 md:flex-none flex text-right transition {selectedTab ===
'evaluations'
? 'bg-gray-100 dark:bg-gray-800'
: ' hover:bg-gray-50 dark:hover:bg-gray-850'}"
on:click={() => {
selectedTab = 'evaluations';
}}
>
<div class=" self-center mr-2">
<DocumentChartBar />
</div>
<div class=" self-center">{$i18n.t('Evaluations')}</div>
</button>
<button
class="px-2.5 py-2 min-w-fit rounded-lg flex-1 md:flex-none flex text-right transition {selectedTab ===
'documents'
? 'bg-gray-100 dark:bg-gray-800'
: ' hover:bg-gray-50 dark:hover:bg-gray-850'}"
@ -172,7 +190,7 @@
</button>
<button
class="px-2.5 py-2.5 min-w-fit rounded-lg flex-1 md:flex-none flex text-right transition {selectedTab ===
class="px-2.5 py-2 min-w-fit rounded-lg flex-1 md:flex-none flex text-right transition {selectedTab ===
'web'
? 'bg-gray-100 dark:bg-gray-800'
: ' hover:bg-gray-50 dark:hover:bg-gray-850'}"
@ -196,7 +214,7 @@
</button>
<button
class="px-2.5 py-2.5 min-w-fit rounded-lg flex-1 md:flex-none flex text-right transition {selectedTab ===
class="px-2.5 py-2 min-w-fit rounded-lg flex-1 md:flex-none flex text-right transition {selectedTab ===
'interface'
? 'bg-gray-100 dark:bg-gray-800'
: ' hover:bg-gray-50 dark:hover:bg-gray-850'}"
@ -222,7 +240,7 @@
</button>
<button
class="px-2.5 py-2.5 min-w-fit rounded-lg flex-1 md:flex-none flex text-right transition {selectedTab ===
class="px-2.5 py-2 min-w-fit rounded-lg flex-1 md:flex-none flex text-right transition {selectedTab ===
'audio'
? 'bg-gray-100 dark:bg-gray-800'
: ' hover:bg-gray-50 dark:hover:bg-gray-850'}"
@ -249,7 +267,7 @@
</button>
<button
class="px-2.5 py-2.5 min-w-fit rounded-lg flex-1 md:flex-none flex text-right transition {selectedTab ===
class="px-2.5 py-2 min-w-fit rounded-lg flex-1 md:flex-none flex text-right transition {selectedTab ===
'images'
? 'bg-gray-100 dark:bg-gray-800'
: ' hover:bg-gray-50 dark:hover:bg-gray-850'}"
@ -275,7 +293,7 @@
</button>
<button
class="px-2.5 py-2.5 min-w-fit rounded-lg flex-1 md:flex-none flex text-right transition {selectedTab ===
class="px-2.5 py-2 min-w-fit rounded-lg flex-1 md:flex-none flex text-right transition {selectedTab ===
'pipelines'
? 'bg-gray-100 dark:bg-gray-800'
: ' hover:bg-gray-50 dark:hover:bg-gray-850'}"
@ -305,7 +323,7 @@
</button>
<button
class="px-2.5 py-2.5 min-w-fit rounded-lg flex-1 md:flex-none flex text-right transition {selectedTab ===
class="px-2.5 py-2 min-w-fit rounded-lg flex-1 md:flex-none flex text-right transition {selectedTab ===
'db'
? 'bg-gray-100 dark:bg-gray-800'
: ' hover:bg-gray-50 dark:hover:bg-gray-850'}"
@ -357,6 +375,8 @@
/>
{:else if selectedTab === 'models'}
<Models />
{:else if selectedTab === 'evaluations'}
<Evaluations />
{:else if selectedTab === 'documents'}
<Documents
on:save={async () => {

View file

@ -38,6 +38,9 @@
let STT_OPENAI_API_KEY = '';
let STT_ENGINE = '';
let STT_MODEL = '';
let STT_WHISPER_MODEL = '';
let STT_WHISPER_MODEL_LOADING = false;
// eslint-disable-next-line no-undef
let voices: SpeechSynthesisVoice[] = [];
@ -99,18 +102,23 @@
OPENAI_API_BASE_URL: STT_OPENAI_API_BASE_URL,
OPENAI_API_KEY: STT_OPENAI_API_KEY,
ENGINE: STT_ENGINE,
MODEL: STT_MODEL
MODEL: STT_MODEL,
WHISPER_MODEL: STT_WHISPER_MODEL
}
});
if (res) {
saveHandler();
getBackendConfig()
.then(config.set)
.catch(() => {});
config.set(await getBackendConfig());
}
};
const sttModelUpdateHandler = async () => {
STT_WHISPER_MODEL_LOADING = true;
await updateConfigHandler();
STT_WHISPER_MODEL_LOADING = false;
};
onMount(async () => {
const res = await getAudioConfig(localStorage.token);
@ -134,6 +142,7 @@
STT_ENGINE = res.stt.ENGINE;
STT_MODEL = res.stt.MODEL;
STT_WHISPER_MODEL = res.stt.WHISPER_MODEL;
}
await getVoices();
@ -201,6 +210,88 @@
</div>
</div>
</div>
{:else if STT_ENGINE === ''}
<div>
<div class=" mb-1.5 text-sm font-medium">{$i18n.t('STT Model')}</div>
<div class="flex w-full">
<div class="flex-1 mr-2">
<input
class="w-full rounded-lg py-2 px-4 text-sm bg-gray-50 dark:text-gray-300 dark:bg-gray-850 outline-none"
placeholder={$i18n.t('Set whisper model')}
bind:value={STT_WHISPER_MODEL}
/>
</div>
<button
class="px-2.5 bg-gray-50 hover:bg-gray-200 text-gray-800 dark:bg-gray-850 dark:hover:bg-gray-800 dark:text-gray-100 rounded-lg transition"
on:click={() => {
sttModelUpdateHandler();
}}
disabled={STT_WHISPER_MODEL_LOADING}
>
{#if STT_WHISPER_MODEL_LOADING}
<div class="self-center">
<svg
class=" w-4 h-4"
viewBox="0 0 24 24"
fill="currentColor"
xmlns="http://www.w3.org/2000/svg"
>
<style>
.spinner_ajPY {
transform-origin: center;
animation: spinner_AtaB 0.75s infinite linear;
}
@keyframes spinner_AtaB {
100% {
transform: rotate(360deg);
}
}
</style>
<path
d="M12,1A11,11,0,1,0,23,12,11,11,0,0,0,12,1Zm0,19a8,8,0,1,1,8-8A8,8,0,0,1,12,20Z"
opacity=".25"
/>
<path
d="M10.14,1.16a11,11,0,0,0-9,8.92A1.59,1.59,0,0,0,2.46,12,1.52,1.52,0,0,0,4.11,10.7a8,8,0,0,1,6.66-6.61A1.42,1.42,0,0,0,12,2.69h0A1.57,1.57,0,0,0,10.14,1.16Z"
class="spinner_ajPY"
/>
</svg>
</div>
{:else}
<svg
xmlns="http://www.w3.org/2000/svg"
viewBox="0 0 16 16"
fill="currentColor"
class="w-4 h-4"
>
<path
d="M8.75 2.75a.75.75 0 0 0-1.5 0v5.69L5.03 6.22a.75.75 0 0 0-1.06 1.06l3.5 3.5a.75.75 0 0 0 1.06 0l3.5-3.5a.75.75 0 0 0-1.06-1.06L8.75 8.44V2.75Z"
/>
<path
d="M3.5 9.75a.75.75 0 0 0-1.5 0v1.5A2.75 2.75 0 0 0 4.75 14h6.5A2.75 2.75 0 0 0 14 11.25v-1.5a.75.75 0 0 0-1.5 0v1.5c0 .69-.56 1.25-1.25 1.25h-6.5c-.69 0-1.25-.56-1.25-1.25v-1.5Z"
/>
</svg>
{/if}
</button>
</div>
<div class="mt-2 mb-1 text-xs text-gray-400 dark:text-gray-500">
{$i18n.t(`Open WebUI uses faster-whisper internally.`)}
<a
class=" hover:underline dark:text-gray-200 text-gray-800"
href="https://github.com/SYSTRAN/faster-whisper"
target="_blank"
>
{$i18n.t(
`Click here to learn more about faster-whisper and see the available models.`
)}
</a>
</div>
</div>
{/if}
</div>
@ -339,7 +430,7 @@
<datalist id="tts-model-list">
{#each models as model}
<option value={model.id} />
<option value={model.id} class="bg-gray-50 dark:bg-gray-700" />
{/each}
</datalist>
</div>
@ -380,7 +471,7 @@
<datalist id="tts-model-list">
{#each models as model}
<option value={model.id} />
<option value={model.id} class="bg-gray-50 dark:bg-gray-700" />
{/each}
</datalist>
</div>
@ -460,7 +551,7 @@
</div>
<div class="flex justify-end text-sm font-medium">
<button
class=" px-4 py-2 bg-emerald-700 hover:bg-emerald-800 text-gray-100 transition rounded-lg"
class="px-3.5 py-1.5 text-sm font-medium bg-black hover:bg-gray-900 text-white dark:bg-white dark:text-black dark:hover:bg-gray-100 transition rounded-full"
type="submit"
>
{$i18n.t('Save')}

View file

@ -439,7 +439,7 @@
<div class="flex justify-end pt-3 text-sm font-medium">
<button
class=" px-4 py-2 bg-emerald-700 hover:bg-emerald-800 text-gray-100 transition rounded-lg"
class="px-3.5 py-1.5 text-sm font-medium bg-black hover:bg-gray-900 text-white dark:bg-white dark:text-black dark:hover:bg-gray-100 transition rounded-full"
type="submit"
>
{$i18n.t('Save')}

View file

@ -26,6 +26,9 @@
import ResetVectorDBConfirmDialog from '$lib/components/common/ConfirmDialog.svelte';
import SensitiveInput from '$lib/components/common/SensitiveInput.svelte';
import Tooltip from '$lib/components/common/Tooltip.svelte';
import Switch from '$lib/components/common/Switch.svelte';
import { text } from '@sveltejs/kit';
import Textarea from '$lib/components/common/Textarea.svelte';
const i18n = getContext('i18n');
@ -38,6 +41,7 @@
let embeddingEngine = '';
let embeddingModel = '';
let embeddingBatchSize = 1;
let rerankingModel = '';
let fileMaxSize = null;
@ -47,13 +51,13 @@
let tikaServerUrl = '';
let showTikaServerUrl = false;
let textSplitter = '';
let chunkSize = 0;
let chunkOverlap = 0;
let pdfExtractImages = true;
let OpenAIKey = '';
let OpenAIUrl = '';
let OpenAIBatchSize = 1;
let querySettings = {
template: '',
@ -100,12 +104,16 @@
const res = await updateEmbeddingConfig(localStorage.token, {
embedding_engine: embeddingEngine,
embedding_model: embeddingModel,
...(embeddingEngine === 'openai' || embeddingEngine === 'ollama'
? {
embedding_batch_size: embeddingBatchSize
}
: {}),
...(embeddingEngine === 'openai'
? {
openai_config: {
key: OpenAIKey,
url: OpenAIUrl,
batch_size: OpenAIBatchSize
url: OpenAIUrl
}
}
: {})
@ -173,6 +181,7 @@
max_count: fileMaxCount === '' ? null : fileMaxCount
},
chunk: {
text_splitter: textSplitter,
chunk_overlap: chunkOverlap,
chunk_size: chunkSize
},
@ -193,10 +202,10 @@
if (embeddingConfig) {
embeddingEngine = embeddingConfig.embedding_engine;
embeddingModel = embeddingConfig.embedding_model;
embeddingBatchSize = embeddingConfig.embedding_batch_size ?? 1;
OpenAIKey = embeddingConfig.openai_config.key;
OpenAIUrl = embeddingConfig.openai_config.url;
OpenAIBatchSize = embeddingConfig.openai_config.batch_size ?? 1;
}
};
@ -218,11 +227,13 @@
await setRerankingConfig();
querySettings = await getQuerySettings(localStorage.token);
const res = await getRAGConfig(localStorage.token);
if (res) {
pdfExtractImages = res.pdf_extract_images;
textSplitter = res.chunk.text_splitter;
chunkSize = res.chunk.chunk_size;
chunkOverlap = res.chunk.chunk_overlap;
@ -309,6 +320,8 @@
<SensitiveInput placeholder={$i18n.t('API Key')} bind:value={OpenAIKey} />
</div>
{/if}
{#if embeddingEngine === 'ollama' || embeddingEngine === 'openai'}
<div class="flex mt-0.5 space-x-2">
<div class=" self-center text-xs font-medium">{$i18n.t('Embedding Batch Size')}</div>
<div class=" flex-1">
@ -318,13 +331,13 @@
min="1"
max="2048"
step="1"
bind:value={OpenAIBatchSize}
bind:value={embeddingBatchSize}
class="w-full h-2 rounded-lg appearance-none cursor-pointer dark:bg-gray-700"
/>
</div>
<div class="">
<input
bind:value={OpenAIBatchSize}
bind:value={embeddingBatchSize}
type="number"
class=" bg-transparent text-center w-14"
min="-2"
@ -529,13 +542,13 @@
<hr class=" dark:border-gray-850" />
<div class="">
<div class="text-sm font-medium">{$i18n.t('Content Extraction')}</div>
<div class="text-sm font-medium mb-1">{$i18n.t('Content Extraction')}</div>
<div class="flex w-full justify-between mt-2">
<div class="flex w-full justify-between">
<div class="self-center text-xs font-medium">{$i18n.t('Engine')}</div>
<div class="flex items-center relative">
<select
class="dark:bg-gray-900 w-fit pr-8 rounded px-2 p-1 text-xs bg-transparent outline-none text-right"
class="dark:bg-gray-900 w-fit pr-8 rounded px-2 text-xs bg-transparent outline-none text-right"
bind:value={contentExtractionEngine}
on:change={(e) => {
showTikaServerUrl = e.target.value === 'tika';
@ -548,7 +561,7 @@
</div>
{#if showTikaServerUrl}
<div class="flex w-full mt-2">
<div class="flex w-full mt-1">
<div class="flex-1 mr-2">
<input
class="w-full rounded-lg py-2 px-4 text-sm bg-gray-50 dark:text-gray-300 dark:bg-gray-850 outline-none"
@ -562,10 +575,137 @@
<hr class=" dark:border-gray-850" />
<div class="">
<div class="text-sm font-medium">{$i18n.t('Files')}</div>
<div class=" ">
<div class=" text-sm font-medium mb-1">{$i18n.t('Query Params')}</div>
<div class=" my-2 flex gap-1.5">
<div class=" flex gap-1.5">
<div class="flex flex-col w-full gap-1">
<div class=" text-xs font-medium w-full">{$i18n.t('Top K')}</div>
<div class="w-full">
<input
class=" w-full rounded-lg py-1.5 px-4 text-sm bg-gray-50 dark:text-gray-300 dark:bg-gray-850 outline-none"
type="number"
placeholder={$i18n.t('Enter Top K')}
bind:value={querySettings.k}
autocomplete="off"
min="0"
/>
</div>
</div>
{#if querySettings.hybrid === true}
<div class=" flex flex-col w-full gap-1">
<div class="text-xs font-medium w-full">
{$i18n.t('Minimum Score')}
</div>
<div class="w-full">
<input
class=" w-full rounded-lg py-1.5 px-4 text-sm bg-gray-50 dark:text-gray-300 dark:bg-gray-850 outline-none"
type="number"
step="0.01"
placeholder={$i18n.t('Enter Score')}
bind:value={querySettings.r}
autocomplete="off"
min="0.0"
title={$i18n.t('The score should be a value between 0.0 (0%) and 1.0 (100%).')}
/>
</div>
</div>
{/if}
</div>
{#if querySettings.hybrid === true}
<div class="mt-2 text-xs text-gray-400 dark:text-gray-500">
{$i18n.t(
'Note: If you set a minimum score, the search will only return documents with a score greater than or equal to the minimum score.'
)}
</div>
{/if}
<div class="mt-2">
<div class=" mb-1 text-xs font-medium">{$i18n.t('RAG Template')}</div>
<Tooltip
content={$i18n.t('Leave empty to use the default prompt, or enter a custom prompt')}
placement="top-start"
>
<Textarea
bind:value={querySettings.template}
placeholder={$i18n.t('Leave empty to use the default prompt, or enter a custom prompt')}
/>
</Tooltip>
</div>
</div>
<hr class=" dark:border-gray-850" />
<div class=" ">
<div class="mb-1 text-sm font-medium">{$i18n.t('Chunk Params')}</div>
<div class="flex w-full justify-between mb-1.5">
<div class="self-center text-xs font-medium">{$i18n.t('Text Splitter')}</div>
<div class="flex items-center relative">
<select
class="dark:bg-gray-900 w-fit pr-8 rounded px-2 text-xs bg-transparent outline-none text-right"
bind:value={textSplitter}
>
<option value="">{$i18n.t('Default')} ({$i18n.t('Character')})</option>
<option value="token">{$i18n.t('Token')} ({$i18n.t('Tiktoken')})</option>
</select>
</div>
</div>
<div class=" flex gap-1.5">
<div class=" w-full justify-between">
<div class="self-center text-xs font-medium min-w-fit mb-1">{$i18n.t('Chunk Size')}</div>
<div class="self-center">
<input
class=" w-full rounded-lg py-1.5 px-4 text-sm bg-gray-50 dark:text-gray-300 dark:bg-gray-850 outline-none"
type="number"
placeholder={$i18n.t('Enter Chunk Size')}
bind:value={chunkSize}
autocomplete="off"
min="0"
/>
</div>
</div>
<div class="w-full">
<div class=" self-center text-xs font-medium min-w-fit mb-1">
{$i18n.t('Chunk Overlap')}
</div>
<div class="self-center">
<input
class="w-full rounded-lg py-1.5 px-4 text-sm bg-gray-50 dark:text-gray-300 dark:bg-gray-850 outline-none"
type="number"
placeholder={$i18n.t('Enter Chunk Overlap')}
bind:value={chunkOverlap}
autocomplete="off"
min="0"
/>
</div>
</div>
</div>
<div class="my-2">
<div class="flex justify-between items-center text-xs">
<div class=" text-xs font-medium">{$i18n.t('PDF Extract Images (OCR)')}</div>
<div>
<Switch bind:state={pdfExtractImages} />
</div>
</div>
</div>
</div>
<hr class=" dark:border-gray-850" />
<div class="">
<div class="text-sm font-medium mb-1">{$i18n.t('Files')}</div>
<div class=" flex gap-1.5">
<div class="w-full">
<div class=" self-center text-xs font-medium min-w-fit mb-1">
{$i18n.t('Max Upload Size')}
@ -617,128 +757,6 @@
<hr class=" dark:border-gray-850" />
<div class=" ">
<div class=" text-sm font-medium">{$i18n.t('Query Params')}</div>
<div class=" flex gap-1">
<div class=" flex w-full justify-between">
<div class="self-center text-xs font-medium min-w-fit">{$i18n.t('Top K')}</div>
<div class="self-center p-3">
<input
class=" w-full rounded-lg py-1.5 px-4 text-sm bg-gray-50 dark:text-gray-300 dark:bg-gray-850 outline-none"
type="number"
placeholder={$i18n.t('Enter Top K')}
bind:value={querySettings.k}
autocomplete="off"
min="0"
/>
</div>
</div>
{#if querySettings.hybrid === true}
<div class=" flex w-full justify-between">
<div class=" self-center text-xs font-medium min-w-fit">
{$i18n.t('Minimum Score')}
</div>
<div class="self-center p-3">
<input
class=" w-full rounded-lg py-1.5 px-4 text-sm bg-gray-50 dark:text-gray-300 dark:bg-gray-850 outline-none"
type="number"
step="0.01"
placeholder={$i18n.t('Enter Score')}
bind:value={querySettings.r}
autocomplete="off"
min="0.0"
title={$i18n.t('The score should be a value between 0.0 (0%) and 1.0 (100%).')}
/>
</div>
</div>
{/if}
</div>
{#if querySettings.hybrid === true}
<div class="mt-2 mb-1 text-xs text-gray-400 dark:text-gray-500">
{$i18n.t(
'Note: If you set a minimum score, the search will only return documents with a score greater than or equal to the minimum score.'
)}
</div>
<hr class=" dark:border-gray-850 my-3" />
{/if}
<div>
<div class=" mb-2.5 text-sm font-medium">{$i18n.t('RAG Template')}</div>
<Tooltip
content={$i18n.t('Leave empty to use the default prompt, or enter a custom prompt')}
placement="top-start"
>
<textarea
bind:value={querySettings.template}
placeholder={$i18n.t('Leave empty to use the default prompt, or enter a custom prompt')}
class="w-full rounded-lg px-4 py-3 text-sm bg-gray-50 dark:text-gray-300 dark:bg-gray-850 outline-none resize-none"
rows="4"
/>
</Tooltip>
</div>
</div>
<hr class=" dark:border-gray-850" />
<div class=" ">
<div class=" text-sm font-medium">{$i18n.t('Chunk Params')}</div>
<div class=" my-2 flex gap-1.5">
<div class=" w-full justify-between">
<div class="self-center text-xs font-medium min-w-fit mb-1">{$i18n.t('Chunk Size')}</div>
<div class="self-center">
<input
class=" w-full rounded-lg py-1.5 px-4 text-sm bg-gray-50 dark:text-gray-300 dark:bg-gray-850 outline-none"
type="number"
placeholder={$i18n.t('Enter Chunk Size')}
bind:value={chunkSize}
autocomplete="off"
min="0"
/>
</div>
</div>
<div class="w-full">
<div class=" self-center text-xs font-medium min-w-fit mb-1">
{$i18n.t('Chunk Overlap')}
</div>
<div class="self-center">
<input
class="w-full rounded-lg py-1.5 px-4 text-sm bg-gray-50 dark:text-gray-300 dark:bg-gray-850 outline-none"
type="number"
placeholder={$i18n.t('Enter Chunk Overlap')}
bind:value={chunkOverlap}
autocomplete="off"
min="0"
/>
</div>
</div>
</div>
<div class="my-3">
<div class="flex justify-between items-center text-xs">
<div class=" text-xs font-medium">{$i18n.t('PDF Extract Images (OCR)')}</div>
<button
class=" text-xs font-medium text-gray-500"
type="button"
on:click={() => {
pdfExtractImages = !pdfExtractImages;
}}>{pdfExtractImages ? $i18n.t('On') : $i18n.t('Off')}</button
>
</div>
</div>
</div>
<hr class=" dark:border-gray-850" />
<div>
<button
class=" flex rounded-xl py-2 px-3.5 w-full hover:bg-gray-200 dark:hover:bg-gray-800 transition"
@ -788,13 +806,15 @@
/>
</svg>
</div>
<div class=" self-center text-sm font-medium">{$i18n.t('Reset Vector Storage')}</div>
<div class=" self-center text-sm font-medium">
{$i18n.t('Reset Vector Storage/Knowledge')}
</div>
</button>
</div>
</div>
<div class="flex justify-end pt-3 text-sm font-medium">
<button
class=" px-4 py-2 bg-emerald-700 hover:bg-emerald-800 text-gray-100 transition rounded-lg"
class="px-3.5 py-1.5 text-sm font-medium bg-black hover:bg-gray-900 text-white dark:bg-white dark:text-black dark:hover:bg-gray-100 transition rounded-full"
type="submit"
>
{$i18n.t('Save')}

View file

@ -0,0 +1,158 @@
<script lang="ts">
import { toast } from 'svelte-sonner';
import { models, user } from '$lib/stores';
import { createEventDispatcher, onMount, getContext, tick } from 'svelte';
const dispatch = createEventDispatcher();
import { getModels } from '$lib/apis';
import Switch from '$lib/components/common/Switch.svelte';
import Spinner from '$lib/components/common/Spinner.svelte';
import Tooltip from '$lib/components/common/Tooltip.svelte';
import Plus from '$lib/components/icons/Plus.svelte';
import Model from './Evaluations/Model.svelte';
import ModelModal from './Evaluations/ModelModal.svelte';
import { getConfig, updateConfig } from '$lib/apis/evaluations';
const i18n = getContext('i18n');
let config = null;
let showAddModel = false;
const submitHandler = async () => {
config = await updateConfig(localStorage.token, config).catch((err) => {
toast.error(err);
return null;
});
if (config) {
toast.success('Settings saved successfully');
}
};
const addModelHandler = async (model) => {
config.EVALUATION_ARENA_MODELS.push(model);
config.EVALUATION_ARENA_MODELS = [...config.EVALUATION_ARENA_MODELS];
await submitHandler();
models.set(await getModels(localStorage.token));
};
const editModelHandler = async (model, modelIdx) => {
config.EVALUATION_ARENA_MODELS[modelIdx] = model;
config.EVALUATION_ARENA_MODELS = [...config.EVALUATION_ARENA_MODELS];
await submitHandler();
models.set(await getModels(localStorage.token));
};
const deleteModelHandler = async (modelIdx) => {
config.EVALUATION_ARENA_MODELS = config.EVALUATION_ARENA_MODELS.filter(
(m, mIdx) => mIdx !== modelIdx
);
await submitHandler();
models.set(await getModels(localStorage.token));
};
onMount(async () => {
if ($user.role === 'admin') {
config = await getConfig(localStorage.token).catch((err) => {
toast.error(err);
return null;
});
}
});
</script>
<ModelModal
bind:show={showAddModel}
on:submit={async (e) => {
addModelHandler(e.detail);
}}
/>
<form
class="flex flex-col h-full justify-between text-sm"
on:submit|preventDefault={() => {
submitHandler();
dispatch('save');
}}
>
<div class="overflow-y-scroll scrollbar-hidden h-full">
{#if config !== null}
<div class="">
<div class="text-sm font-medium mb-2">{$i18n.t('General Settings')}</div>
<div class=" mb-2">
<div class="flex justify-between items-center text-xs">
<div class=" text-xs font-medium">{$i18n.t('Arena Models')}</div>
<Tooltip content={$i18n.t(`Message rating should be enabled to use this feature`)}>
<Switch bind:state={config.ENABLE_EVALUATION_ARENA_MODELS} />
</Tooltip>
</div>
</div>
{#if config.ENABLE_EVALUATION_ARENA_MODELS}
<hr class=" border-gray-50 dark:border-gray-700/10 my-2" />
<div class="flex justify-between items-center mb-2">
<div class="text-sm font-medium">{$i18n.t('Manage Arena Models')}</div>
<div>
<Tooltip content={$i18n.t('Add Arena Model')}>
<button
class="p-1"
type="button"
on:click={() => {
showAddModel = true;
}}
>
<Plus />
</button>
</Tooltip>
</div>
</div>
<div class="flex flex-col gap-2">
{#if (config?.EVALUATION_ARENA_MODELS ?? []).length > 0}
{#each config.EVALUATION_ARENA_MODELS as model, index}
<Model
{model}
on:edit={(e) => {
editModelHandler(e.detail, index);
}}
on:delete={(e) => {
deleteModelHandler(index);
}}
/>
{/each}
{:else}
<div class=" text-center text-xs text-gray-500">
{$i18n.t(
`Using the default arena model with all models. Click the plus button to add custom models.`
)}
</div>
{/if}
</div>
{/if}
</div>
{:else}
<div class="flex h-full justify-center">
<div class="my-auto">
<Spinner className="size-6" />
</div>
</div>
{/if}
</div>
<div class="flex justify-end pt-3 text-sm font-medium">
<button
class="px-3.5 py-1.5 text-sm font-medium bg-black hover:bg-gray-900 text-white dark:bg-white dark:text-black dark:hover:bg-gray-100 transition rounded-full"
type="submit"
>
{$i18n.t('Save')}
</button>
</div>
</form>

View file

@ -0,0 +1,63 @@
<script lang="ts">
import { getContext, createEventDispatcher } from 'svelte';
const dispatch = createEventDispatcher();
const i18n = getContext('i18n');
import Cog6 from '$lib/components/icons/Cog6.svelte';
import ModelModal from './ModelModal.svelte';
export let model;
let showModel = false;
</script>
<ModelModal
bind:show={showModel}
edit={true}
{model}
on:submit={async (e) => {
dispatch('edit', e.detail);
}}
on:delete={async () => {
dispatch('delete');
}}
/>
<div class="py-0.5">
<div class="flex justify-between items-center mb-1">
<div class="flex flex-col flex-1">
<div class="flex gap-2.5 items-center">
<img
src={model.meta.profile_image_url}
alt={model.name}
class="size-8 rounded-full object-cover shrink-0"
/>
<div class="w-full flex flex-col">
<div class="flex items-center gap-1">
<div class="flex-shrink-0 line-clamp-1">
{model.name}
</div>
</div>
<div class="flex items-center gap-1">
<div class=" text-xs w-full text-gray-500 bg-transparent line-clamp-1">
{model?.meta?.description ?? model.id}
</div>
</div>
</div>
</div>
</div>
<div class="flex items-center">
<button
class="self-center w-fit text-sm p-1.5 dark:text-gray-300 dark:hover:text-white hover:bg-black/5 dark:hover:bg-white/5 rounded-xl"
type="button"
on:click={() => {
showModel = true;
}}
>
<Cog6 />
</button>
</div>
</div>
</div>

View file

@ -0,0 +1,418 @@
<script>
import { createEventDispatcher, getContext, onMount } from 'svelte';
const i18n = getContext('i18n');
const dispatch = createEventDispatcher();
import Modal from '$lib/components/common/Modal.svelte';
import { models } from '$lib/stores';
import Plus from '$lib/components/icons/Plus.svelte';
import Minus from '$lib/components/icons/Minus.svelte';
import PencilSolid from '$lib/components/icons/PencilSolid.svelte';
import { toast } from 'svelte-sonner';
export let show = false;
export let edit = false;
export let model = null;
let name = '';
let id = '';
$: if (name) {
generateId();
}
const generateId = () => {
if (!edit) {
id = name
.toLowerCase()
.replace(/[^a-z0-9]/g, '-')
.replace(/-+/g, '-')
.replace(/^-|-$/g, '');
}
};
let profileImageUrl = '/favicon.png';
let description = '';
let selectedModelId = '';
let modelIds = [];
let filterMode = 'include';
let imageInputElement;
let loading = false;
const addModelHandler = () => {
if (selectedModelId) {
modelIds = [...modelIds, selectedModelId];
selectedModelId = '';
}
};
const submitHandler = () => {
loading = true;
if (!name || !id) {
loading = false;
toast.error('Name and ID are required, please fill them out');
return;
}
if (!edit) {
if ($models.find((model) => model.name === name)) {
loading = false;
name = '';
toast.error('Model name already exists, please choose a different one');
return;
}
}
const model = {
id: id,
name: name,
meta: {
profile_image_url: profileImageUrl,
description: description || null,
model_ids: modelIds.length > 0 ? modelIds : null,
filter_mode: modelIds.length > 0 ? (filterMode ? filterMode : null) : null
}
};
dispatch('submit', model);
loading = false;
show = false;
name = '';
id = '';
profileImageUrl = '/favicon.png';
description = '';
modelIds = [];
selectedModelId = '';
};
const initModel = () => {
if (model) {
name = model.name;
id = model.id;
profileImageUrl = model.meta.profile_image_url;
description = model.meta.description;
modelIds = model.meta.model_ids || [];
filterMode = model.meta?.filter_mode ?? 'include';
}
};
$: if (show) {
initModel();
}
onMount(() => {
initModel();
});
</script>
<Modal size="sm" bind:show>
<div>
<div class=" flex justify-between dark:text-gray-100 px-5 pt-4 pb-2">
<div class=" text-lg font-medium self-center font-primary">
{#if edit}
{$i18n.t('Edit Arena Model')}
{:else}
{$i18n.t('Add Arena Model')}
{/if}
</div>
<button
class="self-center"
on:click={() => {
show = false;
}}
>
<svg
xmlns="http://www.w3.org/2000/svg"
viewBox="0 0 20 20"
fill="currentColor"
class="w-5 h-5"
>
<path
d="M6.28 5.22a.75.75 0 00-1.06 1.06L8.94 10l-3.72 3.72a.75.75 0 101.06 1.06L10 11.06l3.72 3.72a.75.75 0 101.06-1.06L11.06 10l3.72-3.72a.75.75 0 00-1.06-1.06L10 8.94 6.28 5.22z"
/>
</svg>
</button>
</div>
<div class="flex flex-col md:flex-row w-full px-4 pb-4 md:space-x-4 dark:text-gray-200">
<div class=" flex flex-col w-full sm:flex-row sm:justify-center sm:space-x-6">
<form
class="flex flex-col w-full"
on:submit|preventDefault={() => {
submitHandler();
}}
>
<div class="px-1">
<div class="flex justify-center pb-3">
<input
bind:this={imageInputElement}
type="file"
hidden
accept="image/*"
on:change={(e) => {
const files = e.target.files ?? [];
let reader = new FileReader();
reader.onload = (event) => {
let originalImageUrl = `${event.target.result}`;
const img = new Image();
img.src = originalImageUrl;
img.onload = function () {
const canvas = document.createElement('canvas');
const ctx = canvas.getContext('2d');
// Calculate the aspect ratio of the image
const aspectRatio = img.width / img.height;
// Calculate the new width and height to fit within 250x250
let newWidth, newHeight;
if (aspectRatio > 1) {
newWidth = 250 * aspectRatio;
newHeight = 250;
} else {
newWidth = 250;
newHeight = 250 / aspectRatio;
}
// Set the canvas size
canvas.width = 250;
canvas.height = 250;
// Calculate the position to center the image
const offsetX = (250 - newWidth) / 2;
const offsetY = (250 - newHeight) / 2;
// Draw the image on the canvas
ctx.drawImage(img, offsetX, offsetY, newWidth, newHeight);
// Get the base64 representation of the compressed image
const compressedSrc = canvas.toDataURL('image/jpeg');
// Display the compressed image
profileImageUrl = compressedSrc;
e.target.files = null;
};
};
if (
files.length > 0 &&
['image/gif', 'image/webp', 'image/jpeg', 'image/png'].includes(
files[0]['type']
)
) {
reader.readAsDataURL(files[0]);
}
}}
/>
<button
class="relative rounded-full w-fit h-fit shrink-0"
type="button"
on:click={() => {
imageInputElement.click();
}}
>
<img
src={profileImageUrl}
class="size-16 rounded-full object-cover shrink-0"
alt="Profile"
/>
<div
class="absolute flex justify-center rounded-full bottom-0 left-0 right-0 top-0 h-full w-full overflow-hidden bg-gray-700 bg-fixed opacity-0 transition duration-300 ease-in-out hover:opacity-50"
>
<div class="my-auto text-white">
<PencilSolid className="size-4" />
</div>
</div>
</button>
</div>
<div class="flex gap-2">
<div class="flex flex-col w-full">
<div class=" mb-0.5 text-xs text-gray-500">{$i18n.t('Name')}</div>
<div class="flex-1">
<input
class="w-full text-sm bg-transparent placeholder:text-gray-300 dark:placeholder:text-gray-700 outline-none"
type="text"
bind:value={name}
placeholder={$i18n.t('Model Name')}
autocomplete="off"
required
/>
</div>
</div>
<div class="flex flex-col w-full">
<div class=" mb-0.5 text-xs text-gray-500">{$i18n.t('ID')}</div>
<div class="flex-1">
<input
class="w-full text-sm bg-transparent placeholder:text-gray-300 dark:placeholder:text-gray-700 outline-none"
type="text"
bind:value={id}
placeholder={$i18n.t('Model ID')}
autocomplete="off"
required
disabled={edit}
/>
</div>
</div>
</div>
<div class="flex flex-col w-full mt-2">
<div class=" mb-1 text-xs text-gray-500">{$i18n.t('Description')}</div>
<div class="flex-1">
<input
class="w-full text-sm bg-transparent placeholder:text-gray-300 dark:placeholder:text-gray-700 outline-none"
type="text"
bind:value={description}
placeholder={$i18n.t('Enter description')}
autocomplete="off"
/>
</div>
</div>
<hr class=" border-gray-100 dark:border-gray-700/10 my-2.5 w-full" />
<div class="flex flex-col w-full">
<div class="mb-1 flex justify-between">
<div class="text-xs text-gray-500">{$i18n.t('Models')}</div>
<div>
<button
class=" text-xs text-gray-500"
type="button"
on:click={() => {
filterMode = filterMode === 'include' ? 'exclude' : 'include';
}}
>
{#if filterMode === 'include'}
{$i18n.t('Include')}
{:else}
{$i18n.t('Exclude')}
{/if}
</button>
</div>
</div>
{#if modelIds.length > 0}
<div class="flex flex-col">
{#each modelIds as modelId, modelIdx}
<div class=" flex gap-2 w-full justify-between items-center">
<div class=" text-sm flex-1 py-1 rounded-lg">
{$models.find((model) => model.id === modelId)?.name}
</div>
<div class="flex-shrink-0">
<button
type="button"
on:click={() => {
modelIds = modelIds.filter((_, idx) => idx !== modelIdx);
}}
>
<Minus strokeWidth="2" className="size-3.5" />
</button>
</div>
</div>
{/each}
</div>
{:else}
<div class="text-gray-500 text-xs text-center py-2">
{$i18n.t('Leave empty to include all models or select specific models')}
</div>
{/if}
</div>
<hr class=" border-gray-100 dark:border-gray-700/10 my-2.5 w-full" />
<div class="flex items-center">
<select
class="w-full py-1 text-sm rounded-lg bg-transparent {selectedModelId
? ''
: 'text-gray-500'} placeholder:text-gray-300 dark:placeholder:text-gray-700 outline-none"
bind:value={selectedModelId}
>
<option value="">{$i18n.t('Select a model')}</option>
{#each $models.filter((m) => m?.owned_by !== 'arena') as model}
<option value={model.id} class="bg-gray-50 dark:bg-gray-700">{model.name}</option>
{/each}
</select>
<div>
<button
type="button"
on:click={() => {
addModelHandler();
}}
>
<Plus className="size-3.5" strokeWidth="2" />
</button>
</div>
</div>
</div>
<div class="flex justify-end pt-3 text-sm font-medium gap-1.5">
{#if edit}
<button
class="px-3.5 py-1.5 text-sm font-medium dark:bg-black dark:hover:bg-gray-900 dark:text-white bg-white text-black hover:bg-gray-100 transition rounded-full flex flex-row space-x-1 items-center"
type="button"
on:click={() => {
dispatch('delete', model);
show = false;
}}
>
{$i18n.t('Delete')}
</button>
{/if}
<button
class="px-3.5 py-1.5 text-sm font-medium bg-black hover:bg-gray-900 text-white dark:bg-white dark:text-black dark:hover:bg-gray-100 transition rounded-full flex flex-row space-x-1 items-center {loading
? ' cursor-not-allowed'
: ''}"
type="submit"
disabled={loading}
>
{$i18n.t('Save')}
{#if loading}
<div class="ml-2 self-center">
<svg
class=" w-4 h-4"
viewBox="0 0 24 24"
fill="currentColor"
xmlns="http://www.w3.org/2000/svg"
><style>
.spinner_ajPY {
transform-origin: center;
animation: spinner_AtaB 0.75s infinite linear;
}
@keyframes spinner_AtaB {
100% {
transform: rotate(360deg);
}
}
</style><path
d="M12,1A11,11,0,1,0,23,12,11,11,0,0,0,12,1Zm0,19a8,8,0,1,1,8-8A8,8,0,0,1,12,20Z"
opacity=".25"
/><path
d="M10.14,1.16a11,11,0,0,0-9,8.92A1.59,1.59,0,0,0,2.46,12,1.52,1.52,0,0,0,4.11,10.7a8,8,0,0,1,6.66-6.61A1.42,1.42,0,0,0,12,2.69h0A1.57,1.57,0,0,0,10.14,1.16Z"
class="spinner_ajPY"
/></svg
>
</div>
{/if}
</button>
</div>
</form>
</div>
</div>
</div>
</Modal>

View file

@ -137,7 +137,7 @@
<div class="flex justify-end pt-3 text-sm font-medium">
<button
class=" px-4 py-2 bg-emerald-700 hover:bg-emerald-800 text-gray-100 transition rounded-lg"
class="px-3.5 py-1.5 text-sm font-medium bg-black hover:bg-gray-900 text-white dark:bg-white dark:text-black dark:hover:bg-gray-100 transition rounded-full"
type="submit"
>
{$i18n.t('Save')}

View file

@ -648,7 +648,7 @@
<div class="flex justify-end pt-3 text-sm font-medium">
<button
class=" px-4 py-2 bg-emerald-700 hover:bg-emerald-800 text-gray-100 transition rounded-lg flex flex-row space-x-1 items-center {loading
class="px-3.5 py-1.5 text-sm font-medium bg-black hover:bg-gray-900 text-white dark:bg-white dark:text-black dark:hover:bg-gray-100 transition rounded-full flex flex-row space-x-1 items-center {loading
? ' cursor-not-allowed'
: ''}"
type="submit"

View file

@ -14,6 +14,7 @@
import Tooltip from '$lib/components/common/Tooltip.svelte';
import Switch from '$lib/components/common/Switch.svelte';
import Textarea from '$lib/components/common/Textarea.svelte';
const dispatch = createEventDispatcher();
@ -23,6 +24,7 @@
TASK_MODEL: '',
TASK_MODEL_EXTERNAL: '',
TITLE_GENERATION_PROMPT_TEMPLATE: '',
TAG_GENERATION_PROMPT_TEMPLATE: '',
ENABLE_SEARCH_QUERY: true,
SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE: ''
};
@ -60,7 +62,7 @@
>
<div class=" overflow-y-scroll scrollbar-hidden h-full pr-1.5">
<div>
<div class=" mb-2.5 text-sm font-medium flex">
<div class=" mb-2.5 text-sm font-medium flex items-center">
<div class=" mr-1">{$i18n.t('Set Task Model')}</div>
<Tooltip
content={$i18n.t(
@ -73,7 +75,7 @@
viewBox="0 0 24 24"
stroke-width="1.5"
stroke="currentColor"
class="w-5 h-5"
class="size-3.5"
>
<path
stroke-linecap="round"
@ -124,10 +126,22 @@
content={$i18n.t('Leave empty to use the default prompt, or enter a custom prompt')}
placement="top-start"
>
<textarea
<Textarea
bind:value={taskConfig.TITLE_GENERATION_PROMPT_TEMPLATE}
class="w-full rounded-lg py-3 px-4 text-sm bg-gray-50 dark:text-gray-300 dark:bg-gray-850 outline-none resize-none"
rows="3"
placeholder={$i18n.t('Leave empty to use the default prompt, or enter a custom prompt')}
/>
</Tooltip>
</div>
<div class="mt-3">
<div class=" mb-2.5 text-xs font-medium">{$i18n.t('Tags Generation Prompt')}</div>
<Tooltip
content={$i18n.t('Leave empty to use the default prompt, or enter a custom prompt')}
placement="top-start"
>
<Textarea
bind:value={taskConfig.TAG_GENERATION_PROMPT_TEMPLATE}
placeholder={$i18n.t('Leave empty to use the default prompt, or enter a custom prompt')}
/>
</Tooltip>
@ -151,10 +165,8 @@
content={$i18n.t('Leave empty to use the default prompt, or enter a custom prompt')}
placement="top-start"
>
<textarea
<Textarea
bind:value={taskConfig.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE}
class="w-full rounded-lg py-3 px-4 text-sm bg-gray-50 dark:text-gray-300 dark:bg-gray-850 outline-none resize-none"
rows="3"
placeholder={$i18n.t(
'Leave empty to use the default prompt, or enter a custom prompt'
)}
@ -349,7 +361,7 @@
<div class="flex justify-end text-sm font-medium">
<button
class=" px-4 py-2 bg-emerald-700 hover:bg-emerald-800 text-gray-100 transition rounded-lg"
class="px-3.5 py-1.5 text-sm font-medium bg-black hover:bg-gray-900 text-white dark:bg-white dark:text-black dark:hover:bg-gray-100 transition rounded-full"
type="submit"
>
{$i18n.t('Save')}

View file

@ -763,7 +763,7 @@
<option value="" disabled selected>{$i18n.t('Select a model')}</option>
{/if}
{#each $models.filter((m) => !(m?.preset ?? false) && m.owned_by === 'ollama' && (selectedOllamaUrlIdx === null ? true : (m?.ollama?.urls ?? []).includes(selectedOllamaUrlIdx))) as model}
<option value={model.name} class="bg-gray-50 dark:bg-gray-700"
<option value={model.id} class="bg-gray-50 dark:bg-gray-700"
>{model.name +
' (' +
(model.ollama.size / 1024 ** 3).toFixed(1) +

View file

@ -546,12 +546,14 @@
{/if}
</div>
<div class="flex justify-end pt-3 text-sm font-medium">
<button
class=" px-4 py-2 bg-emerald-700 hover:bg-emerald-800 text-gray-100 transition rounded-lg"
type="submit"
>
{$i18n.t('Save')}
</button>
</div>
{#if PIPELINES_LIST !== null && PIPELINES_LIST.length > 0}
<div class="flex justify-end pt-3 text-sm font-medium">
<button
class="px-3.5 py-1.5 text-sm font-medium bg-black hover:bg-gray-900 text-white dark:bg-white dark:text-black dark:hover:bg-gray-100 transition rounded-full"
type="submit"
>
{$i18n.t('Save')}
</button>
</div>
{/if}
</form>

View file

@ -24,9 +24,17 @@
}
};
let chatDeletion = true;
let chatEdit = true;
let chatTemporary = true;
onMount(async () => {
permissions = await getUserPermissions(localStorage.token);
chatDeletion = permissions?.chat?.deletion ?? true;
chatEdit = permissions?.chat?.editing ?? true;
chatTemporary = permissions?.chat?.temporary ?? true;
const res = await getModelFilterConfig(localStorage.token);
if (res) {
whitelistEnabled = res.enabled;
@ -43,7 +51,13 @@
// console.log('submit');
await setDefaultModels(localStorage.token, defaultModelId);
await updateUserPermissions(localStorage.token, permissions);
await updateUserPermissions(localStorage.token, {
chat: {
deletion: chatDeletion,
editing: chatEdit,
temporary: chatTemporary
}
});
await updateModelFilterConfig(localStorage.token, whitelistEnabled, whitelistModels);
saveHandler();
@ -54,127 +68,22 @@
<div>
<div class=" mb-2 text-sm font-medium">{$i18n.t('User Permissions')}</div>
<div class=" flex w-full justify-between">
<div class=" flex w-full justify-between my-2 pr-2">
<div class=" self-center text-xs font-medium">{$i18n.t('Allow Chat Deletion')}</div>
<button
class="p-1 px-3 text-xs flex rounded transition"
on:click={() => {
permissions.chat.deletion = !(permissions?.chat?.deletion ?? true);
}}
type="button"
>
{#if permissions?.chat?.deletion ?? true}
<svg
xmlns="http://www.w3.org/2000/svg"
viewBox="0 0 16 16"
fill="currentColor"
class="w-4 h-4"
>
<path
d="M11.5 1A3.5 3.5 0 0 0 8 4.5V7H2.5A1.5 1.5 0 0 0 1 8.5v5A1.5 1.5 0 0 0 2.5 15h7a1.5 1.5 0 0 0 1.5-1.5v-5A1.5 1.5 0 0 0 9.5 7V4.5a2 2 0 1 1 4 0v1.75a.75.75 0 0 0 1.5 0V4.5A3.5 3.5 0 0 0 11.5 1Z"
/>
</svg>
<span class="ml-2 self-center">{$i18n.t('Allow')}</span>
{:else}
<svg
xmlns="http://www.w3.org/2000/svg"
viewBox="0 0 16 16"
fill="currentColor"
class="w-4 h-4"
>
<path
fill-rule="evenodd"
d="M8 1a3.5 3.5 0 0 0-3.5 3.5V7A1.5 1.5 0 0 0 3 8.5v5A1.5 1.5 0 0 0 4.5 15h7a1.5 1.5 0 0 0 1.5-1.5v-5A1.5 1.5 0 0 0 11.5 7V4.5A3.5 3.5 0 0 0 8 1Zm2 6V4.5a2 2 0 1 0-4 0V7h4Z"
clip-rule="evenodd"
/>
</svg>
<span class="ml-2 self-center">{$i18n.t("Don't Allow")}</span>
{/if}
</button>
<Switch bind:state={chatDeletion} />
</div>
<div class=" flex w-full justify-between">
<div class=" flex w-full justify-between my-2 pr-2">
<div class=" self-center text-xs font-medium">{$i18n.t('Allow Chat Editing')}</div>
<button
class="p-1 px-3 text-xs flex rounded transition"
on:click={() => {
permissions.chat.editing = !(permissions?.chat?.editing ?? true);
}}
type="button"
>
{#if permissions?.chat?.editing ?? true}
<svg
xmlns="http://www.w3.org/2000/svg"
viewBox="0 0 16 16"
fill="currentColor"
class="w-4 h-4"
>
<path
d="M11.5 1A3.5 3.5 0 0 0 8 4.5V7H2.5A1.5 1.5 0 0 0 1 8.5v5A1.5 1.5 0 0 0 2.5 15h7a1.5 1.5 0 0 0 1.5-1.5v-5A1.5 1.5 0 0 0 9.5 7V4.5a2 2 0 1 1 4 0v1.75a.75.75 0 0 0 1.5 0V4.5A3.5 3.5 0 0 0 11.5 1Z"
/>
</svg>
<span class="ml-2 self-center">{$i18n.t('Allow')}</span>
{:else}
<svg
xmlns="http://www.w3.org/2000/svg"
viewBox="0 0 16 16"
fill="currentColor"
class="w-4 h-4"
>
<path
fill-rule="evenodd"
d="M8 1a3.5 3.5 0 0 0-3.5 3.5V7A1.5 1.5 0 0 0 3 8.5v5A1.5 1.5 0 0 0 4.5 15h7a1.5 1.5 0 0 0 1.5-1.5v-5A1.5 1.5 0 0 0 11.5 7V4.5A3.5 3.5 0 0 0 8 1Zm2 6V4.5a2 2 0 1 0-4 0V7h4Z"
clip-rule="evenodd"
/>
</svg>
<span class="ml-2 self-center">{$i18n.t("Don't Allow")}</span>
{/if}
</button>
<Switch bind:state={chatEdit} />
</div>
<div class=" flex w-full justify-between">
<div class=" flex w-full justify-between my-2 pr-2">
<div class=" self-center text-xs font-medium">{$i18n.t('Allow Temporary Chat')}</div>
<button
class="p-1 px-3 text-xs flex rounded transition"
on:click={() => {
permissions.chat.temporary = !(permissions?.chat?.temporary ?? true);
}}
type="button"
>
{#if permissions?.chat?.temporary ?? true}
<svg
xmlns="http://www.w3.org/2000/svg"
viewBox="0 0 16 16"
fill="currentColor"
class="w-4 h-4"
>
<path
d="M11.5 1A3.5 3.5 0 0 0 8 4.5V7H2.5A1.5 1.5 0 0 0 1 8.5v5A1.5 1.5 0 0 0 2.5 15h7a1.5 1.5 0 0 0 1.5-1.5v-5A1.5 1.5 0 0 0 9.5 7V4.5a2 2 0 1 1 4 0v1.75a.75.75 0 0 0 1.5 0V4.5A3.5 3.5 0 0 0 11.5 1Z"
/>
</svg>
<span class="ml-2 self-center">{$i18n.t('Allow')}</span>
{:else}
<svg
xmlns="http://www.w3.org/2000/svg"
viewBox="0 0 16 16"
fill="currentColor"
class="w-4 h-4"
>
<path
fill-rule="evenodd"
d="M8 1a3.5 3.5 0 0 0-3.5 3.5V7A1.5 1.5 0 0 0 3 8.5v5A1.5 1.5 0 0 0 4.5 15h7a1.5 1.5 0 0 0 1.5-1.5v-5A1.5 1.5 0 0 0 11.5 7V4.5A3.5 3.5 0 0 0 8 1Zm2 6V4.5a2 2 0 1 0-4 0V7h4Z"
clip-rule="evenodd"
/>
</svg>
<span class="ml-2 self-center">{$i18n.t("Don't Allow")}</span>
{/if}
</button>
<Switch bind:state={chatTemporary} />
</div>
</div>
@ -210,7 +119,7 @@
<div class=" space-y-1">
<div class="mb-2">
<div class="flex justify-between items-center text-xs">
<div class="flex justify-between items-center text-xs my-3 pr-2">
<div class=" text-xs font-medium">{$i18n.t('Model Whitelisting')}</div>
<Switch bind:state={whitelistEnabled} />
@ -296,7 +205,7 @@
<div class="flex justify-end pt-3 text-sm font-medium">
<button
class=" px-4 py-2 bg-emerald-700 hover:bg-emerald-800 text-gray-100 transition rounded-lg"
class="px-3.5 py-1.5 text-sm font-medium bg-black hover:bg-gray-900 text-white dark:bg-white dark:text-black dark:hover:bg-gray-100 transition rounded-full"
type="submit"
>
{$i18n.t('Save')}

View file

@ -310,7 +310,7 @@
</div>
<div class="flex justify-end pt-3 text-sm font-medium">
<button
class=" px-4 py-2 bg-emerald-700 hover:bg-emerald-800 text-gray-100 transition rounded-lg"
class="px-3.5 py-1.5 text-sm font-medium bg-black hover:bg-gray-900 text-white dark:bg-white dark:text-black dark:hover:bg-gray-100 transition rounded-full"
type="submit"
>
{$i18n.t('Save')}

View file

@ -10,6 +10,7 @@
import ArrowsPointingOut from '../icons/ArrowsPointingOut.svelte';
import Tooltip from '../common/Tooltip.svelte';
import SvgPanZoom from '../common/SVGPanZoom.svelte';
import ArrowLeft from '../icons/ArrowLeft.svelte';
export let overlay = false;
export let history;
@ -119,6 +120,11 @@
}
});
if (contents.length === 0) {
showControls.set(false);
showArtifacts.set(false);
}
selectedContentIdx = contents ? contents.length - 1 : 0;
};
@ -183,6 +189,17 @@
<div class=" absolute top-0 left-0 right-0 bottom-0 z-10"></div>
{/if}
<div class="absolute pointer-events-none z-50 w-full flex items-center justify-start p-4">
<button
class="self-center pointer-events-auto p-1 rounded-full bg-white dark:bg-gray-850"
on:click={() => {
showArtifacts.set(false);
}}
>
<ArrowLeft className="size-3.5 text-gray-900 dark:text-white" />
</button>
</div>
<div class=" absolute pointer-events-none z-50 w-full flex items-center justify-end p-4">
<button
class="self-center pointer-events-auto p-1 rounded-full bg-white dark:bg-gray-850"
@ -192,7 +209,7 @@
showArtifacts.set(false);
}}
>
<XMark className="size-3 text-gray-900 dark:text-white" />
<XMark className="size-3.5 text-gray-900 dark:text-white" />
</button>
</div>

View file

@ -10,7 +10,7 @@
import { goto } from '$app/navigation';
import { page } from '$app/stores';
import type { Unsubscriber, Writable } from 'svelte/store';
import { get, type Unsubscriber, type Writable } from 'svelte/store';
import type { i18n as i18nType } from 'i18next';
import { WEBUI_BASE_URL } from '$lib/constants';
@ -20,6 +20,7 @@
config,
type Model,
models,
tags as allTags,
settings,
showSidebar,
WEBUI_NAME,
@ -46,14 +47,18 @@
import { generateChatCompletion } from '$lib/apis/ollama';
import {
addTagById,
createNewChat,
deleteTagById,
deleteTagsById,
getAllTags,
getChatById,
getChatList,
getTagsById,
updateChatById
} from '$lib/apis/chats';
import { generateOpenAIChatCompletion } from '$lib/apis/openai';
import { processWebSearch } from '$lib/apis/retrieval';
import { processWeb, processWebSearch, processYoutubeVideo } from '$lib/apis/retrieval';
import { createOpenAITextStream } from '$lib/apis/streaming';
import { queryMemory } from '$lib/apis/memories';
import { getAndUpdateUserLocation, getUserSettings } from '$lib/apis/users';
@ -62,7 +67,8 @@
generateTitle,
generateSearchQuery,
chatAction,
generateMoACompletion
generateMoACompletion,
generateTags
} from '$lib/apis';
import Banner from '../common/Banner.svelte';
@ -78,12 +84,15 @@
let loaded = false;
const eventTarget = new EventTarget();
let controlPane;
let controlPaneComponent;
let stopResponseFlag = false;
let autoScroll = true;
let processing = '';
let messagesContainerElement: HTMLDivElement;
let navbarElement;
let showEventConfirmation = false;
let eventConfirmationTitle = '';
let eventConfirmationMessage = '';
@ -124,7 +133,7 @@
loaded = true;
window.setTimeout(() => scrollToBottom(), 0);
const chatInput = document.getElementById('chat-textarea');
const chatInput = document.getElementById('chat-input');
chatInput?.focus();
} else {
await goto('/');
@ -174,10 +183,30 @@
message.statusHistory = [data];
}
} else if (type === 'citation') {
if (message?.citations) {
message.citations.push(data);
if (data?.type === 'code_execution') {
// Code execution; update existing code execution by ID, or add new one.
if (!message?.code_executions) {
message.code_executions = [];
}
const existingCodeExecutionIndex = message.code_executions.findIndex(
(execution) => execution.id === data.id
);
if (existingCodeExecutionIndex !== -1) {
message.code_executions[existingCodeExecutionIndex] = data;
} else {
message.code_executions.push(data);
}
message.code_executions = message.code_executions;
} else {
message.citations = [data];
// Regular citation.
if (message?.citations) {
message.citations.push(data);
} else {
message.citations = [data];
}
}
} else if (type === 'message') {
message.content += data.content;
@ -199,6 +228,20 @@
eventConfirmationTitle = data.title;
eventConfirmationMessage = data.message;
} else if (type === 'execute') {
eventCallback = cb;
try {
// Use Function constructor to evaluate code in a safer way
const asyncFunction = new Function(`return (async () => { ${data.code} })()`);
const result = await asyncFunction(); // Await the result of the async function
if (cb) {
cb(result);
}
} catch (error) {
console.error('Error executing code:', error);
}
} else if (type === 'input') {
eventCallback = cb;
@ -229,7 +272,7 @@
if (event.data.type === 'input:prompt') {
console.debug(event.data.text);
const inputElement = document.getElementById('chat-textarea');
const inputElement = document.getElementById('chat-input');
if (inputElement) {
prompt = event.data.text;
@ -276,14 +319,9 @@
if (controlPane && !$mobile) {
try {
if (value) {
const currentSize = controlPane.getSize();
if (currentSize === 0) {
const size = parseInt(localStorage?.chatControlsSize ?? '30');
controlPane.resize(size ? size : 30);
}
controlPaneComponent.openPane();
} else {
controlPane.resize(0);
controlPane.collapse();
}
} catch (e) {
// ignore
@ -293,10 +331,11 @@
if (!value) {
showCallOverlay.set(false);
showOverview.set(false);
showArtifacts.set(false);
}
});
const chatInput = document.getElementById('chat-textarea');
const chatInput = document.getElementById('chat-input');
chatInput?.focus();
chats.subscribe(() => {});
@ -308,6 +347,74 @@
$socket?.off('chat-events');
});
// File upload functions
const uploadWeb = async (url) => {
console.log(url);
const fileItem = {
type: 'doc',
name: url,
collection_name: '',
status: 'uploading',
url: url,
error: ''
};
try {
files = [...files, fileItem];
const res = await processWeb(localStorage.token, '', url);
if (res) {
fileItem.status = 'uploaded';
fileItem.collection_name = res.collection_name;
fileItem.file = {
...res.file,
...fileItem.file
};
files = files;
}
} catch (e) {
// Remove the failed doc from the files array
files = files.filter((f) => f.name !== url);
toast.error(JSON.stringify(e));
}
};
const uploadYoutubeTranscription = async (url) => {
console.log(url);
const fileItem = {
type: 'doc',
name: url,
collection_name: '',
status: 'uploading',
context: 'full',
url: url,
error: ''
};
try {
files = [...files, fileItem];
const res = await processYoutubeVideo(localStorage.token, url);
if (res) {
fileItem.status = 'uploaded';
fileItem.collection_name = res.collection_name;
fileItem.file = {
...res.file,
...fileItem.file
};
files = files;
}
} catch (e) {
// Remove the failed doc from the files array
files = files.filter((f) => f.name !== url);
toast.error(e);
}
};
//////////////////////////
// Web functions
//////////////////////////
@ -338,16 +445,53 @@
if ($page.url.searchParams.get('models')) {
selectedModels = $page.url.searchParams.get('models')?.split(',');
} else if ($page.url.searchParams.get('model')) {
selectedModels = $page.url.searchParams.get('model')?.split(',');
const urlModels = $page.url.searchParams.get('model')?.split(',');
if (urlModels.length === 1) {
const m = $models.find((m) => m.id === urlModels[0]);
if (!m) {
const modelSelectorButton = document.getElementById('model-selector-0-button');
if (modelSelectorButton) {
modelSelectorButton.click();
await tick();
const modelSelectorInput = document.getElementById('model-search-input');
if (modelSelectorInput) {
modelSelectorInput.focus();
modelSelectorInput.value = urlModels[0];
modelSelectorInput.dispatchEvent(new Event('input'));
}
}
} else {
selectedModels = urlModels;
}
} else {
selectedModels = urlModels;
}
} else if ($settings?.models) {
selectedModels = $settings?.models;
} else if ($config?.default_models) {
console.log($config?.default_models.split(',') ?? '');
selectedModels = $config?.default_models.split(',');
} else {
selectedModels = [''];
}
selectedModels = selectedModels.filter((modelId) => $models.map((m) => m.id).includes(modelId));
if (selectedModels.length === 0 || (selectedModels.length === 1 && selectedModels[0] === '')) {
if ($models.length > 0) {
selectedModels = [$models[0].id];
} else {
selectedModels = [''];
}
}
console.log(selectedModels);
if ($page.url.searchParams.get('youtube')) {
uploadYoutubeTranscription(
`https://www.youtube.com/watch?v=${$page.url.searchParams.get('youtube')}`
);
}
if ($page.url.searchParams.get('web-search') === 'true') {
webSearchEnabled = true;
}
@ -366,6 +510,11 @@
.filter((id) => id);
}
if ($page.url.searchParams.get('call') === 'true') {
showCallOverlay.set(true);
showControls.set(true);
}
if ($page.url.searchParams.get('q')) {
prompt = $page.url.searchParams.get('q') ?? '';
@ -375,11 +524,6 @@
}
}
if ($page.url.searchParams.get('call') === 'true') {
showCallOverlay.set(true);
showControls.set(true);
}
selectedModels = selectedModels.map((modelId) =>
$models.map((m) => m.id).includes(modelId) ? modelId : ''
);
@ -392,7 +536,7 @@
settings.set(JSON.parse(localStorage.getItem('settings') ?? '{}'));
}
const chatInput = document.getElementById('chat-textarea');
const chatInput = document.getElementById('chat-input');
setTimeout(() => chatInput?.focus(), 0);
};
@ -404,7 +548,10 @@
});
if (chat) {
tags = await getTags();
tags = await getTagsById(localStorage.token, $chatId).catch(async (error) => {
return [];
});
const chatContent = chat.chat;
if (chatContent) {
@ -648,93 +795,106 @@
//////////////////////////
const submitPrompt = async (userPrompt, { _raw = false } = {}) => {
let _responses = [];
console.log('submitPrompt', $chatId);
const messages = createMessagesList(history.currentId);
console.log('submitPrompt', userPrompt, $chatId);
const messages = createMessagesList(history.currentId);
selectedModels = selectedModels.map((modelId) =>
$models.map((m) => m.id).includes(modelId) ? modelId : ''
);
if (userPrompt === '') {
toast.error($i18n.t('Please enter a prompt'));
return;
}
if (selectedModels.includes('')) {
toast.error($i18n.t('Model not selected'));
} else if (messages.length != 0 && messages.at(-1).done != true) {
return;
}
if (messages.length != 0 && messages.at(-1).done != true) {
// Response not done
console.log('wait');
} else if (messages.length != 0 && messages.at(-1).error) {
return;
}
if (messages.length != 0 && messages.at(-1).error) {
// Error in response
toast.error(
$i18n.t(
`Oops! There was an error in the previous response. Please try again or contact admin.`
)
);
} else if (
toast.error($i18n.t(`Oops! There was an error in the previous response.`));
return;
}
if (
files.length > 0 &&
files.filter((file) => file.type !== 'image' && file.status === 'uploading').length > 0
) {
// Upload not done
toast.error(
$i18n.t(
`Oops! Hold tight! Your files are still in the processing oven. We're cooking them up to perfection. Please be patient and we'll let you know once they're ready.`
)
$i18n.t(`Oops! There are files still uploading. Please wait for the upload to complete.`)
);
} else if (
return;
}
if (
($config?.file?.max_count ?? null) !== null &&
files.length + chatFiles.length > $config?.file?.max_count
) {
console.log(chatFiles.length, files.length);
toast.error(
$i18n.t(`You can only chat with a maximum of {{maxCount}} file(s) at a time.`, {
maxCount: $config?.file?.max_count
})
);
} else {
// Reset chat input textarea
const chatTextAreaElement = document.getElementById('chat-textarea');
if (chatTextAreaElement) {
chatTextAreaElement.value = '';
chatTextAreaElement.style.height = '';
}
const _files = JSON.parse(JSON.stringify(files));
chatFiles.push(..._files.filter((item) => ['doc', 'file', 'collection'].includes(item.type)));
chatFiles = chatFiles.filter(
// Remove duplicates
(item, index, array) =>
array.findIndex((i) => JSON.stringify(i) === JSON.stringify(item)) === index
);
files = [];
prompt = '';
// Create user message
let userMessageId = uuidv4();
let userMessage = {
id: userMessageId,
parentId: messages.length !== 0 ? messages.at(-1).id : null,
childrenIds: [],
role: 'user',
content: userPrompt,
files: _files.length > 0 ? _files : undefined,
timestamp: Math.floor(Date.now() / 1000), // Unix epoch
models: selectedModels
};
// Add message to history and Set currentId to messageId
history.messages[userMessageId] = userMessage;
history.currentId = userMessageId;
// Append messageId to childrenIds of parent message
if (messages.length !== 0) {
history.messages[messages.at(-1).id].childrenIds.push(userMessageId);
}
// Wait until history/message have been updated
await tick();
_responses = await sendPrompt(userPrompt, userMessageId, { newChat: true });
return;
}
let _responses = [];
prompt = '';
await tick();
// Reset chat input textarea
const chatInputContainer = document.getElementById('chat-input-container');
if (chatInputContainer) {
chatInputContainer.value = '';
chatInputContainer.style.height = '';
}
const _files = JSON.parse(JSON.stringify(files));
chatFiles.push(..._files.filter((item) => ['doc', 'file', 'collection'].includes(item.type)));
chatFiles = chatFiles.filter(
// Remove duplicates
(item, index, array) =>
array.findIndex((i) => JSON.stringify(i) === JSON.stringify(item)) === index
);
files = [];
prompt = '';
// Create user message
let userMessageId = uuidv4();
let userMessage = {
id: userMessageId,
parentId: messages.length !== 0 ? messages.at(-1).id : null,
childrenIds: [],
role: 'user',
content: userPrompt,
files: _files.length > 0 ? _files : undefined,
timestamp: Math.floor(Date.now() / 1000), // Unix epoch
models: selectedModels
};
// Add message to history and Set currentId to messageId
history.messages[userMessageId] = userMessage;
history.currentId = userMessageId;
// Append messageId to childrenIds of parent message
if (messages.length !== 0) {
history.messages[messages.at(-1).id].childrenIds.push(userMessageId);
}
// Wait until history/message have been updated
await tick();
// focus on chat input
const chatInput = document.getElementById('chat-input');
chatInput?.focus();
_responses = await sendPrompt(userPrompt, userMessageId, { newChat: true });
return _responses;
};
@ -854,10 +1014,10 @@
}
let _response = null;
if (model?.owned_by === 'openai') {
_response = await sendPromptOpenAI(model, prompt, responseMessageId, _chatId);
} else if (model) {
if (model?.owned_by === 'ollama') {
_response = await sendPromptOllama(model, prompt, responseMessageId, _chatId);
} else if (model) {
_response = await sendPromptOpenAI(model, prompt, responseMessageId, _chatId);
}
_responses.push(_response);
@ -1253,8 +1413,13 @@
const messages = createMessagesList(responseMessageId);
if (messages.length == 2 && messages.at(-1).content !== '' && selectedModels[0] === model.id) {
window.history.replaceState(history.state, '', `/c/${_chatId}`);
const title = await generateChatTitle(userPrompt);
const title = await generateChatTitle(messages);
await setChatTitle(_chatId, title);
if ($settings?.autoTags ?? true) {
await setChatTags(messages);
}
}
return _response;
@ -1430,7 +1595,7 @@
const textStream = await createOpenAITextStream(res.body, $settings.splitLargeChunks);
for await (const update of textStream) {
const { value, done, citations, error, usage } = update;
const { value, done, citations, selectedModelId, error, usage } = update;
if (error) {
await handleOpenAIError(error, null, model, responseMessage);
break;
@ -1450,6 +1615,12 @@
responseMessage.info = { ...usage, openai: true, usage };
}
if (selectedModelId) {
responseMessage.selectedModelId = selectedModelId;
responseMessage.arena = true;
continue;
}
if (citations) {
responseMessage.citations = citations;
// Only remove status if it was initially set
@ -1567,8 +1738,13 @@
const messages = createMessagesList(responseMessageId);
if (messages.length == 2 && selectedModels[0] === model.id) {
window.history.replaceState(history.state, '', `/c/${_chatId}`);
const title = await generateChatTitle(userPrompt);
const title = await generateChatTitle(messages);
await setChatTitle(_chatId, title);
if ($settings?.autoTags ?? true) {
await setChatTags(messages);
}
}
return _response;
@ -1724,21 +1900,21 @@
}
};
const generateChatTitle = async (userPrompt) => {
const generateChatTitle = async (messages) => {
if ($settings?.title?.auto ?? true) {
const title = await generateTitle(
localStorage.token,
selectedModels[0],
userPrompt,
$chatId
).catch((error) => {
console.error(error);
return 'New Chat';
});
const lastMessage = messages.at(-1);
const modelId = selectedModels[0];
const title = await generateTitle(localStorage.token, modelId, messages, $chatId).catch(
(error) => {
console.error(error);
return 'New Chat';
}
);
return title;
} else {
return `${userPrompt}`;
return 'New Chat';
}
};
@ -1755,6 +1931,40 @@
}
};
const setChatTags = async (messages) => {
if (!$temporaryChatEnabled) {
const currentTags = await getTagsById(localStorage.token, $chatId);
if (currentTags.length > 0) {
const res = await deleteTagsById(localStorage.token, $chatId);
if (res) {
allTags.set(await getAllTags(localStorage.token));
}
}
const lastMessage = messages.at(-1);
const modelId = selectedModels[0];
let generatedTags = await generateTags(localStorage.token, modelId, messages, $chatId).catch(
(error) => {
console.error(error);
return [];
}
);
generatedTags = generatedTags.filter(
(tag) => !currentTags.find((t) => t.id === tag.replaceAll(' ', '_').toLowerCase())
);
console.log(generatedTags);
for (const tag of generatedTags) {
await addTagById(localStorage.token, $chatId, tag);
}
chat = await getChatById(localStorage.token, $chatId);
allTags.set(await getAllTags(localStorage.token));
}
};
const getWebSearchResults = async (
model: string,
parentId: string,
@ -1840,12 +2050,6 @@
}
};
const getTags = async () => {
return await getTagsById(localStorage.token, $chatId).catch(async (error) => {
return [];
});
};
const initChatHandler = async () => {
if (!$temporaryChatEnabled) {
chat = await createNewChat(localStorage.token, {
@ -1855,6 +2059,7 @@
system: $settings.system ?? undefined,
params: params,
history: history,
messages: createMessagesList(history.currentId),
tags: [],
timestamp: Date.now()
});
@ -1920,6 +2125,7 @@
class="h-screen max-h-[100dvh] {$showSidebar
? 'md:max-w-[calc(100%-260px)]'
: ''} w-full max-w-full flex flex-col"
id="chat-container"
>
{#if $settings?.backgroundImageUrl ?? null}
<div
@ -1935,7 +2141,18 @@
{/if}
<Navbar
{chat}
bind:this={navbarElement}
chat={{
id: $chatId,
chat: {
title: $chatTitle,
models: selectedModels,
system: $settings.system ?? undefined,
params: params,
history: history,
timestamp: Date.now()
}
}}
title={$chatTitle}
bind:selectedModels
shareEnabled={!!history.currentId}
@ -2050,11 +2267,19 @@
transparentBackground={$settings?.backgroundImageUrl ?? false}
{stopResponse}
{createMessagePair}
on:upload={async (e) => {
const { type, data } = e.detail;
if (type === 'web') {
await uploadWeb(data);
} else if (type === 'youtube') {
await uploadYoutubeTranscription(data);
}
}}
on:submit={async (e) => {
if (e.detail) {
prompt = '';
await tick();
submitPrompt(e.detail);
submitPrompt(e.detail.replaceAll('\n\n', '\n'));
}
}}
/>
@ -2066,38 +2291,49 @@
</div>
</div>
{:else}
<Placeholder
{history}
{selectedModels}
bind:files
bind:prompt
bind:autoScroll
bind:selectedToolIds
bind:webSearchEnabled
bind:atSelectedModel
availableToolIds={selectedModelIds.reduce((a, e, i, arr) => {
const model = $models.find((m) => m.id === e);
if (model?.info?.meta?.toolIds ?? false) {
return [...new Set([...a, ...model.info.meta.toolIds])];
}
return a;
}, [])}
transparentBackground={$settings?.backgroundImageUrl ?? false}
{stopResponse}
{createMessagePair}
on:submit={async (e) => {
if (e.detail) {
prompt = '';
await tick();
submitPrompt(e.detail);
}
}}
/>
<div class="overflow-auto w-full h-full flex items-center">
<Placeholder
{history}
{selectedModels}
bind:files
bind:prompt
bind:autoScroll
bind:selectedToolIds
bind:webSearchEnabled
bind:atSelectedModel
availableToolIds={selectedModelIds.reduce((a, e, i, arr) => {
const model = $models.find((m) => m.id === e);
if (model?.info?.meta?.toolIds ?? false) {
return [...new Set([...a, ...model.info.meta.toolIds])];
}
return a;
}, [])}
transparentBackground={$settings?.backgroundImageUrl ?? false}
{stopResponse}
{createMessagePair}
on:upload={async (e) => {
const { type, data } = e.detail;
if (type === 'web') {
await uploadWeb(data);
} else if (type === 'youtube') {
await uploadYoutubeTranscription(data);
}
}}
on:submit={async (e) => {
if (e.detail) {
await tick();
submitPrompt(e.detail.replaceAll('\n\n', '\n'));
}
}}
/>
</div>
{/if}
</div>
</Pane>
<ChatControls
bind:this={controlPaneComponent}
bind:history
bind:chatFiles
bind:params

View file

@ -1,6 +1,7 @@
<script lang="ts">
import { SvelteFlowProvider } from '@xyflow/svelte';
import { slide } from 'svelte/transition';
import { Pane, PaneResizer } from 'paneforge';
import { onDestroy, onMount, tick } from 'svelte';
import { mobile, showControls, showCallOverlay, showOverview, showArtifacts } from '$lib/stores';
@ -10,9 +11,9 @@
import CallOverlay from './MessageInput/CallOverlay.svelte';
import Drawer from '../common/Drawer.svelte';
import Overview from './Overview.svelte';
import { Pane, PaneResizer } from 'paneforge';
import EllipsisVertical from '../icons/EllipsisVertical.svelte';
import Artifacts from './Artifacts.svelte';
import { min } from '@floating-ui/utils';
export let history;
export let models = [];
@ -35,6 +36,16 @@
let largeScreen = false;
let dragged = false;
let minSize = 0;
export const openPane = () => {
if (parseInt(localStorage?.chatControlsSize)) {
pane.resize(parseInt(localStorage?.chatControlsSize));
} else {
pane.resize(minSize);
}
};
const handleMediaQuery = async (e) => {
if (e.matches) {
largeScreen = true;
@ -71,6 +82,32 @@
mediaQuery.addEventListener('change', handleMediaQuery);
handleMediaQuery(mediaQuery);
// Select the container element you want to observe
const container = document.getElementById('chat-container');
// initialize the minSize based on the container width
minSize = Math.floor((350 / container.clientWidth) * 100);
// Create a new ResizeObserver instance
const resizeObserver = new ResizeObserver((entries) => {
for (let entry of entries) {
const width = entry.contentRect.width;
// calculate the percentage of 200px
const percentage = (350 / width) * 100;
// set the minSize to the percentage, must be an integer
minSize = Math.floor(percentage);
if ($showControls) {
if (pane && pane.isExpanded() && pane.getSize() < minSize) {
pane.resize(minSize);
}
}
}
});
// Start observing the container's size changes
resizeObserver.observe(container);
document.addEventListener('mousedown', onMouseDown);
document.addEventListener('mouseup', onMouseUp);
});
@ -163,23 +200,29 @@
</div>
</PaneResizer>
{/if}
<Pane
bind:pane
defaultSize={$showControls
? parseInt(localStorage?.chatControlsSize ?? '30')
? parseInt(localStorage?.chatControlsSize ?? '30')
: 30
: 0}
defaultSize={0}
onResize={(size) => {
if (size === 0) {
showControls.set(false);
} else {
if (!$showControls) {
showControls.set(true);
console.log('size', size, minSize);
if ($showControls && pane.isExpanded()) {
if (size < minSize) {
pane.resize(minSize);
}
if (size < minSize) {
localStorage.chatControlsSize = 0;
} else {
localStorage.chatControlsSize = size;
}
localStorage.chatControlsSize = size;
}
}}
onCollapse={() => {
showControls.set(false);
}}
collapsible={true}
class="pt-8"
>
{#if $showControls}
@ -187,7 +230,7 @@
<div
class="w-full {($showOverview || $showArtifacts) && !$showCallOverlay
? ' '
: 'px-5 py-4 bg-white dark:shadow-lg dark:bg-gray-850 border border-gray-50 dark:border-gray-800'} rounded-lg z-40 pointer-events-auto overflow-y-auto scrollbar-hidden"
: 'px-4 py-4 bg-white dark:shadow-lg dark:bg-gray-850 border border-gray-50 dark:border-gray-850'} rounded-xl z-40 pointer-events-auto overflow-y-auto scrollbar-hidden"
>
{#if $showCallOverlay}
<div class="w-full h-full flex justify-center">

View file

@ -82,8 +82,8 @@
>
<div>
<div class=" capitalize line-clamp-1" in:fade={{ duration: 200 }}>
{#if models[selectedModelIdx]?.info}
{models[selectedModelIdx]?.info?.name}
{#if models[selectedModelIdx]?.name}
{models[selectedModelIdx]?.name}
{:else}
{$i18n.t('Hello, {{name}}', { name: $user.name })}
{/if}

View file

@ -16,7 +16,7 @@
</script>
<div class=" dark:text-white">
<div class=" flex justify-between dark:text-gray-100 mb-2">
<div class=" flex items-center justify-between dark:text-gray-100 mb-2">
<div class=" text-lg font-medium self-center font-primary">{$i18n.t('Chat Controls')}</div>
<button
class="self-center"
@ -24,13 +24,13 @@
dispatch('close');
}}
>
<XMark className="size-4" />
<XMark className="size-3.5" />
</button>
</div>
<div class=" dark:text-gray-200 text-sm font-primary py-0.5">
<div class=" dark:text-gray-200 text-sm font-primary py-0.5 px-0.5">
{#if chatFiles.length > 0}
<Collapsible title={$i18n.t('Files')} open={true}>
<Collapsible title={$i18n.t('Files')} open={true} buttonClassName="w-full">
<div class="flex flex-col gap-1 mt-1.5" slot="content">
{#each chatFiles as file, fileIdx}
<FileItem
@ -56,31 +56,31 @@
</div>
</Collapsible>
<hr class="my-2 border-gray-100 dark:border-gray-800" />
<hr class="my-2 border-gray-50 dark:border-gray-700/10" />
{/if}
<Collapsible title={$i18n.t('Valves')}>
<div class="text-sm mt-1.5" slot="content">
<Collapsible title={$i18n.t('Valves')} buttonClassName="w-full">
<div class="text-sm" slot="content">
<Valves />
</div>
</Collapsible>
<hr class="my-2 border-gray-100 dark:border-gray-800" />
<hr class="my-2 border-gray-50 dark:border-gray-700/10" />
<Collapsible title={$i18n.t('System Prompt')} open={true}>
<div class=" mt-1.5" slot="content">
<Collapsible title={$i18n.t('System Prompt')} open={true} buttonClassName="w-full">
<div class="" slot="content">
<textarea
bind:value={params.system}
class="w-full rounded-lg px-3.5 py-2.5 text-sm dark:text-gray-300 dark:bg-gray-850 border border-gray-100 dark:border-gray-800 outline-none resize-none"
class="w-full text-xs py-1.5 bg-transparent outline-none resize-none"
rows="4"
placeholder={$i18n.t('Enter system prompt')}
/>
</div>
</Collapsible>
<hr class="my-2 border-gray-100 dark:border-gray-800" />
<hr class="my-2 border-gray-50 dark:border-gray-700/10" />
<Collapsible title={$i18n.t('Advanced Params')} open={true}>
<Collapsible title={$i18n.t('Advanced Params')} open={true} buttonClassName="w-full">
<div class="text-sm mt-1.5" slot="content">
<div>
<AdvancedParams admin={$user?.role === 'admin'} bind:params />

View file

@ -1,6 +1,6 @@
<script lang="ts">
import { toast } from 'svelte-sonner';
import { onMount, tick, getContext, createEventDispatcher } from 'svelte';
import { onMount, tick, getContext, createEventDispatcher, onDestroy } from 'svelte';
const dispatch = createEventDispatcher();
import {
@ -29,6 +29,7 @@
import FilesOverlay from './MessageInput/FilesOverlay.svelte';
import Commands from './MessageInput/Commands.svelte';
import XMark from '../icons/XMark.svelte';
import RichTextInput from '../common/RichTextInput.svelte';
const i18n = getContext('i18n');
@ -52,9 +53,10 @@
let recording = false;
let chatTextAreaElement: HTMLTextAreaElement;
let filesInputElement;
let chatInputContainerElement;
let chatInputElement;
let filesInputElement;
let commandsElement;
let inputFiles;
@ -69,9 +71,10 @@
);
$: if (prompt) {
if (chatTextAreaElement) {
chatTextAreaElement.style.height = '';
chatTextAreaElement.style.height = Math.min(chatTextAreaElement.scrollHeight, 200) + 'px';
if (chatInputContainerElement) {
chatInputContainerElement.style.height = '';
chatInputContainerElement.style.height =
Math.min(chatInputContainerElement.scrollHeight, 200) + 'px';
}
}
@ -175,57 +178,66 @@
});
};
onMount(() => {
window.setTimeout(() => chatTextAreaElement?.focus(), 0);
const handleKeyDown = (event: KeyboardEvent) => {
if (event.key === 'Escape') {
console.log('Escape');
dragged = false;
}
};
const dropZone = document.querySelector('body');
const onDragOver = (e) => {
e.preventDefault();
const handleKeyDown = (event: KeyboardEvent) => {
if (event.key === 'Escape') {
console.log('Escape');
dragged = false;
}
};
const onDragOver = (e) => {
e.preventDefault();
// Check if a file is being dragged.
if (e.dataTransfer?.types?.includes('Files')) {
dragged = true;
};
const onDragLeave = () => {
} else {
dragged = false;
};
}
};
const onDrop = async (e) => {
e.preventDefault();
console.log(e);
const onDragLeave = () => {
dragged = false;
};
if (e.dataTransfer?.files) {
const inputFiles = Array.from(e.dataTransfer?.files);
if (inputFiles && inputFiles.length > 0) {
console.log(inputFiles);
inputFilesHandler(inputFiles);
} else {
toast.error($i18n.t(`File not found.`));
}
const onDrop = async (e) => {
e.preventDefault();
console.log(e);
if (e.dataTransfer?.files) {
const inputFiles = Array.from(e.dataTransfer?.files);
if (inputFiles && inputFiles.length > 0) {
console.log(inputFiles);
inputFilesHandler(inputFiles);
}
}
dragged = false;
};
dragged = false;
};
onMount(() => {
window.setTimeout(() => {
const chatInput = document.getElementById('chat-input');
chatInput?.focus();
}, 0);
window.addEventListener('keydown', handleKeyDown);
const dropZone = document.getElementById('chat-container');
dropZone?.addEventListener('dragover', onDragOver);
dropZone?.addEventListener('drop', onDrop);
dropZone?.addEventListener('dragleave', onDragLeave);
});
return () => {
window.removeEventListener('keydown', handleKeyDown);
onDestroy(() => {
window.removeEventListener('keydown', handleKeyDown);
dropZone?.removeEventListener('dragover', onDragOver);
dropZone?.removeEventListener('drop', onDrop);
dropZone?.removeEventListener('dragleave', onDragLeave);
};
const dropZone = document.getElementById('chat-container');
dropZone?.removeEventListener('dragover', onDragOver);
dropZone?.removeEventListener('drop', onDrop);
dropZone?.removeEventListener('dragleave', onDragLeave);
});
</script>
@ -300,6 +312,9 @@
bind:this={commandsElement}
bind:prompt
bind:files
on:upload={(e) => {
dispatch('upload', e.detail);
}}
on:select={(e) => {
const data = e.detail;
@ -307,7 +322,8 @@
atSelectedModel = data.data;
}
chatTextAreaElement?.focus();
const chatInputElement = document.getElementById('chat-input');
chatInputElement?.focus();
}}
/>
</div>
@ -342,7 +358,7 @@
recording = false;
await tick();
document.getElementById('chat-textarea')?.focus();
document.getElementById('chat-input')?.focus();
}}
on:confirm={async (e) => {
const response = e.detail;
@ -351,7 +367,7 @@
recording = false;
await tick();
document.getElementById('chat-textarea')?.focus();
document.getElementById('chat-input')?.focus();
if ($settings?.speechAutoSend ?? false) {
dispatch('submit', prompt);
@ -469,7 +485,9 @@
}}
onClose={async () => {
await tick();
chatTextAreaElement?.focus();
const chatInput = document.getElementById('chat-input');
chatInput?.focus();
}}
>
<button
@ -491,177 +509,173 @@
</InputMenu>
</div>
<textarea
id="chat-textarea"
bind:this={chatTextAreaElement}
class="scrollbar-hidden bg-gray-50 dark:bg-gray-850 dark:text-gray-100 outline-none w-full py-3 px-1 rounded-xl resize-none h-[48px]"
placeholder={placeholder ? placeholder : $i18n.t('Send a Message')}
bind:value={prompt}
on:keypress={(e) => {
if (
!$mobile ||
<div
bind:this={chatInputContainerElement}
id="chat-input-container"
class="scrollbar-hidden text-left bg-gray-50 dark:bg-gray-850 dark:text-gray-100 outline-none w-full py-2.5 px-1 rounded-xl resize-none h-[48px] overflow-auto"
>
<RichTextInput
bind:this={chatInputElement}
id="chat-input"
trim={true}
placeholder={placeholder ? placeholder : $i18n.t('Send a Message')}
bind:value={prompt}
shiftEnter={!$mobile ||
!(
'ontouchstart' in window ||
navigator.maxTouchPoints > 0 ||
navigator.msMaxTouchPoints > 0
)
) {
// Prevent Enter key from creating a new line
if (e.key === 'Enter' && !e.shiftKey) {
e.preventDefault();
}
// Submit the prompt when Enter key is pressed
if (prompt !== '' && e.key === 'Enter' && !e.shiftKey) {
)}
on:enter={async (e) => {
if (prompt !== '') {
dispatch('submit', prompt);
}
}
}}
on:keydown={async (e) => {
const isCtrlPressed = e.ctrlKey || e.metaKey; // metaKey is for Cmd key on Mac
const commandsContainerElement = document.getElementById('commands-container');
// Command/Ctrl + Shift + Enter to submit a message pair
if (isCtrlPressed && e.key === 'Enter' && e.shiftKey) {
e.preventDefault();
createMessagePair(prompt);
}
// Check if Ctrl + R is pressed
if (prompt === '' && isCtrlPressed && e.key.toLowerCase() === 'r') {
e.preventDefault();
console.log('regenerate');
const regenerateButton = [
...document.getElementsByClassName('regenerate-response-button')
]?.at(-1);
regenerateButton?.click();
}
if (prompt === '' && e.key == 'ArrowUp') {
e.preventDefault();
const userMessageElement = [
...document.getElementsByClassName('user-message')
]?.at(-1);
const editButton = [
...document.getElementsByClassName('edit-user-message-button')
]?.at(-1);
console.log(userMessageElement);
userMessageElement.scrollIntoView({ block: 'center' });
editButton?.click();
}
if (commandsContainerElement && e.key === 'ArrowUp') {
e.preventDefault();
commandsElement.selectUp();
const commandOptionButton = [
...document.getElementsByClassName('selected-command-option-button')
]?.at(-1);
commandOptionButton.scrollIntoView({ block: 'center' });
}
if (commandsContainerElement && e.key === 'ArrowDown') {
e.preventDefault();
commandsElement.selectDown();
const commandOptionButton = [
...document.getElementsByClassName('selected-command-option-button')
]?.at(-1);
commandOptionButton.scrollIntoView({ block: 'center' });
}
if (commandsContainerElement && e.key === 'Enter') {
e.preventDefault();
const commandOptionButton = [
...document.getElementsByClassName('selected-command-option-button')
]?.at(-1);
if (e.shiftKey) {
prompt = `${prompt}\n`;
} else if (commandOptionButton) {
commandOptionButton?.click();
} else {
document.getElementById('send-message-button')?.click();
}}
on:input={async (e) => {
if (chatInputContainerElement) {
chatInputContainerElement.style.height = '';
chatInputContainerElement.style.height =
Math.min(chatInputContainerElement.scrollHeight, 200) + 'px';
}
}
}}
on:focus={async (e) => {
if (chatInputContainerElement) {
chatInputContainerElement.style.height = '';
chatInputContainerElement.style.height =
Math.min(chatInputContainerElement.scrollHeight, 200) + 'px';
}
}}
on:keypress={(e) => {
e = e.detail.event;
}}
on:keydown={async (e) => {
e = e.detail.event;
if (commandsContainerElement && e.key === 'Tab') {
e.preventDefault();
if (chatInputContainerElement) {
chatInputContainerElement.style.height = '';
chatInputContainerElement.style.height =
Math.min(chatInputContainerElement.scrollHeight, 200) + 'px';
}
const commandOptionButton = [
...document.getElementsByClassName('selected-command-option-button')
]?.at(-1);
commandOptionButton?.click();
} else if (e.key === 'Tab') {
const words = findWordIndices(prompt);
if (words.length > 0) {
const word = words.at(0);
const fullPrompt = prompt;
prompt = prompt.substring(0, word?.endIndex + 1);
await tick();
e.target.scrollTop = e.target.scrollHeight;
prompt = fullPrompt;
await tick();
const isCtrlPressed = e.ctrlKey || e.metaKey; // metaKey is for Cmd key on Mac
const commandsContainerElement =
document.getElementById('commands-container');
// Command/Ctrl + Shift + Enter to submit a message pair
if (isCtrlPressed && e.key === 'Enter' && e.shiftKey) {
e.preventDefault();
e.target.setSelectionRange(word?.startIndex, word.endIndex + 1);
createMessagePair(prompt);
}
e.target.style.height = '';
e.target.style.height = Math.min(e.target.scrollHeight, 200) + 'px';
}
// Check if Ctrl + R is pressed
if (prompt === '' && isCtrlPressed && e.key.toLowerCase() === 'r') {
e.preventDefault();
console.log('regenerate');
if (e.key === 'Escape') {
console.log('Escape');
atSelectedModel = undefined;
}
}}
rows="1"
on:input={async (e) => {
e.target.style.height = '';
e.target.style.height = Math.min(e.target.scrollHeight, 200) + 'px';
user = null;
}}
on:focus={async (e) => {
e.target.style.height = '';
e.target.style.height = Math.min(e.target.scrollHeight, 200) + 'px';
}}
on:paste={async (e) => {
const clipboardData = e.clipboardData || window.clipboardData;
const regenerateButton = [
...document.getElementsByClassName('regenerate-response-button')
]?.at(-1);
if (clipboardData && clipboardData.items) {
for (const item of clipboardData.items) {
if (item.type.indexOf('image') !== -1) {
const blob = item.getAsFile();
const reader = new FileReader();
regenerateButton?.click();
}
reader.onload = function (e) {
files = [
...files,
{
type: 'image',
url: `${e.target.result}`
}
];
};
if (prompt === '' && e.key == 'ArrowUp') {
e.preventDefault();
reader.readAsDataURL(blob);
const userMessageElement = [
...document.getElementsByClassName('user-message')
]?.at(-1);
const editButton = [
...document.getElementsByClassName('edit-user-message-button')
]?.at(-1);
console.log(userMessageElement);
userMessageElement.scrollIntoView({ block: 'center' });
editButton?.click();
}
if (commandsContainerElement && e.key === 'ArrowUp') {
e.preventDefault();
commandsElement.selectUp();
const commandOptionButton = [
...document.getElementsByClassName('selected-command-option-button')
]?.at(-1);
commandOptionButton.scrollIntoView({ block: 'center' });
}
if (commandsContainerElement && e.key === 'ArrowDown') {
e.preventDefault();
commandsElement.selectDown();
const commandOptionButton = [
...document.getElementsByClassName('selected-command-option-button')
]?.at(-1);
commandOptionButton.scrollIntoView({ block: 'center' });
}
if (commandsContainerElement && e.key === 'Enter') {
e.preventDefault();
const commandOptionButton = [
...document.getElementsByClassName('selected-command-option-button')
]?.at(-1);
if (e.shiftKey) {
prompt = `${prompt}\n`;
} else if (commandOptionButton) {
commandOptionButton?.click();
} else {
document.getElementById('send-message-button')?.click();
}
}
}
}}
/>
if (commandsContainerElement && e.key === 'Tab') {
e.preventDefault();
const commandOptionButton = [
...document.getElementsByClassName('selected-command-option-button')
]?.at(-1);
commandOptionButton?.click();
}
if (e.key === 'Escape') {
console.log('Escape');
atSelectedModel = undefined;
}
}}
on:paste={async (e) => {
e = e.detail.event;
console.log(e);
const clipboardData = e.clipboardData || window.clipboardData;
if (clipboardData && clipboardData.items) {
for (const item of clipboardData.items) {
if (item.type.indexOf('image') !== -1) {
const blob = item.getAsFile();
const reader = new FileReader();
reader.onload = function (e) {
files = [
...files,
{
type: 'image',
url: `${e.target.result}`
}
];
};
reader.readAsDataURL(blob);
}
}
}
}}
/>
</div>
<div class="self-end mb-2 flex space-x-1 mr-1">
{#if !history?.currentId || history.messages[history.currentId]?.done == true}
@ -791,25 +805,27 @@
{/if}
{:else}
<div class=" flex items-center mb-1.5">
<button
class="bg-white hover:bg-gray-100 text-gray-800 dark:bg-gray-700 dark:text-white dark:hover:bg-gray-800 transition rounded-full p-1.5"
on:click={() => {
stopResponse();
}}
>
<svg
xmlns="http://www.w3.org/2000/svg"
viewBox="0 0 24 24"
fill="currentColor"
class="size-6"
<Tooltip content={$i18n.t('Stop')}>
<button
class="bg-white hover:bg-gray-100 text-gray-800 dark:bg-gray-700 dark:text-white dark:hover:bg-gray-800 transition rounded-full p-1.5"
on:click={() => {
stopResponse();
}}
>
<path
fill-rule="evenodd"
d="M2.25 12c0-5.385 4.365-9.75 9.75-9.75s9.75 4.365 9.75 9.75-4.365 9.75-9.75 9.75S2.25 17.385 2.25 12zm6-2.438c0-.724.588-1.312 1.313-1.312h4.874c.725 0 1.313.588 1.313 1.313v4.874c0 .725-.588 1.313-1.313 1.313H9.564a1.312 1.312 0 01-1.313-1.313V9.564z"
clip-rule="evenodd"
/>
</svg>
</button>
<svg
xmlns="http://www.w3.org/2000/svg"
viewBox="0 0 24 24"
fill="currentColor"
class="size-6"
>
<path
fill-rule="evenodd"
d="M2.25 12c0-5.385 4.365-9.75 9.75-9.75s9.75 4.365 9.75 9.75-4.365 9.75-9.75 9.75S2.25 17.385 2.25 12zm6-2.438c0-.724.588-1.312 1.313-1.312h4.874c.725 0 1.313.588 1.313 1.313v4.874c0 .725-.588 1.313-1.313 1.313H9.564a1.312 1.312 0 01-1.313-1.313V9.564z"
clip-rule="evenodd"
/>
</svg>
</button>
</Tooltip>
</div>
{/if}
</div>

View file

@ -25,96 +25,36 @@
};
let command = '';
$: command = (prompt?.trim() ?? '').split(' ')?.at(-1) ?? '';
const uploadWeb = async (url) => {
console.log(url);
const fileItem = {
type: 'doc',
name: url,
collection_name: '',
status: 'uploading',
url: url,
error: ''
};
try {
files = [...files, fileItem];
const res = await processWeb(localStorage.token, '', url);
if (res) {
fileItem.status = 'uploaded';
fileItem.collection_name = res.collection_name;
fileItem.file = {
...res.file,
...fileItem.file
};
files = files;
}
} catch (e) {
// Remove the failed doc from the files array
files = files.filter((f) => f.name !== url);
toast.error(JSON.stringify(e));
}
};
const uploadYoutubeTranscription = async (url) => {
console.log(url);
const fileItem = {
type: 'doc',
name: url,
collection_name: '',
status: 'uploading',
url: url,
error: ''
};
try {
files = [...files, fileItem];
const res = await processYoutubeVideo(localStorage.token, url);
if (res) {
fileItem.status = 'uploaded';
fileItem.collection_name = res.collection_name;
fileItem.file = {
...res.file,
...fileItem.file
};
files = files;
}
} catch (e) {
// Remove the failed doc from the files array
files = files.filter((f) => f.name !== url);
toast.error(e);
}
};
$: command = prompt?.split('\n').pop()?.split(' ')?.pop() ?? '';
</script>
{#if ['/', '#', '@'].includes(command?.charAt(0))}
{#if ['/', '#', '@'].includes(command?.charAt(0)) || '\\#' === command.slice(0, 2)}
{#if command?.charAt(0) === '/'}
<Prompts bind:this={commandElement} bind:prompt bind:files {command} />
{:else if command?.charAt(0) === '#'}
{:else if (command?.charAt(0) === '#' && command.startsWith('#') && !command.includes('# ')) || ('\\#' === command.slice(0, 2) && command.startsWith('#') && !command.includes('# '))}
<Knowledge
bind:this={commandElement}
bind:prompt
{command}
command={command.includes('\\#') ? command.slice(2) : command}
on:youtube={(e) => {
console.log(e);
uploadYoutubeTranscription(e.detail);
dispatch('upload', {
type: 'youtube',
data: e.detail
});
}}
on:url={(e) => {
console.log(e);
uploadWeb(e.detail);
dispatch('upload', {
type: 'web',
data: e.detail
});
}}
on:select={(e) => {
console.log(e);
files = [
...files,
{
type: e?.detail?.meta?.document ? 'file' : 'collection',
...e.detail,
status: 'processed'
}

View file

@ -2,6 +2,10 @@
import { toast } from 'svelte-sonner';
import Fuse from 'fuse.js';
import dayjs from 'dayjs';
import relativeTime from 'dayjs/plugin/relativeTime';
dayjs.extend(relativeTime);
import { createEventDispatcher, tick, getContext, onMount } from 'svelte';
import { removeLastWordFromString, isValidHttpUrl } from '$lib/utils';
import { knowledge } from '$lib/stores';
@ -42,7 +46,7 @@
dispatch('select', item);
prompt = removeLastWordFromString(prompt, command);
const chatInputElement = document.getElementById('chat-textarea');
const chatInputElement = document.getElementById('chat-input');
await tick();
chatInputElement?.focus();
@ -53,7 +57,7 @@
dispatch('url', url);
prompt = removeLastWordFromString(prompt, command);
const chatInputElement = document.getElementById('chat-textarea');
const chatInputElement = document.getElementById('chat-input');
await tick();
chatInputElement?.focus();
@ -64,7 +68,7 @@
dispatch('youtube', url);
prompt = removeLastWordFromString(prompt, command);
const chatInputElement = document.getElementById('chat-textarea');
const chatInputElement = document.getElementById('chat-input');
await tick();
chatInputElement?.focus();
@ -72,7 +76,13 @@
};
onMount(() => {
let legacy_documents = $knowledge.filter((item) => item?.meta?.document);
let legacy_documents = $knowledge
.filter((item) => item?.meta?.document)
.map((item) => ({
...item,
type: 'file'
}));
let legacy_collections =
legacy_documents.length > 0
? [
@ -101,12 +111,44 @@
]
: [];
items = [...$knowledge, ...legacy_collections].map((item) => {
return {
let collections = $knowledge
.filter((item) => !item?.meta?.document)
.map((item) => ({
...item,
...(item?.legacy || item?.meta?.legacy || item?.meta?.document ? { legacy: true } : {})
};
});
type: 'collection'
}));
let collection_files =
$knowledge.length > 0
? [
...$knowledge
.reduce((a, item) => {
return [
...new Set([
...a,
...(item?.files ?? []).map((file) => ({
...file,
collection: { name: item.name, description: item.description }
}))
])
];
}, [])
.map((file) => ({
...file,
name: file?.meta?.name,
description: `${file?.collection?.name} - ${file?.collection?.description}`,
type: 'file'
}))
]
: [];
items = [...collections, ...collection_files, ...legacy_collections, ...legacy_documents].map(
(item) => {
return {
...item,
...(item?.legacy || item?.meta?.legacy || item?.meta?.document ? { legacy: true } : {})
};
}
);
fuse = new Fuse(items, {
keys: ['name', 'description']
@ -117,20 +159,17 @@
{#if filteredItems.length > 0 || prompt.split(' ')?.at(0)?.substring(1).startsWith('http')}
<div
id="commands-container"
class="pl-2 pr-14 mb-3 text-left w-full absolute bottom-0 left-0 right-0 z-10"
class="pl-3 pr-14 mb-3 text-left w-full absolute bottom-0 left-0 right-0 z-10"
>
<div class="flex w-full dark:border dark:border-gray-850 rounded-lg">
<div class=" bg-gray-50 dark:bg-gray-850 w-10 rounded-l-lg text-center">
<div class=" text-lg font-medium mt-2">#</div>
</div>
<div class="flex w-full rounded-xl border border-gray-50 dark:border-gray-850">
<div
class="max-h-60 flex flex-col w-full rounded-r-xl bg-white dark:bg-gray-900 dark:text-gray-100"
class="max-h-60 flex flex-col w-full rounded-xl bg-white dark:bg-gray-900 dark:text-gray-100"
>
<div class="m-1 overflow-y-auto p-1 rounded-r-xl space-y-0.5 scrollbar-hidden">
{#each filteredItems as item, idx}
<button
class=" px-3 py-1.5 rounded-xl w-full text-left {idx === selectedIdx
class=" px-3 py-1.5 rounded-xl w-full text-left flex justify-between items-center {idx ===
selectedIdx
? ' bg-gray-50 dark:bg-gray-850 dark:text-gray-100 selected-command-option-button'
: ''}"
type="button"
@ -141,38 +180,87 @@
on:mousemove={() => {
selectedIdx = idx;
}}
on:focus={() => {}}
>
<div class=" font-medium text-black dark:text-gray-100 flex items-center gap-1">
{#if item.legacy}
<div
class="bg-gray-500/20 text-gray-700 dark:text-gray-200 rounded uppercase text-xs font-bold px-1"
>
Legacy
</div>
{:else if item?.meta?.document}
<div
class="bg-gray-500/20 text-gray-700 dark:text-gray-200 rounded uppercase text-xs font-bold px-1"
>
Document
</div>
{:else}
<div
class="bg-green-500/20 text-green-700 dark:text-green-200 rounded uppercase text-xs font-bold px-1"
>
Collection
</div>
{/if}
<div>
<div class=" font-medium text-black dark:text-gray-100 flex items-center gap-1">
{#if item.legacy}
<div
class="bg-gray-500/20 text-gray-700 dark:text-gray-200 rounded uppercase text-xs font-bold px-1 flex-shrink-0"
>
Legacy
</div>
{:else if item?.meta?.document}
<div
class="bg-gray-500/20 text-gray-700 dark:text-gray-200 rounded uppercase text-xs font-bold px-1 flex-shrink-0"
>
Document
</div>
{:else if item?.type === 'file'}
<div
class="bg-gray-500/20 text-gray-700 dark:text-gray-200 rounded uppercase text-xs font-bold px-1 flex-shrink-0"
>
File
</div>
{:else}
<div
class="bg-green-500/20 text-green-700 dark:text-green-200 rounded uppercase text-xs font-bold px-1 flex-shrink-0"
>
Collection
</div>
{/if}
<div class="line-clamp-1">
{item.name}
<div class="line-clamp-1">
{item?.name}
</div>
</div>
<div class=" text-xs text-gray-600 dark:text-gray-100 line-clamp-1">
{item?.description}
</div>
</div>
<div class=" text-xs text-gray-600 dark:text-gray-100 line-clamp-1">
{item?.description}
</div>
</button>
<!-- <div slot="content" class=" pl-2 pt-1 flex flex-col gap-0.5">
{#if !item.legacy && (item?.files ?? []).length > 0}
{#each item?.files ?? [] as file, fileIdx}
<button
class=" px-3 py-1.5 rounded-xl w-full text-left flex justify-between items-center hover:bg-gray-50 dark:hover:bg-gray-850 dark:hover:text-gray-100 selected-command-option-button"
type="button"
on:click={() => {
console.log(file);
}}
on:mousemove={() => {
selectedIdx = idx;
}}
>
<div>
<div
class=" font-medium text-black dark:text-gray-100 flex items-center gap-1"
>
<div
class="bg-gray-500/20 text-gray-700 dark:text-gray-200 rounded uppercase text-xs font-bold px-1 flex-shrink-0"
>
File
</div>
<div class="line-clamp-1">
{file?.meta?.name}
</div>
</div>
<div class=" text-xs text-gray-600 dark:text-gray-100 line-clamp-1">
{$i18n.t('Updated')}
{dayjs(file.updated_at * 1000).fromNow()}
</div>
</div>
</button>
{/each}
{:else}
<div class=" text-gray-500 text-xs mt-1 mb-2">
{$i18n.t('No files found.')}
</div>
{/if}
</div> -->
{/each}
{#if prompt

View file

@ -58,7 +58,7 @@
onMount(async () => {
await tick();
const chatInputElement = document.getElementById('chat-textarea');
const chatInputElement = document.getElementById('chat-input');
await tick();
chatInputElement?.focus();
await tick();
@ -68,15 +68,11 @@
{#if filteredItems.length > 0}
<div
id="commands-container"
class="pl-2 pr-14 mb-3 text-left w-full absolute bottom-0 left-0 right-0 z-10"
class="pl-3 pr-14 mb-3 text-left w-full absolute bottom-0 left-0 right-0 z-10"
>
<div class="flex w-full dark:border dark:border-gray-850 rounded-lg">
<div class=" bg-gray-50 dark:bg-gray-850 w-10 rounded-l-lg text-center">
<div class=" text-lg font-medium mt-2">@</div>
</div>
<div class="flex w-full rounded-xl border border-gray-50 dark:border-gray-850">
<div
class="max-h-60 flex flex-col w-full rounded-r-lg bg-white dark:bg-gray-900 dark:text-gray-100"
class="max-h-60 flex flex-col w-full rounded-xl bg-white dark:bg-gray-900 dark:text-gray-100"
>
<div class="m-1 overflow-y-auto p-1 rounded-r-lg space-y-0.5 scrollbar-hidden">
{#each filteredItems as model, modelIdx}

View file

@ -110,21 +110,17 @@
prompt = text;
const chatInputElement = document.getElementById('chat-textarea');
const chatInputContainerElement = document.getElementById('chat-input-container');
const chatInputElement = document.getElementById('chat-input');
await tick();
chatInputElement.style.height = '';
chatInputElement.style.height = Math.min(chatInputElement.scrollHeight, 200) + 'px';
if (chatInputContainerElement) {
chatInputContainerElement.style.height = '';
chatInputContainerElement.style.height =
Math.min(chatInputContainerElement.scrollHeight, 200) + 'px';
chatInputElement?.focus();
await tick();
const words = findWordIndices(prompt);
if (words.length > 0) {
const word = words.at(0);
chatInputElement.setSelectionRange(word?.startIndex, word.endIndex + 1);
chatInputElement?.focus();
}
};
</script>
@ -132,17 +128,13 @@
{#if filteredPrompts.length > 0}
<div
id="commands-container"
class="pl-2 pr-14 mb-3 text-left w-full absolute bottom-0 left-0 right-0 z-10"
class="pl-3 pr-14 mb-3 text-left w-full absolute bottom-0 left-0 right-0 z-10"
>
<div class="flex w-full dark:border dark:border-gray-850 rounded-lg">
<div class=" bg-gray-50 dark:bg-gray-850 w-10 rounded-l-lg text-center">
<div class=" text-lg font-medium mt-2">/</div>
</div>
<div class="flex w-full rounded-xl border border-gray-50 dark:border-gray-850">
<div
class="max-h-60 flex flex-col w-full rounded-r-lg bg-white dark:bg-gray-900 dark:text-gray-100"
class="max-h-60 flex flex-col w-full rounded-xl bg-white dark:bg-gray-900 dark:text-gray-100"
>
<div class="m-1 overflow-y-auto p-1 rounded-r-lg space-y-0.5 scrollbar-hidden">
<div class="m-1 overflow-y-auto p-1 space-y-0.5 scrollbar-hidden">
{#each filteredPrompts as prompt, promptIdx}
<button
class=" px-3 py-1.5 rounded-xl w-full text-left {promptIdx === selectedPromptIdx
@ -169,7 +161,7 @@
</div>
<div
class=" px-2 pb-1 text-xs text-gray-600 dark:text-gray-100 bg-white dark:bg-gray-900 rounded-br-xl flex items-center space-x-1"
class=" px-2 pt-0.5 pb-1 text-xs text-gray-600 dark:text-gray-100 bg-white dark:bg-gray-900 rounded-b-xl flex items-center space-x-1"
>
<div>
<svg

View file

@ -61,9 +61,14 @@
<div
class="flex gap-2 items-center px-3 py-2 text-sm font-medium cursor-pointer rounded-xl"
>
<div class="flex-1 flex items-center gap-2">
<WrenchSolid />
<Tooltip content={tools[toolId]?.description ?? ''} className="flex-1">
<div class="flex-1">
<Tooltip
content={tools[toolId]?.description ?? ''}
placement="top-start"
className="flex flex-1 gap-2 items-center"
>
<WrenchSolid />
<div class=" line-clamp-1">{tools[toolId].name}</div>
</Tooltip>
</div>

View file

@ -11,6 +11,7 @@
const dispatch = createEventDispatcher();
export let recording = false;
export let className = ' p-2.5 w-full max-w-full';
let loading = false;
let confirmed = false;
@ -213,7 +214,7 @@
transcription = `${transcription}${transcript}`;
await tick();
document.getElementById('chat-textarea')?.focus();
document.getElementById('chat-input')?.focus();
// Restart the inactivity timeout
timeoutId = setTimeout(() => {
@ -269,13 +270,20 @@
await mediaRecorder.stop();
}
clearInterval(durationCounter);
if (stream) {
const tracks = stream.getTracks();
tracks.forEach((track) => track.stop());
}
stream = null;
};
</script>
<div
class="{loading
? ' bg-gray-100/50 dark:bg-gray-850/50'
: 'bg-indigo-300/10 dark:bg-indigo-500/10 '} rounded-full flex p-2.5"
: 'bg-indigo-300/10 dark:bg-indigo-500/10 '} rounded-full flex {className}"
>
<div class="flex items-center mr-1">
<button

Some files were not shown because too many files have changed in this diff Show more