Andrew Baek 2024-10-28 14:09:56 +09:00
commit b92d530cf3
243 changed files with 20480 additions and 5556 deletions

View file

@ -8,9 +8,19 @@ assignees: ''
# Bug Report # Bug Report
**Important: Before submitting a bug report, please check whether a similar issue or feature request has already been posted in the Issues or Discussions section. It's likely we're already tracking it. In case of uncertainty, initiate a discussion post first. This helps us all to efficiently focus on improving the project.** ## Important Notes
**Let's collaborate respectfully. If you bring negativity, please understand our capacity to engage may be limited. If you're open to learning and communicating constructively, we're more than happy to assist you. Remember, Open WebUI is a volunteer-driven project maintained by a single maintainer, supported by our amazing contributors who also manage full-time jobs. We respect your time; please respect ours. If you have an issue, We highly encourage you to submit a pull request or to fork the project. We actively work to prevent contributor burnout to preserve the quality and continuity of Open WebUI.** - **Before submitting a bug report**: Please check the Issues or Discussions section to see if a similar issue or feature request has already been posted. It's likely we're already tracking it! If you're unsure, start a discussion post first. This will help us efficiently focus on improving the project.
- **Collaborate respectfully**: We value a constructive attitude, so please be mindful of your communication. If negativity is part of your approach, our capacity to engage may be limited. We're here to help if you're open to learning and communicating positively. Remember, Open WebUI is a volunteer-driven project managed by a single maintainer and supported by contributors who also have full-time jobs. We appreciate your time and ask that you respect ours.
- **Contributing**: If you encounter an issue, we highly encourage you to submit a pull request or fork the project. We actively work to prevent contributor burnout to maintain the quality and continuity of Open WebUI.
- **Bug reproducibility**: If a bug cannot be reproduced with a `:main` or `:dev` Docker setup, or a pip install with Python 3.11, it may require additional help from the community. In such cases, we will move it to the "issues" Discussions section due to our limited resources. We encourage the community to assist with these issues. Remember, it's not that the issue doesn't exist; we need your help!
Note: Please remove the notes above when submitting your post. Thank you for your understanding and support!
---
## Installation Method ## Installation Method

View file

@ -8,9 +8,19 @@ assignees: ''
# Feature Request # Feature Request
**Important: Before submitting a feature request, please check whether a similar issue or feature request has already been posted in the Issues or Discussions section. It's likely we're already tracking it. In case of uncertainty, initiate a discussion post first. This helps us all to efficiently focus on improving the project.** ## Important Notes
**Let's collaborate respectfully. If you bring negativity, please understand our capacity to engage may be limited. If you're open to learning and communicating constructively, we're more than happy to assist you. Remember, Open WebUI is a volunteer-driven project maintained by a single maintainer, supported by our amazing contributors who also manage full-time jobs. We respect your time; please respect ours. If you have an issue, We highly encourage you to submit a pull request or to fork the project. We actively work to prevent contributor burnout to preserve the quality and continuity of Open WebUI.** - **Before submitting a report**: Please check the Issues or Discussions section to see if a similar issue or feature request has already been posted. It's likely we're already tracking it! If you're unsure, start a discussion post first. This will help us efficiently focus on improving the project.
- **Collaborate respectfully**: We value a constructive attitude, so please be mindful of your communication. If negativity is part of your approach, our capacity to engage may be limited. We're here to help if you're open to learning and communicating positively. Remember, Open WebUI is a volunteer-driven project managed by a single maintainer and supported by contributors who also have full-time jobs. We appreciate your time and ask that you respect ours.
- **Contributing**: If you encounter an issue, we highly encourage you to submit a pull request or fork the project. We actively work to prevent contributor burnout to maintain the quality and continuity of Open WebUI.
- **Bug reproducibility**: If a bug cannot be reproduced with a `:main` or `:dev` Docker setup, or a pip install with Python 3.11, it may require additional help from the community. In such cases, we will move it to the "issues" Discussions section due to our limited resources. We encourage the community to assist with these issues. Remember, it's not that the issue doesn't exist; we need your help!
Note: Please remove the notes above when submitting your post. Thank you for your understanding and support!
---
**Is your feature request related to a problem? Please describe.** **Is your feature request related to a problem? Please describe.**
A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]

View file

@ -5,6 +5,73 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## [0.3.35] - 2024-10-26
### Added
- **📁 Robust File Handling**: Enhanced file input handling for chat. If the content extraction fails or is empty, users will now receive a clear warning, preventing silent failures and ensuring you always know what's happening with your uploads.
- **🌍 New Language Support**: Introduced Hungarian translations and updated French translations, expanding the platform's language accessibility for a more global user base.
### Fixed
- **📚 Knowledge Base Loading Issue**: Resolved a critical bug where the Knowledge Base was not loading, ensuring smooth access to your stored documents and improving information retrieval in RAG-enhanced workflows.
- **🛠️ Tool Parameters Issue**: Fixed an error where tools were not functioning correctly when required parameters were missing, ensuring reliable tool performance and more efficient task completions.
- **🔗 Merged Response Loss in Multi-Model Chats**: Addressed an issue where responses in multi-model chat workflows were being deleted after follow-up queries, improving consistency and ensuring smoother interactions across models.
## [0.3.34] - 2024-10-26
### Added
- **🔧 Feedback Export Enhancements**: Feedback history data can now be exported to JSON, allowing for seamless integration in RLHF processing and further analysis.
- **🗂️ Embedding Model Lazy Loading**: Search functionality for leaderboard reranking is now more efficient, as embedding models are lazy-loaded only when needed, optimizing performance.
- **🎨 Rich Text Input Toggle**: Users can now switch back to legacy textarea input for chat if they prefer simpler text input, though rich text is still the default until deprecation.
- **🛠️ Improved Tool Calling Mechanism**: Enhanced method for parsing and calling tools, improving the reliability and robustness of tool function calls.
- **🌐 Globalization Enhancements**: Updates to internationalization (i18n) support, further refining multi-language compatibility and accuracy.
### Fixed
- **🖥️ Folder Rename Fix for Firefox**: Addressed a persistent issue where users could not rename folders by pressing enter in Firefox, now ensuring seamless folder management across browsers.
- **🔠 Tiktoken Model Text Splitter Issue**: Resolved an issue where the tiktoken text splitter wasn't working in Docker installations, restoring full functionality for tokenized text editing.
- **💼 S3 File Upload Issue**: Fixed a problem affecting S3 file uploads, ensuring smooth operations for those who store files on cloud storage.
- **🔒 Strict-Transport-Security Crash**: Resolved a crash when setting the Strict-Transport-Security (HSTS) header, improving stability and security enhancements.
- **🚫 OIDC Boolean Access Fix**: Addressed an issue with boolean values not being accessed correctly during OIDC logins, ensuring login reliability.
- **⚙️ Rich Text Paste Behavior**: Refined paste behavior in rich text input to make it smoother and more intuitive when pasting various content types.
- **🔨 Model Exclusion for Arena Fix**: Corrected the filter function that was not properly excluding models from the arena, improving model management.
- **🏷️ "Tags Generation Prompt" Fix**: Addressed an issue preventing custom "tags generation prompts" from registering properly, ensuring custom prompt work seamlessly.
## [0.3.33] - 2024-10-24
### Added
- **🏆 Evaluation Leaderboard**: Easily track your performance through a new leaderboard system where your ratings contribute to a real-time ranking based on the Elo system. Sibling responses (regenerations, many model chats) are required for your ratings to count in the leaderboard. Additionally, you can opt-in to share your feedback history and be part of the community-wide leaderboard. Expect further improvements as we refine the algorithm—help us build the best community leaderboard!
- **⚔️ Arena Model Evaluation**: Enable blind A/B testing of models directly from Admin Settings > Evaluation for a true side-by-side comparison. Ideal for pinpointing the best model for your needs.
- **🎯 Topic-Based Leaderboard**: Discover more accurate rankings with experimental topic-based reranking, which adjusts leaderboard standings based on tag similarity in feedback. Get more relevant insights based on specific topics!
- **📁 Folders Support for Chats**: Organize your chats better by grouping them into folders. Drag and drop chats between folders and export them seamlessly for easy sharing or analysis.
- **📤 Easy Chat Import via Drag & Drop**: Save time by simply dragging and dropping chat exports (JSON) directly onto the sidebar to import them into your workspace—streamlined, efficient, and intuitive!
- **📚 Enhanced Knowledge Collection**: Now, you can reference individual files from a knowledge collection—ideal for more precise Retrieval-Augmented Generations (RAG) queries and document analysis.
- **🏷️ Enhanced Tagging System**: Tags now take up less space! Utilize the new 'tag:' query system to manage, search, and organize your conversations more effectively without cluttering the interface.
- **🧠 Auto-Tagging for Chats**: Your conversations are now automatically tagged for improved organization, mirroring the efficiency of auto-generated titles.
- **🔍 Backend Chat Query System**: Chat filtering has become more efficient, now handled through the backend instead of your browser, improving search performance and accuracy.
- **🎮 Revamped Playground**: Experience a refreshed and optimized Playground for smoother testing, tweaks, and experimentation of your models and tools.
- **🧩 Token-Based Text Splitter**: Introducing token-based text splitting (tiktoken), giving you more precise control over how text is processed. Previously, only character-based splitting was available.
- **🔢 Ollama Batch Embeddings**: Leverage new batch embedding support for improved efficiency and performance with Ollama embedding models.
- **🔍 Enhanced Add Text Content Modal**: Enjoy a cleaner, more intuitive workflow for adding and curating knowledge content with an upgraded input modal from our Knowledge workspace.
- **🖋️ Rich Text Input for Chats**: Make your chat inputs more dynamic with support for rich text formatting. Your conversations just got a lot more polished and professional.
- **⚡ Faster Whisper Model Configurability**: Customize your local faster whisper model directly from the WebUI.
- **☁️ Experimental S3 Support**: Enable stateless WebUI instances with S3 support, greatly enhancing scalability and balancing heavy workloads.
- **🔕 Disable Update Toast**: Now you can streamline your workspace even further—choose to disable update notifications for a more focused experience.
- **🌟 RAG Citation Relevance Percentage**: Easily assess citation accuracy with the addition of relevance percentages in RAG results.
- **⚙️ Mermaid Copy Button**: Mermaid diagrams now come with a handy copy button, simplifying the extraction and use of diagram contents directly in your workflow.
- **🎨 UI Redesign**: Major interface redesign that will make navigation smoother, keep your focus where it matters, and ensure a modern look.
### Fixed
- **🎙️ Voice Note Mic Stopping Issue**: Fixed the issue where the microphone stayed active after ending a voice note recording, ensuring your audio workflow runs smoothly.
### Removed
- **👋 Goodbye Sidebar Tags**: Sidebar tag clutter is gone. We've shifted tag buttons to more effective query-based tag filtering for a sleeker, more agile interface.
## [0.3.32] - 2024-10-06 ## [0.3.32] - 2024-10-06
### Added ### Added

View file

@ -1,6 +1,6 @@
# syntax=docker/dockerfile:1 # syntax=docker/dockerfile:1
# Initialize device type args # Initialize device type args
# use build args in the docker build commmand with --build-arg="BUILDARG=true" # use build args in the docker build command with --build-arg="BUILDARG=true"
ARG USE_CUDA=false ARG USE_CUDA=false
ARG USE_OLLAMA=false ARG USE_OLLAMA=false
# Tested with cu117 for CUDA 11 and cu121 for CUDA 12 (default) # Tested with cu117 for CUDA 11 and cu121 for CUDA 12 (default)
@ -11,6 +11,10 @@ ARG USE_CUDA_VER=cu121
# IMPORTANT: If you change the embedding model (sentence-transformers/all-MiniLM-L6-v2) and vice versa, you aren't able to use RAG Chat with your previous documents loaded in the WebUI! You need to re-embed them. # IMPORTANT: If you change the embedding model (sentence-transformers/all-MiniLM-L6-v2) and vice versa, you aren't able to use RAG Chat with your previous documents loaded in the WebUI! You need to re-embed them.
ARG USE_EMBEDDING_MODEL=sentence-transformers/all-MiniLM-L6-v2 ARG USE_EMBEDDING_MODEL=sentence-transformers/all-MiniLM-L6-v2
ARG USE_RERANKING_MODEL="" ARG USE_RERANKING_MODEL=""
# Tiktoken encoding name; models to use can be found at https://huggingface.co/models?library=tiktoken
ARG USE_TIKTOKEN_ENCODING_NAME="cl100k_base"
ARG BUILD_HASH=dev-build ARG BUILD_HASH=dev-build
# Override at your own risk - non-root configurations are untested # Override at your own risk - non-root configurations are untested
ARG UID=0 ARG UID=0
@ -72,6 +76,10 @@ ENV RAG_EMBEDDING_MODEL="$USE_EMBEDDING_MODEL_DOCKER" \
RAG_RERANKING_MODEL="$USE_RERANKING_MODEL_DOCKER" \ RAG_RERANKING_MODEL="$USE_RERANKING_MODEL_DOCKER" \
SENTENCE_TRANSFORMERS_HOME="/app/backend/data/cache/embedding/models" SENTENCE_TRANSFORMERS_HOME="/app/backend/data/cache/embedding/models"
## Tiktoken model settings ##
ENV TIKTOKEN_ENCODING_NAME="cl100k_base" \
TIKTOKEN_CACHE_DIR="/app/backend/data/cache/tiktoken"
## Hugging Face download cache ## ## Hugging Face download cache ##
ENV HF_HOME="/app/backend/data/cache/embedding/models" ENV HF_HOME="/app/backend/data/cache/embedding/models"
@ -131,11 +139,13 @@ RUN pip3 install uv && \
uv pip install --system -r requirements.txt --no-cache-dir && \ uv pip install --system -r requirements.txt --no-cache-dir && \
python -c "import os; from sentence_transformers import SentenceTransformer; SentenceTransformer(os.environ['RAG_EMBEDDING_MODEL'], device='cpu')" && \ python -c "import os; from sentence_transformers import SentenceTransformer; SentenceTransformer(os.environ['RAG_EMBEDDING_MODEL'], device='cpu')" && \
python -c "import os; from faster_whisper import WhisperModel; WhisperModel(os.environ['WHISPER_MODEL'], device='cpu', compute_type='int8', download_root=os.environ['WHISPER_MODEL_DIR'])"; \ python -c "import os; from faster_whisper import WhisperModel; WhisperModel(os.environ['WHISPER_MODEL'], device='cpu', compute_type='int8', download_root=os.environ['WHISPER_MODEL_DIR'])"; \
python -c "import os; import tiktoken; tiktoken.get_encoding(os.environ['TIKTOKEN_ENCODING_NAME'])"; \
else \ else \
pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu --no-cache-dir && \ pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu --no-cache-dir && \
uv pip install --system -r requirements.txt --no-cache-dir && \ uv pip install --system -r requirements.txt --no-cache-dir && \
python -c "import os; from sentence_transformers import SentenceTransformer; SentenceTransformer(os.environ['RAG_EMBEDDING_MODEL'], device='cpu')" && \ python -c "import os; from sentence_transformers import SentenceTransformer; SentenceTransformer(os.environ['RAG_EMBEDDING_MODEL'], device='cpu')" && \
python -c "import os; from faster_whisper import WhisperModel; WhisperModel(os.environ['WHISPER_MODEL'], device='cpu', compute_type='int8', download_root=os.environ['WHISPER_MODEL_DIR'])"; \ python -c "import os; from faster_whisper import WhisperModel; WhisperModel(os.environ['WHISPER_MODEL'], device='cpu', compute_type='int8', download_root=os.environ['WHISPER_MODEL_DIR'])"; \
python -c "import os; import tiktoken; tiktoken.get_encoding(os.environ['TIKTOKEN_ENCODING_NAME'])"; \
fi; \ fi; \
chown -R $UID:$GID /app/backend/data/ chown -R $UID:$GID /app/backend/data/
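The new build args pre-fetch the tiktoken encoding during the image build so token-based chunking works without a network call at runtime. A minimal sketch of what that warm-up step does, assuming the TIKTOKEN_ENCODING_NAME and TIKTOKEN_CACHE_DIR variables set above; the helper name is illustrative:

```python
import os

import tiktoken


def warm_tiktoken_cache() -> None:
    """Download and cache the configured tiktoken encoding ahead of time.

    tiktoken honours TIKTOKEN_CACHE_DIR, so the BPE files end up baked into
    the image instead of being fetched on first use.
    """
    os.environ.setdefault("TIKTOKEN_CACHE_DIR", "/app/backend/data/cache/tiktoken")
    encoding_name = os.environ.get("TIKTOKEN_ENCODING_NAME", "cl100k_base")
    encoding = tiktoken.get_encoding(encoding_name)
    # A tiny round-trip confirms the cached encoding is usable.
    assert encoding.decode(encoding.encode("hello")) == "hello"


if __name__ == "__main__":
    warm_tiktoken_cache()
```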

View file

@ -18,7 +18,7 @@ If you're experiencing connection issues, its often due to the WebUI docker c
docker run -d --network=host -v open-webui:/app/backend/data -e OLLAMA_BASE_URL=http://127.0.0.1:11434 --name open-webui --restart always ghcr.io/open-webui/open-webui:main docker run -d --network=host -v open-webui:/app/backend/data -e OLLAMA_BASE_URL=http://127.0.0.1:11434 --name open-webui --restart always ghcr.io/open-webui/open-webui:main
``` ```
### Error on Slow Reponses for Ollama ### Error on Slow Responses for Ollama
Open WebUI has a default timeout of 5 minutes for Ollama to finish generating the response. If needed, this can be adjusted via the environment variable AIOHTTP_CLIENT_TIMEOUT, which sets the timeout in seconds. Open WebUI has a default timeout of 5 minutes for Ollama to finish generating the response. If needed, this can be adjusted via the environment variable AIOHTTP_CLIENT_TIMEOUT, which sets the timeout in seconds.
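For reference, here is a rough sketch of how a seconds-based setting like AIOHTTP_CLIENT_TIMEOUT can be read and applied to an aiohttp request; the variable name matches the paragraph above, but the parsing helper is illustrative rather than the exact Open WebUI code:

```python
import os

import aiohttp

# AIOHTTP_CLIENT_TIMEOUT is expressed in seconds; an empty value means "no limit".
_raw = os.environ.get("AIOHTTP_CLIENT_TIMEOUT", "300")
AIOHTTP_CLIENT_TIMEOUT = int(_raw) if _raw.strip() else None


async def generate(url: str, payload: dict) -> dict:
    # aiohttp accepts total=None to disable the overall timeout entirely.
    timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
    async with aiohttp.ClientSession(timeout=timeout) as session:
        async with session.post(url, json=payload) as response:
            return await response.json()
```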

View file

@ -63,6 +63,9 @@ app.state.config.STT_OPENAI_API_KEY = AUDIO_STT_OPENAI_API_KEY
app.state.config.STT_ENGINE = AUDIO_STT_ENGINE app.state.config.STT_ENGINE = AUDIO_STT_ENGINE
app.state.config.STT_MODEL = AUDIO_STT_MODEL app.state.config.STT_MODEL = AUDIO_STT_MODEL
app.state.config.WHISPER_MODEL = WHISPER_MODEL
app.state.faster_whisper_model = None
app.state.config.TTS_OPENAI_API_BASE_URL = AUDIO_TTS_OPENAI_API_BASE_URL app.state.config.TTS_OPENAI_API_BASE_URL = AUDIO_TTS_OPENAI_API_BASE_URL
app.state.config.TTS_OPENAI_API_KEY = AUDIO_TTS_OPENAI_API_KEY app.state.config.TTS_OPENAI_API_KEY = AUDIO_TTS_OPENAI_API_KEY
app.state.config.TTS_ENGINE = AUDIO_TTS_ENGINE app.state.config.TTS_ENGINE = AUDIO_TTS_ENGINE
@ -82,6 +85,31 @@ SPEECH_CACHE_DIR = Path(CACHE_DIR).joinpath("./audio/speech/")
SPEECH_CACHE_DIR.mkdir(parents=True, exist_ok=True) SPEECH_CACHE_DIR.mkdir(parents=True, exist_ok=True)
def set_faster_whisper_model(model: str, auto_update: bool = False):
if model and app.state.config.STT_ENGINE == "":
from faster_whisper import WhisperModel
faster_whisper_kwargs = {
"model_size_or_path": model,
"device": whisper_device_type,
"compute_type": "int8",
"download_root": WHISPER_MODEL_DIR,
"local_files_only": not auto_update,
}
try:
app.state.faster_whisper_model = WhisperModel(**faster_whisper_kwargs)
except Exception:
log.warning(
"WhisperModel initialization failed, attempting download with local_files_only=False"
)
faster_whisper_kwargs["local_files_only"] = False
app.state.faster_whisper_model = WhisperModel(**faster_whisper_kwargs)
else:
app.state.faster_whisper_model = None
class TTSConfigForm(BaseModel): class TTSConfigForm(BaseModel):
OPENAI_API_BASE_URL: str OPENAI_API_BASE_URL: str
OPENAI_API_KEY: str OPENAI_API_KEY: str
@ -99,6 +127,7 @@ class STTConfigForm(BaseModel):
OPENAI_API_KEY: str OPENAI_API_KEY: str
ENGINE: str ENGINE: str
MODEL: str MODEL: str
WHISPER_MODEL: str
class AudioConfigUpdateForm(BaseModel): class AudioConfigUpdateForm(BaseModel):
@ -152,6 +181,7 @@ async def get_audio_config(user=Depends(get_admin_user)):
"OPENAI_API_KEY": app.state.config.STT_OPENAI_API_KEY, "OPENAI_API_KEY": app.state.config.STT_OPENAI_API_KEY,
"ENGINE": app.state.config.STT_ENGINE, "ENGINE": app.state.config.STT_ENGINE,
"MODEL": app.state.config.STT_MODEL, "MODEL": app.state.config.STT_MODEL,
"WHISPER_MODEL": app.state.config.WHISPER_MODEL,
}, },
} }
@ -176,6 +206,8 @@ async def update_audio_config(
app.state.config.STT_OPENAI_API_KEY = form_data.stt.OPENAI_API_KEY app.state.config.STT_OPENAI_API_KEY = form_data.stt.OPENAI_API_KEY
app.state.config.STT_ENGINE = form_data.stt.ENGINE app.state.config.STT_ENGINE = form_data.stt.ENGINE
app.state.config.STT_MODEL = form_data.stt.MODEL app.state.config.STT_MODEL = form_data.stt.MODEL
app.state.config.WHISPER_MODEL = form_data.stt.WHISPER_MODEL
set_faster_whisper_model(form_data.stt.WHISPER_MODEL, WHISPER_MODEL_AUTO_UPDATE)
return { return {
"tts": { "tts": {
@ -194,6 +226,7 @@ async def update_audio_config(
"OPENAI_API_KEY": app.state.config.STT_OPENAI_API_KEY, "OPENAI_API_KEY": app.state.config.STT_OPENAI_API_KEY,
"ENGINE": app.state.config.STT_ENGINE, "ENGINE": app.state.config.STT_ENGINE,
"MODEL": app.state.config.STT_MODEL, "MODEL": app.state.config.STT_MODEL,
"WHISPER_MODEL": app.state.config.WHISPER_MODEL,
}, },
} }
@ -367,27 +400,10 @@ def transcribe(file_path):
id = filename.split(".")[0] id = filename.split(".")[0]
if app.state.config.STT_ENGINE == "": if app.state.config.STT_ENGINE == "":
from faster_whisper import WhisperModel if app.state.faster_whisper_model is None:
set_faster_whisper_model(app.state.config.WHISPER_MODEL)
whisper_kwargs = {
"model_size_or_path": WHISPER_MODEL,
"device": whisper_device_type,
"compute_type": "int8",
"download_root": WHISPER_MODEL_DIR,
"local_files_only": not WHISPER_MODEL_AUTO_UPDATE,
}
log.debug(f"whisper_kwargs: {whisper_kwargs}")
try:
model = WhisperModel(**whisper_kwargs)
except Exception:
log.warning(
"WhisperModel initialization failed, attempting download with local_files_only=False"
)
whisper_kwargs["local_files_only"] = False
model = WhisperModel(**whisper_kwargs)
model = app.state.faster_whisper_model
segments, info = model.transcribe(file_path, beam_size=5) segments, info = model.transcribe(file_path, beam_size=5)
log.info( log.info(
"Detected language '%s' with probability %f" "Detected language '%s' with probability %f"
@ -395,7 +411,6 @@ def transcribe(file_path):
) )
transcript = "".join([segment.text for segment in list(segments)]) transcript = "".join([segment.text for segment in list(segments)])
data = {"text": transcript.strip()} data = {"text": transcript.strip()}
# save the transcript to a json file # save the transcript to a json file
@ -403,7 +418,7 @@ def transcribe(file_path):
with open(transcript_file, "w") as f: with open(transcript_file, "w") as f:
json.dump(data, f) json.dump(data, f)
print(data) log.debug(data)
return data return data
elif app.state.config.STT_ENGINE == "openai": elif app.state.config.STT_ENGINE == "openai":
if is_mp4_audio(file_path): if is_mp4_audio(file_path):
@ -417,7 +432,7 @@ def transcribe(file_path):
files = {"file": (filename, open(file_path, "rb"))} files = {"file": (filename, open(file_path, "rb"))}
data = {"model": app.state.config.STT_MODEL} data = {"model": app.state.config.STT_MODEL}
print(files, data) log.debug(files, data)
r = None r = None
try: try:
@ -507,7 +522,8 @@ def transcription(
else: else:
data = transcribe(file_path) data = transcribe(file_path)
return data file_path = file_path.split("/")[-1]
return {**data, "filename": file_path}
except Exception as e: except Exception as e:
log.exception(e) log.exception(e)
raise HTTPException( raise HTTPException(
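The audio changes above cache the faster-whisper model on app.state and load it lazily, instead of rebuilding it on every transcription request. A simplified sketch of the same lazy-load-with-fallback pattern; the function and cache here are illustrative, not the exact router code:

```python
from faster_whisper import WhisperModel

_model_cache: dict = {}


def get_whisper_model(model: str, device: str = "cpu", auto_update: bool = False) -> WhisperModel:
    """Return a cached WhisperModel, building it only on first use."""
    if model not in _model_cache:
        kwargs = {
            "model_size_or_path": model,
            "device": device,
            "compute_type": "int8",
            "local_files_only": not auto_update,
        }
        try:
            _model_cache[model] = WhisperModel(**kwargs)
        except Exception:
            # Model files are not cached locally yet: retry with downloads allowed.
            kwargs["local_files_only"] = False
            _model_cache[model] = WhisperModel(**kwargs)
    return _model_cache[model]


# Usage: segments, info = get_whisper_model("base").transcribe("audio.wav", beam_size=5)
```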

View file

@ -125,22 +125,34 @@ async def comfyui_generate_image(
workflow[node_id]["inputs"][node.key] = model workflow[node_id]["inputs"][node.key] = model
elif node.type == "prompt": elif node.type == "prompt":
for node_id in node.node_ids: for node_id in node.node_ids:
workflow[node_id]["inputs"]["text"] = payload.prompt workflow[node_id]["inputs"][
node.key if node.key else "text"
] = payload.prompt
elif node.type == "negative_prompt": elif node.type == "negative_prompt":
for node_id in node.node_ids: for node_id in node.node_ids:
workflow[node_id]["inputs"]["text"] = payload.negative_prompt workflow[node_id]["inputs"][
node.key if node.key else "text"
] = payload.negative_prompt
elif node.type == "width": elif node.type == "width":
for node_id in node.node_ids: for node_id in node.node_ids:
workflow[node_id]["inputs"]["width"] = payload.width workflow[node_id]["inputs"][
node.key if node.key else "width"
] = payload.width
elif node.type == "height": elif node.type == "height":
for node_id in node.node_ids: for node_id in node.node_ids:
workflow[node_id]["inputs"]["height"] = payload.height workflow[node_id]["inputs"][
node.key if node.key else "height"
] = payload.height
elif node.type == "n": elif node.type == "n":
for node_id in node.node_ids: for node_id in node.node_ids:
workflow[node_id]["inputs"]["batch_size"] = payload.n workflow[node_id]["inputs"][
node.key if node.key else "batch_size"
] = payload.n
elif node.type == "steps": elif node.type == "steps":
for node_id in node.node_ids: for node_id in node.node_ids:
workflow[node_id]["inputs"]["steps"] = payload.steps workflow[node_id]["inputs"][
node.key if node.key else "steps"
] = payload.steps
elif node.type == "seed": elif node.type == "seed":
seed = ( seed = (
payload.seed payload.seed
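The ComfyUI change above lets each mapped node name the workflow input it should write to, falling back to the old hard-coded keys ("text", "width", "batch_size", and so on) when no key is configured. A small self-contained sketch of that fallback, with simplified node and payload shapes that are assumptions for illustration only:

```python
from dataclasses import dataclass
from typing import Optional

# Default input key per node type, used when the admin does not configure one.
DEFAULT_KEYS = {
    "prompt": "text",
    "negative_prompt": "text",
    "width": "width",
    "height": "height",
    "n": "batch_size",
    "steps": "steps",
}


@dataclass
class WorkflowNode:
    type: str
    node_ids: list
    key: Optional[str] = None  # optional override of the target input key


def apply_payload(workflow: dict, nodes: list, payload: dict) -> dict:
    """Write payload values into the workflow, honouring per-node key overrides."""
    for node in nodes:
        value = payload.get(node.type)
        if value is None:
            continue
        target_key = node.key if node.key else DEFAULT_KEYS.get(node.type, node.type)
        for node_id in node.node_ids:
            workflow[node_id]["inputs"][target_key] = value
    return workflow


# Example: route the prompt into a node whose input is named "positive" instead of "text".
workflow = {"6": {"inputs": {"positive": ""}}}
nodes = [WorkflowNode(type="prompt", node_ids=["6"], key="positive")]
print(apply_payload(workflow, nodes, {"prompt": "a watercolor fox"}))
```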

View file

@ -547,7 +547,7 @@ class GenerateEmbeddingsForm(BaseModel):
class GenerateEmbedForm(BaseModel): class GenerateEmbedForm(BaseModel):
model: str model: str
input: list[str] input: list[str] | str
truncate: Optional[bool] = None truncate: Optional[bool] = None
options: Optional[dict] = None options: Optional[dict] = None
keep_alive: Optional[Union[int, str]] = None keep_alive: Optional[Union[int, str]] = None
@ -692,7 +692,7 @@ class GenerateCompletionForm(BaseModel):
options: Optional[dict] = None options: Optional[dict] = None
system: Optional[str] = None system: Optional[str] = None
template: Optional[str] = None template: Optional[str] = None
context: Optional[str] = None context: Optional[list[int]] = None
stream: Optional[bool] = True stream: Optional[bool] = True
raw: Optional[bool] = None raw: Optional[bool] = None
keep_alive: Optional[Union[int, str]] = None keep_alive: Optional[Union[int, str]] = None
@ -739,7 +739,7 @@ class GenerateChatCompletionForm(BaseModel):
format: Optional[str] = None format: Optional[str] = None
options: Optional[dict] = None options: Optional[dict] = None
template: Optional[str] = None template: Optional[str] = None
stream: Optional[bool] = None stream: Optional[bool] = True
keep_alive: Optional[Union[int, str]] = None keep_alive: Optional[Union[int, str]] = None
@ -761,6 +761,7 @@ async def generate_chat_completion(
form_data: GenerateChatCompletionForm, form_data: GenerateChatCompletionForm,
url_idx: Optional[int] = None, url_idx: Optional[int] = None,
user=Depends(get_verified_user), user=Depends(get_verified_user),
bypass_filter: Optional[bool] = False,
): ):
payload = {**form_data.model_dump(exclude_none=True)} payload = {**form_data.model_dump(exclude_none=True)}
log.debug(f"generate_chat_completion() - 1.payload = {payload}") log.debug(f"generate_chat_completion() - 1.payload = {payload}")
@ -769,7 +770,7 @@ async def generate_chat_completion(
model_id = form_data.model model_id = form_data.model
if app.state.config.ENABLE_MODEL_FILTER: if not bypass_filter and app.state.config.ENABLE_MODEL_FILTER:
if user.role == "user" and model_id not in app.state.config.MODEL_FILTER_LIST: if user.role == "user" and model_id not in app.state.config.MODEL_FILTER_LIST:
raise HTTPException( raise HTTPException(
status_code=403, status_code=403,
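The Ollama route changes are mostly corrections to the Pydantic request models: embeddings may be a single string or a batch, the completion context is a list of token ids rather than a string, and chat streaming defaults to on. A hedged sketch of those shapes, trimmed to the fields this diff touches:

```python
from typing import Optional, Union

from pydantic import BaseModel


class GenerateEmbedForm(BaseModel):
    model: str
    # /api/embed accepts either a single input string or a batch of strings.
    input: Union[list[str], str]
    truncate: Optional[bool] = None
    options: Optional[dict] = None
    keep_alive: Optional[Union[int, str]] = None


class GenerateCompletionForm(BaseModel):
    model: str
    prompt: str
    # The continuation context Ollama returns is a list of token ids, not a string.
    context: Optional[list[int]] = None
    stream: Optional[bool] = True


# Both single and batched inputs validate against the same form.
GenerateEmbedForm(model="all-minilm", input="hello")
GenerateEmbedForm(model="all-minilm", input=["hello", "world"])
```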

View file

@ -18,7 +18,10 @@ from open_webui.config import (
OPENAI_API_KEYS, OPENAI_API_KEYS,
AppConfig, AppConfig,
) )
from open_webui.env import AIOHTTP_CLIENT_TIMEOUT from open_webui.env import (
AIOHTTP_CLIENT_TIMEOUT,
AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST,
)
from open_webui.constants import ERROR_MESSAGES from open_webui.constants import ERROR_MESSAGES
from open_webui.env import SRC_LOG_LEVELS from open_webui.env import SRC_LOG_LEVELS
@ -179,7 +182,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
async def fetch_url(url, key): async def fetch_url(url, key):
timeout = aiohttp.ClientTimeout(total=3) timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST)
try: try:
headers = {"Authorization": f"Bearer {key}"} headers = {"Authorization": f"Bearer {key}"}
async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session: async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
@ -237,9 +240,7 @@ def merge_models_lists(model_lists):
def is_openai_api_disabled(): def is_openai_api_disabled():
api_keys = app.state.config.OPENAI_API_KEYS return not app.state.config.ENABLE_OPENAI_API
no_keys = len(api_keys) == 1 and api_keys[0] == ""
return no_keys or not app.state.config.ENABLE_OPENAI_API
async def get_all_models_raw() -> list: async def get_all_models_raw() -> list:

View file

@ -14,7 +14,11 @@ from typing import Iterator, Optional, Sequence, Union
from fastapi import Depends, FastAPI, File, Form, HTTPException, UploadFile, status from fastapi import Depends, FastAPI, File, Form, HTTPException, UploadFile, status
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel from pydantic import BaseModel
import tiktoken
from open_webui.storage.provider import Storage
from open_webui.apps.webui.models.knowledge import Knowledges
from open_webui.apps.retrieval.vector.connector import VECTOR_DB_CLIENT from open_webui.apps.retrieval.vector.connector import VECTOR_DB_CLIENT
# Document loaders # Document loaders
@ -47,6 +51,8 @@ from open_webui.apps.retrieval.utils import (
from open_webui.apps.webui.models.files import Files from open_webui.apps.webui.models.files import Files
from open_webui.config import ( from open_webui.config import (
BRAVE_SEARCH_API_KEY, BRAVE_SEARCH_API_KEY,
TIKTOKEN_ENCODING_NAME,
RAG_TEXT_SPLITTER,
CHUNK_OVERLAP, CHUNK_OVERLAP,
CHUNK_SIZE, CHUNK_SIZE,
CONTENT_EXTRACTION_ENGINE, CONTENT_EXTRACTION_ENGINE,
@ -102,7 +108,7 @@ from open_webui.utils.misc import (
) )
from open_webui.utils.utils import get_admin_user, get_verified_user from open_webui.utils.utils import get_admin_user, get_verified_user
from langchain.text_splitter import RecursiveCharacterTextSplitter from langchain.text_splitter import RecursiveCharacterTextSplitter, TokenTextSplitter
from langchain_community.document_loaders import ( from langchain_community.document_loaders import (
YoutubeLoader, YoutubeLoader,
) )
@ -129,6 +135,9 @@ app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
app.state.config.CONTENT_EXTRACTION_ENGINE = CONTENT_EXTRACTION_ENGINE app.state.config.CONTENT_EXTRACTION_ENGINE = CONTENT_EXTRACTION_ENGINE
app.state.config.TIKA_SERVER_URL = TIKA_SERVER_URL app.state.config.TIKA_SERVER_URL = TIKA_SERVER_URL
app.state.config.TEXT_SPLITTER = RAG_TEXT_SPLITTER
app.state.config.TIKTOKEN_ENCODING_NAME = TIKTOKEN_ENCODING_NAME
app.state.config.CHUNK_SIZE = CHUNK_SIZE app.state.config.CHUNK_SIZE = CHUNK_SIZE
app.state.config.CHUNK_OVERLAP = CHUNK_OVERLAP app.state.config.CHUNK_OVERLAP = CHUNK_OVERLAP
@ -171,9 +180,9 @@ def update_embedding_model(
auto_update: bool = False, auto_update: bool = False,
): ):
if embedding_model and app.state.config.RAG_EMBEDDING_ENGINE == "": if embedding_model and app.state.config.RAG_EMBEDDING_ENGINE == "":
import sentence_transformers from sentence_transformers import SentenceTransformer
app.state.sentence_transformer_ef = sentence_transformers.SentenceTransformer( app.state.sentence_transformer_ef = SentenceTransformer(
get_model_path(embedding_model, auto_update), get_model_path(embedding_model, auto_update),
device=DEVICE_TYPE, device=DEVICE_TYPE,
trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE, trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
@ -384,18 +393,19 @@ async def get_rag_config(user=Depends(get_admin_user)):
return { return {
"status": True, "status": True,
"pdf_extract_images": app.state.config.PDF_EXTRACT_IMAGES, "pdf_extract_images": app.state.config.PDF_EXTRACT_IMAGES,
"file": {
"max_size": app.state.config.FILE_MAX_SIZE,
"max_count": app.state.config.FILE_MAX_COUNT,
},
"content_extraction": { "content_extraction": {
"engine": app.state.config.CONTENT_EXTRACTION_ENGINE, "engine": app.state.config.CONTENT_EXTRACTION_ENGINE,
"tika_server_url": app.state.config.TIKA_SERVER_URL, "tika_server_url": app.state.config.TIKA_SERVER_URL,
}, },
"chunk": { "chunk": {
"text_splitter": app.state.config.TEXT_SPLITTER,
"chunk_size": app.state.config.CHUNK_SIZE, "chunk_size": app.state.config.CHUNK_SIZE,
"chunk_overlap": app.state.config.CHUNK_OVERLAP, "chunk_overlap": app.state.config.CHUNK_OVERLAP,
}, },
"file": {
"max_size": app.state.config.FILE_MAX_SIZE,
"max_count": app.state.config.FILE_MAX_COUNT,
},
"youtube": { "youtube": {
"language": app.state.config.YOUTUBE_LOADER_LANGUAGE, "language": app.state.config.YOUTUBE_LOADER_LANGUAGE,
"translation": app.state.YOUTUBE_LOADER_TRANSLATION, "translation": app.state.YOUTUBE_LOADER_TRANSLATION,
@ -434,6 +444,7 @@ class ContentExtractionConfig(BaseModel):
class ChunkParamUpdateForm(BaseModel): class ChunkParamUpdateForm(BaseModel):
text_splitter: Optional[str] = None
chunk_size: int chunk_size: int
chunk_overlap: int chunk_overlap: int
@ -493,6 +504,7 @@ async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_
app.state.config.TIKA_SERVER_URL = form_data.content_extraction.tika_server_url app.state.config.TIKA_SERVER_URL = form_data.content_extraction.tika_server_url
if form_data.chunk is not None: if form_data.chunk is not None:
app.state.config.TEXT_SPLITTER = form_data.chunk.text_splitter
app.state.config.CHUNK_SIZE = form_data.chunk.chunk_size app.state.config.CHUNK_SIZE = form_data.chunk.chunk_size
app.state.config.CHUNK_OVERLAP = form_data.chunk.chunk_overlap app.state.config.CHUNK_OVERLAP = form_data.chunk.chunk_overlap
@ -539,6 +551,7 @@ async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_
"tika_server_url": app.state.config.TIKA_SERVER_URL, "tika_server_url": app.state.config.TIKA_SERVER_URL,
}, },
"chunk": { "chunk": {
"text_splitter": app.state.config.TEXT_SPLITTER,
"chunk_size": app.state.config.CHUNK_SIZE, "chunk_size": app.state.config.CHUNK_SIZE,
"chunk_overlap": app.state.config.CHUNK_OVERLAP, "chunk_overlap": app.state.config.CHUNK_OVERLAP,
}, },
@ -599,11 +612,10 @@ class QuerySettingsForm(BaseModel):
async def update_query_settings( async def update_query_settings(
form_data: QuerySettingsForm, user=Depends(get_admin_user) form_data: QuerySettingsForm, user=Depends(get_admin_user)
): ):
app.state.config.RAG_TEMPLATE = ( app.state.config.RAG_TEMPLATE = form_data.template
form_data.template if form_data.template != "" else DEFAULT_RAG_TEMPLATE
)
app.state.config.TOP_K = form_data.k if form_data.k else 4 app.state.config.TOP_K = form_data.k if form_data.k else 4
app.state.config.RELEVANCE_THRESHOLD = form_data.r if form_data.r else 0.0 app.state.config.RELEVANCE_THRESHOLD = form_data.r if form_data.r else 0.0
app.state.config.ENABLE_RAG_HYBRID_SEARCH = ( app.state.config.ENABLE_RAG_HYBRID_SEARCH = (
form_data.hybrid if form_data.hybrid else False form_data.hybrid if form_data.hybrid else False
) )
@ -648,18 +660,46 @@ def save_docs_to_vector_db(
raise ValueError(ERROR_MESSAGES.DUPLICATE_CONTENT) raise ValueError(ERROR_MESSAGES.DUPLICATE_CONTENT)
if split: if split:
if app.state.config.TEXT_SPLITTER in ["", "character"]:
text_splitter = RecursiveCharacterTextSplitter( text_splitter = RecursiveCharacterTextSplitter(
chunk_size=app.state.config.CHUNK_SIZE, chunk_size=app.state.config.CHUNK_SIZE,
chunk_overlap=app.state.config.CHUNK_OVERLAP, chunk_overlap=app.state.config.CHUNK_OVERLAP,
add_start_index=True, add_start_index=True,
) )
elif app.state.config.TEXT_SPLITTER == "token":
log.info(
f"Using token text splitter: {app.state.config.TIKTOKEN_ENCODING_NAME}"
)
tiktoken.get_encoding(str(app.state.config.TIKTOKEN_ENCODING_NAME))
text_splitter = TokenTextSplitter(
encoding_name=str(app.state.config.TIKTOKEN_ENCODING_NAME),
chunk_size=app.state.config.CHUNK_SIZE,
chunk_overlap=app.state.config.CHUNK_OVERLAP,
add_start_index=True,
)
else:
raise ValueError(ERROR_MESSAGES.DEFAULT("Invalid text splitter"))
docs = text_splitter.split_documents(docs) docs = text_splitter.split_documents(docs)
if len(docs) == 0: if len(docs) == 0:
raise ValueError(ERROR_MESSAGES.EMPTY_CONTENT) raise ValueError(ERROR_MESSAGES.EMPTY_CONTENT)
texts = [doc.page_content for doc in docs] texts = [doc.page_content for doc in docs]
metadatas = [{**doc.metadata, **(metadata if metadata else {})} for doc in docs] metadatas = [
{
**doc.metadata,
**(metadata if metadata else {}),
"embedding_config": json.dumps(
{
"engine": app.state.config.RAG_EMBEDDING_ENGINE,
"model": app.state.config.RAG_EMBEDDING_MODEL,
}
),
}
for doc in docs
]
# ChromaDB does not like datetime formats # ChromaDB does not like datetime formats
# for meta-data so convert them to string. # for meta-data so convert them to string.
@ -675,8 +715,10 @@ def save_docs_to_vector_db(
if overwrite: if overwrite:
VECTOR_DB_CLIENT.delete_collection(collection_name=collection_name) VECTOR_DB_CLIENT.delete_collection(collection_name=collection_name)
log.info(f"deleting existing collection {collection_name}") log.info(f"deleting existing collection {collection_name}")
elif add is False:
if add is False: log.info(
f"collection {collection_name} already exists, overwrite is False and add is False"
)
return True return True
log.info(f"adding to collection {collection_name}") log.info(f"adding to collection {collection_name}")
@ -788,15 +830,14 @@ def process_file(
else: else:
# Process the file and save the content # Process the file and save the content
# Usage: /files/ # Usage: /files/
file_path = file.path
file_path = file.meta.get("path", None)
if file_path: if file_path:
file_path = Storage.get_file(file_path)
loader = Loader( loader = Loader(
engine=app.state.config.CONTENT_EXTRACTION_ENGINE, engine=app.state.config.CONTENT_EXTRACTION_ENGINE,
TIKA_SERVER_URL=app.state.config.TIKA_SERVER_URL, TIKA_SERVER_URL=app.state.config.TIKA_SERVER_URL,
PDF_EXTRACT_IMAGES=app.state.config.PDF_EXTRACT_IMAGES, PDF_EXTRACT_IMAGES=app.state.config.PDF_EXTRACT_IMAGES,
) )
docs = loader.load( docs = loader.load(
file.filename, file.meta.get("content_type"), file_path file.filename, file.meta.get("content_type"), file_path
) )
@ -812,7 +853,6 @@ def process_file(
}, },
) )
] ]
text_content = " ".join([doc.page_content for doc in docs]) text_content = " ".join([doc.page_content for doc in docs])
log.debug(f"text_content: {text_content}") log.debug(f"text_content: {text_content}")
@ -1255,6 +1295,7 @@ def delete_entries_from_collection(form_data: DeleteForm, user=Depends(get_admin
@app.post("/reset/db") @app.post("/reset/db")
def reset_vector_db(user=Depends(get_admin_user)): def reset_vector_db(user=Depends(get_admin_user)):
VECTOR_DB_CLIENT.reset() VECTOR_DB_CLIENT.reset()
Knowledges.delete_all_knowledge()
@app.post("/reset/uploads") @app.post("/reset/uploads")
@ -1277,28 +1318,6 @@ def reset_upload_dir(user=Depends(get_admin_user)) -> bool:
print(f"The directory {folder} does not exist") print(f"The directory {folder} does not exist")
except Exception as e: except Exception as e:
print(f"Failed to process the directory {folder}. Reason: {e}") print(f"Failed to process the directory {folder}. Reason: {e}")
return True
@app.post("/reset")
def reset(user=Depends(get_admin_user)) -> bool:
folder = f"{UPLOAD_DIR}"
for filename in os.listdir(folder):
file_path = os.path.join(folder, filename)
try:
if os.path.isfile(file_path) or os.path.islink(file_path):
os.unlink(file_path)
elif os.path.isdir(file_path):
shutil.rmtree(file_path)
except Exception as e:
log.error("Failed to delete %s. Reason: %s" % (file_path, e))
try:
VECTOR_DB_CLIENT.reset()
except Exception as e:
log.exception(e)
return True return True
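Among the retrieval changes, document chunking can now use either the existing character splitter or a tiktoken-based token splitter, chosen by the new text_splitter setting. A condensed sketch of that selection using the same LangChain classes as the diff; the plain function arguments stand in for the app.state.config values:

```python
from langchain.text_splitter import RecursiveCharacterTextSplitter, TokenTextSplitter


def build_text_splitter(
    splitter: str = "character",
    encoding_name: str = "cl100k_base",
    chunk_size: int = 1000,
    chunk_overlap: int = 100,
):
    """Return a character-based splitter (default) or a token-based one."""
    if splitter in ("", "character"):
        return RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            add_start_index=True,
        )
    if splitter == "token":
        # Chunk size and overlap are counted in tokens of the chosen tiktoken encoding.
        return TokenTextSplitter(
            encoding_name=encoding_name,
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            add_start_index=True,
        )
    raise ValueError(f"Invalid text splitter: {splitter}")


# docs = build_text_splitter("token").split_documents(docs)
```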

View file

@ -19,6 +19,7 @@ from open_webui.apps.retrieval.vector.connector import VECTOR_DB_CLIENT
from open_webui.utils.misc import get_last_user_message from open_webui.utils.misc import get_last_user_message
from open_webui.env import SRC_LOG_LEVELS from open_webui.env import SRC_LOG_LEVELS
from open_webui.config import DEFAULT_RAG_TEMPLATE
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -239,8 +240,13 @@ def query_collection_with_hybrid_search(
def rag_template(template: str, context: str, query: str): def rag_template(template: str, context: str, query: str):
count = template.count("[context]") if template == "":
assert "[context]" in template, "RAG template does not contain '[context]'" template = DEFAULT_RAG_TEMPLATE
if "[context]" not in template and "{{CONTEXT}}" not in template:
log.debug(
"WARNING: The RAG template does not contain the '[context]' or '{{CONTEXT}}' placeholder."
)
if "<context>" in context and "</context>" in context: if "<context>" in context and "</context>" in context:
log.debug( log.debug(
@ -249,14 +255,25 @@ def rag_template(template: str, context: str, query: str):
"nothing, or the user might be trying to hack something." "nothing, or the user might be trying to hack something."
) )
query_placeholders = []
if "[query]" in context: if "[query]" in context:
query_placeholder = f"[query-{str(uuid.uuid4())}]" query_placeholder = "{{QUERY" + str(uuid.uuid4()) + "}}"
template = template.replace("[query]", query_placeholder) template = template.replace("[query]", query_placeholder)
query_placeholders.append(query_placeholder)
if "{{QUERY}}" in context:
query_placeholder = "{{QUERY" + str(uuid.uuid4()) + "}}"
template = template.replace("{{QUERY}}", query_placeholder)
query_placeholders.append(query_placeholder)
template = template.replace("[context]", context) template = template.replace("[context]", context)
template = template.replace(query_placeholder, query) template = template.replace("{{CONTEXT}}", context)
else:
template = template.replace("[context]", context)
template = template.replace("[query]", query) template = template.replace("[query]", query)
template = template.replace("{{QUERY}}", query)
for query_placeholder in query_placeholders:
template = template.replace(query_placeholder, query)
return template return template
@ -271,27 +288,22 @@ def get_embedding_function(
if embedding_engine == "": if embedding_engine == "":
return lambda query: embedding_function.encode(query).tolist() return lambda query: embedding_function.encode(query).tolist()
elif embedding_engine in ["ollama", "openai"]: elif embedding_engine in ["ollama", "openai"]:
if embedding_engine == "ollama": func = lambda query: generate_embeddings(
func = lambda query: generate_ollama_embeddings( engine=embedding_engine,
model=embedding_model,
input=query,
)
elif embedding_engine == "openai":
func = lambda query: generate_openai_embeddings(
model=embedding_model, model=embedding_model,
text=query, text=query,
key=openai_key, key=openai_key if embedding_engine == "openai" else "",
url=openai_url, url=openai_url if embedding_engine == "openai" else "",
) )
def generate_multiple(query, f): def generate_multiple(query, func):
if isinstance(query, list): if isinstance(query, list):
embeddings = [] embeddings = []
for i in range(0, len(query), embedding_batch_size): for i in range(0, len(query), embedding_batch_size):
embeddings.extend(f(query[i : i + embedding_batch_size])) embeddings.extend(func(query[i : i + embedding_batch_size]))
return embeddings return embeddings
else: else:
return f(query) return func(query)
return lambda query: generate_multiple(query, func) return lambda query: generate_multiple(query, func)
@ -373,6 +385,8 @@ def get_rag_context(
extracted_collections.extend(collection_names) extracted_collections.extend(collection_names)
if context: if context:
if "data" in file:
del file["data"]
relevant_contexts.append({**context, "file": file}) relevant_contexts.append({**context, "file": file})
contexts = [] contexts = []
@ -380,23 +394,37 @@ def get_rag_context(
for context in relevant_contexts: for context in relevant_contexts:
try: try:
if "documents" in context: if "documents" in context:
file_names = list(
set(
[
metadata["name"]
for metadata in context["metadatas"][0]
if metadata is not None and "name" in metadata
]
)
)
contexts.append( contexts.append(
"\n\n".join( ((", ".join(file_names) + ":\n\n") if file_names else "")
+ "\n\n".join(
[text for text in context["documents"][0] if text is not None] [text for text in context["documents"][0] if text is not None]
) )
) )
if "metadatas" in context: if "metadatas" in context:
citations.append( citation = {
{
"source": context["file"], "source": context["file"],
"document": context["documents"][0], "document": context["documents"][0],
"metadata": context["metadatas"][0], "metadata": context["metadatas"][0],
} }
) if "distances" in context and context["distances"]:
citation["distances"] = context["distances"][0]
citations.append(citation)
except Exception as e: except Exception as e:
log.exception(e) log.exception(e)
print("contexts", contexts)
print("citations", citations)
return contexts, citations return contexts, citations
@ -438,20 +466,6 @@ def get_model_path(model: str, update_model: bool = False):
return model return model
def generate_openai_embeddings(
model: str,
text: Union[str, list[str]],
key: str,
url: str = "https://api.openai.com/v1",
):
if isinstance(text, list):
embeddings = generate_openai_batch_embeddings(model, text, key, url)
else:
embeddings = generate_openai_batch_embeddings(model, [text], key, url)
return embeddings[0] if isinstance(text, str) else embeddings
def generate_openai_batch_embeddings( def generate_openai_batch_embeddings(
model: str, texts: list[str], key: str, url: str = "https://api.openai.com/v1" model: str, texts: list[str], key: str, url: str = "https://api.openai.com/v1"
) -> Optional[list[list[float]]]: ) -> Optional[list[list[float]]]:
@ -475,19 +489,31 @@ def generate_openai_batch_embeddings(
return None return None
def generate_ollama_embeddings( def generate_embeddings(engine: str, model: str, text: Union[str, list[str]], **kwargs):
model: str, input: list[str] if engine == "ollama":
) -> Optional[list[list[float]]]: if isinstance(text, list):
if isinstance(input, list):
embeddings = generate_ollama_batch_embeddings( embeddings = generate_ollama_batch_embeddings(
GenerateEmbedForm(**{"model": model, "input": input}) GenerateEmbedForm(**{"model": model, "input": text})
) )
else: else:
embeddings = generate_ollama_batch_embeddings( embeddings = generate_ollama_batch_embeddings(
GenerateEmbedForm(**{"model": model, "input": [input]}) GenerateEmbedForm(**{"model": model, "input": [text]})
) )
return (
embeddings["embeddings"][0]
if isinstance(text, str)
else embeddings["embeddings"]
)
elif engine == "openai":
key = kwargs.get("key", "")
url = kwargs.get("url", "https://api.openai.com/v1")
return embeddings["embeddings"] if isinstance(text, list):
embeddings = generate_openai_batch_embeddings(model, text, key, url)
else:
embeddings = generate_openai_batch_embeddings(model, [text], key, url)
return embeddings[0] if isinstance(text, str) else embeddings
import operator import operator
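The rag_template rework accepts both the legacy [context]/[query] placeholders and the newer {{CONTEXT}}/{{QUERY}} style, falls back to the default template when none is set, and guards against retrieved context that itself contains a query placeholder. A compact sketch of that guard; the default template string below is a stand-in rather than the project's DEFAULT_RAG_TEMPLATE, and the details differ slightly from the diff:

```python
import uuid

DEFAULT_TEMPLATE = "Use the following context:\n{{CONTEXT}}\n\nQuestion: {{QUERY}}"  # stand-in


def rag_template(template: str, context: str, query: str) -> str:
    if template == "":
        template = DEFAULT_TEMPLATE

    # Rename the template's own query placeholders to a unique token first, so any
    # "[query]" / "{{QUERY}}" text inside the retrieved context cannot be confused
    # with them once the context is spliced in.
    unique = "{{QUERY-" + str(uuid.uuid4()) + "}}"
    template = template.replace("[query]", unique).replace("{{QUERY}}", unique)

    template = template.replace("[context]", context).replace("{{CONTEXT}}", context)
    return template.replace(unique, query)


print(rag_template("", "a passage that literally mentions [query]", "what changed?"))
```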

View file

@ -4,6 +4,10 @@ if VECTOR_DB == "milvus":
from open_webui.apps.retrieval.vector.dbs.milvus import MilvusClient from open_webui.apps.retrieval.vector.dbs.milvus import MilvusClient
VECTOR_DB_CLIENT = MilvusClient() VECTOR_DB_CLIENT = MilvusClient()
elif VECTOR_DB == "qdrant":
from open_webui.apps.retrieval.vector.dbs.qdrant import QdrantClient
VECTOR_DB_CLIENT = QdrantClient()
else: else:
from open_webui.apps.retrieval.vector.dbs.chroma import ChromaClient from open_webui.apps.retrieval.vector.dbs.chroma import ChromaClient
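The connector now recognises a third backend. A short configuration sketch: assuming, as elsewhere in the project, that VECTOR_DB and QDRANT_URI are read from environment variables of the same name, pointing Open WebUI at a local Qdrant server looks roughly like this (set the variables before the config module is imported):

```python
import os

# "chroma" remains the default; "milvus" and, new in this commit, "qdrant" are the
# alternatives. QDRANT_URI must point at a reachable Qdrant instance.
os.environ["VECTOR_DB"] = "qdrant"
os.environ["QDRANT_URI"] = "http://localhost:6333"

from open_webui.apps.retrieval.vector.connector import VECTOR_DB_CLIENT

print(type(VECTOR_DB_CLIENT).__name__)  # expected: "QdrantClient"
```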

View file

@ -13,11 +13,22 @@ from open_webui.config import (
CHROMA_HTTP_SSL, CHROMA_HTTP_SSL,
CHROMA_TENANT, CHROMA_TENANT,
CHROMA_DATABASE, CHROMA_DATABASE,
CHROMA_CLIENT_AUTH_PROVIDER,
CHROMA_CLIENT_AUTH_CREDENTIALS,
) )
class ChromaClient: class ChromaClient:
def __init__(self): def __init__(self):
settings_dict = {
"allow_reset": True,
"anonymized_telemetry": False,
}
if CHROMA_CLIENT_AUTH_PROVIDER is not None:
settings_dict["chroma_client_auth_provider"] = CHROMA_CLIENT_AUTH_PROVIDER
if CHROMA_CLIENT_AUTH_CREDENTIALS is not None:
settings_dict["chroma_client_auth_credentials"] = CHROMA_CLIENT_AUTH_CREDENTIALS
if CHROMA_HTTP_HOST != "": if CHROMA_HTTP_HOST != "":
self.client = chromadb.HttpClient( self.client = chromadb.HttpClient(
host=CHROMA_HTTP_HOST, host=CHROMA_HTTP_HOST,
@ -26,12 +37,12 @@ class ChromaClient:
ssl=CHROMA_HTTP_SSL, ssl=CHROMA_HTTP_SSL,
tenant=CHROMA_TENANT, tenant=CHROMA_TENANT,
database=CHROMA_DATABASE, database=CHROMA_DATABASE,
settings=Settings(allow_reset=True, anonymized_telemetry=False), settings=Settings(**settings_dict),
) )
else: else:
self.client = chromadb.PersistentClient( self.client = chromadb.PersistentClient(
path=CHROMA_DATA_PATH, path=CHROMA_DATA_PATH,
settings=Settings(allow_reset=True, anonymized_telemetry=False), settings=Settings(**settings_dict),
tenant=CHROMA_TENANT, tenant=CHROMA_TENANT,
database=CHROMA_DATABASE, database=CHROMA_DATABASE,
) )
@ -109,7 +120,9 @@ class ChromaClient:
def insert(self, collection_name: str, items: list[VectorItem]): def insert(self, collection_name: str, items: list[VectorItem]):
# Insert the items into the collection, if the collection does not exist, it will be created. # Insert the items into the collection, if the collection does not exist, it will be created.
collection = self.client.get_or_create_collection(name=collection_name) collection = self.client.get_or_create_collection(
name=collection_name, metadata={"hnsw:space": "cosine"}
)
ids = [item["id"] for item in items] ids = [item["id"] for item in items]
documents = [item["text"] for item in items] documents = [item["text"] for item in items]
@ -127,7 +140,9 @@ class ChromaClient:
def upsert(self, collection_name: str, items: list[VectorItem]): def upsert(self, collection_name: str, items: list[VectorItem]):
# Update the items in the collection, if the items are not present, insert them. If the collection does not exist, it will be created. # Update the items in the collection, if the items are not present, insert them. If the collection does not exist, it will be created.
collection = self.client.get_or_create_collection(name=collection_name) collection = self.client.get_or_create_collection(
name=collection_name, metadata={"hnsw:space": "cosine"}
)
ids = [item["id"] for item in items] ids = [item["id"] for item in items]
documents = [item["text"] for item in items] documents = [item["text"] for item in items]
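Two things change for Chroma here: optional client auth settings are passed through only when configured, and collections are created with cosine distance. A small local sketch of both, using chromadb's PersistentClient so it runs without a server; the auth values are left unset, and the concrete provider class path would depend on your chromadb version:

```python
import chromadb
from chromadb.config import Settings

# Mirror the optional auth wiring: start from the defaults and only add the auth
# fields when they are configured (placeholders below, intentionally unset).
CHROMA_CLIENT_AUTH_PROVIDER = None      # e.g. a chromadb auth provider class path
CHROMA_CLIENT_AUTH_CREDENTIALS = None   # e.g. a token or "user:password"

settings_dict = {"allow_reset": True, "anonymized_telemetry": False}
if CHROMA_CLIENT_AUTH_PROVIDER is not None:
    settings_dict["chroma_client_auth_provider"] = CHROMA_CLIENT_AUTH_PROVIDER
if CHROMA_CLIENT_AUTH_CREDENTIALS is not None:
    settings_dict["chroma_client_auth_credentials"] = CHROMA_CLIENT_AUTH_CREDENTIALS

client = chromadb.PersistentClient(path="./chroma-demo", settings=Settings(**settings_dict))

# New collections now use cosine distance, matching the insert/upsert change above.
collection = client.get_or_create_collection(name="docs", metadata={"hnsw:space": "cosine"})
collection.upsert(ids=["1"], documents=["hello world"], metadatas=[{"name": "demo"}])
```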

View file

@ -0,0 +1,179 @@
from typing import Optional
from qdrant_client import QdrantClient as Qclient
from qdrant_client.http.models import PointStruct
from qdrant_client.models import models
from open_webui.apps.retrieval.vector.main import VectorItem, SearchResult, GetResult
from open_webui.config import QDRANT_URI
NO_LIMIT = 999999999
class QdrantClient:
def __init__(self):
self.collection_prefix = "open-webui"
self.QDRANT_URI = QDRANT_URI
self.client = Qclient(url=self.QDRANT_URI) if self.QDRANT_URI else None
def _result_to_get_result(self, points) -> GetResult:
ids = []
documents = []
metadatas = []
for point in points:
payload = point.payload
ids.append(point.id)
documents.append(payload["text"])
metadatas.append(payload["metadata"])
return GetResult(
**{
"ids": [ids],
"documents": [documents],
"metadatas": [metadatas],
}
)
def _create_collection(self, collection_name: str, dimension: int):
collection_name_with_prefix = f"{self.collection_prefix}_{collection_name}"
self.client.create_collection(
collection_name=collection_name_with_prefix,
vectors_config=models.VectorParams(
size=dimension, distance=models.Distance.COSINE
),
)
print(f"collection {collection_name_with_prefix} successfully created!")
def _create_collection_if_not_exists(self, collection_name, dimension):
if not self.has_collection(collection_name=collection_name):
self._create_collection(
collection_name=collection_name, dimension=dimension
)
def _create_points(self, items: list[VectorItem]):
return [
PointStruct(
id=item["id"],
vector=item["vector"],
payload={"text": item["text"], "metadata": item["metadata"]},
)
for item in items
]
def has_collection(self, collection_name: str) -> bool:
return self.client.collection_exists(
f"{self.collection_prefix}_{collection_name}"
)
def delete_collection(self, collection_name: str):
return self.client.delete_collection(
collection_name=f"{self.collection_prefix}_{collection_name}"
)
def search(
self, collection_name: str, vectors: list[list[float | int]], limit: int
) -> Optional[SearchResult]:
# Search for the nearest neighbor items based on the vectors and return 'limit' number of results.
if limit is None:
limit = NO_LIMIT # otherwise qdrant would set limit to 10!
query_response = self.client.query_points(
collection_name=f"{self.collection_prefix}_{collection_name}",
query=vectors[0],
limit=limit,
)
get_result = self._result_to_get_result(query_response.points)
return SearchResult(
ids=get_result.ids,
documents=get_result.documents,
metadatas=get_result.metadatas,
distances=[[point.score for point in query_response.points]],
)
def query(self, collection_name: str, filter: dict, limit: Optional[int] = None):
# Construct the filter string for querying
if not self.has_collection(collection_name):
return None
try:
if limit is None:
limit = NO_LIMIT # otherwise qdrant would set limit to 10!
field_conditions = []
for key, value in filter.items():
field_conditions.append(
models.FieldCondition(
key=f"metadata.{key}", match=models.MatchValue(value=value)
)
)
points = self.client.query_points(
collection_name=f"{self.collection_prefix}_{collection_name}",
query_filter=models.Filter(should=field_conditions),
limit=limit,
)
return self._result_to_get_result(points.points)
except Exception as e:
print(e)
return None
def get(self, collection_name: str) -> Optional[GetResult]:
# Get all the items in the collection.
points = self.client.query_points(
collection_name=f"{self.collection_prefix}_{collection_name}",
limit=NO_LIMIT, # otherwise qdrant would set limit to 10!
)
return self._result_to_get_result(points.points)
def insert(self, collection_name: str, items: list[VectorItem]):
# Insert the items into the collection, if the collection does not exist, it will be created.
self._create_collection_if_not_exists(collection_name, len(items[0]["vector"]))
points = self._create_points(items)
self.client.upload_points(f"{self.collection_prefix}_{collection_name}", points)
def upsert(self, collection_name: str, items: list[VectorItem]):
# Update the items in the collection, if the items are not present, insert them. If the collection does not exist, it will be created.
self._create_collection_if_not_exists(collection_name, len(items[0]["vector"]))
points = self._create_points(items)
return self.client.upsert(f"{self.collection_prefix}_{collection_name}", points)
def delete(
self,
collection_name: str,
ids: Optional[list[str]] = None,
filter: Optional[dict] = None,
):
# Delete the items from the collection based on the ids.
field_conditions = []
if ids:
for id_value in ids:
field_conditions.append(
models.FieldCondition(
key="metadata.id",
match=models.MatchValue(value=id_value),
),
),
elif filter:
for key, value in filter.items():
field_conditions.append(
models.FieldCondition(
key=f"metadata.{key}",
match=models.MatchValue(value=value),
),
),
return self.client.delete(
collection_name=f"{self.collection_prefix}_{collection_name}",
points_selector=models.FilterSelector(
filter=models.Filter(must=field_conditions)
),
)
def reset(self):
# Resets the database. This will delete all collections and item entries.
collection_names = self.client.get_collections().collections
for collection_name in collection_names:
if collection_name.name.startswith(self.collection_prefix):
self.client.delete_collection(collection_name=collection_name.name)
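A minimal, hypothetical usage sketch of the wrapper above; the collection name and item values are invented, and the item shape follows the fields read by _create_points and insert():
def index_and_query(client):
    # `client` is assumed to be an instance of the Qdrant wrapper class above.
    items = [
        {
            "id": "doc-1",
            "text": "hello world",
            "vector": [0.1, 0.2, 0.3],
            "metadata": {"id": "doc-1", "source": "example.txt"},
        }
    ]
    # upsert() creates the prefixed collection on first use, then writes the points.
    client.upsert("docs", items)
    # search() returns a SearchResult with ids, documents, metadatas and distances.
    result = client.search("docs", vectors=[[0.1, 0.2, 0.3]], limit=5)
    # query() filters on metadata fields instead of vector similarity.
    by_source = client.query("docs", filter={"source": "example.txt"})
    return result, by_source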

@ -1,6 +1,7 @@
import inspect import inspect
import json import json
import logging import logging
import time
from typing import AsyncGenerator, Generator, Iterator from typing import AsyncGenerator, Generator, Iterator
from open_webui.apps.socket.main import get_event_call, get_event_emitter from open_webui.apps.socket.main import get_event_call, get_event_emitter
@ -9,6 +10,7 @@ from open_webui.apps.webui.models.models import Models
from open_webui.apps.webui.routers import ( from open_webui.apps.webui.routers import (
auths, auths,
chats, chats,
folders,
configs, configs,
files, files,
functions, functions,
@ -16,6 +18,7 @@ from open_webui.apps.webui.routers import (
models, models,
knowledge, knowledge,
prompts, prompts,
evaluations,
tools, tools,
users, users,
utils, utils,
@ -31,10 +34,17 @@ from open_webui.config import (
ENABLE_LOGIN_FORM, ENABLE_LOGIN_FORM,
ENABLE_MESSAGE_RATING, ENABLE_MESSAGE_RATING,
ENABLE_SIGNUP, ENABLE_SIGNUP,
ENABLE_EVALUATION_ARENA_MODELS,
EVALUATION_ARENA_MODELS,
DEFAULT_ARENA_MODEL,
JWT_EXPIRES_IN, JWT_EXPIRES_IN,
ENABLE_OAUTH_ROLE_MANAGEMENT,
OAUTH_ROLES_CLAIM,
OAUTH_EMAIL_CLAIM, OAUTH_EMAIL_CLAIM,
OAUTH_PICTURE_CLAIM, OAUTH_PICTURE_CLAIM,
OAUTH_USERNAME_CLAIM, OAUTH_USERNAME_CLAIM,
OAUTH_ALLOWED_ROLES,
OAUTH_ADMIN_ROLES,
SHOW_ADMIN_DETAILS, SHOW_ADMIN_DETAILS,
USER_PERMISSIONS, USER_PERMISSIONS,
WEBHOOK_URL, WEBHOOK_URL,
@ -89,10 +99,18 @@ app.state.config.BANNERS = WEBUI_BANNERS
app.state.config.ENABLE_COMMUNITY_SHARING = ENABLE_COMMUNITY_SHARING app.state.config.ENABLE_COMMUNITY_SHARING = ENABLE_COMMUNITY_SHARING
app.state.config.ENABLE_MESSAGE_RATING = ENABLE_MESSAGE_RATING app.state.config.ENABLE_MESSAGE_RATING = ENABLE_MESSAGE_RATING
app.state.config.ENABLE_EVALUATION_ARENA_MODELS = ENABLE_EVALUATION_ARENA_MODELS
app.state.config.EVALUATION_ARENA_MODELS = EVALUATION_ARENA_MODELS
app.state.config.OAUTH_USERNAME_CLAIM = OAUTH_USERNAME_CLAIM app.state.config.OAUTH_USERNAME_CLAIM = OAUTH_USERNAME_CLAIM
app.state.config.OAUTH_PICTURE_CLAIM = OAUTH_PICTURE_CLAIM app.state.config.OAUTH_PICTURE_CLAIM = OAUTH_PICTURE_CLAIM
app.state.config.OAUTH_EMAIL_CLAIM = OAUTH_EMAIL_CLAIM app.state.config.OAUTH_EMAIL_CLAIM = OAUTH_EMAIL_CLAIM
app.state.config.ENABLE_OAUTH_ROLE_MANAGEMENT = ENABLE_OAUTH_ROLE_MANAGEMENT
app.state.config.OAUTH_ROLES_CLAIM = OAUTH_ROLES_CLAIM
app.state.config.OAUTH_ALLOWED_ROLES = OAUTH_ALLOWED_ROLES
app.state.config.OAUTH_ADMIN_ROLES = OAUTH_ADMIN_ROLES
app.state.MODELS = {} app.state.MODELS = {}
app.state.TOOLS = {} app.state.TOOLS = {}
app.state.FUNCTIONS = {} app.state.FUNCTIONS = {}
@ -107,19 +125,24 @@ app.add_middleware(
app.include_router(configs.router, prefix="/configs", tags=["configs"]) app.include_router(configs.router, prefix="/configs", tags=["configs"])
app.include_router(auths.router, prefix="/auths", tags=["auths"]) app.include_router(auths.router, prefix="/auths", tags=["auths"])
app.include_router(users.router, prefix="/users", tags=["users"]) app.include_router(users.router, prefix="/users", tags=["users"])
app.include_router(chats.router, prefix="/chats", tags=["chats"]) app.include_router(chats.router, prefix="/chats", tags=["chats"])
app.include_router(models.router, prefix="/models", tags=["models"]) app.include_router(models.router, prefix="/models", tags=["models"])
app.include_router(knowledge.router, prefix="/knowledge", tags=["knowledge"]) app.include_router(knowledge.router, prefix="/knowledge", tags=["knowledge"])
app.include_router(prompts.router, prefix="/prompts", tags=["prompts"]) app.include_router(prompts.router, prefix="/prompts", tags=["prompts"])
app.include_router(files.router, prefix="/files", tags=["files"])
app.include_router(tools.router, prefix="/tools", tags=["tools"]) app.include_router(tools.router, prefix="/tools", tags=["tools"])
app.include_router(functions.router, prefix="/functions", tags=["functions"]) app.include_router(functions.router, prefix="/functions", tags=["functions"])
app.include_router(memories.router, prefix="/memories", tags=["memories"]) app.include_router(memories.router, prefix="/memories", tags=["memories"])
app.include_router(evaluations.router, prefix="/evaluations", tags=["evaluations"])
app.include_router(folders.router, prefix="/folders", tags=["folders"])
app.include_router(files.router, prefix="/files", tags=["files"])
app.include_router(utils.router, prefix="/utils", tags=["utils"]) app.include_router(utils.router, prefix="/utils", tags=["utils"])
@ -133,6 +156,47 @@ async def get_status():
} }
async def get_all_models():
models = []
pipe_models = await get_pipe_models()
models = models + pipe_models
if app.state.config.ENABLE_EVALUATION_ARENA_MODELS:
arena_models = []
if len(app.state.config.EVALUATION_ARENA_MODELS) > 0:
arena_models = [
{
"id": model["id"],
"name": model["name"],
"info": {
"meta": model["meta"],
},
"object": "model",
"created": int(time.time()),
"owned_by": "arena",
"arena": True,
}
for model in app.state.config.EVALUATION_ARENA_MODELS
]
else:
# Add default arena model
arena_models = [
{
"id": DEFAULT_ARENA_MODEL["id"],
"name": DEFAULT_ARENA_MODEL["name"],
"info": {
"meta": DEFAULT_ARENA_MODEL["meta"],
},
"object": "model",
"created": int(time.time()),
"owned_by": "arena",
"arena": True,
}
]
models = models + arena_models
return models
def get_function_module(pipe_id: str): def get_function_module(pipe_id: str):
# Check if function is already loaded # Check if function is already loaded
if pipe_id not in app.state.FUNCTIONS: if pipe_id not in app.state.FUNCTIONS:
@ -335,7 +399,7 @@ async def generate_function_chat_completion(form_data, user):
pipe = function_module.pipe pipe = function_module.pipe
params = get_function_params(function_module, form_data, user, extra_params) params = get_function_params(function_module, form_data, user, extra_params)
if form_data["stream"]: if form_data.get("stream", False):
async def stream_content(): async def stream_content():
try: try:
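The change from form_data["stream"] to form_data.get("stream", False) above makes non-streaming the default when the key is missing; a small illustration with an invented payload:
form_data = {"model": "my-pipe", "messages": []}  # hypothetical payload without "stream"

# form_data["stream"] would raise KeyError here; .get() falls back to False instead.
if form_data.get("stream", False):
    print("streaming response")
else:
    print("non-streaming response")  # this branch is taken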

@ -4,10 +4,13 @@ import uuid
from typing import Optional from typing import Optional
from open_webui.apps.webui.internal.db import Base, get_db from open_webui.apps.webui.internal.db import Base, get_db
from open_webui.apps.webui.models.tags import TagModel, Tag, Tags
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON
from sqlalchemy import or_, func, select from sqlalchemy import or_, func, select, and_, text
from sqlalchemy.sql import exists
#################### ####################
# Chat DB Schema # Chat DB Schema
@ -27,6 +30,10 @@ class Chat(Base):
share_id = Column(Text, unique=True, nullable=True) share_id = Column(Text, unique=True, nullable=True)
archived = Column(Boolean, default=False) archived = Column(Boolean, default=False)
pinned = Column(Boolean, default=False, nullable=True)
meta = Column(JSON, server_default="{}")
folder_id = Column(Text, nullable=True)
class ChatModel(BaseModel): class ChatModel(BaseModel):
@ -42,6 +49,10 @@ class ChatModel(BaseModel):
share_id: Optional[str] = None share_id: Optional[str] = None
archived: bool = False archived: bool = False
pinned: Optional[bool] = False
meta: dict = {}
folder_id: Optional[str] = None
#################### ####################
@ -53,6 +64,17 @@ class ChatForm(BaseModel):
chat: dict chat: dict
class ChatImportForm(ChatForm):
meta: Optional[dict] = {}
pinned: Optional[bool] = False
folder_id: Optional[str] = None
class ChatTitleMessagesForm(BaseModel):
title: str
messages: list[dict]
class ChatTitleForm(BaseModel): class ChatTitleForm(BaseModel):
title: str title: str
@ -66,6 +88,9 @@ class ChatResponse(BaseModel):
created_at: int # timestamp in epoch created_at: int # timestamp in epoch
share_id: Optional[str] = None # id of the chat to be shared share_id: Optional[str] = None # id of the chat to be shared
archived: bool archived: bool
pinned: Optional[bool] = False
meta: dict = {}
folder_id: Optional[str] = None
class ChatTitleIdResponse(BaseModel): class ChatTitleIdResponse(BaseModel):
@ -100,6 +125,35 @@ class ChatTable:
db.refresh(result) db.refresh(result)
return ChatModel.model_validate(result) if result else None return ChatModel.model_validate(result) if result else None
def import_chat(
self, user_id: str, form_data: ChatImportForm
) -> Optional[ChatModel]:
with get_db() as db:
id = str(uuid.uuid4())
chat = ChatModel(
**{
"id": id,
"user_id": user_id,
"title": (
form_data.chat["title"]
if "title" in form_data.chat
else "New Chat"
),
"chat": form_data.chat,
"meta": form_data.meta,
"pinned": form_data.pinned,
"folder_id": form_data.folder_id,
"created_at": int(time.time()),
"updated_at": int(time.time()),
}
)
result = Chat(**chat.model_dump())
db.add(result)
db.commit()
db.refresh(result)
return ChatModel.model_validate(result) if result else None
def update_chat_by_id(self, id: str, chat: dict) -> Optional[ChatModel]: def update_chat_by_id(self, id: str, chat: dict) -> Optional[ChatModel]:
try: try:
with get_db() as db: with get_db() as db:
@ -184,11 +238,24 @@ class ChatTable:
except Exception: except Exception:
return None return None
def toggle_chat_pinned_by_id(self, id: str) -> Optional[ChatModel]:
try:
with get_db() as db:
chat = db.get(Chat, id)
chat.pinned = not chat.pinned
chat.updated_at = int(time.time())
db.commit()
db.refresh(chat)
return ChatModel.model_validate(chat)
except Exception:
return None
def toggle_chat_archive_by_id(self, id: str) -> Optional[ChatModel]: def toggle_chat_archive_by_id(self, id: str) -> Optional[ChatModel]:
try: try:
with get_db() as db: with get_db() as db:
chat = db.get(Chat, id) chat = db.get(Chat, id)
chat.archived = not chat.archived chat.archived = not chat.archived
chat.updated_at = int(time.time())
db.commit() db.commit()
db.refresh(chat) db.refresh(chat)
return ChatModel.model_validate(chat) return ChatModel.model_validate(chat)
@ -225,14 +292,18 @@ class ChatTable:
limit: int = 50, limit: int = 50,
) -> list[ChatModel]: ) -> list[ChatModel]:
with get_db() as db: with get_db() as db:
query = db.query(Chat).filter_by(user_id=user_id) query = db.query(Chat).filter_by(user_id=user_id).filter_by(folder_id=None)
if not include_archived: if not include_archived:
query = query.filter_by(archived=False) query = query.filter_by(archived=False)
all_chats = (
query.order_by(Chat.updated_at.desc()) query = query.order_by(Chat.updated_at.desc())
# .limit(limit).offset(skip)
.all() if skip:
) query = query.offset(skip)
if limit:
query = query.limit(limit)
all_chats = query.all()
return [ChatModel.model_validate(chat) for chat in all_chats] return [ChatModel.model_validate(chat) for chat in all_chats]
def get_chat_title_id_list_by_user_id( def get_chat_title_id_list_by_user_id(
@ -243,7 +314,9 @@ class ChatTable:
limit: Optional[int] = None, limit: Optional[int] = None,
) -> list[ChatTitleIdResponse]: ) -> list[ChatTitleIdResponse]:
with get_db() as db: with get_db() as db:
query = db.query(Chat).filter_by(user_id=user_id) query = db.query(Chat).filter_by(user_id=user_id).filter_by(folder_id=None)
query = query.filter(or_(Chat.pinned == False, Chat.pinned == None))
if not include_archived: if not include_archived:
query = query.filter_by(archived=False) query = query.filter_by(archived=False)
@ -330,6 +403,15 @@ class ChatTable:
) )
return [ChatModel.model_validate(chat) for chat in all_chats] return [ChatModel.model_validate(chat) for chat in all_chats]
def get_pinned_chats_by_user_id(self, user_id: str) -> list[ChatModel]:
with get_db() as db:
all_chats = (
db.query(Chat)
.filter_by(user_id=user_id, pinned=True, archived=False)
.order_by(Chat.updated_at.desc())
)
return [ChatModel.model_validate(chat) for chat in all_chats]
def get_archived_chats_by_user_id(self, user_id: str) -> list[ChatModel]: def get_archived_chats_by_user_id(self, user_id: str) -> list[ChatModel]:
with get_db() as db: with get_db() as db:
all_chats = ( all_chats = (
@ -351,37 +433,316 @@ class ChatTable:
Filters chats based on a search query at the SQL level, with pagination via skip and limit.
""" """
search_text = search_text.lower().strip() search_text = search_text.lower().strip()
if not search_text: if not search_text:
return self.get_chat_list_by_user_id(user_id, include_archived, skip, limit) return self.get_chat_list_by_user_id(user_id, include_archived, skip, limit)
search_text_words = search_text.split(" ")
# search_text may contain 'tag:tag_name' tokens; collect those tag names and strip the tokens from the search text
tag_ids = [
word.replace("tag:", "").replace(" ", "_").lower()
for word in search_text_words
if word.startswith("tag:")
]
search_text_words = [
word for word in search_text_words if not word.startswith("tag:")
]
search_text = " ".join(search_text_words)
with get_db() as db: with get_db() as db:
query = db.query(Chat).filter(Chat.user_id == user_id) query = db.query(Chat).filter(Chat.user_id == user_id)
if not include_archived: if not include_archived:
query = query.filter(Chat.archived == False) query = query.filter(Chat.archived == False)
# Fetch all potentially relevant chats query = query.order_by(Chat.updated_at.desc())
all_chats = query.all()
# Filter chats using Python # Check if the database dialect is either 'sqlite' or 'postgresql'
filtered_chats = [] dialect_name = db.bind.dialect.name
for chat in all_chats: if dialect_name == "sqlite":
# Check chat title # SQLite case: using JSON1 extension for JSON searching
title_matches = search_text in chat.title.lower() query = query.filter(
(
# Check chat content in chat JSON Chat.title.ilike(
content_matches = any( f"%{search_text}%"
search_text in message.get("content", "").lower() ) # Case-insensitive search in title
for message in chat.chat.get("messages", []) | text(
if "content" in message """
EXISTS (
SELECT 1
FROM json_each(Chat.chat, '$.messages') AS message
WHERE LOWER(message.value->>'content') LIKE '%' || :search_text || '%'
)
"""
)
).params(search_text=search_text)
) )
if title_matches or content_matches: # Check if there are any tags to filter, it should have all the tags
filtered_chats.append(chat) if "none" in tag_ids:
query = query.filter(
text(
"""
NOT EXISTS (
SELECT 1
FROM json_each(Chat.meta, '$.tags') AS tag
)
"""
)
)
elif tag_ids:
query = query.filter(
and_(
*[
text(
f"""
EXISTS (
SELECT 1
FROM json_each(Chat.meta, '$.tags') AS tag
WHERE tag.value = :tag_id_{tag_idx}
)
"""
).params(**{f"tag_id_{tag_idx}": tag_id})
for tag_idx, tag_id in enumerate(tag_ids)
]
)
)
# Implementing pagination manually elif dialect_name == "postgresql":
paginated_chats = filtered_chats[skip : skip + limit] # PostgreSQL relies on proper JSON query for search
return [ChatModel.model_validate(chat) for chat in paginated_chats] query = query.filter(
(
Chat.title.ilike(
f"%{search_text}%"
) # Case-insensitive search in title
| text(
"""
EXISTS (
SELECT 1
FROM json_array_elements(Chat.chat->'messages') AS message
WHERE LOWER(message->>'content') LIKE '%' || :search_text || '%'
)
"""
)
).params(search_text=search_text)
)
# Check if there are any tags to filter, it should have all the tags
if "none" in tag_ids:
query = query.filter(
text(
"""
NOT EXISTS (
SELECT 1
FROM json_array_elements_text(Chat.meta->'tags') AS tag
)
"""
)
)
elif tag_ids:
query = query.filter(
and_(
*[
text(
f"""
EXISTS (
SELECT 1
FROM json_array_elements_text(Chat.meta->'tags') AS tag
WHERE tag = :tag_id_{tag_idx}
)
"""
).params(**{f"tag_id_{tag_idx}": tag_id})
for tag_idx, tag_id in enumerate(tag_ids)
]
)
)
else:
raise NotImplementedError(
f"Unsupported dialect: {db.bind.dialect.name}"
)
# Perform pagination at the SQL level
all_chats = query.offset(skip).limit(limit).all()
print(len(all_chats))
# Validate and return chats
return [ChatModel.model_validate(chat) for chat in all_chats]
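To make the tag-token handling above concrete, a small worked example of how a search string is split (the values are invented):
# Hypothetical illustration of the 'tag:' token parsing performed above.
search_text = "tag:Project_X deployment notes".lower().strip()
words = search_text.split(" ")

tag_ids = [w.replace("tag:", "").replace(" ", "_").lower() for w in words if w.startswith("tag:")]
remaining = " ".join(w for w in words if not w.startswith("tag:"))

print(tag_ids)    # ['project_x']      -> matched against Chat.meta['tags']
print(remaining)  # 'deployment notes' -> matched against titles and message content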
def get_chats_by_folder_id_and_user_id(
self, folder_id: str, user_id: str
) -> list[ChatModel]:
with get_db() as db:
query = db.query(Chat).filter_by(folder_id=folder_id, user_id=user_id)
query = query.filter(or_(Chat.pinned == False, Chat.pinned == None))
query = query.filter_by(archived=False)
query = query.order_by(Chat.updated_at.desc())
all_chats = query.all()
return [ChatModel.model_validate(chat) for chat in all_chats]
def get_chats_by_folder_ids_and_user_id(
self, folder_ids: list[str], user_id: str
) -> list[ChatModel]:
with get_db() as db:
query = db.query(Chat).filter(
Chat.folder_id.in_(folder_ids), Chat.user_id == user_id
)
query = query.filter(or_(Chat.pinned == False, Chat.pinned == None))
query = query.filter_by(archived=False)
query = query.order_by(Chat.updated_at.desc())
all_chats = query.all()
return [ChatModel.model_validate(chat) for chat in all_chats]
def update_chat_folder_id_by_id_and_user_id(
self, id: str, user_id: str, folder_id: str
) -> Optional[ChatModel]:
try:
with get_db() as db:
chat = db.get(Chat, id)
chat.folder_id = folder_id
chat.updated_at = int(time.time())
chat.pinned = False
db.commit()
db.refresh(chat)
return ChatModel.model_validate(chat)
except Exception:
return None
def get_chat_tags_by_id_and_user_id(self, id: str, user_id: str) -> list[TagModel]:
with get_db() as db:
chat = db.get(Chat, id)
tags = chat.meta.get("tags", [])
return [Tags.get_tag_by_name_and_user_id(tag, user_id) for tag in tags]
def get_chat_list_by_user_id_and_tag_name(
self, user_id: str, tag_name: str, skip: int = 0, limit: int = 50
) -> list[ChatModel]:
with get_db() as db:
query = db.query(Chat).filter_by(user_id=user_id)
tag_id = tag_name.replace(" ", "_").lower()
print(db.bind.dialect.name)
if db.bind.dialect.name == "sqlite":
# SQLite JSON1 querying for tags within the meta JSON field
query = query.filter(
text(
f"EXISTS (SELECT 1 FROM json_each(Chat.meta, '$.tags') WHERE json_each.value = :tag_id)"
)
).params(tag_id=tag_id)
elif db.bind.dialect.name == "postgresql":
# PostgreSQL JSON query for tags within the meta JSON field (for `json` type)
query = query.filter(
text(
"EXISTS (SELECT 1 FROM json_array_elements_text(Chat.meta->'tags') elem WHERE elem = :tag_id)"
)
).params(tag_id=tag_id)
else:
raise NotImplementedError(
f"Unsupported dialect: {db.bind.dialect.name}"
)
all_chats = query.all()
print("all_chats", all_chats)
return [ChatModel.model_validate(chat) for chat in all_chats]
def add_chat_tag_by_id_and_user_id_and_tag_name(
self, id: str, user_id: str, tag_name: str
) -> Optional[ChatModel]:
tag = Tags.get_tag_by_name_and_user_id(tag_name, user_id)
if tag is None:
tag = Tags.insert_new_tag(tag_name, user_id)
try:
with get_db() as db:
chat = db.get(Chat, id)
tag_id = tag.id
if tag_id not in chat.meta.get("tags", []):
chat.meta = {
**chat.meta,
"tags": list(set(chat.meta.get("tags", []) + [tag_id])),
}
db.commit()
db.refresh(chat)
return ChatModel.model_validate(chat)
except Exception:
return None
def count_chats_by_tag_name_and_user_id(self, tag_name: str, user_id: str) -> int:
with get_db() as db: # Assuming `get_db()` returns a session object
query = db.query(Chat).filter_by(user_id=user_id, archived=False)
# Normalize the tag_name for consistency
tag_id = tag_name.replace(" ", "_").lower()
if db.bind.dialect.name == "sqlite":
# SQLite JSON1 support for querying the tags inside the `meta` JSON field
query = query.filter(
text(
f"EXISTS (SELECT 1 FROM json_each(Chat.meta, '$.tags') WHERE json_each.value = :tag_id)"
)
).params(tag_id=tag_id)
elif db.bind.dialect.name == "postgresql":
# PostgreSQL JSONB support for querying the tags inside the `meta` JSON field
query = query.filter(
text(
"EXISTS (SELECT 1 FROM json_array_elements_text(Chat.meta->'tags') elem WHERE elem = :tag_id)"
)
).params(tag_id=tag_id)
else:
raise NotImplementedError(
f"Unsupported dialect: {db.bind.dialect.name}"
)
# Get the count of matching records
count = query.count()
# Debugging output for inspection
print(f"Count of chats for tag '{tag_name}':", count)
return count
def delete_tag_by_id_and_user_id_and_tag_name(
self, id: str, user_id: str, tag_name: str
) -> bool:
try:
with get_db() as db:
chat = db.get(Chat, id)
tags = chat.meta.get("tags", [])
tag_id = tag_name.replace(" ", "_").lower()
tags = [tag for tag in tags if tag != tag_id]
chat.meta = {
**chat.meta,
"tags": list(set(tags)),
}
db.commit()
return True
except Exception:
return False
def delete_all_tags_by_id_and_user_id(self, id: str, user_id: str) -> bool:
try:
with get_db() as db:
chat = db.get(Chat, id)
chat.meta = {
**chat.meta,
"tags": [],
}
db.commit()
return True
except Exception:
return False
def delete_chat_by_id(self, id: str) -> bool: def delete_chat_by_id(self, id: str) -> bool:
try: try:
@ -415,6 +776,18 @@ class ChatTable:
except Exception: except Exception:
return False return False
def delete_chats_by_user_id_and_folder_id(
self, user_id: str, folder_id: str
) -> bool:
try:
with get_db() as db:
db.query(Chat).filter_by(user_id=user_id, folder_id=folder_id).delete()
db.commit()
return True
except Exception:
return False
def delete_shared_chats_by_user_id(self, user_id: str) -> bool: def delete_shared_chats_by_user_id(self, user_id: str) -> bool:
try: try:
with get_db() as db: with get_db() as db:

@ -0,0 +1,254 @@
import logging
import time
import uuid
from typing import Optional
from open_webui.apps.webui.internal.db import Base, get_db
from open_webui.apps.webui.models.chats import Chats
from open_webui.env import SRC_LOG_LEVELS
from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Column, Text, JSON, Boolean
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
####################
# Feedback DB Schema
####################
class Feedback(Base):
__tablename__ = "feedback"
id = Column(Text, primary_key=True)
user_id = Column(Text)
version = Column(BigInteger, default=0)
type = Column(Text)
data = Column(JSON, nullable=True)
meta = Column(JSON, nullable=True)
snapshot = Column(JSON, nullable=True)
created_at = Column(BigInteger)
updated_at = Column(BigInteger)
class FeedbackModel(BaseModel):
id: str
user_id: str
version: int
type: str
data: Optional[dict] = None
meta: Optional[dict] = None
snapshot: Optional[dict] = None
created_at: int
updated_at: int
model_config = ConfigDict(from_attributes=True)
####################
# Forms
####################
class FeedbackResponse(BaseModel):
id: str
user_id: str
version: int
type: str
data: Optional[dict] = None
meta: Optional[dict] = None
created_at: int
updated_at: int
class RatingData(BaseModel):
rating: Optional[str | int] = None
model_id: Optional[str] = None
sibling_model_ids: Optional[list[str]] = None
reason: Optional[str] = None
comment: Optional[str] = None
model_config = ConfigDict(extra="allow", protected_namespaces=())
class MetaData(BaseModel):
arena: Optional[bool] = None
chat_id: Optional[str] = None
message_id: Optional[str] = None
tags: Optional[list[str]] = None
model_config = ConfigDict(extra="allow")
class SnapshotData(BaseModel):
chat: Optional[dict] = None
model_config = ConfigDict(extra="allow")
class FeedbackForm(BaseModel):
type: str
data: Optional[RatingData] = None
meta: Optional[dict] = None
snapshot: Optional[SnapshotData] = None
model_config = ConfigDict(extra="allow")
class FeedbackTable:
def insert_new_feedback(
self, user_id: str, form_data: FeedbackForm
) -> Optional[FeedbackModel]:
with get_db() as db:
id = str(uuid.uuid4())
feedback = FeedbackModel(
**{
"id": id,
"user_id": user_id,
"version": 0,
**form_data.model_dump(),
"created_at": int(time.time()),
"updated_at": int(time.time()),
}
)
try:
result = Feedback(**feedback.model_dump())
db.add(result)
db.commit()
db.refresh(result)
if result:
return FeedbackModel.model_validate(result)
else:
return None
except Exception as e:
print(e)
return None
def get_feedback_by_id(self, id: str) -> Optional[FeedbackModel]:
try:
with get_db() as db:
feedback = db.query(Feedback).filter_by(id=id).first()
if not feedback:
return None
return FeedbackModel.model_validate(feedback)
except Exception:
return None
def get_feedback_by_id_and_user_id(
self, id: str, user_id: str
) -> Optional[FeedbackModel]:
try:
with get_db() as db:
feedback = db.query(Feedback).filter_by(id=id, user_id=user_id).first()
if not feedback:
return None
return FeedbackModel.model_validate(feedback)
except Exception:
return None
def get_all_feedbacks(self) -> list[FeedbackModel]:
with get_db() as db:
return [
FeedbackModel.model_validate(feedback)
for feedback in db.query(Feedback)
.order_by(Feedback.updated_at.desc())
.all()
]
def get_feedbacks_by_type(self, type: str) -> list[FeedbackModel]:
with get_db() as db:
return [
FeedbackModel.model_validate(feedback)
for feedback in db.query(Feedback)
.filter_by(type=type)
.order_by(Feedback.updated_at.desc())
.all()
]
def get_feedbacks_by_user_id(self, user_id: str) -> list[FeedbackModel]:
with get_db() as db:
return [
FeedbackModel.model_validate(feedback)
for feedback in db.query(Feedback)
.filter_by(user_id=user_id)
.order_by(Feedback.updated_at.desc())
.all()
]
def update_feedback_by_id(
self, id: str, form_data: FeedbackForm
) -> Optional[FeedbackModel]:
with get_db() as db:
feedback = db.query(Feedback).filter_by(id=id).first()
if not feedback:
return None
if form_data.data:
feedback.data = form_data.data.model_dump()
if form_data.meta:
feedback.meta = form_data.meta
if form_data.snapshot:
feedback.snapshot = form_data.snapshot.model_dump()
feedback.updated_at = int(time.time())
db.commit()
return FeedbackModel.model_validate(feedback)
def update_feedback_by_id_and_user_id(
self, id: str, user_id: str, form_data: FeedbackForm
) -> Optional[FeedbackModel]:
with get_db() as db:
feedback = db.query(Feedback).filter_by(id=id, user_id=user_id).first()
if not feedback:
return None
if form_data.data:
feedback.data = form_data.data.model_dump()
if form_data.meta:
feedback.meta = form_data.meta
if form_data.snapshot:
feedback.snapshot = form_data.snapshot.model_dump()
feedback.updated_at = int(time.time())
db.commit()
return FeedbackModel.model_validate(feedback)
def delete_feedback_by_id(self, id: str) -> bool:
with get_db() as db:
feedback = db.query(Feedback).filter_by(id=id).first()
if not feedback:
return False
db.delete(feedback)
db.commit()
return True
def delete_feedback_by_id_and_user_id(self, id: str, user_id: str) -> bool:
with get_db() as db:
feedback = db.query(Feedback).filter_by(id=id, user_id=user_id).first()
if not feedback:
return False
db.delete(feedback)
db.commit()
return True
def delete_feedbacks_by_user_id(self, user_id: str) -> bool:
with get_db() as db:
feedbacks = db.query(Feedback).filter_by(user_id=user_id).all()
if not feedbacks:
return False
for feedback in feedbacks:
db.delete(feedback)
db.commit()
return True
def delete_all_feedbacks(self) -> bool:
with get_db() as db:
feedbacks = db.query(Feedback).all()
if not feedbacks:
return False
for feedback in feedbacks:
db.delete(feedback)
db.commit()
return True
Feedbacks = FeedbackTable()
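A brief, hypothetical sketch of driving the Feedbacks table defined above; the ids and form values are invented:
form = FeedbackForm(
    type="rating",
    data=RatingData(rating=1, model_id="model-a", reason="accurate answer"),
    meta={"arena": True, "chat_id": "chat-123", "message_id": "msg-456"},
)

feedback = Feedbacks.insert_new_feedback(user_id="user-1", form_data=form)
if feedback:
    # Amend the same entry later and read it back.
    Feedbacks.update_feedback_by_id_and_user_id(feedback.id, "user-1", form)
    print(Feedbacks.get_feedback_by_id(feedback.id))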

@ -17,14 +17,15 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"])
class File(Base): class File(Base):
__tablename__ = "file" __tablename__ = "file"
id = Column(String, primary_key=True) id = Column(String, primary_key=True)
user_id = Column(String) user_id = Column(String)
hash = Column(Text, nullable=True) hash = Column(Text, nullable=True)
filename = Column(Text) filename = Column(Text)
path = Column(Text, nullable=True)
data = Column(JSON, nullable=True) data = Column(JSON, nullable=True)
meta = Column(JSONField) meta = Column(JSON, nullable=True)
created_at = Column(BigInteger) created_at = Column(BigInteger)
updated_at = Column(BigInteger) updated_at = Column(BigInteger)
@ -38,11 +39,13 @@ class FileModel(BaseModel):
hash: Optional[str] = None hash: Optional[str] = None
filename: str filename: str
data: Optional[dict] = None path: Optional[str] = None
meta: dict
created_at: int # timestamp in epoch data: Optional[dict] = None
updated_at: int # timestamp in epoch meta: Optional[dict] = None
created_at: Optional[int] # timestamp in epoch
updated_at: Optional[int] # timestamp in epoch
#################### ####################
@ -50,6 +53,14 @@ class FileModel(BaseModel):
#################### ####################
class FileMeta(BaseModel):
name: Optional[str] = None
content_type: Optional[str] = None
size: Optional[int] = None
model_config = ConfigDict(extra="allow")
class FileModelResponse(BaseModel): class FileModelResponse(BaseModel):
id: str id: str
user_id: str user_id: str
@ -57,16 +68,26 @@ class FileModelResponse(BaseModel):
filename: str filename: str
data: Optional[dict] = None data: Optional[dict] = None
meta: dict meta: FileMeta
created_at: int # timestamp in epoch created_at: int # timestamp in epoch
updated_at: int # timestamp in epoch updated_at: int # timestamp in epoch
model_config = ConfigDict(extra="allow")
class FileMetadataResponse(BaseModel):
id: str
meta: dict
created_at: int # timestamp in epoch
updated_at: int # timestamp in epoch
class FileForm(BaseModel): class FileForm(BaseModel):
id: str id: str
hash: Optional[str] = None hash: Optional[str] = None
filename: str filename: str
path: str
data: dict = {} data: dict = {}
meta: dict = {} meta: dict = {}
@ -104,6 +125,19 @@ class FilesTable:
except Exception: except Exception:
return None return None
def get_file_metadata_by_id(self, id: str) -> Optional[FileMetadataResponse]:
with get_db() as db:
try:
file = db.get(File, id)
return FileMetadataResponse(
id=file.id,
meta=file.meta,
created_at=file.created_at,
updated_at=file.updated_at,
)
except Exception:
return None
def get_files(self) -> list[FileModel]: def get_files(self) -> list[FileModel]:
with get_db() as db: with get_db() as db:
return [FileModel.model_validate(file) for file in db.query(File).all()] return [FileModel.model_validate(file) for file in db.query(File).all()]
@ -118,6 +152,21 @@ class FilesTable:
.all() .all()
] ]
def get_file_metadatas_by_ids(self, ids: list[str]) -> list[FileMetadataResponse]:
with get_db() as db:
return [
FileMetadataResponse(
id=file.id,
meta=file.meta,
created_at=file.created_at,
updated_at=file.updated_at,
)
for file in db.query(File)
.filter(File.id.in_(ids))
.order_by(File.updated_at.desc())
.all()
]
def get_files_by_user_id(self, user_id: str) -> list[FileModel]: def get_files_by_user_id(self, user_id: str) -> list[FileModel]:
with get_db() as db: with get_db() as db:
return [ return [

@ -0,0 +1,271 @@
import logging
import time
import uuid
from typing import Optional
from open_webui.apps.webui.internal.db import Base, get_db
from open_webui.apps.webui.models.chats import Chats
from open_webui.env import SRC_LOG_LEVELS
from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Column, Text, JSON, Boolean
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
####################
# Folder DB Schema
####################
class Folder(Base):
__tablename__ = "folder"
id = Column(Text, primary_key=True)
parent_id = Column(Text, nullable=True)
user_id = Column(Text)
name = Column(Text)
items = Column(JSON, nullable=True)
meta = Column(JSON, nullable=True)
is_expanded = Column(Boolean, default=False)
created_at = Column(BigInteger)
updated_at = Column(BigInteger)
class FolderModel(BaseModel):
id: str
parent_id: Optional[str] = None
user_id: str
name: str
items: Optional[dict] = None
meta: Optional[dict] = None
is_expanded: bool = False
created_at: int
updated_at: int
model_config = ConfigDict(from_attributes=True)
####################
# Forms
####################
class FolderForm(BaseModel):
name: str
model_config = ConfigDict(extra="allow")
class FolderTable:
def insert_new_folder(
self, user_id: str, name: str, parent_id: Optional[str] = None
) -> Optional[FolderModel]:
with get_db() as db:
id = str(uuid.uuid4())
folder = FolderModel(
**{
"id": id,
"user_id": user_id,
"name": name,
"parent_id": parent_id,
"created_at": int(time.time()),
"updated_at": int(time.time()),
}
)
try:
result = Folder(**folder.model_dump())
db.add(result)
db.commit()
db.refresh(result)
if result:
return FolderModel.model_validate(result)
else:
return None
except Exception as e:
print(e)
return None
def get_folder_by_id_and_user_id(
self, id: str, user_id: str
) -> Optional[FolderModel]:
try:
with get_db() as db:
folder = db.query(Folder).filter_by(id=id, user_id=user_id).first()
if not folder:
return None
return FolderModel.model_validate(folder)
except Exception:
return None
def get_children_folders_by_id_and_user_id(
self, id: str, user_id: str
) -> Optional[list[FolderModel]]:
try:
with get_db() as db:
folders = []
def get_children(folder):
children = self.get_folders_by_parent_id_and_user_id(
folder.id, user_id
)
for child in children:
get_children(child)
folders.append(child)
folder = db.query(Folder).filter_by(id=id, user_id=user_id).first()
if not folder:
return None
get_children(folder)
return folders
except Exception:
return None
def get_folders_by_user_id(self, user_id: str) -> list[FolderModel]:
with get_db() as db:
return [
FolderModel.model_validate(folder)
for folder in db.query(Folder).filter_by(user_id=user_id).all()
]
def get_folder_by_parent_id_and_user_id_and_name(
self, parent_id: Optional[str], user_id: str, name: str
) -> Optional[FolderModel]:
try:
with get_db() as db:
# Check if folder exists
folder = (
db.query(Folder)
.filter_by(parent_id=parent_id, user_id=user_id)
.filter(Folder.name.ilike(name))
.first()
)
if not folder:
return None
return FolderModel.model_validate(folder)
except Exception as e:
log.error(f"get_folder_by_parent_id_and_user_id_and_name: {e}")
return None
def get_folders_by_parent_id_and_user_id(
self, parent_id: Optional[str], user_id: str
) -> list[FolderModel]:
with get_db() as db:
return [
FolderModel.model_validate(folder)
for folder in db.query(Folder)
.filter_by(parent_id=parent_id, user_id=user_id)
.all()
]
def update_folder_parent_id_by_id_and_user_id(
self,
id: str,
user_id: str,
parent_id: str,
) -> Optional[FolderModel]:
try:
with get_db() as db:
folder = db.query(Folder).filter_by(id=id, user_id=user_id).first()
if not folder:
return None
folder.parent_id = parent_id
folder.updated_at = int(time.time())
db.commit()
return FolderModel.model_validate(folder)
except Exception as e:
log.error(f"update_folder: {e}")
return
def update_folder_name_by_id_and_user_id(
self, id: str, user_id: str, name: str
) -> Optional[FolderModel]:
try:
with get_db() as db:
folder = db.query(Folder).filter_by(id=id, user_id=user_id).first()
if not folder:
return None
existing_folder = (
db.query(Folder)
.filter_by(name=name, parent_id=folder.parent_id, user_id=user_id)
.first()
)
if existing_folder:
return None
folder.name = name
folder.updated_at = int(time.time())
db.commit()
return FolderModel.model_validate(folder)
except Exception as e:
log.error(f"update_folder: {e}")
return
def update_folder_is_expanded_by_id_and_user_id(
self, id: str, user_id: str, is_expanded: bool
) -> Optional[FolderModel]:
try:
with get_db() as db:
folder = db.query(Folder).filter_by(id=id, user_id=user_id).first()
if not folder:
return None
folder.is_expanded = is_expanded
folder.updated_at = int(time.time())
db.commit()
return FolderModel.model_validate(folder)
except Exception as e:
log.error(f"update_folder: {e}")
return
def delete_folder_by_id_and_user_id(self, id: str, user_id: str) -> bool:
try:
with get_db() as db:
folder = db.query(Folder).filter_by(id=id, user_id=user_id).first()
if not folder:
return False
# Delete all chats in the folder
Chats.delete_chats_by_user_id_and_folder_id(user_id, folder.id)
# Delete all children folders
def delete_children(folder):
folder_children = self.get_folders_by_parent_id_and_user_id(
folder.id, user_id
)
for folder_child in folder_children:
Chats.delete_chats_by_user_id_and_folder_id(
user_id, folder_child.id
)
delete_children(folder_child)
folder = db.query(Folder).filter_by(id=folder_child.id).first()
db.delete(folder)
db.commit()
delete_children(folder)
db.delete(folder)
db.commit()
return True
except Exception as e:
log.error(f"delete_folder: {e}")
return False
Folders = FolderTable()
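A short, hypothetical walkthrough of the folder-tree operations above (names are invented, and the sketch assumes the inserts succeed):
root = Folders.insert_new_folder(user_id="user-1", name="Projects")
child = Folders.insert_new_folder(user_id="user-1", name="Research", parent_id=root.id)

# Renaming returns None when a sibling with the same name already exists.
Folders.update_folder_name_by_id_and_user_id(child.id, "user-1", "Papers")

# Collect every descendant of the root folder.
descendants = Folders.get_children_folders_by_id_and_user_id(root.id, "user-1")

# Deleting a folder also removes its chats and all child folders.
Folders.delete_folder_by_id_and_user_id(root.id, "user-1")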

@ -6,6 +6,10 @@ import uuid
from open_webui.apps.webui.internal.db import Base, get_db from open_webui.apps.webui.internal.db import Base, get_db
from open_webui.env import SRC_LOG_LEVELS from open_webui.env import SRC_LOG_LEVELS
from open_webui.apps.webui.models.files import FileMetadataResponse
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Column, String, Text, JSON from sqlalchemy import BigInteger, Column, String, Text, JSON
@ -64,6 +68,8 @@ class KnowledgeResponse(BaseModel):
created_at: int # timestamp in epoch created_at: int # timestamp in epoch
updated_at: int # timestamp in epoch updated_at: int # timestamp in epoch
files: Optional[list[FileMetadataResponse | dict]] = None
class KnowledgeForm(BaseModel): class KnowledgeForm(BaseModel):
name: str name: str
@ -148,5 +154,15 @@ class KnowledgeTable:
except Exception: except Exception:
return False return False
def delete_all_knowledge(self) -> bool:
with get_db() as db:
try:
db.query(Knowledge).delete()
db.commit()
return True
except Exception:
return False
Knowledges = KnowledgeTable() Knowledges = KnowledgeTable()

@ -4,53 +4,35 @@ import uuid
from typing import Optional from typing import Optional
from open_webui.apps.webui.internal.db import Base, get_db from open_webui.apps.webui.internal.db import Base, get_db
from open_webui.env import SRC_LOG_LEVELS from open_webui.env import SRC_LOG_LEVELS
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Column, String, Text from sqlalchemy import BigInteger, Column, String, JSON, PrimaryKeyConstraint
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"]) log.setLevel(SRC_LOG_LEVELS["MODELS"])
#################### ####################
# Tag DB Schema # Tag DB Schema
#################### ####################
class Tag(Base): class Tag(Base):
__tablename__ = "tag" __tablename__ = "tag"
id = Column(String)
id = Column(String, primary_key=True)
name = Column(String) name = Column(String)
user_id = Column(String) user_id = Column(String)
data = Column(Text, nullable=True) meta = Column(JSON, nullable=True)
# Unique constraint ensuring (id, user_id) is unique, not just the `id` column
class ChatIdTag(Base): __table_args__ = (PrimaryKeyConstraint("id", "user_id", name="pk_id_user_id"),)
__tablename__ = "chatidtag"
id = Column(String, primary_key=True)
tag_name = Column(String)
chat_id = Column(String)
user_id = Column(String)
timestamp = Column(BigInteger)
class TagModel(BaseModel): class TagModel(BaseModel):
id: str id: str
name: str name: str
user_id: str user_id: str
data: Optional[str] = None meta: Optional[dict] = None
model_config = ConfigDict(from_attributes=True)
class ChatIdTagModel(BaseModel):
id: str
tag_name: str
chat_id: str
user_id: str
timestamp: int
model_config = ConfigDict(from_attributes=True) model_config = ConfigDict(from_attributes=True)
@ -59,23 +41,15 @@ class ChatIdTagModel(BaseModel):
#################### ####################
class ChatIdTagForm(BaseModel): class TagChatIdForm(BaseModel):
tag_name: str name: str
chat_id: str chat_id: str
class TagChatIdsResponse(BaseModel):
chat_ids: list[str]
class ChatTagsResponse(BaseModel):
tags: list[str]
class TagTable: class TagTable:
def insert_new_tag(self, name: str, user_id: str) -> Optional[TagModel]: def insert_new_tag(self, name: str, user_id: str) -> Optional[TagModel]:
with get_db() as db: with get_db() as db:
id = str(uuid.uuid4()) id = name.replace(" ", "_").lower()
tag = TagModel(**{"id": id, "user_id": user_id, "name": name}) tag = TagModel(**{"id": id, "user_id": user_id, "name": name})
try: try:
result = Tag(**tag.model_dump()) result = Tag(**tag.model_dump())
@ -86,177 +60,50 @@ class TagTable:
return TagModel.model_validate(result) return TagModel.model_validate(result)
else: else:
return None return None
except Exception: except Exception as e:
print(e)
return None return None
def get_tag_by_name_and_user_id( def get_tag_by_name_and_user_id(
self, name: str, user_id: str self, name: str, user_id: str
) -> Optional[TagModel]: ) -> Optional[TagModel]:
try: try:
id = name.replace(" ", "_").lower()
with get_db() as db: with get_db() as db:
tag = db.query(Tag).filter_by(name=name, user_id=user_id).first() tag = db.query(Tag).filter_by(id=id, user_id=user_id).first()
return TagModel.model_validate(tag) return TagModel.model_validate(tag)
except Exception: except Exception:
return None return None
def add_tag_to_chat(
self, user_id: str, form_data: ChatIdTagForm
) -> Optional[ChatIdTagModel]:
tag = self.get_tag_by_name_and_user_id(form_data.tag_name, user_id)
if tag is None:
tag = self.insert_new_tag(form_data.tag_name, user_id)
id = str(uuid.uuid4())
chatIdTag = ChatIdTagModel(
**{
"id": id,
"user_id": user_id,
"chat_id": form_data.chat_id,
"tag_name": tag.name,
"timestamp": int(time.time()),
}
)
try:
with get_db() as db:
result = ChatIdTag(**chatIdTag.model_dump())
db.add(result)
db.commit()
db.refresh(result)
if result:
return ChatIdTagModel.model_validate(result)
else:
return None
except Exception:
return None
def get_tags_by_user_id(self, user_id: str) -> list[TagModel]: def get_tags_by_user_id(self, user_id: str) -> list[TagModel]:
with get_db() as db: with get_db() as db:
tag_names = [
chat_id_tag.tag_name
for chat_id_tag in (
db.query(ChatIdTag)
.filter_by(user_id=user_id)
.order_by(ChatIdTag.timestamp.desc())
.all()
)
]
return [ return [
TagModel.model_validate(tag) TagModel.model_validate(tag)
for tag in ( for tag in (db.query(Tag).filter_by(user_id=user_id).all())
db.query(Tag)
.filter_by(user_id=user_id)
.filter(Tag.name.in_(tag_names))
.all()
)
] ]
def get_tags_by_chat_id_and_user_id( def get_tags_by_ids_and_user_id(
self, chat_id: str, user_id: str self, ids: list[str], user_id: str
) -> list[TagModel]: ) -> list[TagModel]:
with get_db() as db: with get_db() as db:
tag_names = [
chat_id_tag.tag_name
for chat_id_tag in (
db.query(ChatIdTag)
.filter_by(user_id=user_id, chat_id=chat_id)
.order_by(ChatIdTag.timestamp.desc())
.all()
)
]
return [ return [
TagModel.model_validate(tag) TagModel.model_validate(tag)
for tag in ( for tag in (
db.query(Tag) db.query(Tag).filter(Tag.id.in_(ids), Tag.user_id == user_id).all()
.filter_by(user_id=user_id)
.filter(Tag.name.in_(tag_names))
.all()
) )
] ]
def get_chat_ids_by_tag_name_and_user_id( def delete_tag_by_name_and_user_id(self, name: str, user_id: str) -> bool:
self, tag_name: str, user_id: str
) -> list[ChatIdTagModel]:
with get_db() as db:
return [
ChatIdTagModel.model_validate(chat_id_tag)
for chat_id_tag in (
db.query(ChatIdTag)
.filter_by(user_id=user_id, tag_name=tag_name)
.order_by(ChatIdTag.timestamp.desc())
.all()
)
]
def count_chat_ids_by_tag_name_and_user_id(
self, tag_name: str, user_id: str
) -> int:
with get_db() as db:
return (
db.query(ChatIdTag)
.filter_by(tag_name=tag_name, user_id=user_id)
.count()
)
def delete_tag_by_tag_name_and_user_id(self, tag_name: str, user_id: str) -> bool:
try: try:
with get_db() as db: with get_db() as db:
res = ( id = name.replace(" ", "_").lower()
db.query(ChatIdTag) res = db.query(Tag).filter_by(id=id, user_id=user_id).delete()
.filter_by(tag_name=tag_name, user_id=user_id)
.delete()
)
log.debug(f"res: {res}") log.debug(f"res: {res}")
db.commit() db.commit()
tag_count = self.count_chat_ids_by_tag_name_and_user_id(
tag_name, user_id
)
if tag_count == 0:
# Remove tag item from Tag col as well
db.query(Tag).filter_by(name=tag_name, user_id=user_id).delete()
db.commit()
return True return True
except Exception as e: except Exception as e:
log.error(f"delete_tag: {e}") log.error(f"delete_tag: {e}")
return False return False
def delete_tag_by_tag_name_and_chat_id_and_user_id(
self, tag_name: str, chat_id: str, user_id: str
) -> bool:
try:
with get_db() as db:
res = (
db.query(ChatIdTag)
.filter_by(tag_name=tag_name, chat_id=chat_id, user_id=user_id)
.delete()
)
log.debug(f"res: {res}")
db.commit()
tag_count = self.count_chat_ids_by_tag_name_and_user_id(
tag_name, user_id
)
if tag_count == 0:
# Remove tag item from Tag col as well
db.query(Tag).filter_by(name=tag_name, user_id=user_id).delete()
db.commit()
return True
except Exception as e:
log.error(f"delete_tag: {e}")
return False
def delete_tags_by_chat_id_and_user_id(self, chat_id: str, user_id: str) -> bool:
tags = self.get_tags_by_chat_id_and_user_id(chat_id, user_id)
for tag in tags:
self.delete_tag_by_tag_name_and_chat_id_and_user_id(
tag.tag_name, chat_id, user_id
)
return True
Tags = TagTable() Tags = TagTable()
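The reworked schema derives a tag's id from its display name and keys rows on (id, user_id); a small sketch of the normalization (values invented):
name = "Work In Progress"
tag_id = name.replace(" ", "_").lower()
print(tag_id)  # work_in_progress

tag = Tags.insert_new_tag(name, user_id="user-1")
# Lookups normalize the name the same way, so the display name round-trips to the same row.
same_tag = Tags.get_tag_by_name_and_user_id("Work In Progress", "user-1")
# Different users may hold tags with the same id thanks to the composite (id, user_id) key.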

@ -1,10 +1,13 @@
import re import re
import uuid import uuid
import time
import datetime
from open_webui.apps.webui.models.auths import ( from open_webui.apps.webui.models.auths import (
AddUserForm, AddUserForm,
ApiKey, ApiKey,
Auths, Auths,
Token,
SigninForm, SigninForm,
SigninResponse, SigninResponse,
SignupForm, SignupForm,
@ -18,6 +21,8 @@ from open_webui.constants import ERROR_MESSAGES, WEBHOOK_MESSAGES
from open_webui.env import ( from open_webui.env import (
WEBUI_AUTH_TRUSTED_EMAIL_HEADER, WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
WEBUI_AUTH_TRUSTED_NAME_HEADER, WEBUI_AUTH_TRUSTED_NAME_HEADER,
WEBUI_SESSION_COOKIE_SAME_SITE,
WEBUI_SESSION_COOKIE_SECURE,
) )
from fastapi import APIRouter, Depends, HTTPException, Request, status from fastapi import APIRouter, Depends, HTTPException, Request, status
from fastapi.responses import Response from fastapi.responses import Response
@ -27,10 +32,12 @@ from open_webui.utils.utils import (
create_api_key, create_api_key,
create_token, create_token,
get_admin_user, get_admin_user,
get_verified_user,
get_current_user, get_current_user,
get_password_hash, get_password_hash,
) )
from open_webui.utils.webhook import post_webhook from open_webui.utils.webhook import post_webhook
from typing import Optional
router = APIRouter() router = APIRouter()
@ -39,23 +46,44 @@ router = APIRouter()
############################ ############################
@router.get("/", response_model=UserResponse) class SessionUserResponse(Token, UserResponse):
expires_at: Optional[int] = None
@router.get("/", response_model=SessionUserResponse)
async def get_session_user( async def get_session_user(
request: Request, response: Response, user=Depends(get_current_user) request: Request, response: Response, user=Depends(get_current_user)
): ):
expires_delta = parse_duration(request.app.state.config.JWT_EXPIRES_IN)
expires_at = None
if expires_delta:
expires_at = int(time.time()) + int(expires_delta.total_seconds())
token = create_token( token = create_token(
data={"id": user.id}, data={"id": user.id},
expires_delta=parse_duration(request.app.state.config.JWT_EXPIRES_IN), expires_delta=expires_delta,
)
datetime_expires_at = (
datetime.datetime.fromtimestamp(expires_at, datetime.timezone.utc)
if expires_at
else None
) )
# Set the cookie token # Set the cookie token
response.set_cookie( response.set_cookie(
key="token", key="token",
value=token, value=token,
expires=datetime_expires_at,
httponly=True, # Ensures the cookie is not accessible via JavaScript httponly=True, # Ensures the cookie is not accessible via JavaScript
samesite=WEBUI_SESSION_COOKIE_SAME_SITE,
secure=WEBUI_SESSION_COOKIE_SECURE,
) )
return { return {
"token": token,
"token_type": "Bearer",
"expires_at": expires_at,
"id": user.id, "id": user.id,
"email": user.email, "email": user.email,
"name": user.name, "name": user.name,
@ -71,7 +99,7 @@ async def get_session_user(
@router.post("/update/profile", response_model=UserResponse) @router.post("/update/profile", response_model=UserResponse)
async def update_profile( async def update_profile(
form_data: UpdateProfileForm, session_user=Depends(get_current_user) form_data: UpdateProfileForm, session_user=Depends(get_verified_user)
): ):
if session_user: if session_user:
user = Users.update_user_by_id( user = Users.update_user_by_id(
@ -114,7 +142,7 @@ async def update_password(
############################ ############################
@router.post("/signin", response_model=SigninResponse) @router.post("/signin", response_model=SessionUserResponse)
async def signin(request: Request, response: Response, form_data: SigninForm): async def signin(request: Request, response: Response, form_data: SigninForm):
if WEBUI_AUTH_TRUSTED_EMAIL_HEADER: if WEBUI_AUTH_TRUSTED_EMAIL_HEADER:
if WEBUI_AUTH_TRUSTED_EMAIL_HEADER not in request.headers: if WEBUI_AUTH_TRUSTED_EMAIL_HEADER not in request.headers:
@ -156,21 +184,37 @@ async def signin(request: Request, response: Response, form_data: SigninForm):
user = Auths.authenticate_user(form_data.email.lower(), form_data.password) user = Auths.authenticate_user(form_data.email.lower(), form_data.password)
if user: if user:
expires_delta = parse_duration(request.app.state.config.JWT_EXPIRES_IN)
expires_at = None
if expires_delta:
expires_at = int(time.time()) + int(expires_delta.total_seconds())
token = create_token( token = create_token(
data={"id": user.id}, data={"id": user.id},
expires_delta=parse_duration(request.app.state.config.JWT_EXPIRES_IN), expires_delta=expires_delta,
)
datetime_expires_at = (
datetime.datetime.fromtimestamp(expires_at, datetime.timezone.utc)
if expires_at
else None
) )
# Set the cookie token # Set the cookie token
response.set_cookie( response.set_cookie(
key="token", key="token",
value=token, value=token,
expires=datetime_expires_at,
httponly=True, # Ensures the cookie is not accessible via JavaScript httponly=True, # Ensures the cookie is not accessible via JavaScript
samesite=WEBUI_SESSION_COOKIE_SAME_SITE,
secure=WEBUI_SESSION_COOKIE_SECURE,
) )
return { return {
"token": token, "token": token,
"token_type": "Bearer", "token_type": "Bearer",
"expires_at": expires_at,
"id": user.id, "id": user.id,
"email": user.email, "email": user.email,
"name": user.name, "name": user.name,
@ -186,7 +230,7 @@ async def signin(request: Request, response: Response, form_data: SigninForm):
############################ ############################
@router.post("/signup", response_model=SigninResponse) @router.post("/signup", response_model=SessionUserResponse)
async def signup(request: Request, response: Response, form_data: SignupForm): async def signup(request: Request, response: Response, form_data: SignupForm):
if WEBUI_AUTH: if WEBUI_AUTH:
if ( if (
@ -226,16 +270,30 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
) )
if user: if user:
expires_delta = parse_duration(request.app.state.config.JWT_EXPIRES_IN)
expires_at = None
if expires_delta:
expires_at = int(time.time()) + int(expires_delta.total_seconds())
token = create_token( token = create_token(
data={"id": user.id}, data={"id": user.id},
expires_delta=parse_duration(request.app.state.config.JWT_EXPIRES_IN), expires_delta=expires_delta,
)
datetime_expires_at = (
datetime.datetime.fromtimestamp(expires_at, datetime.timezone.utc)
if expires_at
else None
) )
# Set the cookie token # Set the cookie token
response.set_cookie( response.set_cookie(
key="token", key="token",
value=token, value=token,
expires=datetime_expires_at,
httponly=True, # Ensures the cookie is not accessible via JavaScript httponly=True, # Ensures the cookie is not accessible via JavaScript
samesite=WEBUI_SESSION_COOKIE_SAME_SITE,
secure=WEBUI_SESSION_COOKIE_SECURE,
) )
if request.app.state.config.WEBHOOK_URL: if request.app.state.config.WEBHOOK_URL:
@ -252,6 +310,7 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
return { return {
"token": token, "token": token,
"token_type": "Bearer", "token_type": "Bearer",
"expires_at": expires_at,
"id": user.id, "id": user.id,
"email": user.email, "email": user.email,
"name": user.name, "name": user.name,
@ -264,6 +323,12 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
raise HTTPException(500, detail=ERROR_MESSAGES.DEFAULT(err)) raise HTTPException(500, detail=ERROR_MESSAGES.DEFAULT(err))
@router.get("/signout")
async def signout(response: Response):
response.delete_cookie("token")
return {"status": True}
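To make the new expiry handling concrete, a sketch of how expires_at and the cookie expiry are derived; it assumes parse_duration yields a timedelta, or None when the token should not expire, as the surrounding code implies:
import time
import datetime

expires_delta = datetime.timedelta(hours=4)  # stand-in for parse_duration(JWT_EXPIRES_IN)

expires_at = None
if expires_delta:
    expires_at = int(time.time()) + int(expires_delta.total_seconds())

datetime_expires_at = (
    datetime.datetime.fromtimestamp(expires_at, datetime.timezone.utc)
    if expires_at
    else None
)
# expires_at is returned to the client in the JSON body, while
# datetime_expires_at is passed to response.set_cookie(expires=...).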
############################ ############################
# AddUser # AddUser
############################ ############################

@ -4,16 +4,14 @@ from typing import Optional
from open_webui.apps.webui.models.chats import ( from open_webui.apps.webui.models.chats import (
ChatForm, ChatForm,
ChatImportForm,
ChatResponse, ChatResponse,
Chats, Chats,
ChatTitleIdResponse, ChatTitleIdResponse,
) )
from open_webui.apps.webui.models.tags import ( from open_webui.apps.webui.models.tags import TagModel, Tags
ChatIdTagForm, from open_webui.apps.webui.models.folders import Folders
ChatIdTagModel,
TagModel,
Tags,
)
from open_webui.config import ENABLE_ADMIN_CHAT_ACCESS, ENABLE_ADMIN_EXPORT from open_webui.config import ENABLE_ADMIN_CHAT_ACCESS, ENABLE_ADMIN_EXPORT
from open_webui.constants import ERROR_MESSAGES from open_webui.constants import ERROR_MESSAGES
from open_webui.env import SRC_LOG_LEVELS from open_webui.env import SRC_LOG_LEVELS
@ -103,6 +101,34 @@ async def create_new_chat(form_data: ChatForm, user=Depends(get_verified_user)):
) )
############################
# ImportChat
############################
@router.post("/import", response_model=Optional[ChatResponse])
async def import_chat(form_data: ChatImportForm, user=Depends(get_verified_user)):
try:
chat = Chats.import_chat(user.id, form_data)
if chat:
tags = chat.meta.get("tags", [])
for tag_id in tags:
tag_id = tag_id.replace(" ", "_").lower()
tag_name = " ".join([word.capitalize() for word in tag_id.split("_")])
if (
tag_id != "none"
and Tags.get_tag_by_name_and_user_id(tag_name, user.id) is None
):
Tags.insert_new_tag(tag_name, user.id)
return ChatResponse(**chat.model_dump())
except Exception as e:
log.exception(e)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
)
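A hypothetical request body for the /import route above, matching the ChatImportForm fields declared in the models (the chat contents are invented):
payload = {
    "chat": {
        "title": "Imported conversation",
        "messages": [{"role": "user", "content": "hello"}],
    },
    "meta": {"tags": ["imported", "demo"]},
    "pinned": False,
    "folder_id": None,
}
# import_chat() stores the chat and, for each tag id in meta["tags"],
# creates the tag for the user if it does not already exist.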
############################ ############################
# GetChats # GetChats
############################ ############################
@ -118,13 +144,57 @@ async def search_user_chats(
limit = 60 limit = 60
skip = (page - 1) * limit skip = (page - 1) * limit
return [ chat_list = [
ChatTitleIdResponse(**chat.model_dump()) ChatTitleIdResponse(**chat.model_dump())
for chat in Chats.get_chats_by_user_id_and_search_text( for chat in Chats.get_chats_by_user_id_and_search_text(
user.id, text, skip=skip, limit=limit user.id, text, skip=skip, limit=limit
) )
] ]
# Delete tag if no chat is found
words = text.strip().split(" ")
if page == 1 and len(words) == 1 and words[0].startswith("tag:"):
tag_id = words[0].replace("tag:", "")
if len(chat_list) == 0:
if Tags.get_tag_by_name_and_user_id(tag_id, user.id):
log.debug(f"deleting tag: {tag_id}")
Tags.delete_tag_by_name_and_user_id(tag_id, user.id)
return chat_list
############################
# GetChatsByFolderId
############################
@router.get("/folder/{folder_id}", response_model=list[ChatResponse])
async def get_chats_by_folder_id(folder_id: str, user=Depends(get_verified_user)):
folder_ids = [folder_id]
children_folders = Folders.get_children_folders_by_id_and_user_id(
folder_id, user.id
)
if children_folders:
folder_ids.extend([folder.id for folder in children_folders])
return [
ChatResponse(**chat.model_dump())
for chat in Chats.get_chats_by_folder_ids_and_user_id(folder_ids, user.id)
]
############################
# GetPinnedChats
############################
@router.get("/pinned", response_model=list[ChatResponse])
async def get_user_pinned_chats(user=Depends(get_verified_user)):
return [
ChatResponse(**chat.model_dump())
for chat in Chats.get_pinned_chats_by_user_id(user.id)
]
############################ ############################
# GetChats # GetChats
@ -152,6 +222,23 @@ async def get_user_archived_chats(user=Depends(get_verified_user)):
] ]
############################
# GetAllTags
############################
@router.get("/all/tags", response_model=list[TagModel])
async def get_all_user_tags(user=Depends(get_verified_user)):
try:
tags = Tags.get_tags_by_user_id(user.id)
return tags
except Exception as e:
log.exception(e)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
)
############################ ############################
# GetAllChatsInDB # GetAllChatsInDB
############################ ############################
@@ -220,48 +307,28 @@ async def get_shared_chat_by_id(share_id: str, user=Depends(get_verified_user)):
############################


-class TagNameForm(BaseModel):
+class TagForm(BaseModel):
    name: str


+class TagFilterForm(TagForm):
    skip: Optional[int] = 0
    limit: Optional[int] = 50


@router.post("/tags", response_model=list[ChatTitleIdResponse])
async def get_user_chat_list_by_tag_name(
-    form_data: TagNameForm, user=Depends(get_verified_user)
+    form_data: TagFilterForm, user=Depends(get_verified_user)
):
-    chat_ids = [
-        chat_id_tag.chat_id
-        for chat_id_tag in Tags.get_chat_ids_by_tag_name_and_user_id(
-            form_data.name, user.id
-        )
-    ]
-
-    chats = Chats.get_chat_list_by_chat_ids(chat_ids, form_data.skip, form_data.limit)
+    chats = Chats.get_chat_list_by_user_id_and_tag_name(
+        user.id, form_data.name, form_data.skip, form_data.limit
+    )

    if len(chats) == 0:
-        Tags.delete_tag_by_tag_name_and_user_id(form_data.name, user.id)
+        Tags.delete_tag_by_name_and_user_id(form_data.name, user.id)

    return chats
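# Illustrative request body for the tag filter endpoint above (not part of the commit);
# the fields follow TagFilterForm: {"name": "work", "skip": 0, "limit": 50}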
-############################
-# GetAllTags
-############################
-
-
-@router.get("/tags/all", response_model=list[TagModel])
-async def get_all_tags(user=Depends(get_verified_user)):
-    try:
-        tags = Tags.get_tags_by_user_id(user.id)
-        return tags
-    except Exception as e:
-        log.exception(e)
-        raise HTTPException(
-            status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
-        )
############################
# GetChatById
############################

@@ -309,7 +376,13 @@ async def update_chat_by_id(
@router.delete("/{id}", response_model=bool)
async def delete_chat_by_id(request: Request, id: str, user=Depends(get_verified_user)):
    if user.role == "admin":
        chat = Chats.get_chat_by_id(id)
        for tag in chat.meta.get("tags", []):
            if Chats.count_chats_by_tag_name_and_user_id(tag, user.id) == 1:
                Tags.delete_tag_by_name_and_user_id(tag, user.id)

        result = Chats.delete_chat_by_id(id)
        return result
    else:
        if not request.app.state.config.USER_PERMISSIONS.get("chat", {}).get(

@@ -320,16 +393,54 @@ async def delete_chat_by_id(request: Request, id: str, user=Depends(get_verified_user)):
                detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
            )

        chat = Chats.get_chat_by_id(id)
        for tag in chat.meta.get("tags", []):
            if Chats.count_chats_by_tag_name_and_user_id(tag, user.id) == 1:
                Tags.delete_tag_by_name_and_user_id(tag, user.id)

        result = Chats.delete_chat_by_id_and_user_id(id, user.id)
        return result


############################
# GetPinnedStatusById
############################


@router.get("/{id}/pinned", response_model=Optional[bool])
async def get_pinned_status_by_id(id: str, user=Depends(get_verified_user)):
    chat = Chats.get_chat_by_id_and_user_id(id, user.id)
    if chat:
        return chat.pinned
    else:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT()
        )


############################
# PinChatById
############################


@router.post("/{id}/pin", response_model=Optional[ChatResponse])
async def pin_chat_by_id(id: str, user=Depends(get_verified_user)):
    chat = Chats.get_chat_by_id_and_user_id(id, user.id)
    if chat:
        chat = Chats.toggle_chat_pinned_by_id(id)
        return chat
    else:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT()
        )
############################
# CloneChat
############################


-@router.get("/{id}/clone", response_model=Optional[ChatResponse])
+@router.post("/{id}/clone", response_model=Optional[ChatResponse])
async def clone_chat_by_id(id: str, user=Depends(get_verified_user)):
    chat = Chats.get_chat_by_id_and_user_id(id, user.id)
    if chat:

@@ -353,11 +464,25 @@ async def clone_chat_by_id(id: str, user=Depends(get_verified_user)):
############################


-@router.get("/{id}/archive", response_model=Optional[ChatResponse])
+@router.post("/{id}/archive", response_model=Optional[ChatResponse])
async def archive_chat_by_id(id: str, user=Depends(get_verified_user)):
    chat = Chats.get_chat_by_id_and_user_id(id, user.id)
    if chat:
        chat = Chats.toggle_chat_archive_by_id(id)

        # Delete tags if chat is archived
        if chat.archived:
            for tag_id in chat.meta.get("tags", []):
                if Chats.count_chats_by_tag_name_and_user_id(tag_id, user.id) == 0:
                    log.debug(f"deleting tag: {tag_id}")
                    Tags.delete_tag_by_name_and_user_id(tag_id, user.id)
        else:
            for tag_id in chat.meta.get("tags", []):
                tag = Tags.get_tag_by_name_and_user_id(tag_id, user.id)
                if tag is None:
                    log.debug(f"inserting tag: {tag_id}")
                    tag = Tags.insert_new_tag(tag_id, user.id)

        return ChatResponse(**chat.model_dump())
    else:
        raise HTTPException(

@@ -416,6 +541,31 @@ async def delete_shared_chat_by_id(id: str, user=Depends(get_verified_user)):
        )


############################
# UpdateChatFolderIdById
############################


class ChatFolderIdForm(BaseModel):
    folder_id: Optional[str] = None


@router.post("/{id}/folder", response_model=Optional[ChatResponse])
async def update_chat_folder_id_by_id(
    id: str, form_data: ChatFolderIdForm, user=Depends(get_verified_user)
):
    chat = Chats.get_chat_by_id_and_user_id(id, user.id)
    if chat:
        chat = Chats.update_chat_folder_id_by_id_and_user_id(
            id, user.id, form_data.folder_id
        )
        return ChatResponse(**chat.model_dump())
    else:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT()
        )
############################
# GetChatTagsById
############################

@@ -423,10 +573,10 @@ async def delete_shared_chat_by_id(id: str, user=Depends(get_verified_user)):
@router.get("/{id}/tags", response_model=list[TagModel])
async def get_chat_tags_by_id(id: str, user=Depends(get_verified_user)):
-    tags = Tags.get_tags_by_chat_id_and_user_id(id, user.id)
-
-    if tags != None:
-        return tags
+    chat = Chats.get_chat_by_id_and_user_id(id, user.id)
+    if chat:
+        tags = chat.meta.get("tags", [])
+        return Tags.get_tags_by_ids_and_user_id(tags, user.id)
    else:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND

@@ -438,22 +588,30 @@ async def get_chat_tags_by_id(id: str, user=Depends(get_verified_user)):
############################


-@router.post("/{id}/tags", response_model=Optional[ChatIdTagModel])
-async def add_chat_tag_by_id(
-    id: str, form_data: ChatIdTagForm, user=Depends(get_verified_user)
-):
-    tags = Tags.get_tags_by_chat_id_and_user_id(id, user.id)
-
-    if form_data.tag_name not in tags:
-        tag = Tags.add_tag_to_chat(user.id, form_data)
-        if tag:
-            return tag
-        else:
-            raise HTTPException(
-                status_code=status.HTTP_401_UNAUTHORIZED,
-                detail=ERROR_MESSAGES.NOT_FOUND,
-            )
+@router.post("/{id}/tags", response_model=list[TagModel])
+async def add_tag_by_id_and_tag_name(
+    id: str, form_data: TagForm, user=Depends(get_verified_user)
+):
+    chat = Chats.get_chat_by_id_and_user_id(id, user.id)
+    if chat:
+        tags = chat.meta.get("tags", [])
+        tag_id = form_data.name.replace(" ", "_").lower()
+
+        if tag_id == "none":
+            raise HTTPException(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                detail=ERROR_MESSAGES.DEFAULT("Tag name cannot be 'None'"),
+            )
+
+        print(tags, tag_id)
+        if tag_id not in tags:
+            Chats.add_chat_tag_by_id_and_user_id_and_tag_name(
+                id, user.id, form_data.name
+            )
+
+        chat = Chats.get_chat_by_id_and_user_id(id, user.id)
+        tags = chat.meta.get("tags", [])
+        return Tags.get_tags_by_ids_and_user_id(tags, user.id)
    else:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT()

@@ -465,16 +623,20 @@ async def add_chat_tag_by_id(
############################


-@router.delete("/{id}/tags", response_model=Optional[bool])
-async def delete_chat_tag_by_id(
-    id: str, form_data: ChatIdTagForm, user=Depends(get_verified_user)
-):
-    result = Tags.delete_tag_by_tag_name_and_chat_id_and_user_id(
-        form_data.tag_name, id, user.id
-    )
-
-    if result:
-        return result
+@router.delete("/{id}/tags", response_model=list[TagModel])
+async def delete_tag_by_id_and_tag_name(
+    id: str, form_data: TagForm, user=Depends(get_verified_user)
+):
+    chat = Chats.get_chat_by_id_and_user_id(id, user.id)
+    if chat:
+        Chats.delete_tag_by_id_and_user_id_and_tag_name(id, user.id, form_data.name)
+        if Chats.count_chats_by_tag_name_and_user_id(form_data.name, user.id) == 0:
+            Tags.delete_tag_by_name_and_user_id(form_data.name, user.id)
+
+        chat = Chats.get_chat_by_id_and_user_id(id, user.id)
+        tags = chat.meta.get("tags", [])
+        return Tags.get_tags_by_ids_and_user_id(tags, user.id)
    else:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND

@@ -482,16 +644,21 @@ async def delete_chat_tag_by_id(
############################
-# DeleteAllChatTagsById
+# DeleteAllTagsById
############################


@router.delete("/{id}/tags/all", response_model=Optional[bool])
-async def delete_all_chat_tags_by_id(id: str, user=Depends(get_verified_user)):
+async def delete_all_tags_by_id(id: str, user=Depends(get_verified_user)):
-    result = Tags.delete_tags_by_chat_id_and_user_id(id, user.id)
+    chat = Chats.get_chat_by_id_and_user_id(id, user.id)
+    if chat:
+        Chats.delete_all_tags_by_id_and_user_id(id, user.id)

-    if result:
-        return result
+        for tag in chat.meta.get("tags", []):
+            if Chats.count_chats_by_tag_name_and_user_id(tag, user.id) == 0:
+                Tags.delete_tag_by_name_and_user_id(tag, user.id)
+
+        return True
    else:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND

View file

@ -0,0 +1,159 @@
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, status, Request
from pydantic import BaseModel
from open_webui.apps.webui.models.users import Users, UserModel
from open_webui.apps.webui.models.feedbacks import (
FeedbackModel,
FeedbackResponse,
FeedbackForm,
Feedbacks,
)
from open_webui.constants import ERROR_MESSAGES
from open_webui.utils.utils import get_admin_user, get_verified_user
router = APIRouter()
############################
# GetConfig
############################
@router.get("/config")
async def get_config(request: Request, user=Depends(get_admin_user)):
return {
"ENABLE_EVALUATION_ARENA_MODELS": request.app.state.config.ENABLE_EVALUATION_ARENA_MODELS,
"EVALUATION_ARENA_MODELS": request.app.state.config.EVALUATION_ARENA_MODELS,
}
############################
# UpdateConfig
############################
class UpdateConfigForm(BaseModel):
ENABLE_EVALUATION_ARENA_MODELS: Optional[bool] = None
EVALUATION_ARENA_MODELS: Optional[list[dict]] = None
@router.post("/config")
async def update_config(
request: Request,
form_data: UpdateConfigForm,
user=Depends(get_admin_user),
):
config = request.app.state.config
if form_data.ENABLE_EVALUATION_ARENA_MODELS is not None:
config.ENABLE_EVALUATION_ARENA_MODELS = form_data.ENABLE_EVALUATION_ARENA_MODELS
if form_data.EVALUATION_ARENA_MODELS is not None:
config.EVALUATION_ARENA_MODELS = form_data.EVALUATION_ARENA_MODELS
return {
"ENABLE_EVALUATION_ARENA_MODELS": config.ENABLE_EVALUATION_ARENA_MODELS,
"EVALUATION_ARENA_MODELS": config.EVALUATION_ARENA_MODELS,
}
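# The admin-facing feedback listings below enrich each feedback entry with the submitting user's record.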
class FeedbackUserResponse(FeedbackResponse):
user: Optional[UserModel] = None
@router.get("/feedbacks/all", response_model=list[FeedbackUserResponse])
async def get_all_feedbacks(user=Depends(get_admin_user)):
feedbacks = Feedbacks.get_all_feedbacks()
return [
FeedbackUserResponse(
**feedback.model_dump(), user=Users.get_user_by_id(feedback.user_id)
)
for feedback in feedbacks
]
@router.delete("/feedbacks/all")
async def delete_all_feedbacks(user=Depends(get_admin_user)):
success = Feedbacks.delete_all_feedbacks()
return success
@router.get("/feedbacks/all/export", response_model=list[FeedbackModel])
async def get_all_feedbacks(user=Depends(get_admin_user)):
feedbacks = Feedbacks.get_all_feedbacks()
return [
FeedbackModel(
**feedback.model_dump(), user=Users.get_user_by_id(feedback.user_id)
)
for feedback in feedbacks
]
@router.get("/feedbacks/user", response_model=list[FeedbackUserResponse])
async def get_feedbacks(user=Depends(get_verified_user)):
feedbacks = Feedbacks.get_feedbacks_by_user_id(user.id)
return feedbacks
@router.delete("/feedbacks", response_model=bool)
async def delete_feedbacks(user=Depends(get_verified_user)):
success = Feedbacks.delete_feedbacks_by_user_id(user.id)
return success
@router.post("/feedback", response_model=FeedbackModel)
async def create_feedback(
request: Request,
form_data: FeedbackForm,
user=Depends(get_verified_user),
):
feedback = Feedbacks.insert_new_feedback(user_id=user.id, form_data=form_data)
if not feedback:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(),
)
return feedback
@router.get("/feedback/{id}", response_model=FeedbackModel)
async def get_feedback_by_id(id: str, user=Depends(get_verified_user)):
feedback = Feedbacks.get_feedback_by_id_and_user_id(id=id, user_id=user.id)
if not feedback:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
)
return feedback
@router.post("/feedback/{id}", response_model=FeedbackModel)
async def update_feedback_by_id(
id: str, form_data: FeedbackForm, user=Depends(get_verified_user)
):
feedback = Feedbacks.update_feedback_by_id_and_user_id(
id=id, user_id=user.id, form_data=form_data
)
if not feedback:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
)
return feedback
@router.delete("/feedback/{id}")
async def delete_feedback_by_id(id: str, user=Depends(get_verified_user)):
if user.role == "admin":
success = Feedbacks.delete_feedback_by_id(id=id)
else:
success = Feedbacks.delete_feedback_by_id_and_user_id(id=id, user_id=user.id)
if not success:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
)
return success

View file

@@ -1,14 +1,19 @@
import logging
import os
import shutil
import uuid
from pathlib import Path
from typing import Optional
from pydantic import BaseModel
import mimetypes

from open_webui.storage.provider import Storage

-from open_webui.apps.webui.models.files import FileForm, FileModel, Files
+from open_webui.apps.webui.models.files import (
+    FileForm,
+    FileModel,
+    FileModelResponse,
+    Files,
+)
from open_webui.apps.retrieval.main import process_file, ProcessFileForm

from open_webui.config import UPLOAD_DIR

@@ -33,7 +38,7 @@ router = APIRouter()
############################


-@router.post("/")
+@router.post("/", response_model=FileModelResponse)
def upload_file(file: UploadFile = File(...), user=Depends(get_verified_user)):
    log.info(f"file.content_type: {file.content_type}")
    try:
@@ -44,24 +49,19 @@ def upload_file(file: UploadFile = File(...), user=Depends(get_verified_user)):
        id = str(uuid.uuid4())
        name = filename
        filename = f"{id}_{filename}"
-        file_path = f"{UPLOAD_DIR}/{filename}"
+        contents, file_path = Storage.upload_file(file.file, filename)

-        contents = file.file.read()
-        with open(file_path, "wb") as f:
-            f.write(contents)
-            f.close()
-
-        file = Files.insert_new_file(
+        file_item = Files.insert_new_file(
            user.id,
            FileForm(
                **{
                    "id": id,
                    "filename": filename,
+                    "path": file_path,
                    "meta": {
                        "name": name,
                        "content_type": file.content_type,
                        "size": len(contents),
-                        "path": file_path,
                    },
                }
            ),

@@ -69,13 +69,19 @@ def upload_file(file: UploadFile = File(...), user=Depends(get_verified_user)):
        try:
            process_file(ProcessFileForm(file_id=id))
-            file = Files.get_file_by_id(id=id)
+            file_item = Files.get_file_by_id(id=id)
        except Exception as e:
            log.exception(e)
-            log.error(f"Error processing file: {file.id}")
+            log.error(f"Error processing file: {file_item.id}")
+            file_item = FileModelResponse(
+                **{
+                    **file_item.model_dump(),
+                    "error": str(e.detail) if hasattr(e, "detail") else str(e),
+                }
+            )

-        if file:
-            return file
+        if file_item:
+            return file_item
        else:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
@@ -95,7 +101,7 @@ def upload_file(file: UploadFile = File(...), user=Depends(get_verified_user)):
############################


-@router.get("/", response_model=list[FileModel])
+@router.get("/", response_model=list[FileModelResponse])
async def list_files(user=Depends(get_verified_user)):
    if user.role == "admin":
        files = Files.get_files()

@@ -112,27 +118,16 @@ async def list_files(user=Depends(get_verified_user)):
@router.delete("/all")
async def delete_all_files(user=Depends(get_admin_user)):
    result = Files.delete_all_files()
    if result:
-        folder = f"{UPLOAD_DIR}"
        try:
-            # Check if the directory exists
-            if os.path.exists(folder):
-                # Iterate over all the files and directories in the specified directory
-                for filename in os.listdir(folder):
-                    file_path = os.path.join(folder, filename)
-                    try:
-                        if os.path.isfile(file_path) or os.path.islink(file_path):
-                            os.unlink(file_path)  # Remove the file or link
-                        elif os.path.isdir(file_path):
-                            shutil.rmtree(file_path)  # Remove the directory
-                    except Exception as e:
-                        print(f"Failed to delete {file_path}. Reason: {e}")
-            else:
-                print(f"The directory {folder} does not exist")
+            Storage.delete_all_files()
        except Exception as e:
-            print(f"Failed to process the directory {folder}. Reason: {e}")
+            log.exception(e)
+            log.error(f"Error deleting files")
+            raise HTTPException(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                detail=ERROR_MESSAGES.DEFAULT("Error deleting files"),
+            )

        return {"message": "All files deleted successfully"}
    else:
        raise HTTPException(
@@ -213,12 +208,13 @@ async def update_file_data_content_by_id(
############################


-@router.get("/{id}/content", response_model=Optional[FileModel])
+@router.get("/{id}/content")
async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
    file = Files.get_file_by_id(id)
    if file and (file.user_id == user.id or user.role == "admin"):
-        file_path = Path(file.meta["path"])
+        try:
+            file_path = Storage.get_file(file.path)
+            file_path = Path(file_path)

            # Check if the file already exists in the cache
            if file_path.is_file():

@@ -232,6 +228,13 @@ async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
                    status_code=status.HTTP_404_NOT_FOUND,
                    detail=ERROR_MESSAGES.NOT_FOUND,
                )
        except Exception as e:
            log.exception(e)
            log.error(f"Error getting file content")
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=ERROR_MESSAGES.DEFAULT("Error getting file content"),
            )
    else:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,

@@ -239,13 +242,12 @@ async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
        )


-@router.get("/{id}/content/{file_name}", response_model=Optional[FileModel])
+@router.get("/{id}/content/html")
-async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
+async def get_html_file_content_by_id(id: str, user=Depends(get_verified_user)):
    file = Files.get_file_by_id(id)
    if file and (file.user_id == user.id or user.role == "admin"):
-        file_path = file.meta.get("path")
-        if file_path:
-            file_path = Path(file_path)
+        try:
+            file_path = Storage.get_file(file.path)
+            file_path = Path(file_path)

            # Check if the file already exists in the cache

@@ -257,6 +259,42 @@ async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
                    status_code=status.HTTP_404_NOT_FOUND,
                    detail=ERROR_MESSAGES.NOT_FOUND,
                )
except Exception as e:
log.exception(e)
log.error(f"Error getting file content")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT("Error getting file content"),
)
else:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=ERROR_MESSAGES.NOT_FOUND,
)
@router.get("/{id}/content/{file_name}")
async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
file = Files.get_file_by_id(id)
if file and (file.user_id == user.id or user.role == "admin"):
file_path = file.path
if file_path:
file_path = Storage.get_file(file_path)
file_path = Path(file_path)
# Check if the file already exists in the cache
if file_path.is_file():
print(f"file_path: {file_path}")
headers = {
"Content-Disposition": f'attachment; filename="{file.meta.get("name", file.filename)}"'
}
return FileResponse(file_path, headers=headers)
else:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=ERROR_MESSAGES.NOT_FOUND,
)
    else:
        # File path doesn't exist, return the content as .txt if possible
        file_content = file.content.get("content", "")

@@ -289,6 +327,15 @@ async def delete_file_by_id(id: str, user=Depends(get_verified_user)):
    if file and (file.user_id == user.id or user.role == "admin"):
        result = Files.delete_file_by_id(id)
        if result:
            try:
                Storage.delete_file(file.filename)
            except Exception as e:
                log.exception(e)
                log.error(f"Error deleting files")
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail=ERROR_MESSAGES.DEFAULT("Error deleting files"),
                )

            return {"message": "File deleted successfully"}
        else:
            raise HTTPException(

View file

@ -0,0 +1,251 @@
import logging
import os
import shutil
import uuid
from pathlib import Path
from typing import Optional
from pydantic import BaseModel
import mimetypes
from open_webui.apps.webui.models.folders import (
FolderForm,
FolderModel,
Folders,
)
from open_webui.apps.webui.models.chats import Chats
from open_webui.config import UPLOAD_DIR
from open_webui.env import SRC_LOG_LEVELS
from open_webui.constants import ERROR_MESSAGES
from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, status
from fastapi.responses import FileResponse, StreamingResponse
from open_webui.utils.utils import get_admin_user, get_verified_user
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
router = APIRouter()
############################
# Get Folders
############################
@router.get("/", response_model=list[FolderModel])
async def get_folders(user=Depends(get_verified_user)):
folders = Folders.get_folders_by_user_id(user.id)
return [
{
**folder.model_dump(),
"items": {
"chats": [
{"title": chat.title, "id": chat.id}
for chat in Chats.get_chats_by_folder_id_and_user_id(
folder.id, user.id
)
]
},
}
for folder in folders
]
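# Folder names are kept unique per parent folder and user; the create/update endpoints below check for an existing name before writing.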
############################
# Create Folder
############################
@router.post("/")
def create_folder(form_data: FolderForm, user=Depends(get_verified_user)):
folder = Folders.get_folder_by_parent_id_and_user_id_and_name(
None, user.id, form_data.name
)
if folder:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT("Folder already exists"),
)
try:
folder = Folders.insert_new_folder(user.id, form_data.name)
return folder
except Exception as e:
log.exception(e)
log.error("Error creating folder")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT("Error creating folder"),
)
############################
# Get Folders By Id
############################
@router.get("/{id}", response_model=Optional[FolderModel])
async def get_folder_by_id(id: str, user=Depends(get_verified_user)):
folder = Folders.get_folder_by_id_and_user_id(id, user.id)
if folder:
return folder
else:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=ERROR_MESSAGES.NOT_FOUND,
)
############################
# Update Folder Name By Id
############################
@router.post("/{id}/update")
async def update_folder_name_by_id(
id: str, form_data: FolderForm, user=Depends(get_verified_user)
):
folder = Folders.get_folder_by_id_and_user_id(id, user.id)
if folder:
existing_folder = Folders.get_folder_by_parent_id_and_user_id_and_name(
folder.parent_id, user.id, form_data.name
)
if existing_folder:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT("Folder already exists"),
)
try:
folder = Folders.update_folder_name_by_id_and_user_id(
id, user.id, form_data.name
)
return folder
except Exception as e:
log.exception(e)
log.error(f"Error updating folder: {id}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT("Error updating folder"),
)
else:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=ERROR_MESSAGES.NOT_FOUND,
)
############################
# Update Folder Parent Id By Id
############################
class FolderParentIdForm(BaseModel):
parent_id: Optional[str] = None
@router.post("/{id}/update/parent")
async def update_folder_parent_id_by_id(
id: str, form_data: FolderParentIdForm, user=Depends(get_verified_user)
):
folder = Folders.get_folder_by_id_and_user_id(id, user.id)
if folder:
existing_folder = Folders.get_folder_by_parent_id_and_user_id_and_name(
form_data.parent_id, user.id, folder.name
)
if existing_folder:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT("Folder already exists"),
)
try:
folder = Folders.update_folder_parent_id_by_id_and_user_id(
id, user.id, form_data.parent_id
)
return folder
except Exception as e:
log.exception(e)
log.error(f"Error updating folder: {id}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT("Error updating folder"),
)
else:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=ERROR_MESSAGES.NOT_FOUND,
)
############################
# Update Folder Is Expanded By Id
############################
class FolderIsExpandedForm(BaseModel):
is_expanded: bool
@router.post("/{id}/update/expanded")
async def update_folder_is_expanded_by_id(
id: str, form_data: FolderIsExpandedForm, user=Depends(get_verified_user)
):
folder = Folders.get_folder_by_id_and_user_id(id, user.id)
if folder:
try:
folder = Folders.update_folder_is_expanded_by_id_and_user_id(
id, user.id, form_data.is_expanded
)
return folder
except Exception as e:
log.exception(e)
log.error(f"Error updating folder: {id}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT("Error updating folder"),
)
else:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=ERROR_MESSAGES.NOT_FOUND,
)
############################
# Delete Folder By Id
############################
@router.delete("/{id}")
async def delete_folder_by_id(id: str, user=Depends(get_verified_user)):
folder = Folders.get_folder_by_id_and_user_id(id, user.id)
if folder:
try:
result = Folders.delete_folder_by_id_and_user_id(id, user.id)
if result:
return result
else:
raise Exception("Error deleting folder")
except Exception as e:
log.exception(e)
log.error(f"Error deleting folder: {id}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT("Error deleting folder"),
)
else:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=ERROR_MESSAGES.NOT_FOUND,
)

View file

@@ -9,7 +9,7 @@ from open_webui.apps.webui.models.functions import (
    Functions,
)
from open_webui.apps.webui.utils import load_function_module_by_id, replace_imports
-from open_webui.config import CACHE_DIR, FUNCTIONS_DIR
+from open_webui.config import CACHE_DIR
from open_webui.constants import ERROR_MESSAGES
from fastapi import APIRouter, Depends, HTTPException, Request, status
from open_webui.utils.utils import get_admin_user, get_verified_user

View file

@@ -47,10 +47,43 @@ async def get_knowledge_items(
            detail=ERROR_MESSAGES.NOT_FOUND,
        )
    else:
-        return [
-            KnowledgeResponse(**knowledge.model_dump())
-            for knowledge in Knowledges.get_knowledge_items()
-        ]
+        knowledge_bases = []
+
+        for knowledge in Knowledges.get_knowledge_items():
files = []
if knowledge.data:
files = Files.get_file_metadatas_by_ids(
knowledge.data.get("file_ids", [])
)
# Check if all files exist
if len(files) != len(knowledge.data.get("file_ids", [])):
missing_files = list(
set(knowledge.data.get("file_ids", []))
- set([file.id for file in files])
)
if missing_files:
data = knowledge.data or {}
file_ids = data.get("file_ids", [])
for missing_file in missing_files:
file_ids.remove(missing_file)
data["file_ids"] = file_ids
Knowledges.update_knowledge_by_id(
id=knowledge.id, form_data=KnowledgeUpdateForm(data=data)
)
files = Files.get_file_metadatas_by_ids(file_ids)
knowledge_bases.append(
KnowledgeResponse(
**knowledge.model_dump(),
files=files,
)
)
return knowledge_bases
############################

View file

@@ -10,9 +10,6 @@ from fastapi import APIRouter, Depends, HTTPException, Request, status
from open_webui.utils.tools import get_tools_specs
from open_webui.utils.utils import get_admin_user, get_verified_user

-TOOLS_DIR = f"{DATA_DIR}/tools"
-os.makedirs(TOOLS_DIR, exist_ok=True)
-
router = APIRouter()

View file

@@ -1,16 +1,14 @@
-import site
-from pathlib import Path
-
import black
import markdown

+from open_webui.apps.webui.models.chats import ChatTitleMessagesForm
from open_webui.config import DATA_DIR, ENABLE_ADMIN_EXPORT
-from open_webui.env import FONTS_DIR
from open_webui.constants import ERROR_MESSAGES
from fastapi import APIRouter, Depends, HTTPException, Response, status
-from fpdf import FPDF
from pydantic import BaseModel
from starlette.responses import FileResponse
from open_webui.utils.misc import get_gravatar_url
+from open_webui.utils.pdf_generator import PDFGenerator
from open_webui.utils.utils import get_admin_user

router = APIRouter()
@@ -56,58 +54,19 @@ class ChatForm(BaseModel):
@router.post("/pdf")
async def download_chat_as_pdf(
-    form_data: ChatForm,
+    form_data: ChatTitleMessagesForm,
):
-    global FONTS_DIR
-
-    pdf = FPDF()
-    pdf.add_page()
-
-    # When running using `pip install` the static directory is in the site packages.
-    if not FONTS_DIR.exists():
-        FONTS_DIR = Path(site.getsitepackages()[0]) / "static/fonts"
-
-    # When running using `pip install -e .` the static directory is in the site packages.
-    # This path only works if `open-webui serve` is run from the root of this project.
-    if not FONTS_DIR.exists():
-        FONTS_DIR = Path("./backend/static/fonts")
-
-    pdf.add_font("NotoSans", "", f"{FONTS_DIR}/NotoSans-Regular.ttf")
-    pdf.add_font("NotoSans", "b", f"{FONTS_DIR}/NotoSans-Bold.ttf")
-    pdf.add_font("NotoSans", "i", f"{FONTS_DIR}/NotoSans-Italic.ttf")
-    pdf.add_font("NotoSansKR", "", f"{FONTS_DIR}/NotoSansKR-Regular.ttf")
-    pdf.add_font("NotoSansJP", "", f"{FONTS_DIR}/NotoSansJP-Regular.ttf")
-    pdf.add_font("NotoSansSC", "", f"{FONTS_DIR}/NotoSansSC-Regular.ttf")
-
-    pdf.set_font("NotoSans", size=12)
-    pdf.set_fallback_fonts(["NotoSansKR", "NotoSansJP", "NotoSansSC"])
-
-    pdf.set_auto_page_break(auto=True, margin=15)
-
-    # Adjust the effective page width for multi_cell
-    effective_page_width = (
-        pdf.w - 2 * pdf.l_margin - 10
-    )  # Subtracted an additional 10 for extra padding
-
-    # Add chat messages
-    for message in form_data.messages:
-        role = message["role"]
-        content = message["content"]
-        pdf.set_font("NotoSans", "B", size=14)  # Bold for the role
-        pdf.multi_cell(effective_page_width, 10, f"{role.upper()}", 0, "L")
-
-        pdf.ln(1)  # Extra space between messages
-
-        pdf.set_font("NotoSans", size=10)  # Regular for content
-        pdf.multi_cell(effective_page_width, 6, content, 0, "L")
-
-        pdf.ln(1.5)  # Extra space between messages
-
-    # Save the pdf with name .pdf
-    pdf_bytes = pdf.output()
-
-    return Response(
-        content=bytes(pdf_bytes),
-        media_type="application/pdf",
-        headers={"Content-Disposition": "attachment;filename=chat.pdf"},
-    )
+    try:
+        pdf_bytes = PDFGenerator(form_data).generate_chat_pdf()
+
+        return Response(
+            content=pdf_bytes,
+            media_type="application/pdf",
+            headers={"Content-Disposition": "attachment;filename=chat.pdf"},
+        )
+    except Exception as e:
+        print(e)
+        raise HTTPException(status_code=400, detail=str(e))


@router.get("/db/download")

View file

@@ -8,7 +8,6 @@ import tempfile
from open_webui.apps.webui.models.functions import Functions
from open_webui.apps.webui.models.tools import Tools
-from open_webui.config import FUNCTIONS_DIR, TOOLS_DIR


def extract_frontmatter(content):

View file

@@ -383,7 +383,7 @@ OAUTH_USERNAME_CLAIM = PersistentConfig(
)

OAUTH_PICTURE_CLAIM = PersistentConfig(
-    "OAUTH_USERNAME_CLAIM",
+    "OAUTH_PICTURE_CLAIM",
    "oauth.oidc.avatar_claim",
    os.environ.get("OAUTH_PICTURE_CLAIM", "picture"),
)

@@ -394,6 +394,33 @@ OAUTH_EMAIL_CLAIM = PersistentConfig(
    os.environ.get("OAUTH_EMAIL_CLAIM", "email"),
)
ENABLE_OAUTH_ROLE_MANAGEMENT = PersistentConfig(
"ENABLE_OAUTH_ROLE_MANAGEMENT",
"oauth.enable_role_mapping",
os.environ.get("ENABLE_OAUTH_ROLE_MANAGEMENT", "False").lower() == "true",
)
OAUTH_ROLES_CLAIM = PersistentConfig(
"OAUTH_ROLES_CLAIM",
"oauth.roles_claim",
os.environ.get("OAUTH_ROLES_CLAIM", "roles"),
)
OAUTH_ALLOWED_ROLES = PersistentConfig(
"OAUTH_ALLOWED_ROLES",
"oauth.allowed_roles",
[
role.strip()
for role in os.environ.get("OAUTH_ALLOWED_ROLES", "user,admin").split(",")
],
)
OAUTH_ADMIN_ROLES = PersistentConfig(
"OAUTH_ADMIN_ROLES",
"oauth.admin_roles",
[role.strip() for role in os.environ.get("OAUTH_ADMIN_ROLES", "admin").split(",")],
)
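# Assumed behavior when ENABLE_OAUTH_ROLE_MANAGEMENT is true: the roles claim is matched against
# OAUTH_ALLOWED_ROLES to permit login and against OAUTH_ADMIN_ROLES to grant the admin role.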
def load_oauth_providers():
    OAUTH_PROVIDERS.clear()

@@ -506,6 +533,18 @@ if CUSTOM_NAME:
        pass
####################################
# STORAGE PROVIDER
####################################
STORAGE_PROVIDER = os.environ.get("STORAGE_PROVIDER", "") # defaults to local, s3
S3_ACCESS_KEY_ID = os.environ.get("S3_ACCESS_KEY_ID", None)
S3_SECRET_ACCESS_KEY = os.environ.get("S3_SECRET_ACCESS_KEY", None)
S3_REGION_NAME = os.environ.get("S3_REGION_NAME", None)
S3_BUCKET_NAME = os.environ.get("S3_BUCKET_NAME", None)
S3_ENDPOINT_URL = os.environ.get("S3_ENDPOINT_URL", None)
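# The S3_* values above are read only when STORAGE_PROVIDER is set to "s3".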
####################################
# File Upload DIR
####################################

@@ -521,26 +560,10 @@ Path(UPLOAD_DIR).mkdir(parents=True, exist_ok=True)
CACHE_DIR = f"{DATA_DIR}/cache"
Path(CACHE_DIR).mkdir(parents=True, exist_ok=True)

-####################################
-# Tools DIR
-####################################
-
-TOOLS_DIR = os.getenv("TOOLS_DIR", f"{DATA_DIR}/tools")
-Path(TOOLS_DIR).mkdir(parents=True, exist_ok=True)
-
-####################################
-# Functions DIR
-####################################
-
-FUNCTIONS_DIR = os.getenv("FUNCTIONS_DIR", f"{DATA_DIR}/functions")
-Path(FUNCTIONS_DIR).mkdir(parents=True, exist_ok=True)
-
####################################
# OLLAMA_BASE_URL
####################################

ENABLE_OLLAMA_API = PersistentConfig(
    "ENABLE_OLLAMA_API",
    "ollama.enable",
@@ -728,6 +751,28 @@ USER_PERMISSIONS = PersistentConfig(
    },
)
ENABLE_EVALUATION_ARENA_MODELS = PersistentConfig(
"ENABLE_EVALUATION_ARENA_MODELS",
"evaluation.arena.enable",
os.environ.get("ENABLE_EVALUATION_ARENA_MODELS", "True").lower() == "true",
)
EVALUATION_ARENA_MODELS = PersistentConfig(
"EVALUATION_ARENA_MODELS",
"evaluation.arena.models",
[],
)
DEFAULT_ARENA_MODEL = {
"id": "arena-model",
"name": "Arena Model",
"meta": {
"profile_image_url": "/favicon.png",
"description": "Submit your questions to anonymous AI chatbots and vote on the best response.",
"model_ids": None,
},
}
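# "model_ids": None is treated as "sample from all available models" rather than a fixed list.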
ENABLE_MODEL_FILTER = PersistentConfig(
    "ENABLE_MODEL_FILTER",
    "model_filter.enable",

@@ -853,6 +898,12 @@ TITLE_GENERATION_PROMPT_TEMPLATE = PersistentConfig(
    os.environ.get("TITLE_GENERATION_PROMPT_TEMPLATE", ""),
)
TAGS_GENERATION_PROMPT_TEMPLATE = PersistentConfig(
"TAGS_GENERATION_PROMPT_TEMPLATE",
"task.tags.prompt_template",
os.environ.get("TAGS_GENERATION_PROMPT_TEMPLATE", ""),
)
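# Leaving this template empty falls back to the built-in default tags-generation prompt.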
ENABLE_SEARCH_QUERY = PersistentConfig(
    "ENABLE_SEARCH_QUERY",
    "task.search.enable",

@@ -886,6 +937,8 @@ CHROMA_TENANT = os.environ.get("CHROMA_TENANT", chromadb.DEFAULT_TENANT)
CHROMA_DATABASE = os.environ.get("CHROMA_DATABASE", chromadb.DEFAULT_DATABASE)
CHROMA_HTTP_HOST = os.environ.get("CHROMA_HTTP_HOST", "")
CHROMA_HTTP_PORT = int(os.environ.get("CHROMA_HTTP_PORT", "8000"))
CHROMA_CLIENT_AUTH_PROVIDER = os.environ.get("CHROMA_CLIENT_AUTH_PROVIDER", "")
CHROMA_CLIENT_AUTH_CREDENTIALS = os.environ.get("CHROMA_CLIENT_AUTH_CREDENTIALS", "")
# Comma-separated list of header=value pairs
CHROMA_HTTP_HEADERS = os.environ.get("CHROMA_HTTP_HEADERS", "")
if CHROMA_HTTP_HEADERS:

@@ -901,6 +954,9 @@ CHROMA_HTTP_SSL = os.environ.get("CHROMA_HTTP_SSL", "false").lower() == "true"
MILVUS_URI = os.environ.get("MILVUS_URI", f"{DATA_DIR}/vector_db/milvus.db")
# Qdrant
QDRANT_URI = os.environ.get("QDRANT_URI", None)
####################################
# Information Retrieval (RAG)
####################################

@@ -1011,6 +1067,22 @@ RAG_RERANKING_MODEL_TRUST_REMOTE_CODE = (
    os.environ.get("RAG_RERANKING_MODEL_TRUST_REMOTE_CODE", "").lower() == "true"
)
RAG_TEXT_SPLITTER = PersistentConfig(
"RAG_TEXT_SPLITTER",
"rag.text_splitter",
os.environ.get("RAG_TEXT_SPLITTER", ""),
)
TIKTOKEN_CACHE_DIR = os.environ.get("TIKTOKEN_CACHE_DIR", f"{CACHE_DIR}/tiktoken")
TIKTOKEN_ENCODING_NAME = PersistentConfig(
"TIKTOKEN_ENCODING_NAME",
"rag.tiktoken_encoding_name",
os.environ.get("TIKTOKEN_ENCODING_NAME", "cl100k_base"),
)
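# These tiktoken settings apply when RAG_TEXT_SPLITTER is set to "token"; the character splitter ignores them.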
CHUNK_SIZE = PersistentConfig(
    "CHUNK_SIZE", "rag.chunk_size", int(os.environ.get("CHUNK_SIZE", "1000"))
)

@@ -1023,7 +1095,7 @@ CHUNK_OVERLAP = PersistentConfig(
DEFAULT_RAG_TEMPLATE = """You are given a user query, some textual context and rules, all inside xml tags. You have to answer the query based on the context while respecting the rules.

<context>
-[context]
+{{CONTEXT}}
</context>

<rules>

@@ -1036,7 +1108,7 @@ DEFAULT_RAG_TEMPLATE = """You are given a user query, some textual context and rules, all inside xml tags. You have to answer the query based on the context while respecting the rules.
</rules>

<user_query>
-[query]
+{{QUERY}}
</user_query>
"""
@@ -1171,17 +1243,6 @@ RAG_WEB_SEARCH_CONCURRENT_REQUESTS = PersistentConfig(
)

-####################################
-# Transcribe
-####################################
-
-WHISPER_MODEL = os.getenv("WHISPER_MODEL", "base")
-WHISPER_MODEL_DIR = os.getenv("WHISPER_MODEL_DIR", f"{CACHE_DIR}/whisper/models")
-WHISPER_MODEL_AUTO_UPDATE = (
-    os.environ.get("WHISPER_MODEL_AUTO_UPDATE", "").lower() == "true"
-)
-
####################################
# Images
####################################

@@ -1397,6 +1458,19 @@ IMAGE_GENERATION_MODEL = PersistentConfig(
# Audio
####################################
# Transcription
WHISPER_MODEL = PersistentConfig(
"WHISPER_MODEL",
"audio.stt.whisper_model",
os.getenv("WHISPER_MODEL", "base"),
)
WHISPER_MODEL_DIR = os.getenv("WHISPER_MODEL_DIR", f"{CACHE_DIR}/whisper/models")
WHISPER_MODEL_AUTO_UPDATE = (
os.environ.get("WHISPER_MODEL_AUTO_UPDATE", "").lower() == "true"
)
AUDIO_STT_OPENAI_API_BASE_URL = PersistentConfig(
    "AUDIO_STT_OPENAI_API_BASE_URL",
    "audio.stt.openai.api_base_url",

@@ -1418,7 +1492,7 @@ AUDIO_STT_ENGINE = PersistentConfig(
AUDIO_STT_MODEL = PersistentConfig(
    "AUDIO_STT_MODEL",
    "audio.stt.model",
-    os.getenv("AUDIO_STT_MODEL", "whisper-1"),
+    os.getenv("AUDIO_STT_MODEL", ""),
)

AUDIO_TTS_OPENAI_API_BASE_URL = PersistentConfig(

View file

@@ -20,7 +20,9 @@ class ERROR_MESSAGES(str, Enum):
    def __str__(self) -> str:
        return super().__str__()

-    DEFAULT = lambda err="": f"Something went wrong :/\n{err if err else ''}"
+    DEFAULT = (
+        lambda err="": f'{"Something went wrong :/" if err == "" else "[ERROR: " + str(err) + "]"}'
+    )
    ENV_VAR_NOT_FOUND = "Required environment variable not found. Terminating now."
    CREATE_USER_ERROR = "Oops! Something went wrong while creating your account. Please try again later. If the issue persists, contact support for assistance."
    DELETE_USER_ERROR = "Oops! Something went wrong. We encountered an issue while trying to delete the user. Please give it another shot."

@@ -106,6 +108,7 @@ class TASKS(str, Enum):
    DEFAULT = lambda task="": f"{task if task else 'generation'}"
    TITLE_GENERATION = "title_generation"
    TAGS_GENERATION = "tags_generation"
    EMOJI_GENERATION = "emoji_generation"
    QUERY_GENERATION = "query_generation"
    FUNCTION_CALLING = "function_calling"

View file

@@ -230,6 +230,8 @@ if FROM_INIT_PY:
DATA_DIR = Path(os.getenv("DATA_DIR", OPEN_WEBUI_DIR / "data"))

STATIC_DIR = Path(os.getenv("STATIC_DIR", OPEN_WEBUI_DIR / "static"))

FONTS_DIR = Path(os.getenv("FONTS_DIR", OPEN_WEBUI_DIR / "static" / "fonts"))

FRONTEND_BUILD_DIR = Path(os.getenv("FRONTEND_BUILD_DIR", BASE_DIR / "build")).resolve()

@@ -361,6 +363,20 @@ else:
    except Exception:
        AIOHTTP_CLIENT_TIMEOUT = 300
AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST = os.environ.get(
"AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST", "3"
)
if AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST == "":
AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST = None
else:
try:
AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST = int(
AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST
)
except Exception:
AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST = 3
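# Kept short (3 seconds by default) so an unreachable OpenAI-compatible endpoint does not stall model listing.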
####################################
# OFFLINE_MODE
####################################

View file

@@ -1,4 +1,4 @@
-import base64
+import asyncio
import inspect
import json
import logging

@@ -7,20 +7,38 @@ import os
import shutil
import sys
import time
-import uuid
+import random
-import asyncio
from contextlib import asynccontextmanager
from typing import Optional

import aiohttp
import requests
from fastapi import (
Depends,
FastAPI,
File,
Form,
HTTPException,
Request,
UploadFile,
status,
)
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, RedirectResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
from sqlalchemy import text
from starlette.exceptions import HTTPException as StarletteHTTPException
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.middleware.sessions import SessionMiddleware
from starlette.responses import Response, StreamingResponse
from open_webui.apps.audio.main import app as audio_app
from open_webui.apps.images.main import app as images_app
from open_webui.apps.ollama.main import (
    app as ollama_app,
    get_all_models as get_ollama_models,
    generate_chat_completion as generate_ollama_chat_completion,
    generate_openai_chat_completion as generate_ollama_openai_chat_completion,
    GenerateChatCompletionForm,
)
from open_webui.apps.openai.main import (
@@ -28,38 +46,24 @@ from open_webui.apps.openai.main import (
    generate_chat_completion as generate_openai_chat_completion,
    get_all_models as get_openai_models,
)
from open_webui.apps.retrieval.main import app as retrieval_app
from open_webui.apps.retrieval.utils import get_rag_context, rag_template
from open_webui.apps.socket.main import (
    app as socket_app,
    periodic_usage_pool_cleanup,
    get_event_call,
    get_event_emitter,
)
from open_webui.apps.webui.internal.db import Session
from open_webui.apps.webui.main import (
    app as webui_app,
    generate_function_chat_completion,
-    get_pipe_models,
+    get_all_models as get_open_webui_models,
)
-from open_webui.apps.webui.internal.db import Session
-from open_webui.apps.webui.models.auths import Auths
from open_webui.apps.webui.models.functions import Functions
from open_webui.apps.webui.models.models import Models
from open_webui.apps.webui.models.users import UserModel, Users
from open_webui.apps.webui.utils import load_function_module_by_id
-
-from open_webui.apps.audio.main import app as audio_app
-from open_webui.apps.images.main import app as images_app
-
-from authlib.integrations.starlette_client import OAuth
-from authlib.oidc.core import UserInfo
-
from open_webui.config import (
    CACHE_DIR,
    CORS_ALLOW_ORIGIN,

@@ -67,13 +71,11 @@ from open_webui.config import (
    ENABLE_ADMIN_CHAT_ACCESS,
    ENABLE_ADMIN_EXPORT,
    ENABLE_MODEL_FILTER,
-    ENABLE_OAUTH_SIGNUP,
    ENABLE_OLLAMA_API,
    ENABLE_OPENAI_API,
    ENV,
    FRONTEND_BUILD_DIR,
    MODEL_FILTER_LIST,
-    OAUTH_MERGE_ACCOUNTS_BY_EMAIL,
    OAUTH_PROVIDERS,
    ENABLE_SEARCH_QUERY,
    SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
@@ -81,15 +83,15 @@ from open_webui.config import (
    TASK_MODEL,
    TASK_MODEL_EXTERNAL,
    TITLE_GENERATION_PROMPT_TEMPLATE,
    TAGS_GENERATION_PROMPT_TEMPLATE,
    TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
    WEBHOOK_URL,
    WEBUI_AUTH,
    WEBUI_NAME,
    AppConfig,
    run_migrations,
    reset_config,
)
-from open_webui.constants import ERROR_MESSAGES, TASKS, WEBHOOK_MESSAGES
+from open_webui.constants import TASKS
from open_webui.env import (
    CHANGELOG,
    GLOBAL_LOG_LEVEL,

@@ -104,63 +106,39 @@ from open_webui.env import (
    RESET_CONFIG_ON_START,
    OFFLINE_MODE,
)
-from fastapi import (
-    Depends,
-    FastAPI,
-    File,
-    Form,
-    HTTPException,
-    Request,
-    UploadFile,
-    status,
-)
-from fastapi.middleware.cors import CORSMiddleware
-from fastapi.responses import JSONResponse
-from fastapi.staticfiles import StaticFiles
-from pydantic import BaseModel
-from sqlalchemy import text
-from starlette.exceptions import HTTPException as StarletteHTTPException
-from starlette.middleware.base import BaseHTTPMiddleware
-from starlette.middleware.sessions import SessionMiddleware
-from starlette.responses import RedirectResponse, Response, StreamingResponse
-
-from open_webui.utils.security_headers import SecurityHeadersMiddleware
-
from open_webui.utils.misc import (
    add_or_update_system_message,
    get_last_user_message,
-    parse_duration,
    prepend_to_first_user_message_content,
)
-from open_webui.utils.task import (
-    moa_response_generation_template,
-    search_query_generation_template,
-    title_generation_template,
-    tools_function_calling_generation_template,
-)
-from open_webui.utils.tools import get_tools
-from open_webui.utils.utils import (
-    create_token,
-    decode_token,
-    get_admin_user,
-    get_current_user,
-    get_http_authorization_cred,
-    get_password_hash,
-    get_verified_user,
-)
-from open_webui.utils.webhook import post_webhook
+from open_webui.utils.oauth import oauth_manager
from open_webui.utils.payload import convert_payload_openai_to_ollama
from open_webui.utils.response import (
    convert_response_ollama_to_openai,
    convert_streaming_response_ollama_to_openai,
)
from open_webui.utils.security_headers import SecurityHeadersMiddleware
from open_webui.utils.task import (
moa_response_generation_template,
tags_generation_template,
search_query_generation_template,
emoji_generation_template,
title_generation_template,
tools_function_calling_generation_template,
)
from open_webui.utils.tools import get_tools
from open_webui.utils.utils import (
decode_token,
get_admin_user,
get_current_user,
get_http_authorization_cred,
get_verified_user,
)
if SAFE_MODE:
    print("SAFE MODE ENABLED")
    Functions.deactivate_all_functions()

logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MAIN"])

@@ -217,10 +195,10 @@ app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST
app.state.config.WEBHOOK_URL = WEBHOOK_URL

app.state.config.TASK_MODEL = TASK_MODEL
app.state.config.TASK_MODEL_EXTERNAL = TASK_MODEL_EXTERNAL
app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = TITLE_GENERATION_PROMPT_TEMPLATE
app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE = TAGS_GENERATION_PROMPT_TEMPLATE
app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE = (
    SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE
)
@@ -461,9 +439,20 @@ async def chat_completion_tools_handler(
            tool_function_params = result.get("parameters", {})

            try:
-                tool_output = await tools[tool_function_name]["callable"](
-                    **tool_function_params
-                )
+                required_params = (
+                    tools[tool_function_name]
+                    .get("spec", {})
+                    .get("parameters", {})
+                    .get("required", [])
+                )
+                tool_function = tools[tool_function_name]["callable"]
+                tool_function_params = {
+                    k: v
+                    for k, v in tool_function_params.items()
+                    if k in required_params
+                }
+                tool_output = await tool_function(**tool_function_params)
            except Exception as e:
                tool_output = str(e)
@ -578,7 +567,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
}
# Initialize data_items to store additional data to be sent to the client
# Initalize contexts and citation # Initialize contexts and citation
data_items = []
contexts = []
citations = []
@ -690,6 +679,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
app.add_middleware(ChatCompletionMiddleware)
##################################
#
# Pipeline Middleware
@ -927,12 +917,10 @@ webui_app.state.EMBEDDING_FUNCTION = retrieval_app.state.EMBEDDING_FUNCTION
async def get_all_models():
# TODO: Optimize this function
pipe_models = [] open_webui_models = []
openai_models = []
ollama_models = []
pipe_models = await get_pipe_models()
if app.state.config.ENABLE_OPENAI_API:
openai_models = await get_openai_models()
openai_models = openai_models["data"]
@ -951,7 +939,13 @@ async def get_all_models():
for model in ollama_models["models"]
]
models = pipe_models + openai_models + ollama_models open_webui_models = await get_open_webui_models()
models = open_webui_models + openai_models + ollama_models
# If there are no models, return an empty list
if len([model for model in models if model["owned_by"] != "arena"]) == 0:
return []
global_action_ids = [
function.id for function in Functions.get_global_action_functions()
@ -990,11 +984,13 @@ async def get_all_models():
owned_by = model["owned_by"]
if "pipe" in model:
pipe = model["pipe"]
if "info" in model and "meta" in model["info"]:
action_ids.extend(model["info"]["meta"].get("actionIds", []))
break
if custom_model.meta:
meta = custom_model.meta.model_dump()
if "actionIds" in meta:
action_ids.extend(meta["actionIds"])
models.append(
{
"id": custom_model.id,
@ -1097,7 +1093,9 @@ async def get_models(user=Depends(get_verified_user)):
@app.post("/api/chat/completions")
async def generate_chat_completions(form_data: dict, user=Depends(get_verified_user)): async def generate_chat_completions(
form_data: dict, user=Depends(get_verified_user), bypass_filter: bool = False
):
model_id = form_data["model"]
if model_id not in app.state.MODELS:
@ -1106,7 +1104,7 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u
detail="Model not found",
)
if app.state.config.ENABLE_MODEL_FILTER: if not bypass_filter and app.state.config.ENABLE_MODEL_FILTER:
if user.role == "user" and model_id not in app.state.config.MODEL_FILTER_LIST:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
@ -1114,13 +1112,62 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u
)
model = app.state.MODELS[model_id]
if model["owned_by"] == "arena":
model_ids = model.get("info", {}).get("meta", {}).get("model_ids")
filter_mode = model.get("info", {}).get("meta", {}).get("filter_mode")
if model_ids and filter_mode == "exclude":
model_ids = [
model["id"]
for model in await get_all_models()
if model.get("owned_by") != "arena"
and not model.get("info", {}).get("meta", {}).get("hidden", False)
and model["id"] not in model_ids
]
selected_model_id = None
if isinstance(model_ids, list) and model_ids:
selected_model_id = random.choice(model_ids)
else:
model_ids = [
model["id"]
for model in await get_all_models()
if model.get("owned_by") != "arena"
and not model.get("info", {}).get("meta", {}).get("hidden", False)
]
selected_model_id = random.choice(model_ids)
form_data["model"] = selected_model_id
if form_data.get("stream") == True:
async def stream_wrapper(stream):
yield f"data: {json.dumps({'selected_model_id': selected_model_id})}\n\n"
async for chunk in stream:
yield chunk
response = await generate_chat_completions(
form_data, user, bypass_filter=True
)
return StreamingResponse(
stream_wrapper(response.body_iterator), media_type="text/event-stream"
)
else:
return {
**(
await generate_chat_completions(form_data, user, bypass_filter=True)
),
"selected_model_id": selected_model_id,
}
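# A condensed sketch of the arena selection rule used above: pick a concrete model at
# random from the arena model's configured model_ids (or, with filter_mode "exclude",
# from every visible non-arena model not in that list), then rerun the completion with
# the model filter bypassed. The model metadata below is illustrative only.
import random


def pick_arena_candidate(arena_model: dict, models: list[dict]) -> str:
    meta = arena_model.get("info", {}).get("meta", {})
    model_ids = meta.get("model_ids")
    visible = [
        m["id"]
        for m in models
        if m.get("owned_by") != "arena"
        and not m.get("info", {}).get("meta", {}).get("hidden", False)
    ]
    if model_ids and meta.get("filter_mode") == "exclude":
        model_ids = [m for m in visible if m not in model_ids]
    if not model_ids:
        model_ids = visible
    return random.choice(model_ids)


example_models = [
    {"id": "gpt-4o", "owned_by": "openai", "info": {"meta": {}}},
    {"id": "llama3", "owned_by": "ollama", "info": {"meta": {"hidden": True}}},
    {"id": "my-arena", "owned_by": "arena", "info": {"meta": {"model_ids": ["gpt-4o"]}}},
]
print(pick_arena_candidate(example_models[2], example_models))  # -> "gpt-4o"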
if model.get("pipe"): if model.get("pipe"):
return await generate_function_chat_completion(form_data, user=user) return await generate_function_chat_completion(form_data, user=user)
if model["owned_by"] == "ollama": if model["owned_by"] == "ollama":
# Using /ollama/api/chat endpoint # Using /ollama/api/chat endpoint
form_data = convert_payload_openai_to_ollama(form_data) form_data = convert_payload_openai_to_ollama(form_data)
form_data = GenerateChatCompletionForm(**form_data) form_data = GenerateChatCompletionForm(**form_data)
response = await generate_ollama_chat_completion(form_data=form_data, user=user) response = await generate_ollama_chat_completion(
form_data=form_data, user=user, bypass_filter=True
)
if form_data.stream: if form_data.stream:
response.headers["content-type"] = "text/event-stream" response.headers["content-type"] = "text/event-stream"
return StreamingResponse( return StreamingResponse(
@ -1425,6 +1472,7 @@ async def get_task_config(user=Depends(get_verified_user)):
"TASK_MODEL": app.state.config.TASK_MODEL, "TASK_MODEL": app.state.config.TASK_MODEL,
"TASK_MODEL_EXTERNAL": app.state.config.TASK_MODEL_EXTERNAL, "TASK_MODEL_EXTERNAL": app.state.config.TASK_MODEL_EXTERNAL,
"TITLE_GENERATION_PROMPT_TEMPLATE": app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE, "TITLE_GENERATION_PROMPT_TEMPLATE": app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE,
"TAGS_GENERATION_PROMPT_TEMPLATE": app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE,
"ENABLE_SEARCH_QUERY": app.state.config.ENABLE_SEARCH_QUERY, "ENABLE_SEARCH_QUERY": app.state.config.ENABLE_SEARCH_QUERY,
"SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE": app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE, "SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE": app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
"TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE": app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE, "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE": app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
@ -1435,6 +1483,7 @@ class TaskConfigForm(BaseModel):
TASK_MODEL: Optional[str] TASK_MODEL: Optional[str]
TASK_MODEL_EXTERNAL: Optional[str] TASK_MODEL_EXTERNAL: Optional[str]
TITLE_GENERATION_PROMPT_TEMPLATE: str TITLE_GENERATION_PROMPT_TEMPLATE: str
TAGS_GENERATION_PROMPT_TEMPLATE: str
SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE: str SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE: str
ENABLE_SEARCH_QUERY: bool ENABLE_SEARCH_QUERY: bool
TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE: str TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE: str
@ -1447,6 +1496,10 @@ async def update_task_config(form_data: TaskConfigForm, user=Depends(get_admin_u
app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = (
form_data.TITLE_GENERATION_PROMPT_TEMPLATE
)
app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE = (
form_data.TAGS_GENERATION_PROMPT_TEMPLATE
)
app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE = (
form_data.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE
)
@ -1459,6 +1512,7 @@ async def update_task_config(form_data: TaskConfigForm, user=Depends(get_admin_u
"TASK_MODEL": app.state.config.TASK_MODEL,
"TASK_MODEL_EXTERNAL": app.state.config.TASK_MODEL_EXTERNAL,
"TITLE_GENERATION_PROMPT_TEMPLATE": app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE,
"TAGS_GENERATION_PROMPT_TEMPLATE": app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE,
"SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE": app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
"ENABLE_SEARCH_QUERY": app.state.config.ENABLE_SEARCH_QUERY,
"TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE": app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
@ -1486,7 +1540,7 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)):
if app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE != "":
template = app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE
else:
template = """Create a concise, 3-5 word title with an emoji as a title for the prompt in the given language. Suitable Emojis for the summary can be used to enhance understanding but avoid quotation marks or special formatting. RESPOND ONLY WITH THE TITLE TEXT. template = """Create a concise, 3-5 word title with an emoji as a title for the chat history, in the given language. Suitable Emojis for the summary can be used to enhance understanding but avoid quotation marks or special formatting. RESPOND ONLY WITH THE TITLE TEXT.
Examples of titles:
📉 Stock Market Trends
@ -1496,11 +1550,13 @@ Remote Work Productivity Tips
Artificial Intelligence in Healthcare
🎮 Video Game Development Insights
Prompt: {{prompt:middletruncate:8000}}""" <chat_history>
{{MESSAGES:END:2}}
</chat_history>"""
content = title_generation_template(
template,
form_data["prompt"], form_data["messages"],
{
"name": user.name,
"location": user.info.get("location") if user.info else None,
@ -1543,6 +1599,75 @@ Prompt: {{prompt:middletruncate:8000}}"""
return await generate_chat_completions(form_data=payload, user=user)
@app.post("/api/task/tags/completions")
async def generate_chat_tags(form_data: dict, user=Depends(get_verified_user)):
print("generate_chat_tags")
model_id = form_data["model"]
if model_id not in app.state.MODELS:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Model not found",
)
# Check if the user has a custom task model
# If the user has a custom task model, use that model
task_model_id = get_task_model_id(model_id)
print(task_model_id)
if app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE != "":
template = app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE
else:
template = """### Task:
Generate 1-3 broad tags categorizing the main themes of the chat history, along with 1-3 more specific subtopic tags.
### Guidelines:
- Start with high-level domains (e.g. Science, Technology, Philosophy, Arts, Politics, Business, Health, Sports, Entertainment, Education)
- Consider including relevant subfields/subdomains if they are strongly represented throughout the conversation
- If content is too short (less than 3 messages) or too diverse, use only ["General"]
- Use the chat's primary language; default to English if multilingual
- Prioritize accuracy over specificity
### Output:
JSON format: { "tags": ["tag1", "tag2", "tag3"] }
### Chat History:
<chat_history>
{{MESSAGES:END:6}}
</chat_history>"""
content = tags_generation_template(
template, form_data["messages"], {"name": user.name}
)
print("content", content)
payload = {
"model": task_model_id,
"messages": [{"role": "user", "content": content}],
"stream": False,
"metadata": {"task": str(TASKS.TAGS_GENERATION), "task_body": form_data},
}
log.debug(payload)
# Handle pipeline filters
try:
payload = filter_pipeline(payload, user)
except Exception as e:
if len(e.args) > 1:
return JSONResponse(
status_code=e.args[0],
content={"detail": e.args[1]},
)
else:
return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content={"detail": str(e)},
)
if "chat_id" in payload:
del payload["chat_id"]
return await generate_chat_completions(form_data=payload, user=user)
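# The endpoint above asks the task model to answer with JSON of the form
# { "tags": ["tag1", "tag2"] }, but the reply still arrives as an ordinary chat
# completion. A hedged client-side sketch of extracting the tags follows; the helper
# and the response literal are illustrative, assuming the OpenAI-style non-streaming
# response shape used elsewhere in this file.
import json


def extract_tags(completion: dict) -> list[str]:
    """Pull the tags array out of a tags-generation completion, defaulting to ["General"]."""
    try:
        content = completion["choices"][0]["message"]["content"]
        return json.loads(content).get("tags", ["General"])
    except (KeyError, IndexError, TypeError, json.JSONDecodeError):
        return ["General"]


example_response = {
    "choices": [{"message": {"content": '{ "tags": ["Technology", "Python"] }'}}]
}
print(extract_tags(example_response))  # -> ['Technology', 'Python']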
@app.post("/api/task/query/completions") @app.post("/api/task/query/completions")
async def generate_search_query(form_data: dict, user=Depends(get_verified_user)): async def generate_search_query(form_data: dict, user=Depends(get_verified_user)):
print("generate_search_query") print("generate_search_query")
@ -1643,7 +1768,7 @@ Your task is to reflect the speaker's likely facial expression through a fitting
Message: """{{prompt}}""" Message: """{{prompt}}"""
''' '''
content = title_generation_template( content = emoji_generation_template(
template, template,
form_data["prompt"], form_data["prompt"],
{ {
@ -2233,20 +2358,6 @@ async def get_app_latest_release_version():
# OAuth Login & Callback # OAuth Login & Callback
############################ ############################
oauth = OAuth()
for provider_name, provider_config in OAUTH_PROVIDERS.items():
oauth.register(
name=provider_name,
client_id=provider_config["client_id"],
client_secret=provider_config["client_secret"],
server_metadata_url=provider_config["server_metadata_url"],
client_kwargs={
"scope": provider_config["scope"],
},
redirect_uri=provider_config["redirect_uri"],
)
# SessionMiddleware is used by authlib for oauth
if len(OAUTH_PROVIDERS) > 0:
app.add_middleware(
@ -2260,16 +2371,7 @@ if len(OAUTH_PROVIDERS) > 0:
@app.get("/oauth/{provider}/login")
async def oauth_login(provider: str, request: Request):
if provider not in OAUTH_PROVIDERS: return await oauth_manager.handle_login(provider, request)
raise HTTPException(404)
# If the provider has a custom redirect URL, use that, otherwise automatically generate one
redirect_uri = OAUTH_PROVIDERS[provider].get("redirect_uri") or request.url_for(
"oauth_callback", provider=provider
)
client = oauth.create_client(provider)
if client is None:
raise HTTPException(404)
return await client.authorize_redirect(request, redirect_uri)
# OAuth login logic is as follows:
@ -2277,119 +2379,10 @@ async def oauth_login(provider: str, request: Request):
# 2. If OAUTH_MERGE_ACCOUNTS_BY_EMAIL is true, find a user with the email address provided via OAuth
# - This is considered insecure in general, as OAuth providers do not always verify email addresses
# 3. If there is no user, and ENABLE_OAUTH_SIGNUP is true, create a user
# - Email addresses are considered unique, so we fail registration if the email address is alreayd taken # - Email addresses are considered unique, so we fail registration if the email address is already taken
@app.get("/oauth/{provider}/callback")
async def oauth_callback(provider: str, request: Request, response: Response):
if provider not in OAUTH_PROVIDERS: return await oauth_manager.handle_callback(provider, request, response)
raise HTTPException(404)
client = oauth.create_client(provider)
try:
token = await client.authorize_access_token(request)
except Exception as e:
log.warning(f"OAuth callback error: {e}")
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
user_data: UserInfo = token["userinfo"]
sub = user_data.get("sub")
if not sub:
log.warning(f"OAuth callback failed, sub is missing: {user_data}")
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
provider_sub = f"{provider}@{sub}"
email_claim = webui_app.state.config.OAUTH_EMAIL_CLAIM
email = user_data.get(email_claim, "").lower()
# We currently mandate that email addresses are provided
if not email:
log.warning(f"OAuth callback failed, email is missing: {user_data}")
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
# Check if the user exists
user = Users.get_user_by_oauth_sub(provider_sub)
if not user:
# If the user does not exist, check if merging is enabled
if OAUTH_MERGE_ACCOUNTS_BY_EMAIL.value:
# Check if the user exists by email
user = Users.get_user_by_email(email)
if user:
# Update the user with the new oauth sub
Users.update_user_oauth_sub_by_id(user.id, provider_sub)
if not user:
# If the user does not exist, check if signups are enabled
if ENABLE_OAUTH_SIGNUP.value:
# Check if an existing user with the same email already exists
existing_user = Users.get_user_by_email(user_data.get("email", "").lower())
if existing_user:
raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
picture_claim = webui_app.state.config.OAUTH_PICTURE_CLAIM
picture_url = user_data.get(picture_claim, "")
if picture_url:
# Download the profile image into a base64 string
try:
async with aiohttp.ClientSession() as session:
async with session.get(picture_url) as resp:
picture = await resp.read()
base64_encoded_picture = base64.b64encode(picture).decode(
"utf-8"
)
guessed_mime_type = mimetypes.guess_type(picture_url)[0]
if guessed_mime_type is None:
# assume JPG, browsers are tolerant enough of image formats
guessed_mime_type = "image/jpeg"
picture_url = f"data:{guessed_mime_type};base64,{base64_encoded_picture}"
except Exception as e:
log.error(f"Error downloading profile image '{picture_url}': {e}")
picture_url = ""
if not picture_url:
picture_url = "/user.png"
username_claim = webui_app.state.config.OAUTH_USERNAME_CLAIM
role = (
"admin"
if Users.get_num_users() == 0
else webui_app.state.config.DEFAULT_USER_ROLE
)
user = Auths.insert_new_auth(
email=email,
password=get_password_hash(
str(uuid.uuid4())
), # Random password, not used
name=user_data.get(username_claim, "User"),
profile_image_url=picture_url,
role=role,
oauth_sub=provider_sub,
)
if webui_app.state.config.WEBHOOK_URL:
post_webhook(
webui_app.state.config.WEBHOOK_URL,
WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
{
"action": "signup",
"message": WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
"user": user.model_dump_json(exclude_none=True),
},
)
else:
raise HTTPException(
status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED
)
jwt_token = create_token(
data={"id": user.id},
expires_delta=parse_duration(webui_app.state.config.JWT_EXPIRES_IN),
)
# Set the cookie token
response.set_cookie(
key="token",
value=jwt_token,
httponly=True, # Ensures the cookie is not accessible via JavaScript
)
# Redirect back to the frontend with the JWT token
redirect_url = f"{request.base_url}auth#token={jwt_token}"
return RedirectResponse(url=redirect_url)
@app.get("/manifest.json") @app.get("/manifest.json")


@ -0,0 +1,151 @@
"""Migrate tags
Revision ID: 1af9b942657b
Revises: 242a2047eae0
Create Date: 2024-10-09 21:02:35.241684
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.sql import table, select, update, column
from sqlalchemy.engine.reflection import Inspector
import json
revision = "1af9b942657b"
down_revision = "242a2047eae0"
branch_labels = None
depends_on = None
def upgrade():
# Setup an inspection on the existing table to avoid issues
conn = op.get_bind()
inspector = Inspector.from_engine(conn)
# Clean up potential leftover temp table from previous failures
conn.execute(sa.text("DROP TABLE IF EXISTS _alembic_tmp_tag"))
# Check if the 'tag' table exists
tables = inspector.get_table_names()
# Step 1: Modify Tag table using batch mode for SQLite support
if "tag" in tables:
# Get the current columns in the 'tag' table
columns = [col["name"] for col in inspector.get_columns("tag")]
# Get any existing unique constraints on the 'tag' table
current_constraints = inspector.get_unique_constraints("tag")
with op.batch_alter_table("tag", schema=None) as batch_op:
# Check if the unique constraint already exists
if not any(
constraint["name"] == "uq_id_user_id"
for constraint in current_constraints
):
# Create unique constraint if it doesn't exist
batch_op.create_unique_constraint("uq_id_user_id", ["id", "user_id"])
# Check if the 'data' column exists before trying to drop it
if "data" in columns:
batch_op.drop_column("data")
# Check if the 'meta' column needs to be created
if "meta" not in columns:
# Add the 'meta' column if it doesn't already exist
batch_op.add_column(sa.Column("meta", sa.JSON(), nullable=True))
tag = table(
"tag",
column("id", sa.String()),
column("name", sa.String()),
column("user_id", sa.String()),
column("meta", sa.JSON()),
)
# Step 2: Migrate tags
conn = op.get_bind()
result = conn.execute(sa.select(tag.c.id, tag.c.name, tag.c.user_id))
tag_updates = {}
for row in result:
new_id = row.name.replace(" ", "_").lower()
tag_updates[row.id] = new_id
for tag_id, new_tag_id in tag_updates.items():
print(f"Updating tag {tag_id} to {new_tag_id}")
if new_tag_id == "pinned":
# delete tag
delete_stmt = sa.delete(tag).where(tag.c.id == tag_id)
conn.execute(delete_stmt)
else:
# Check if the new_tag_id already exists in the database
existing_tag_query = sa.select(tag.c.id).where(tag.c.id == new_tag_id)
existing_tag_result = conn.execute(existing_tag_query).fetchone()
if existing_tag_result:
# Handle duplicate case: the new_tag_id already exists
print(
f"Tag {new_tag_id} already exists. Removing current tag with ID {tag_id} to avoid duplicates."
)
# Option 1: Delete the current tag if an update to new_tag_id would cause duplication
delete_stmt = sa.delete(tag).where(tag.c.id == tag_id)
conn.execute(delete_stmt)
else:
update_stmt = sa.update(tag).where(tag.c.id == tag_id)
update_stmt = update_stmt.values(id=new_tag_id)
conn.execute(update_stmt)
# Add columns `pinned` and `meta` to 'chat'
op.add_column("chat", sa.Column("pinned", sa.Boolean(), nullable=True))
op.add_column(
"chat", sa.Column("meta", sa.JSON(), nullable=False, server_default="{}")
)
chatidtag = table(
"chatidtag", column("chat_id", sa.String()), column("tag_name", sa.String())
)
chat = table(
"chat",
column("id", sa.String()),
column("pinned", sa.Boolean()),
column("meta", sa.JSON()),
)
# Fetch existing tags
conn = op.get_bind()
result = conn.execute(sa.select(chatidtag.c.chat_id, chatidtag.c.tag_name))
chat_updates = {}
for row in result:
chat_id = row.chat_id
tag_name = row.tag_name.replace(" ", "_").lower()
if tag_name == "pinned":
# Specifically handle 'pinned' tag
if chat_id not in chat_updates:
chat_updates[chat_id] = {"pinned": True, "meta": {}}
else:
chat_updates[chat_id]["pinned"] = True
else:
if chat_id not in chat_updates:
chat_updates[chat_id] = {"pinned": False, "meta": {"tags": [tag_name]}}
else:
tags = chat_updates[chat_id]["meta"].get("tags", [])
tags.append(tag_name)
chat_updates[chat_id]["meta"]["tags"] = list(set(tags))
# Update chats based on accumulated changes
for chat_id, updates in chat_updates.items():
update_stmt = sa.update(chat).where(chat.c.id == chat_id)
update_stmt = update_stmt.values(
meta=updates.get("meta", {}), pinned=updates.get("pinned", False)
)
conn.execute(update_stmt)
pass
def downgrade():
pass
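# Illustration of what this migration does to legacy tag data, independent of Alembic:
# tag ids are derived from the name (lower-cased, spaces replaced with underscores),
# the special "pinned" tag becomes a boolean column on chat, and the remaining tags
# are folded into chat.meta["tags"]. The helper names and sample data are hypothetical.
def normalize_tag_id(name: str) -> str:
    return name.replace(" ", "_").lower()


def fold_chat_tags(legacy_tag_names: list[str]) -> dict:
    tags = {normalize_tag_id(n) for n in legacy_tag_names}
    pinned = "pinned" in tags
    tags.discard("pinned")
    return {"pinned": pinned, "meta": {"tags": sorted(tags)}}


print(fold_chat_tags(["Work Notes", "pinned", "Python"]))
# -> {'pinned': True, 'meta': {'tags': ['python', 'work_notes']}}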


@ -19,17 +19,41 @@ depends_on = None
def upgrade():
conn = op.get_bind()
inspector = sa.inspect(conn)
columns = inspector.get_columns("chat")
column_dict = {col["name"]: col for col in columns}
chat_column = column_dict.get("chat")
old_chat_exists = "old_chat" in column_dict
if chat_column:
if isinstance(chat_column["type"], sa.Text):
print("Converting 'chat' column to JSON")
if old_chat_exists:
print("Dropping old 'old_chat' column")
op.drop_column("chat", "old_chat")
# Step 1: Rename current 'chat' column to 'old_chat'
op.alter_column("chat", "chat", new_column_name="old_chat", existing_type=sa.Text) print("Renaming 'chat' column to 'old_chat'")
op.alter_column(
"chat", "chat", new_column_name="old_chat", existing_type=sa.Text()
)
# Step 2: Add new 'chat' column of type JSON
print("Adding new 'chat' column of type JSON")
op.add_column("chat", sa.Column("chat", sa.JSON(), nullable=True))
else:
# If the column is already JSON, no need to do anything
pass
# Step 3: Migrate data from 'old_chat' to 'chat'
chat_table = table(
"chat",
sa.Column("id", sa.String, primary_key=True), sa.Column("id", sa.String(), primary_key=True),
sa.Column("old_chat", sa.Text), sa.Column("old_chat", sa.Text()),
sa.Column("chat", sa.JSON()),
)
@ -50,6 +74,7 @@ def upgrade():
)
# Step 4: Drop 'old_chat' column
print("Dropping 'old_chat' column")
op.drop_column("chat", "old_chat")
@ -60,7 +85,7 @@ def downgrade():
# Step 2: Convert 'chat' JSON data back to text and store in 'old_chat'
chat_table = table(
"chat",
sa.Column("id", sa.String, primary_key=True), sa.Column("id", sa.String(), primary_key=True),
sa.Column("chat", sa.JSON()),
sa.Column("old_chat", sa.Text()),
)
@ -79,4 +104,4 @@ def downgrade():
op.drop_column("chat", "chat")
# Step 4: Rename 'old_chat' back to 'chat'
op.alter_column("chat", "old_chat", new_column_name="chat", existing_type=sa.Text) op.alter_column("chat", "old_chat", new_column_name="chat", existing_type=sa.Text())

View file

@ -0,0 +1,81 @@
"""Update tags
Revision ID: 3ab32c4b8f59
Revises: 1af9b942657b
Create Date: 2024-10-09 21:02:35.241684
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.sql import table, select, update, column
from sqlalchemy.engine.reflection import Inspector
import json
revision = "3ab32c4b8f59"
down_revision = "1af9b942657b"
branch_labels = None
depends_on = None
def upgrade():
conn = op.get_bind()
inspector = Inspector.from_engine(conn)
# Inspecting the 'tag' table constraints and structure
existing_pk = inspector.get_pk_constraint("tag")
unique_constraints = inspector.get_unique_constraints("tag")
existing_indexes = inspector.get_indexes("tag")
print(f"Primary Key: {existing_pk}")
print(f"Unique Constraints: {unique_constraints}")
print(f"Indexes: {existing_indexes}")
with op.batch_alter_table("tag", schema=None) as batch_op:
# Drop existing primary key constraint if it exists
if existing_pk and existing_pk.get("constrained_columns"):
pk_name = existing_pk.get("name")
if pk_name:
print(f"Dropping primary key constraint: {pk_name}")
batch_op.drop_constraint(pk_name, type_="primary")
# Now create the new primary key with the combination of 'id' and 'user_id'
print("Creating new primary key with 'id' and 'user_id'.")
batch_op.create_primary_key("pk_id_user_id", ["id", "user_id"])
# Drop unique constraints that could conflict with the new primary key
for constraint in unique_constraints:
if (
constraint["name"] == "uq_id_user_id"
): # Adjust this name according to what is actually returned by the inspector
print(f"Dropping unique constraint: {constraint['name']}")
batch_op.drop_constraint(constraint["name"], type_="unique")
for index in existing_indexes:
if index["unique"]:
if not any(
constraint["name"] == index["name"]
for constraint in unique_constraints
):
# You are attempting to drop unique indexes
print(f"Dropping unique index: {index['name']}")
batch_op.drop_index(index["name"])
def downgrade():
conn = op.get_bind()
inspector = Inspector.from_engine(conn)
current_pk = inspector.get_pk_constraint("tag")
with op.batch_alter_table("tag", schema=None) as batch_op:
# Drop the current primary key first, if it matches the one we know we added in upgrade
if current_pk and "pk_id_user_id" == current_pk.get("name"):
batch_op.drop_constraint("pk_id_user_id", type_="primary")
# Restore the original primary key
batch_op.create_primary_key("pk_id", ["id"])
# Since primary key on just 'id' is restored, we now add back any unique constraints if necessary
batch_op.create_unique_constraint("uq_id_user_id", ["id", "user_id"])


@ -0,0 +1,67 @@
"""Update folder table and change DateTime to BigInteger for timestamp fields
Revision ID: 4ace53fd72c8
Revises: af906e964978
Create Date: 2024-10-23 03:00:00.000000
"""
from alembic import op
import sqlalchemy as sa
revision = "4ace53fd72c8"
down_revision = "af906e964978"
branch_labels = None
depends_on = None
def upgrade():
# Perform safe alterations using batch operation
with op.batch_alter_table("folder", schema=None) as batch_op:
# Step 1: Remove server defaults for created_at and updated_at
batch_op.alter_column(
"created_at",
server_default=None, # Removing server default
)
batch_op.alter_column(
"updated_at",
server_default=None, # Removing server default
)
# Step 2: Change the column types to BigInteger for created_at
batch_op.alter_column(
"created_at",
type_=sa.BigInteger(),
existing_type=sa.DateTime(),
existing_nullable=False,
postgresql_using="extract(epoch from created_at)::bigint", # Conversion for PostgreSQL
)
# Change the column types to BigInteger for updated_at
batch_op.alter_column(
"updated_at",
type_=sa.BigInteger(),
existing_type=sa.DateTime(),
existing_nullable=False,
postgresql_using="extract(epoch from updated_at)::bigint", # Conversion for PostgreSQL
)
def downgrade():
# Downgrade: Convert columns back to DateTime and restore defaults
with op.batch_alter_table("folder", schema=None) as batch_op:
batch_op.alter_column(
"created_at",
type_=sa.DateTime(),
existing_type=sa.BigInteger(),
existing_nullable=False,
server_default=sa.func.now(), # Restoring server default on downgrade
)
batch_op.alter_column(
"updated_at",
type_=sa.DateTime(),
existing_type=sa.BigInteger(),
existing_nullable=False,
server_default=sa.func.now(), # Restoring server default on downgrade
onupdate=sa.func.now(), # Restoring onupdate behavior if it was there
)
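# The columns above are converted from DateTime to BigInteger epoch seconds; on PostgreSQL
# the `postgresql_using="extract(epoch from created_at)::bigint"` clause performs that cast
# in SQL. A small sketch of the equivalent conversion in application code, for reference
# only (helper names are illustrative):
from datetime import datetime, timezone


def to_epoch(dt: datetime) -> int:
    # Same value the migrated columns store: whole seconds since the Unix epoch.
    return int(dt.timestamp())


def from_epoch(epoch: int) -> datetime:
    return datetime.fromtimestamp(epoch, tz=timezone.utc)


now = datetime.now(tz=timezone.utc)
assert from_epoch(to_epoch(now)).replace(microsecond=0) == now.replace(microsecond=0)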


@ -0,0 +1,51 @@
"""Add feedback table
Revision ID: af906e964978
Revises: c29facfe716b
Create Date: 2024-10-20 17:02:35.241684
"""
from alembic import op
import sqlalchemy as sa
# Revision identifiers, used by Alembic.
revision = "af906e964978"
down_revision = "c29facfe716b"
branch_labels = None
depends_on = None
def upgrade():
# ### Create feedback table ###
op.create_table(
"feedback",
sa.Column(
"id", sa.Text(), primary_key=True
), # Unique identifier for each feedback (TEXT type)
sa.Column(
"user_id", sa.Text(), nullable=True
), # ID of the user providing the feedback (TEXT type)
sa.Column(
"version", sa.BigInteger(), default=0
), # Version of feedback (BIGINT type)
sa.Column("type", sa.Text(), nullable=True), # Type of feedback (TEXT type)
sa.Column("data", sa.JSON(), nullable=True), # Feedback data (JSON type)
sa.Column(
"meta", sa.JSON(), nullable=True
), # Metadata for feedback (JSON type)
sa.Column(
"snapshot", sa.JSON(), nullable=True
), # snapshot data for feedback (JSON type)
sa.Column(
"created_at", sa.BigInteger(), nullable=False
), # Feedback creation timestamp (BIGINT representing epoch)
sa.Column(
"updated_at", sa.BigInteger(), nullable=False
), # Feedback update timestamp (BIGINT representing epoch)
)
def downgrade():
# ### Drop feedback table ###
op.drop_table("feedback")


@ -0,0 +1,79 @@
"""Update file table path
Revision ID: c29facfe716b
Revises: c69f45358db4
Create Date: 2024-10-20 17:02:35.241684
"""
from alembic import op
import sqlalchemy as sa
import json
from sqlalchemy.sql import table, column
from sqlalchemy import String, Text, JSON, and_
revision = "c29facfe716b"
down_revision = "c69f45358db4"
branch_labels = None
depends_on = None
def upgrade():
# 1. Add the `path` column to the "file" table.
op.add_column("file", sa.Column("path", sa.Text(), nullable=True))
# 2. Convert the `meta` column from Text/JSONField to `JSON()`
# Use Alembic's default batch_op for dialect compatibility.
with op.batch_alter_table("file", schema=None) as batch_op:
batch_op.alter_column(
"meta",
type_=sa.JSON(),
existing_type=sa.Text(),
existing_nullable=True,
nullable=True,
postgresql_using="meta::json",
)
# 3. Migrate legacy data from `meta` JSONField
# Fetch and process `meta` data from the table, add values to the new `path` column as necessary.
# We will use SQLAlchemy core bindings to ensure safety across different databases.
file_table = table(
"file", column("id", String), column("meta", JSON), column("path", Text)
)
# Create connection to the database
connection = op.get_bind()
# Get the rows where `meta` has a path and `path` column is null (new column)
# Loop through each row in the result set to update the path
results = connection.execute(
sa.select(file_table.c.id, file_table.c.meta).where(
and_(file_table.c.path.is_(None), file_table.c.meta.isnot(None))
)
).fetchall()
# Iterate over each row to extract and update the `path` from `meta` column
for row in results:
if "path" in row.meta:
# Extract the `path` field from the `meta` JSON
path = row.meta.get("path")
# Update the `file` table with the new `path` value
connection.execute(
file_table.update()
.where(file_table.c.id == row.id)
.values({"path": path})
)
def downgrade():
# 1. Remove the `path` column
op.drop_column("file", "path")
# 2. Revert the `meta` column back to Text/JSONField
with op.batch_alter_table("file", schema=None) as batch_op:
batch_op.alter_column(
"meta", type_=sa.Text(), existing_type=sa.JSON(), existing_nullable=True
)


@ -0,0 +1,50 @@
"""Add folder table
Revision ID: c69f45358db4
Revises: 3ab32c4b8f59
Create Date: 2024-10-16 02:02:35.241684
"""
from alembic import op
import sqlalchemy as sa
revision = "c69f45358db4"
down_revision = "3ab32c4b8f59"
branch_labels = None
depends_on = None
def upgrade():
op.create_table(
"folder",
sa.Column("id", sa.Text(), nullable=False),
sa.Column("parent_id", sa.Text(), nullable=True),
sa.Column("user_id", sa.Text(), nullable=False),
sa.Column("name", sa.Text(), nullable=False),
sa.Column("items", sa.JSON(), nullable=True),
sa.Column("meta", sa.JSON(), nullable=True),
sa.Column("is_expanded", sa.Boolean(), default=False, nullable=False),
sa.Column(
"created_at", sa.DateTime(), server_default=sa.func.now(), nullable=False
),
sa.Column(
"updated_at",
sa.DateTime(),
nullable=False,
server_default=sa.func.now(),
onupdate=sa.func.now(),
),
sa.PrimaryKeyConstraint("id", "user_id"),
)
op.add_column(
"chat",
sa.Column("folder_id", sa.Text(), nullable=True),
)
def downgrade():
op.drop_column("chat", "folder_id")
op.drop_table("folder")


@ -0,0 +1,319 @@
/* HTML and Body */
@font-face {
font-family: 'NotoSans';
src: url('fonts/NotoSans-Variable.ttf');
}
@font-face {
font-family: 'NotoSansJP';
src: url('fonts/NotoSansJP-Variable.ttf');
}
@font-face {
font-family: 'NotoSansKR';
src: url('fonts/NotoSansKR-Variable.ttf');
}
@font-face {
font-family: 'NotoSansSC';
src: url('fonts/NotoSansSC-Variable.ttf');
}
@font-face {
font-family: 'NotoSansSC-Regular';
src: url('fonts/NotoSansSC-Regular.ttf');
}
html {
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', 'NotoSans', 'NotoSansJP', 'NotoSansKR',
'NotoSansSC', 'STSong-Light', 'MSung-Light', 'HeiseiMin-W3', 'HYSMyeongJo-Medium', Roboto,
'Helvetica Neue', Arial, sans-serif;
font-size: 14px; /* Default font size */
line-height: 1.5;
}
*,
*::before,
*::after {
box-sizing: inherit;
}
body {
margin: 0;
color: #212529;
background-color: #fff;
width: auto;
}
/* Typography */
h1,
h2,
h3,
h4,
h5,
h6 {
font-weight: 500;
margin: 0;
}
h1 {
font-size: 2.5rem;
}
h2 {
font-size: 2rem;
}
h3 {
font-size: 1.75rem;
}
h4 {
font-size: 1.5rem;
}
h5 {
font-size: 1.25rem;
}
h6 {
font-size: 1rem;
}
p {
margin-top: 0;
margin-bottom: 1rem;
}
/* Grid System */
.container {
width: 100%;
padding-right: 15px;
padding-left: 15px;
margin-right: auto;
margin-left: auto;
}
/* Utilities */
.text-center {
text-align: center;
}
/* Additional Text Utilities */
.text-muted {
color: #6c757d; /* Muted text color */
}
/* Small Text */
small {
font-size: 80%; /* Smaller font size relative to the base */
color: #6c757d; /* Lighter text color for secondary information */
margin-bottom: 0;
margin-top: 0;
}
/* Strong Element Styles */
strong {
font-weight: bolder; /* Ensures the text is bold */
color: inherit; /* Inherits the color from its parent element */
}
/* link */
a {
color: #007bff;
text-decoration: none;
background-color: transparent;
}
a:hover {
color: #0056b3;
text-decoration: underline;
}
/* General styles for lists */
ol,
ul,
li {
padding-left: 40px; /* Increase padding to move bullet points to the right */
margin-left: 20px; /* Indent lists from the left */
}
/* Ordered list styles */
ol {
list-style-type: decimal; /* Use numbers for ordered lists */
margin-bottom: 10px; /* Space after each list */
}
ol li {
margin-bottom: 0.5rem; /* Space between ordered list items */
}
/* Unordered list styles */
ul {
list-style-type: disc; /* Use bullets for unordered lists */
margin-bottom: 10px; /* Space after each list */
}
ul li {
margin-bottom: 0.5rem; /* Space between unordered list items */
}
/* List item styles */
li {
margin-bottom: 5px; /* Space between list items */
line-height: 1.5; /* Line height for better readability */
}
/* Nested lists */
ol ol,
ol ul,
ul ol,
ul ul {
padding-left: 20px;
margin-left: 30px; /* Further indent nested lists */
margin-bottom: 0; /* Remove extra margin at the bottom of nested lists */
}
/* Code blocks */
pre {
background-color: #f4f4f4;
padding: 10px;
overflow-x: auto;
max-width: 100%; /* Ensure it doesn't overflow the page */
width: 80%; /* Set a specific width for a container-like appearance */
margin: 0 1em; /* Center the pre block */
box-sizing: border-box; /* Include padding in the width */
border: 1px solid #ccc; /* Optional: Add a border for better definition */
border-radius: 4px; /* Optional: Add rounded corners */
}
code {
font-family: 'Courier New', Courier, monospace;
background-color: #f4f4f4;
padding: 2px 4px;
border-radius: 4px;
box-sizing: border-box; /* Include padding in the width */
}
.message {
margin-top: 8px;
margin-bottom: 8px;
max-width: 100%;
overflow-wrap: break-word;
}
/* Table Styles */
table {
width: 100%;
margin-bottom: 1rem;
color: #212529;
border-collapse: collapse; /* Removes the space between borders */
}
th,
td {
margin: 0;
padding: 0.75rem;
vertical-align: top;
border-top: 1px solid #dee2e6;
}
thead th {
vertical-align: bottom;
border-bottom: 2px solid #dee2e6;
}
tbody + tbody {
border-top: 2px solid #dee2e6;
}
/* markdown-section styles */
.markdown-section blockquote,
.markdown-section h1,
.markdown-section h2,
.markdown-section h3,
.markdown-section h4,
.markdown-section h5,
.markdown-section h6,
.markdown-section p,
.markdown-section pre,
.markdown-section table,
.markdown-section ul {
/* Give most block elements margin top and bottom */
margin-top: 1rem;
}
/* Remove top margin if it's the first child */
.markdown-section blockquote:first-child,
.markdown-section h1:first-child,
.markdown-section h2:first-child,
.markdown-section h3:first-child,
.markdown-section h4:first-child,
.markdown-section h5:first-child,
.markdown-section h6:first-child,
.markdown-section p:first-child,
.markdown-section pre:first-child,
.markdown-section table:first-child,
.markdown-section ul:first-child {
margin-top: 0;
}
/* Remove top margin of <ul> following a <p> */
.markdown-section p + ul {
margin-top: 0;
}
/* Remove bottom margin of <p> if it is followed by a <ul> */
/* Note: :has is not supported in CSS, so you would need JavaScript for this behavior */
.markdown-section p {
margin-bottom: 0;
}
/* Add a rule to reset margin-bottom for <p> not followed by <ul> */
.markdown-section p + ul {
margin-top: 0;
}
/* List item styles */
.markdown-section li {
padding: 2px;
}
.markdown-section li p {
margin-bottom: 0;
padding: 0;
}
/* Avoid margins for nested lists */
.markdown-section li > ul {
margin-top: 0;
margin-bottom: 0;
}
/* Table styles */
.markdown-section table {
width: 100%;
border-collapse: collapse;
margin: 1rem 0;
}
.markdown-section th,
.markdown-section td {
border: 1px solid #ddd;
padding: 0.5rem;
text-align: left;
}
.markdown-section th {
background-color: #f2f2f2;
}
.markdown-section pre {
padding: 10px;
margin: 10px;
}
.markdown-section pre code {
position: relative;
color: rgb(172, 0, 95);
}

Binary file not shown.


@ -0,0 +1,164 @@
import os
import boto3
from botocore.exceptions import ClientError
import shutil
from typing import BinaryIO, Tuple, Optional, Union
from open_webui.constants import ERROR_MESSAGES
from open_webui.config import (
STORAGE_PROVIDER,
S3_ACCESS_KEY_ID,
S3_SECRET_ACCESS_KEY,
S3_BUCKET_NAME,
S3_REGION_NAME,
S3_ENDPOINT_URL,
UPLOAD_DIR,
)
import boto3
from botocore.exceptions import ClientError
from typing import BinaryIO, Tuple, Optional
class StorageProvider:
def __init__(self, provider: Optional[str] = None):
self.storage_provider: str = provider or STORAGE_PROVIDER
self.s3_client = None
self.s3_bucket_name: Optional[str] = None
if self.storage_provider == "s3":
self._initialize_s3()
def _initialize_s3(self) -> None:
"""Initializes the S3 client and bucket name if using S3 storage."""
self.s3_client = boto3.client(
"s3",
region_name=S3_REGION_NAME,
endpoint_url=S3_ENDPOINT_URL,
aws_access_key_id=S3_ACCESS_KEY_ID,
aws_secret_access_key=S3_SECRET_ACCESS_KEY,
)
self.bucket_name = S3_BUCKET_NAME
def _upload_to_s3(self, file_path: str, filename: str) -> Tuple[bytes, str]:
"""Handles uploading of the file to S3 storage."""
if not self.s3_client:
raise RuntimeError("S3 Client is not initialized.")
try:
self.s3_client.upload_file(file_path, self.bucket_name, filename)
return open(file_path, "rb").read(), file_path
except ClientError as e:
raise RuntimeError(f"Error uploading file to S3: {e}")
def _upload_to_local(self, contents: bytes, filename: str) -> Tuple[bytes, str]:
"""Handles uploading of the file to local storage."""
file_path = f"{UPLOAD_DIR}/{filename}"
with open(file_path, "wb") as f:
f.write(contents)
return contents, file_path
def _get_file_from_s3(self, file_path: str) -> str:
"""Handles downloading of the file from S3 storage."""
if not self.s3_client:
raise RuntimeError("S3 Client is not initialized.")
try:
bucket_name, key = file_path.split("//")[1].split("/")
local_file_path = f"{UPLOAD_DIR}/{key}"
self.s3_client.download_file(bucket_name, key, local_file_path)
return local_file_path
except ClientError as e:
raise RuntimeError(f"Error downloading file from S3: {e}")
def _get_file_from_local(self, file_path: str) -> str:
"""Handles downloading of the file from local storage."""
return file_path
def _delete_from_s3(self, filename: str) -> None:
"""Handles deletion of the file from S3 storage."""
if not self.s3_client:
raise RuntimeError("S3 Client is not initialized.")
try:
self.s3_client.delete_object(Bucket=self.bucket_name, Key=filename)
except ClientError as e:
raise RuntimeError(f"Error deleting file from S3: {e}")
def _delete_from_local(self, filename: str) -> None:
"""Handles deletion of the file from local storage."""
file_path = f"{UPLOAD_DIR}/{filename}"
if os.path.isfile(file_path):
os.remove(file_path)
else:
print(f"File {file_path} not found in local storage.")
def _delete_all_from_s3(self) -> None:
"""Handles deletion of all files from S3 storage."""
if not self.s3_client:
raise RuntimeError("S3 Client is not initialized.")
try:
response = self.s3_client.list_objects_v2(Bucket=self.bucket_name)
if "Contents" in response:
for content in response["Contents"]:
self.s3_client.delete_object(
Bucket=self.bucket_name, Key=content["Key"]
)
except ClientError as e:
raise RuntimeError(f"Error deleting all files from S3: {e}")
def _delete_all_from_local(self) -> None:
"""Handles deletion of all files from local storage."""
if os.path.exists(UPLOAD_DIR):
for filename in os.listdir(UPLOAD_DIR):
file_path = os.path.join(UPLOAD_DIR, filename)
try:
if os.path.isfile(file_path) or os.path.islink(file_path):
os.unlink(file_path) # Remove the file or link
elif os.path.isdir(file_path):
shutil.rmtree(file_path) # Remove the directory
except Exception as e:
print(f"Failed to delete {file_path}. Reason: {e}")
else:
print(f"Directory {UPLOAD_DIR} not found in local storage.")
def upload_file(self, file: BinaryIO, filename: str) -> Tuple[bytes, str]:
"""Uploads a file either to S3 or the local file system."""
contents = file.read()
if not contents:
raise ValueError(ERROR_MESSAGES.EMPTY_CONTENT)
contents, file_path = self._upload_to_local(contents, filename)
if self.storage_provider == "s3":
return self._upload_to_s3(file_path, filename)
return contents, file_path
def get_file(self, file_path: str) -> str:
"""Downloads a file either from S3 or the local file system and returns the file path."""
if self.storage_provider == "s3":
return self._get_file_from_s3(file_path)
return self._get_file_from_local(file_path)
def delete_file(self, filename: str) -> None:
"""Deletes a file either from S3 or the local file system."""
if self.storage_provider == "s3":
self._delete_from_s3(filename)
# Always delete from local storage
self._delete_from_local(filename)
def delete_all_files(self) -> None:
"""Deletes all files from the storage."""
if self.storage_provider == "s3":
self._delete_all_from_s3()
# Always delete from local storage
self._delete_all_from_local()
Storage = StorageProvider(provider=STORAGE_PROVIDER)
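# A hedged usage sketch of the provider above. The module path in the import is an
# assumption about where this new file lives in the package; the filename and payload
# are made up. Uploads always land in UPLOAD_DIR first and are mirrored to S3 only when
# STORAGE_PROVIDER == "s3".
from io import BytesIO

from open_webui.storage.provider import Storage  # assumed import path for this module

contents, file_path = Storage.upload_file(BytesIO(b"hello world"), "example.txt")
local_path = Storage.get_file(file_path)  # returns a local filesystem path in either backend
Storage.delete_file("example.txt")  # removes from S3 (if configured) and always from local storage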


@ -0,0 +1,261 @@
import base64
import logging
import mimetypes
import uuid
import aiohttp
from authlib.integrations.starlette_client import OAuth
from authlib.oidc.core import UserInfo
from fastapi import (
HTTPException,
status,
)
from starlette.responses import RedirectResponse
from open_webui.apps.webui.models.auths import Auths
from open_webui.apps.webui.models.users import Users
from open_webui.config import (
DEFAULT_USER_ROLE,
ENABLE_OAUTH_SIGNUP,
OAUTH_MERGE_ACCOUNTS_BY_EMAIL,
OAUTH_PROVIDERS,
ENABLE_OAUTH_ROLE_MANAGEMENT,
OAUTH_ROLES_CLAIM,
OAUTH_EMAIL_CLAIM,
OAUTH_PICTURE_CLAIM,
OAUTH_USERNAME_CLAIM,
OAUTH_ALLOWED_ROLES,
OAUTH_ADMIN_ROLES,
WEBHOOK_URL,
JWT_EXPIRES_IN,
AppConfig,
)
from open_webui.constants import ERROR_MESSAGES
from open_webui.env import WEBUI_SESSION_COOKIE_SAME_SITE, WEBUI_SESSION_COOKIE_SECURE
from open_webui.utils.misc import parse_duration
from open_webui.utils.utils import get_password_hash, create_token
from open_webui.utils.webhook import post_webhook
log = logging.getLogger(__name__)
auth_manager_config = AppConfig()
auth_manager_config.DEFAULT_USER_ROLE = DEFAULT_USER_ROLE
auth_manager_config.ENABLE_OAUTH_SIGNUP = ENABLE_OAUTH_SIGNUP
auth_manager_config.OAUTH_MERGE_ACCOUNTS_BY_EMAIL = OAUTH_MERGE_ACCOUNTS_BY_EMAIL
auth_manager_config.ENABLE_OAUTH_ROLE_MANAGEMENT = ENABLE_OAUTH_ROLE_MANAGEMENT
auth_manager_config.OAUTH_ROLES_CLAIM = OAUTH_ROLES_CLAIM
auth_manager_config.OAUTH_EMAIL_CLAIM = OAUTH_EMAIL_CLAIM
auth_manager_config.OAUTH_PICTURE_CLAIM = OAUTH_PICTURE_CLAIM
auth_manager_config.OAUTH_USERNAME_CLAIM = OAUTH_USERNAME_CLAIM
auth_manager_config.OAUTH_ALLOWED_ROLES = OAUTH_ALLOWED_ROLES
auth_manager_config.OAUTH_ADMIN_ROLES = OAUTH_ADMIN_ROLES
auth_manager_config.WEBHOOK_URL = WEBHOOK_URL
auth_manager_config.JWT_EXPIRES_IN = JWT_EXPIRES_IN
class OAuthManager:
def __init__(self):
self.oauth = OAuth()
for provider_name, provider_config in OAUTH_PROVIDERS.items():
self.oauth.register(
name=provider_name,
client_id=provider_config["client_id"],
client_secret=provider_config["client_secret"],
server_metadata_url=provider_config["server_metadata_url"],
client_kwargs={
"scope": provider_config["scope"],
},
redirect_uri=provider_config["redirect_uri"],
)
def get_client(self, provider_name):
return self.oauth.create_client(provider_name)
def get_user_role(self, user, user_data):
if user and Users.get_num_users() == 1:
# If the user is the only user, assign the role "admin" - actually repairs role for single user on login
return "admin"
if not user and Users.get_num_users() == 0:
# If there are no users, assign the role "admin", as the first user will be an admin
return "admin"
if auth_manager_config.ENABLE_OAUTH_ROLE_MANAGEMENT:
oauth_claim = auth_manager_config.OAUTH_ROLES_CLAIM
oauth_allowed_roles = auth_manager_config.OAUTH_ALLOWED_ROLES
oauth_admin_roles = auth_manager_config.OAUTH_ADMIN_ROLES
oauth_roles = None
role = "pending" # Default/fallback role if no matching roles are found
# Next block extracts the roles from the user data, accepting nested claims of any depth
if oauth_claim and oauth_allowed_roles and oauth_admin_roles:
claim_data = user_data
nested_claims = oauth_claim.split(".")
for nested_claim in nested_claims:
claim_data = claim_data.get(nested_claim, {})
oauth_roles = claim_data if isinstance(claim_data, list) else None
# If any roles are found, check if they match the allowed or admin roles
if oauth_roles:
# If role management is enabled, and matching roles are provided, use the roles
for allowed_role in oauth_allowed_roles:
# If the user has any of the allowed roles, assign the role "user"
if allowed_role in oauth_roles:
role = "user"
break
for admin_role in oauth_admin_roles:
# If the user has any of the admin roles, assign the role "admin"
if admin_role in oauth_roles:
role = "admin"
break
else:
if not user:
# If role management is disabled, use the default role for new users
role = auth_manager_config.DEFAULT_USER_ROLE
else:
# If role management is disabled, use the existing role for existing users
role = user.role
return role
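# get_user_role walks the roles claim as a dotted path, so providers that nest roles
# (for example a Keycloak-style "realm_access.roles") can still be mapped to "user" or
# "admin". A standalone sketch of just that lookup follows; the claim values and helper
# name are illustrative, not part of this module.
def resolve_role(user_data: dict, roles_claim: str,
                 allowed_roles: list[str], admin_roles: list[str]) -> str:
    claim_data = user_data
    for part in roles_claim.split("."):
        claim_data = claim_data.get(part, {})
    oauth_roles = claim_data if isinstance(claim_data, list) else []

    role = "pending"  # fallback when nothing matches
    if any(r in oauth_roles for r in allowed_roles):
        role = "user"
    if any(r in oauth_roles for r in admin_roles):
        role = "admin"
    return role


example_claims = {"realm_access": {"roles": ["offline_access", "webui-admin"]}}
print(resolve_role(example_claims, "realm_access.roles", ["webui-user"], ["webui-admin"]))
# -> "admin"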
async def handle_login(self, provider, request):
if provider not in OAUTH_PROVIDERS:
raise HTTPException(404)
# If the provider has a custom redirect URL, use that, otherwise automatically generate one
redirect_uri = OAUTH_PROVIDERS[provider].get("redirect_uri") or request.url_for(
"oauth_callback", provider=provider
)
client = self.get_client(provider)
if client is None:
raise HTTPException(404)
return await client.authorize_redirect(request, redirect_uri)
async def handle_callback(self, provider, request, response):
if provider not in OAUTH_PROVIDERS:
raise HTTPException(404)
client = self.get_client(provider)
try:
token = await client.authorize_access_token(request)
except Exception as e:
log.warning(f"OAuth callback error: {e}")
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
user_data: UserInfo = token["userinfo"]
if not user_data:
user_data: UserInfo = await client.userinfo(token=token)
if not user_data:
log.warning(f"OAuth callback failed, user data is missing: {token}")
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
sub = user_data.get("sub")
if not sub:
log.warning(f"OAuth callback failed, sub is missing: {user_data}")
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
provider_sub = f"{provider}@{sub}"
email_claim = auth_manager_config.OAUTH_EMAIL_CLAIM
email = user_data.get(email_claim, "").lower()
# We currently mandate that email addresses are provided
if not email:
log.warning(f"OAuth callback failed, email is missing: {user_data}")
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
# Check if the user exists
user = Users.get_user_by_oauth_sub(provider_sub)
if not user:
# If the user does not exist, check if merging is enabled
if auth_manager_config.OAUTH_MERGE_ACCOUNTS_BY_EMAIL:
# Check if the user exists by email
user = Users.get_user_by_email(email)
if user:
# Update the user with the new oauth sub
Users.update_user_oauth_sub_by_id(user.id, provider_sub)
if user:
determined_role = self.get_user_role(user, user_data)
if user.role != determined_role:
Users.update_user_role_by_id(user.id, determined_role)
if not user:
# If the user does not exist, check if signups are enabled
if auth_manager_config.ENABLE_OAUTH_SIGNUP:
# Check if an existing user with the same email already exists
existing_user = Users.get_user_by_email(
user_data.get("email", "").lower()
)
if existing_user:
raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
picture_claim = auth_manager_config.OAUTH_PICTURE_CLAIM
picture_url = user_data.get(picture_claim, "")
if picture_url:
# Download the profile image into a base64 string
try:
async with aiohttp.ClientSession() as session:
async with session.get(picture_url) as resp:
picture = await resp.read()
base64_encoded_picture = base64.b64encode(
picture
).decode("utf-8")
guessed_mime_type = mimetypes.guess_type(picture_url)[0]
if guessed_mime_type is None:
# assume JPG, browsers are tolerant enough of image formats
guessed_mime_type = "image/jpeg"
picture_url = f"data:{guessed_mime_type};base64,{base64_encoded_picture}"
except Exception as e:
log.error(
f"Error downloading profile image '{picture_url}': {e}"
)
picture_url = ""
if not picture_url:
picture_url = "/user.png"
username_claim = auth_manager_config.OAUTH_USERNAME_CLAIM
role = self.get_user_role(None, user_data)
user = Auths.insert_new_auth(
email=email,
password=get_password_hash(
str(uuid.uuid4())
), # Random password, not used
name=user_data.get(username_claim, "User"),
profile_image_url=picture_url,
role=role,
oauth_sub=provider_sub,
)
if auth_manager_config.WEBHOOK_URL:
post_webhook(
auth_manager_config.WEBHOOK_URL,
auth_manager_config.WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
{
"action": "signup",
"message": auth_manager_config.WEBHOOK_MESSAGES.USER_SIGNUP(
user.name
),
"user": user.model_dump_json(exclude_none=True),
},
)
else:
raise HTTPException(
status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED
)
jwt_token = create_token(
data={"id": user.id},
expires_delta=parse_duration(auth_manager_config.JWT_EXPIRES_IN),
)
# Set the cookie token
response.set_cookie(
key="token",
value=jwt_token,
httponly=True, # Ensures the cookie is not accessible via JavaScript
samesite=WEBUI_SESSION_COOKIE_SAME_SITE,
secure=WEBUI_SESSION_COOKIE_SECURE,
)
# Redirect back to the frontend with the JWT token
redirect_url = f"{request.base_url}auth#token={jwt_token}"
return RedirectResponse(url=redirect_url)
oauth_manager = OAuthManager()


@ -88,6 +88,53 @@ def apply_model_params_to_body_ollama(params: dict, form_data: dict) -> dict:
return form_data
def convert_messages_openai_to_ollama(messages: list[dict]) -> list[dict]:
ollama_messages = []
for message in messages:
# Initialize the new message structure with the role
new_message = {"role": message["role"]}
content = message.get("content", [])
# Check if the content is a string (just a simple message)
if isinstance(content, str):
# If the content is a string, it's pure text
new_message["content"] = content
else:
# Otherwise, assume the content is a list of dicts, e.g., text followed by an image URL
content_text = ""
images = []
# Iterate through the list of content items
for item in content:
# Check if it's a text type
if item.get("type") == "text":
content_text += item.get("text", "")
# Check if it's an image URL type
elif item.get("type") == "image_url":
img_url = item.get("image_url", {}).get("url", "")
if img_url:
# If the image url starts with data:, it's a base64 image and should be trimmed
if img_url.startswith("data:"):
img_url = img_url.split(",")[-1]
images.append(img_url)
# Add content text (if any)
if content_text:
new_message["content"] = content_text.strip()
# Add images (if any)
if images:
new_message["images"] = images
# Append the new formatted message to the result
ollama_messages.append(new_message)
return ollama_messages
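# An illustrative input/output pair for the converter above: OpenAI-style multi-part
# content is flattened into Ollama's `content` string plus an `images` list, and the
# `data:` prefix is stripped from base64 image URLs. The message values are made up.
example_openai_messages = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "What is in this picture?"},
            {"type": "image_url", "image_url": {"url": "data:image/png;base64,iVBORw0KGgo..."}},
        ],
    }
]

print(convert_messages_openai_to_ollama(example_openai_messages))
# [{'role': 'user', 'content': 'What is in this picture?', 'images': ['iVBORw0KGgo...']}]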
def convert_payload_openai_to_ollama(openai_payload: dict) -> dict:
"""
Converts a payload formatted for OpenAI's API to be compatible with Ollama's API endpoint for chat completions.
@ -102,7 +149,9 @@ def convert_payload_openai_to_ollama(openai_payload: dict) -> dict:
# Mapping basic model and message details
ollama_payload["model"] = openai_payload.get("model")
ollama_payload["messages"] = openai_payload.get("messages") ollama_payload["messages"] = convert_messages_openai_to_ollama(
openai_payload.get("messages")
)
ollama_payload["stream"] = openai_payload.get("stream", False)
# If there are advanced parameters in the payload, format them in Ollama's options field

View file

@ -0,0 +1,139 @@
from datetime import datetime
from io import BytesIO
from pathlib import Path
from typing import Dict, Any, List
from markdown import markdown
import site
from fpdf import FPDF
from open_webui.env import STATIC_DIR, FONTS_DIR
from open_webui.apps.webui.models.chats import ChatTitleMessagesForm
class PDFGenerator:
"""
Description:
The `PDFGenerator` class is designed to create PDF documents from chat messages.
The process transforms markdown content into HTML and then renders the HTML into a PDF.
Attributes:
- `form_data`: An instance of `ChatTitleMessagesForm` containing title and messages.
"""
def __init__(self, form_data: ChatTitleMessagesForm):
self.html_body = None
self.messages_html = None
self.form_data = form_data
self.css = Path(STATIC_DIR / "assets" / "pdf-style.css").read_text()
def format_timestamp(self, timestamp: float) -> str:
"""Convert a UNIX timestamp to a formatted date string."""
try:
date_time = datetime.fromtimestamp(timestamp)
return date_time.strftime("%Y-%m-%d, %H:%M:%S")
except (ValueError, TypeError) as e:
# Log the error if necessary
return ""
def _build_html_message(self, message: Dict[str, Any]) -> str:
"""Build HTML for a single message."""
role = message.get("role", "user")
content = message.get("content", "")
timestamp = message.get("timestamp")
model = message.get("model") if role == "assistant" else ""
date_str = self.format_timestamp(timestamp) if timestamp else ""
# extends pymdownx extension to convert markdown to html.
# - https://facelessuser.github.io/pymdown-extensions/usage_notes/
html_content = markdown(content, extensions=["pymdownx.extra"])
html_message = f"""
<div class="message">
<small> {date_str} </small>
<div>
<h2>
<strong>{role.title()}</strong>
<small class="text-muted">{model}</small>
</h2>
</div>
<div class="markdown-section">
{html_content}
</div>
</div>
"""
return html_message
def _generate_html_body(self) -> str:
"""Generate the full HTML body for the PDF."""
return f"""
<html>
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
</head>
<body>
<div class="container">
<div class="text-center">
<h1>{self.form_data.title}</h1>
</div>
<div>
{self.messages_html}
</div>
</div>
</body>
</html>
"""
def generate_chat_pdf(self) -> bytes:
"""
Generate a PDF from chat messages.
"""
try:
global FONTS_DIR
pdf = FPDF()
pdf.add_page()
# When running using `pip install` the static directory is in the site packages.
if not FONTS_DIR.exists():
FONTS_DIR = Path(site.getsitepackages()[0]) / "static/fonts"
# When running using `pip install -e .` the static directory is in the site packages.
# This path only works if `open-webui serve` is run from the root of this project.
if not FONTS_DIR.exists():
FONTS_DIR = Path("./backend/static/fonts")
pdf.add_font("NotoSans", "", f"{FONTS_DIR}/NotoSans-Regular.ttf")
pdf.add_font("NotoSans", "b", f"{FONTS_DIR}/NotoSans-Bold.ttf")
pdf.add_font("NotoSans", "i", f"{FONTS_DIR}/NotoSans-Italic.ttf")
pdf.add_font("NotoSansKR", "", f"{FONTS_DIR}/NotoSansKR-Regular.ttf")
pdf.add_font("NotoSansJP", "", f"{FONTS_DIR}/NotoSansJP-Regular.ttf")
pdf.add_font("NotoSansSC", "", f"{FONTS_DIR}/NotoSansSC-Regular.ttf")
pdf.set_font("NotoSans", size=12)
pdf.set_fallback_fonts(["NotoSansKR", "NotoSansJP", "NotoSansSC"])
pdf.set_auto_page_break(auto=True, margin=15)
# Build HTML messages
messages_html_list: List[str] = [
self._build_html_message(msg) for msg in self.form_data.messages
]
self.messages_html = "<div>" + "".join(messages_html_list) + "</div>"
# Generate full HTML body
self.html_body = self._generate_html_body()
pdf.write_html(self.html_body)
# Produce the PDF and return its raw bytes
pdf_bytes = pdf.output()
return bytes(pdf_bytes)
except Exception as e:
raise e
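# Hedged usage sketch (not part of this file); the field names follow the
# ChatTitleMessagesForm usage above, while the message dict shape is an assumption:
#
#   form_data = ChatTitleMessagesForm(
#       title="Example Chat",
#       messages=[{"role": "user", "content": "Hello!", "timestamp": 1730000000}],
#   )
#   pdf_bytes = PDFGenerator(form_data).generate_chat_pdf()
#   Path("chat.pdf").write_bytes(pdf_bytes)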

View file

@ -26,7 +26,6 @@ async def convert_streaming_response_ollama_to_openai(ollama_streaming_response)
)
line = f"data: {json.dumps(data)}\n\n"
- if done:
-     line += "data: [DONE]\n\n"
yield line
+ yield "data: [DONE]\n\n"

View file

@ -60,7 +60,7 @@ def set_hsts(value: str):
pattern = r"^max-age=(\d+)(;includeSubDomains)?(;preload)?$"
match = re.match(pattern, value, re.IGNORECASE)
if not match:
-     return "max-age=31536000;includeSubDomains"
+     value = "max-age=31536000;includeSubDomains"
return {"Strict-Transport-Security": value}
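# Behaviour sketch for the change above (illustrative values):
#   set_hsts("max-age=63072000;includeSubDomains")
#   # -> {"Strict-Transport-Security": "max-age=63072000;includeSubDomains"}
#   set_hsts("not-a-valid-value")
#   # -> {"Strict-Transport-Security": "max-age=31536000;includeSubDomains"}
# i.e. an invalid value now falls back to the default header dict instead of
# returning a bare string.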

View file

@ -70,22 +70,6 @@ def replace_prompt_variable(template: str, prompt: str) -> str:
return template
def title_generation_template(
template: str, prompt: str, user: Optional[dict] = None
) -> str:
template = replace_prompt_variable(template, prompt)
template = prompt_template(
template,
**(
{"user_name": user.get("name"), "user_location": user.get("location")}
if user
else {}
),
)
return template
def replace_messages_variable(template: str, messages: list[str]) -> str:
def replacement_function(match):
full_match = match.group(0)
@ -123,6 +107,62 @@ def replace_messages_variable(template: str, messages: list[str]) -> str:
return template
# {{prompt:middletruncate:8000}}
def title_generation_template(
template: str, messages: list[dict], user: Optional[dict] = None
) -> str:
prompt = get_last_user_message(messages)
template = replace_prompt_variable(template, prompt)
template = replace_messages_variable(template, messages)
template = prompt_template(
template,
**(
{"user_name": user.get("name"), "user_location": user.get("location")}
if user
else {}
),
)
return template
def tags_generation_template(
template: str, messages: list[dict], user: Optional[dict] = None
) -> str:
prompt = get_last_user_message(messages)
template = replace_prompt_variable(template, prompt)
template = replace_messages_variable(template, messages)
template = prompt_template(
template,
**(
{"user_name": user.get("name"), "user_location": user.get("location")}
if user
else {}
),
)
return template
def emoji_generation_template(
template: str, prompt: str, user: Optional[dict] = None
) -> str:
template = replace_prompt_variable(template, prompt)
template = prompt_template(
template,
**(
{"user_name": user.get("name"), "user_location": user.get("location")}
if user
else {}
),
)
return template
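# Hedged usage sketch for the template helpers above (template text and user dict are
# illustrative; the {{prompt:middletruncate:8000}} placeholder follows the comment above):
#
#   template = "Generate a concise title for: {{prompt:middletruncate:8000}}"
#   messages = [{"role": "user", "content": "Explain HSTS headers in one paragraph."}]
#   title_prompt = title_generation_template(
#       template, messages, user={"name": "Ada", "location": "Berlin"}
#   )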
def search_query_generation_template(
template: str, messages: list[dict], user: Optional[dict] = None
) -> str:

View file

@ -7,7 +7,7 @@ import jwt
from open_webui.apps.webui.models.users import Users
from open_webui.constants import ERROR_MESSAGES
from open_webui.env import WEBUI_SECRET_KEY
- from fastapi import Depends, HTTPException, Request, status
+ from fastapi import Depends, HTTPException, Request, Response, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from passlib.context import CryptContext

View file

@ -12,6 +12,7 @@ passlib[bcrypt]==1.7.4
requests==2.32.3
aiohttp==3.10.8
+ async-timeout
sqlalchemy==2.0.32
alembic==1.13.2
@ -39,16 +40,19 @@ langchain-community==0.2.12
langchain-chroma==0.1.4
fake-useragent==1.5.1
- chromadb==0.5.9
+ chromadb==0.5.15
pymilvus==2.4.7
+ qdrant-client~=1.12.0
- sentence-transformers==3.0.1
+ sentence-transformers==3.2.0
colbert-ai==0.2.21
einops==0.8.0
ftfy==6.2.3
pypdf==4.3.1
+ xhtml2pdf==0.2.16
+ pymdown-extensions==10.11.2
docx2txt==0.8
python-pptx==1.0.0
unstructured==0.15.9
@ -86,3 +90,5 @@ duckduckgo-search~=6.2.13
docker~=7.1.0
pytest~=8.3.2
pytest-docker~=3.1.1
+ googleapis-common-protos==1.63.2

View file

@ -30,7 +30,7 @@ describe('Settings', () => {
// Select the first model
cy.get('button[aria-label="model-item"]').first().click();
// Type a message
- cy.get('#chat-textarea').type('Hi, what can you do? A single sentence only please.', {
+ cy.get('#chat-input').type('Hi, what can you do? A single sentence only please.', {
force: true
});
// Send the message
@ -50,7 +50,7 @@ describe('Settings', () => {
// Select the first model
cy.get('button[aria-label="model-item"]').first().click();
// Type a message
- cy.get('#chat-textarea').type('Hi, what can you do? A single sentence only please.', {
+ cy.get('#chat-input').type('Hi, what can you do? A single sentence only please.', {
force: true
});
// Send the message
@ -85,7 +85,7 @@ describe('Settings', () => {
// Select the first model
cy.get('button[aria-label="model-item"]').first().click();
// Type a message
- cy.get('#chat-textarea').type('Hi, what can you do? A single sentence only please.', {
+ cy.get('#chat-input').type('Hi, what can you do? A single sentence only please.', {
force: true
});
// Send the message

View file

@ -118,7 +118,7 @@ Navigate to the apache sites-available directory:
`nano models.server.city.conf` # match this with your ollama server domain
- Add the folloing virtualhost containing this example (modify as needed):
+ Add the following virtualhost containing this example (modify as needed):
```

package-lock.json (generated, 1040): file diff suppressed because it is too large

View file

@ -1,6 +1,6 @@
{
"name": "open-webui",
- "version": "0.3.32",
+ "version": "0.3.35",
"private": true,
"scripts": {
"dev": "npm run pyodide:fetch && vite dev --host",
@ -52,6 +52,7 @@
"@codemirror/lang-python": "^6.1.6",
"@codemirror/language-data": "^6.5.1",
"@codemirror/theme-one-dark": "^6.1.2",
+ "@huggingface/transformers": "^3.0.0",
"@pyscript/core": "^0.4.32",
"@sveltejs/adapter-node": "^2.0.0",
"@xyflow/svelte": "^0.1.19",
@ -72,9 +73,19 @@
"js-sha256": "^0.10.1",
"katex": "^0.16.9",
"marked": "^9.1.0",
- "mermaid": "^10.9.1",
+ "mermaid": "^10.9.3",
"paneforge": "^0.0.6",
"panzoom": "^9.4.3",
+ "prosemirror-commands": "^1.6.0",
+ "prosemirror-example-setup": "^1.2.3",
+ "prosemirror-history": "^1.4.1",
+ "prosemirror-keymap": "^1.2.2",
+ "prosemirror-markdown": "^1.13.1",
+ "prosemirror-model": "^1.23.0",
+ "prosemirror-schema-basic": "^1.2.3",
+ "prosemirror-schema-list": "^1.4.1",
+ "prosemirror-state": "^1.4.3",
+ "prosemirror-view": "^1.34.3",
"pyodide": "^0.26.1",
"socket.io-client": "^4.2.0",
"sortablejs": "^1.15.2",

View file

@ -20,6 +20,7 @@ dependencies = [
"requests==2.32.3",
"aiohttp==3.10.8",
+ "async-timeout",
"sqlalchemy==2.0.32",
"alembic==1.13.2",
@ -49,13 +50,14 @@ dependencies = [
"chromadb==0.5.9",
"pymilvus==2.4.7",
- "sentence-transformers==3.0.1",
+ "sentence-transformers==3.2.0",
"colbert-ai==0.2.21",
"einops==0.8.0",
"ftfy==6.2.3",
"pypdf==4.3.1",
+ "xhtml2pdf==0.2.16",
+ "pymdown-extensions==10.11.2",
"docx2txt==0.8",
"python-pptx==1.0.0",
"unstructured==0.15.9",
@ -91,7 +93,9 @@ dependencies = [
"docker~=7.1.0",
"pytest~=8.3.2",
- "pytest-docker~=3.1.1"
+ "pytest-docker~=3.1.1",
+ "googleapis-common-protos==1.63.2"
]
readme = "README.md"
requires-python = ">= 3.11, < 3.12.0a1"

View file

@ -34,6 +34,14 @@ math {
@apply rounded-lg;
}
+ .input-prose {
+ @apply prose dark:prose-invert prose-p:my-0 prose-img:my-1 prose-headings:my-1 prose-pre:my-0 prose-table:my-0 prose-blockquote:my-0 prose-ul:-my-0 prose-ol:-my-0 prose-li:-my-0 whitespace-pre-line;
+ }
+ .input-prose-sm {
+ @apply prose dark:prose-invert prose-p:my-0 prose-img:my-1 prose-headings:my-1 prose-pre:my-0 prose-table:my-0 prose-blockquote:my-0 prose-ul:-my-0 prose-ol:-my-0 prose-li:-my-0 whitespace-pre-line text-sm;
+ }
.markdown-prose {
@apply prose dark:prose-invert prose-p:my-0 prose-img:my-1 prose-headings:my-1 prose-pre:my-0 prose-table:my-0 prose-blockquote:my-0 prose-ul:-my-0 prose-ol:-my-0 prose-li:-my-0 whitespace-pre-line;
}
@ -56,7 +64,7 @@ li p {
::-webkit-scrollbar-thumb {
--tw-border-opacity: 1;
- background-color: rgba(217, 217, 227, 0.8);
+ background-color: rgba(236, 236, 236, 0.8);
border-color: rgba(255, 255, 255, var(--tw-border-opacity));
border-radius: 9999px;
border-width: 1px;
@ -64,7 +72,7 @@ li p {
/* Dark theme scrollbar styles */
.dark ::-webkit-scrollbar-thumb {
- background-color: rgba(69, 69, 74, 0.8); /* Darker color for dark theme */
+ background-color: rgba(33, 33, 33, 0.8); /* Darker color for dark theme */
border-color: rgba(0, 0, 0, var(--tw-border-opacity));
}
@ -179,3 +187,21 @@ input[type='number'] {
.bg-gray-950-90 {
background-color: rgba(var(--color-gray-950, #0d0d0d), 0.9);
}
.ProseMirror {
@apply h-full min-h-fit max-h-full whitespace-pre-wrap;
}
.ProseMirror:focus {
outline: none;
}
.placeholder::after {
content: attr(data-placeholder);
cursor: text;
pointer-events: none;
float: left;
@apply absolute inset-0 z-0 text-gray-500;
}

View file

@ -180,6 +180,31 @@ export const userSignUp = async (
return res;
};
export const userSignOut = async () => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/auths/signout`, {
method: 'GET',
headers: {
'Content-Type': 'application/json'
},
credentials: 'include'
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res;
})
.catch((err) => {
console.log(err);
error = err.detail;
return null;
});
if (error) {
throw error;
}
};
export const addUser = async (
token: string,
name: string,

View file

@ -32,6 +32,46 @@ export const createNewChat = async (token: string, chat: object) => {
return res;
};
export const importChat = async (
token: string,
chat: object,
meta: object | null,
pinned?: boolean,
folderId?: string | null
) => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/chats/import`, {
method: 'POST',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
authorization: `Bearer ${token}`
},
body: JSON.stringify({
chat: chat,
meta: meta ?? {},
pinned: pinned,
folder_id: folderId
})
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.catch((err) => {
error = err;
console.log(err);
return null;
});
if (error) {
throw error;
}
return res;
};
export const getChatList = async (token: string = '', page: number | null = null) => {
let error = null;
const searchParams = new URLSearchParams();
@ -199,6 +239,40 @@ export const getChatListBySearchText = async (token: string, text: string, page:
throw error;
}
return res.map((chat) => ({
...chat,
time_range: getTimeRange(chat.updated_at)
}));
};
export const getChatsByFolderId = async (token: string, folderId: string) => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/chats/folder/${folderId}`, {
method: 'GET',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
...(token && { authorization: `Bearer ${token}` })
}
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.then((json) => {
return json;
})
.catch((err) => {
error = err;
console.log(err);
return null;
});
if (error) {
throw error;
}
return res;
};
@ -264,10 +338,10 @@ export const getAllUserChats = async (token: string) => {
return res;
};
- export const getAllChatTags = async (token: string) => {
+ export const getAllTags = async (token: string) => {
let error = null;
- const res = await fetch(`${WEBUI_API_BASE_URL}/chats/tags/all`, {
+ const res = await fetch(`${WEBUI_API_BASE_URL}/chats/all/tags`, {
method: 'GET',
headers: {
Accept: 'application/json',
@ -295,6 +369,40 @@ export const getAllChatTags = async (token: string) => {
return res;
};
export const getPinnedChatList = async (token: string = '') => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/chats/pinned`, {
method: 'GET',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
...(token && { authorization: `Bearer ${token}` })
}
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.then((json) => {
return json;
})
.catch((err) => {
error = err;
console.log(err);
return null;
});
if (error) {
throw error;
}
return res.map((chat) => ({
...chat,
time_range: getTimeRange(chat.updated_at)
}));
};
export const getChatListByTagName = async (token: string = '', tagName: string) => {
let error = null;
@ -396,11 +504,87 @@ export const getChatByShareId = async (token: string, share_id: string) => {
return res;
};
export const getChatPinnedStatusById = async (token: string, id: string) => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/chats/${id}/pinned`, {
method: 'GET',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
...(token && { authorization: `Bearer ${token}` })
}
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.then((json) => {
return json;
})
.catch((err) => {
error = err;
if ('detail' in err) {
error = err.detail;
} else {
error = err;
}
console.log(err);
return null;
});
if (error) {
throw error;
}
return res;
};
export const toggleChatPinnedStatusById = async (token: string, id: string) => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/chats/${id}/pin`, {
method: 'POST',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
...(token && { authorization: `Bearer ${token}` })
}
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.then((json) => {
return json;
})
.catch((err) => {
error = err;
if ('detail' in err) {
error = err.detail;
} else {
error = err;
}
console.log(err);
return null;
});
if (error) {
throw error;
}
return res;
};
export const cloneChatById = async (token: string, id: string) => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/chats/${id}/clone`, {
- method: 'GET',
+ method: 'POST',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
@ -466,11 +650,46 @@ export const shareChatById = async (token: string, id: string) => {
return res;
};
export const updateChatFolderIdById = async (token: string, id: string, folderId?: string) => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/chats/${id}/folder`, {
method: 'POST',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
...(token && { authorization: `Bearer ${token}` })
},
body: JSON.stringify({
folder_id: folderId
})
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.then((json) => {
return json;
})
.catch((err) => {
error = err;
console.log(err);
return null;
});
if (error) {
throw error;
}
return res;
};
export const archiveChatById = async (token: string, id: string) => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/chats/${id}/archive`, {
- method: 'GET',
+ method: 'POST',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
@ -640,8 +859,7 @@ export const addTagById = async (token: string, id: string, tagName: string) =>
...(token && { authorization: `Bearer ${token}` })
},
body: JSON.stringify({
- tag_name: tagName,
+ name: tagName
- chat_id: id
})
})
.then(async (res) => {
@ -652,8 +870,7 @@ export const addTagById = async (token: string, id: string, tagName: string) =>
return json;
})
.catch((err) => {
- error = err;
+ error = err.detail;
console.log(err);
return null;
});
@ -676,8 +893,7 @@ export const deleteTagById = async (token: string, id: string, tagName: string)
...(token && { authorization: `Bearer ${token}` })
},
body: JSON.stringify({
- tag_name: tagName,
+ name: tagName
- chat_id: id
})
})
.then(async (res) => {

View file

@ -0,0 +1,246 @@
import { WEBUI_API_BASE_URL } from '$lib/constants';
export const getConfig = async (token: string = '') => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/evaluations/config`, {
method: 'GET',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
authorization: `Bearer ${token}`
}
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.then((json) => {
return json;
})
.catch((err) => {
error = err.detail;
console.log(err);
return null;
});
if (error) {
throw error;
}
return res;
};
export const updateConfig = async (token: string, config: object) => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/evaluations/config`, {
method: 'POST',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
authorization: `Bearer ${token}`
},
body: JSON.stringify({
...config
})
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.catch((err) => {
error = err.detail;
console.log(err);
return null;
});
if (error) {
throw error;
}
return res;
};
export const getAllFeedbacks = async (token: string = '') => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/evaluations/feedbacks/all`, {
method: 'GET',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
authorization: `Bearer ${token}`
}
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.then((json) => {
return json;
})
.catch((err) => {
error = err.detail;
console.log(err);
return null;
});
if (error) {
throw error;
}
return res;
};
export const exportAllFeedbacks = async (token: string = '') => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/evaluations/feedbacks/all/export`, {
method: 'GET',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
authorization: `Bearer ${token}`
}
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.then((json) => {
return json;
})
.catch((err) => {
error = err.detail;
console.log(err);
return null;
});
if (error) {
throw error;
}
return res;
};
export const createNewFeedback = async (token: string, feedback: object) => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/evaluations/feedback`, {
method: 'POST',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
authorization: `Bearer ${token}`
},
body: JSON.stringify({
...feedback
})
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.catch((err) => {
error = err.detail;
console.log(err);
return null;
});
if (error) {
throw error;
}
return res;
};
export const getFeedbackById = async (token: string, feedbackId: string) => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/evaluations/feedback/${feedbackId}`, {
method: 'GET',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
authorization: `Bearer ${token}`
}
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.then((json) => {
return json;
})
.catch((err) => {
error = err.detail;
console.log(err);
return null;
});
if (error) {
throw error;
}
return res;
};
export const updateFeedbackById = async (token: string, feedbackId: string, feedback: object) => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/evaluations/feedback/${feedbackId}`, {
method: 'POST',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
authorization: `Bearer ${token}`
},
body: JSON.stringify({
...feedback
})
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.catch((err) => {
error = err.detail;
console.log(err);
return null;
});
if (error) {
throw error;
}
return res;
};
export const deleteFeedbackById = async (token: string, feedbackId: string) => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/evaluations/feedback/${feedbackId}`, {
method: 'DELETE',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
authorization: `Bearer ${token}`
}
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.catch((err) => {
error = err.detail;
console.log(err);
return null;
});
if (error) {
throw error;
}
return res;
};

View file

@ -0,0 +1,269 @@
import { WEBUI_API_BASE_URL } from '$lib/constants';
export const createNewFolder = async (token: string, name: string) => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/folders/`, {
method: 'POST',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
authorization: `Bearer ${token}`
},
body: JSON.stringify({
name: name
})
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.catch((err) => {
error = err.detail;
return null;
});
if (error) {
throw error;
}
return res;
};
export const getFolders = async (token: string = '') => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/folders/`, {
method: 'GET',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
authorization: `Bearer ${token}`
}
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.then((json) => {
return json;
})
.catch((err) => {
error = err.detail;
console.log(err);
return null;
});
if (error) {
throw error;
}
return res;
};
export const getFolderById = async (token: string, id: string) => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/folders/${id}`, {
method: 'GET',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
authorization: `Bearer ${token}`
}
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.then((json) => {
return json;
})
.catch((err) => {
error = err.detail;
console.log(err);
return null;
});
if (error) {
throw error;
}
return res;
};
export const updateFolderNameById = async (token: string, id: string, name: string) => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/folders/${id}/update`, {
method: 'POST',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
authorization: `Bearer ${token}`
},
body: JSON.stringify({
name: name
})
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.then((json) => {
return json;
})
.catch((err) => {
error = err.detail;
console.log(err);
return null;
});
if (error) {
throw error;
}
return res;
};
export const updateFolderIsExpandedById = async (
token: string,
id: string,
isExpanded: boolean
) => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/folders/${id}/update/expanded`, {
method: 'POST',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
authorization: `Bearer ${token}`
},
body: JSON.stringify({
is_expanded: isExpanded
})
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.then((json) => {
return json;
})
.catch((err) => {
error = err.detail;
console.log(err);
return null;
});
if (error) {
throw error;
}
return res;
};
export const updateFolderParentIdById = async (token: string, id: string, parentId?: string) => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/folders/${id}/update/parent`, {
method: 'POST',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
authorization: `Bearer ${token}`
},
body: JSON.stringify({
parent_id: parentId
})
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.then((json) => {
return json;
})
.catch((err) => {
error = err.detail;
console.log(err);
return null;
});
if (error) {
throw error;
}
return res;
};
type FolderItems = {
chat_ids: string[];
file_ids: string[];
};
export const updateFolderItemsById = async (token: string, id: string, items: FolderItems) => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/folders/${id}/update/items`, {
method: 'POST',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
authorization: `Bearer ${token}`
},
body: JSON.stringify({
items: items
})
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.then((json) => {
return json;
})
.catch((err) => {
error = err.detail;
console.log(err);
return null;
});
if (error) {
throw error;
}
return res;
};
export const deleteFolderById = async (token: string, id: string) => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/folders/${id}`, {
method: 'DELETE',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
authorization: `Bearer ${token}`
}
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.then((json) => {
return json;
})
.catch((err) => {
error = err.detail;
console.log(err);
return null;
});
if (error) {
throw error;
}
return res;
};

View file

@ -208,7 +208,7 @@ export const updateTaskConfig = async (token: string, config: object) => {
export const generateTitle = async (
token: string = '',
model: string,
- prompt: string,
+ messages: string[],
chat_id?: string
) => {
let error = null;
@ -222,7 +222,7 @@ export const generateTitle = async (
},
body: JSON.stringify({
model: model,
- prompt: prompt,
+ messages: messages,
...(chat_id && { chat_id: chat_id })
})
})
@ -245,6 +245,78 @@ export const generateTitle = async (
return res?.choices[0]?.message?.content.replace(/["']/g, '') ?? 'New Chat';
};
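// Call-site sketch (names assumed): generateTitle(localStorage.token, model.id, messages, $chatId)
// now takes the chat's message history instead of a single prompt string.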
export const generateTags = async (
token: string = '',
model: string,
messages: string,
chat_id?: string
) => {
let error = null;
const res = await fetch(`${WEBUI_BASE_URL}/api/task/tags/completions`, {
method: 'POST',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
Authorization: `Bearer ${token}`
},
body: JSON.stringify({
model: model,
messages: messages,
...(chat_id && { chat_id: chat_id })
})
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.catch((err) => {
console.log(err);
if ('detail' in err) {
error = err.detail;
}
return null;
});
if (error) {
throw error;
}
try {
// Step 1: Safely extract the response string
const response = res?.choices[0]?.message?.content ?? '';
// Step 2: Attempt to fix common JSON format issues like single quotes
const sanitizedResponse = response.replace(/['`]/g, '"'); // Convert single quotes to double quotes for valid JSON
// Step 3: Find the relevant JSON block within the response
const jsonStartIndex = sanitizedResponse.indexOf('{');
const jsonEndIndex = sanitizedResponse.lastIndexOf('}');
// Step 4: Check if we found a valid JSON block (with both `{` and `}`)
if (jsonStartIndex !== -1 && jsonEndIndex !== -1) {
const jsonResponse = sanitizedResponse.substring(jsonStartIndex, jsonEndIndex + 1);
// Step 5: Parse the JSON block
const parsed = JSON.parse(jsonResponse);
// Step 6: If there's a "tags" key, return the tags array; otherwise, return an empty array
if (parsed && parsed.tags) {
return Array.isArray(parsed.tags) ? parsed.tags : [];
} else {
return [];
}
}
// If no valid JSON block found, return an empty array
return [];
} catch (e) {
// Catch and safely return empty array on any parsing errors
console.error('Failed to parse response: ', e);
return [];
}
};
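// Worked example of the extraction above (illustrative model output): given
//   "Here you go: {'tags': ['python', 'fastapi']}"
// the quotes are normalised to double quotes, the {...} block is sliced out,
// JSON.parse yields { tags: ["python", "fastapi"] }, and that array is returned.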
export const generateEmoji = async (
token: string = '',
model: string,

View file

@ -7,6 +7,7 @@ type TextStreamUpdate = {
// eslint-disable-next-line @typescript-eslint/no-explicit-any
citations?: any;
// eslint-disable-next-line @typescript-eslint/no-explicit-any
+ selectedModelId?: any;
error?: any;
usage?: ResponseUsage;
};
@ -71,6 +72,11 @@ async function* openAIStreamToIterator(
continue;
}
+ if (parsedData.selected_model_id) {
+ yield { done: false, value: '', selectedModelId: parsedData.selected_model_id };
+ continue;
+ }
yield {
done: false,
value: parsedData.choices?.[0]?.delta?.content ?? '',

View file

@ -2,12 +2,13 @@
import { onMount, getContext } from 'svelte';
import { Confetti } from 'svelte-confetti';
- import { WEBUI_NAME, config } from '$lib/stores';
+ import { WEBUI_NAME, config, settings } from '$lib/stores';
import { WEBUI_VERSION } from '$lib/constants';
import { getChangelog } from '$lib/apis';
import Modal from './common/Modal.svelte';
+ import { updateUserSettings } from '$lib/apis/users';
const i18n = getContext('i18n');
@ -104,8 +105,10 @@
</div>
<div class="flex justify-end pt-3 text-sm font-medium">
<button
- on:click={() => {
+ on:click={async () => {
localStorage.version = $config.version;
+ await settings.set({ ...$settings, ...{ version: $config.version } });
+ await updateUserSettings(localStorage.token, { ui: $settings });
show = false;
}}
class=" px-4 py-2 bg-emerald-700 hover:bg-emerald-800 text-gray-100 transition rounded-lg"

View file

@ -139,7 +139,7 @@
</button>
</div>
- <div class="flex flex-col md:flex-row w-full px-5 pb-4 md:space-x-4 dark:text-gray-200">
+ <div class="flex flex-col md:flex-row w-full px-4 pb-3 md:space-x-4 dark:text-gray-200">
<div class=" flex flex-col w-full sm:flex-row sm:justify-center sm:space-x-6">
<form
class="flex flex-col w-full"
@ -147,9 +147,9 @@
submitHandler();
}}
>
- <div class="flex text-center text-sm font-medium rounded-xl bg-transparent/10 p-1 mb-2">
+ <div class="flex text-center text-sm font-medium rounded-full bg-transparent/10 p-1 mb-2">
<button
- class="w-full rounded-lg p-1.5 {tab === '' ? 'bg-gray-50 dark:bg-gray-850' : ''}"
+ class="w-full rounded-full p-1.5 {tab === '' ? 'bg-gray-50 dark:bg-gray-850' : ''}"
type="button"
on:click={() => {
tab = '';
@ -157,7 +157,9 @@
>
<button
- class="w-full rounded-lg p-1 {tab === 'import' ? 'bg-gray-50 dark:bg-gray-850' : ''}"
+ class="w-full rounded-full p-1 {tab === 'import'
+ ? 'bg-gray-50 dark:bg-gray-850'
+ : ''}"
type="button"
on:click={() => {
tab = 'import';
@ -183,7 +185,7 @@
</div>
</div>
- <div class="flex flex-col w-full mt-2">
+ <div class="flex flex-col w-full mt-1">
<div class=" mb-1 text-xs text-gray-500">{$i18n.t('Name')}</div>
<div class="flex-1">
@ -198,7 +200,7 @@
</div>
</div>
- <hr class=" dark:border-gray-800 my-3 w-full" />
+ <hr class=" border-gray-50 dark:border-gray-850 my-2.5 w-full" />
<div class="flex flex-col w-full">
<div class=" mb-1 text-xs text-gray-500">{$i18n.t('Email')}</div>
@ -209,13 +211,12 @@
type="email"
bind:value={_user.email}
placeholder={$i18n.t('Enter Your Email')}
- autocomplete="off"
required
/>
</div>
</div>
- <div class="flex flex-col w-full mt-2">
+ <div class="flex flex-col w-full mt-1">
<div class=" mb-1 text-xs text-gray-500">{$i18n.t('Password')}</div>
<div class="flex-1">
@ -271,13 +272,13 @@
<div class="flex justify-end pt-3 text-sm font-medium">
<button
- class=" px-4 py-2 bg-emerald-700 hover:bg-emerald-800 text-gray-100 transition rounded-lg flex flex-row space-x-1 items-center {loading
+ class="px-3.5 py-1.5 text-sm font-medium bg-black hover:bg-gray-900 text-white dark:bg-white dark:text-black dark:hover:bg-gray-100 transition rounded-full flex flex-row space-x-1 items-center {loading
? ' cursor-not-allowed'
: ''}"
type="submit"
disabled={loading}
>
- {$i18n.t('Submit')}
+ {$i18n.t('Save')}
{#if loading}
<div class="ml-2 self-center">

View file

@ -0,0 +1,677 @@
<script lang="ts">
import fileSaver from 'file-saver';
const { saveAs } = fileSaver;
import { onMount, getContext } from 'svelte';
import dayjs from 'dayjs';
import relativeTime from 'dayjs/plugin/relativeTime';
dayjs.extend(relativeTime);
import * as ort from 'onnxruntime-web';
import { AutoModel, AutoTokenizer } from '@huggingface/transformers';
const EMBEDDING_MODEL = 'TaylorAI/bge-micro-v2';
let tokenizer = null;
let model = null;
import { models } from '$lib/stores';
import { deleteFeedbackById, exportAllFeedbacks, getAllFeedbacks } from '$lib/apis/evaluations';
import FeedbackMenu from './Evaluations/FeedbackMenu.svelte';
import EllipsisHorizontal from '../icons/EllipsisHorizontal.svelte';
import Tooltip from '../common/Tooltip.svelte';
import Badge from '../common/Badge.svelte';
import Pagination from '../common/Pagination.svelte';
import MagnifyingGlass from '../icons/MagnifyingGlass.svelte';
import Share from '../icons/Share.svelte';
import CloudArrowUp from '../icons/CloudArrowUp.svelte';
import { toast } from 'svelte-sonner';
import Spinner from '../common/Spinner.svelte';
import DocumentArrowUpSolid from '../icons/DocumentArrowUpSolid.svelte';
import DocumentArrowDown from '../icons/DocumentArrowDown.svelte';
import ArrowDownTray from '../icons/ArrowDownTray.svelte';
const i18n = getContext('i18n');
let rankedModels = [];
let feedbacks = [];
let query = '';
let page = 1;
let tagEmbeddings = new Map();
let loaded = false;
let loadingLeaderboard = true;
let debounceTimer;
$: paginatedFeedbacks = feedbacks.slice((page - 1) * 10, page * 10);
type Feedback = {
id: string;
data: {
rating: number;
model_id: string;
sibling_model_ids: string[] | null;
reason: string;
comment: string;
tags: string[];
};
user: {
name: string;
profile_image_url: string;
};
updated_at: number;
};
type ModelStats = {
rating: number;
won: number;
lost: number;
};
//////////////////////
//
// Rank models by Elo rating
//
//////////////////////
const rankHandler = async (similarities: Map<string, number> = new Map()) => {
const modelStats = calculateModelStats(feedbacks, similarities);
rankedModels = $models
.filter((m) => m?.owned_by !== 'arena' && (m?.info?.meta?.hidden ?? false) !== true)
.map((model) => {
const stats = modelStats.get(model.id);
return {
...model,
rating: stats ? Math.round(stats.rating) : '-',
stats: {
count: stats ? stats.won + stats.lost : 0,
won: stats ? stats.won.toString() : '-',
lost: stats ? stats.lost.toString() : '-'
}
};
})
.sort((a, b) => {
if (a.rating === '-' && b.rating !== '-') return 1;
if (b.rating === '-' && a.rating !== '-') return -1;
if (a.rating !== '-' && b.rating !== '-') return b.rating - a.rating;
return a.name.localeCompare(b.name);
});
loadingLeaderboard = false;
};
function calculateModelStats(
feedbacks: Feedback[],
similarities: Map<string, number>
): Map<string, ModelStats> {
const stats = new Map<string, ModelStats>();
const K = 32;
function getOrDefaultStats(modelId: string): ModelStats {
return stats.get(modelId) || { rating: 1000, won: 0, lost: 0 };
}
function updateStats(modelId: string, ratingChange: number, outcome: number) {
const currentStats = getOrDefaultStats(modelId);
currentStats.rating += ratingChange;
if (outcome === 1) currentStats.won++;
else if (outcome === 0) currentStats.lost++;
stats.set(modelId, currentStats);
}
function calculateEloChange(
ratingA: number,
ratingB: number,
outcome: number,
similarity: number
): number {
const expectedScore = 1 / (1 + Math.pow(10, (ratingB - ratingA) / 400));
return K * (outcome - expectedScore) * similarity;
}
feedbacks.forEach((feedback) => {
const modelA = feedback.data.model_id;
const statsA = getOrDefaultStats(modelA);
let outcome: number;
switch (feedback.data.rating.toString()) {
case '1':
outcome = 1;
break;
case '-1':
outcome = 0;
break;
default:
return; // Skip invalid ratings
}
// If the query is empty, set similarity to 1, else get the similarity from the map
const similarity = query !== '' ? similarities.get(feedback.id) || 0 : 1;
const opponents = feedback.data.sibling_model_ids || [];
opponents.forEach((modelB) => {
const statsB = getOrDefaultStats(modelB);
const changeA = calculateEloChange(statsA.rating, statsB.rating, outcome, similarity);
const changeB = calculateEloChange(statsB.rating, statsA.rating, 1 - outcome, similarity);
updateStats(modelA, changeA, outcome);
updateStats(modelB, changeB, 1 - outcome);
});
});
return stats;
}
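// Worked example of the Elo update above (illustrative numbers): with
// ratingA = ratingB = 1000, outcome = 1 and similarity = 1,
// expectedScore = 1 / (1 + 10^((1000 - 1000) / 400)) = 0.5,
// so changeA = 32 * (1 - 0.5) * 1 = +16 and changeB = 32 * (0 - 0.5) * 1 = -16.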
//////////////////////
//
// Calculate cosine similarity
//
//////////////////////
const cosineSimilarity = (vecA, vecB) => {
// Ensure the lengths of the vectors are the same
if (vecA.length !== vecB.length) {
throw new Error('Vectors must be the same length');
}
// Calculate the dot product
let dotProduct = 0;
let normA = 0;
let normB = 0;
for (let i = 0; i < vecA.length; i++) {
dotProduct += vecA[i] * vecB[i];
normA += vecA[i] ** 2;
normB += vecB[i] ** 2;
}
// Calculate the magnitudes
normA = Math.sqrt(normA);
normB = Math.sqrt(normB);
// Avoid division by zero
if (normA === 0 || normB === 0) {
return 0;
}
// Return the cosine similarity
return dotProduct / (normA * normB);
};
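// Worked example (illustrative vectors): for vecA = [1, 0] and vecB = [1, 1],
// dotProduct = 1, normA = 1, normB = sqrt(2), so the similarity is about 0.707.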
const calculateMaxSimilarity = (queryEmbedding, tagEmbeddings: Map<string, number[]>) => {
let maxSimilarity = 0;
for (const tagEmbedding of tagEmbeddings.values()) {
const similarity = cosineSimilarity(queryEmbedding, tagEmbedding);
maxSimilarity = Math.max(maxSimilarity, similarity);
}
return maxSimilarity;
};
//////////////////////
//
// Embedding functions
//
//////////////////////
const loadEmbeddingModel = async () => {
// Check if the tokenizer and model are already loaded and stored in the window object
if (!window.tokenizer) {
window.tokenizer = await AutoTokenizer.from_pretrained(EMBEDDING_MODEL);
}
if (!window.model) {
window.model = await AutoModel.from_pretrained(EMBEDDING_MODEL);
}
// Use the tokenizer and model from the window object
tokenizer = window.tokenizer;
model = window.model;
// Pre-compute embeddings for all unique tags
const allTags = new Set(feedbacks.flatMap((feedback) => feedback.data.tags || []));
await getTagEmbeddings(Array.from(allTags));
};
const getEmbeddings = async (text: string) => {
const tokens = await tokenizer(text);
const output = await model(tokens);
// Perform mean pooling on the last hidden states
const embeddings = output.last_hidden_state.mean(1);
return embeddings.ort_tensor.data;
};
const getTagEmbeddings = async (tags: string[]) => {
const embeddings = new Map();
for (const tag of tags) {
if (!tagEmbeddings.has(tag)) {
tagEmbeddings.set(tag, await getEmbeddings(tag));
}
embeddings.set(tag, tagEmbeddings.get(tag));
}
return embeddings;
};
const debouncedQueryHandler = async () => {
loadingLeaderboard = true;
if (query.trim() === '') {
rankHandler();
return;
}
clearTimeout(debounceTimer);
debounceTimer = setTimeout(async () => {
const queryEmbedding = await getEmbeddings(query);
const similarities = new Map<string, number>();
for (const feedback of feedbacks) {
const feedbackTags = feedback.data.tags || [];
const tagEmbeddings = await getTagEmbeddings(feedbackTags);
const maxSimilarity = calculateMaxSimilarity(queryEmbedding, tagEmbeddings);
similarities.set(feedback.id, maxSimilarity);
}
rankHandler(similarities);
}, 1500); // Debounce for 1.5 seconds
};
$: query, debouncedQueryHandler();
//////////////////////
//
// CRUD operations
//
//////////////////////
const deleteFeedbackHandler = async (feedbackId: string) => {
const response = await deleteFeedbackById(localStorage.token, feedbackId).catch((err) => {
toast.error(err);
return null;
});
if (response) {
feedbacks = feedbacks.filter((f) => f.id !== feedbackId);
}
};
const shareHandler = async () => {
toast.success($i18n.t('Redirecting you to OpenWebUI Community'));
// remove snapshot from feedbacks
const feedbacksToShare = feedbacks.map((f) => {
const { snapshot, user, ...rest } = f;
return rest;
});
console.log(feedbacksToShare);
const url = 'https://openwebui.com';
const tab = await window.open(`${url}/leaderboard`, '_blank');
// Define the event handler function
const messageHandler = (event) => {
if (event.origin !== url) return;
if (event.data === 'loaded') {
tab.postMessage(JSON.stringify(feedbacksToShare), '*');
// Remove the event listener after handling the message
window.removeEventListener('message', messageHandler);
}
};
window.addEventListener('message', messageHandler, false);
};
const exportHandler = async () => {
const _feedbacks = await exportAllFeedbacks(localStorage.token).catch((err) => {
toast.error(err);
return null;
});
if (_feedbacks) {
let blob = new Blob([JSON.stringify(_feedbacks)], {
type: 'application/json'
});
saveAs(blob, `feedback-history-export-${Date.now()}.json`);
}
};
onMount(async () => {
feedbacks = await getAllFeedbacks(localStorage.token);
loaded = true;
rankHandler();
});
</script>
{#if loaded}
<div class="mt-0.5 mb-2 gap-1 flex flex-col md:flex-row justify-between">
<div class="flex md:self-center text-lg font-medium px-0.5 shrink-0 items-center">
<div class=" gap-1">
{$i18n.t('Leaderboard')}
</div>
<div class="flex self-center w-[1px] h-6 mx-2.5 bg-gray-50 dark:bg-gray-850" />
<span class="text-lg font-medium text-gray-500 dark:text-gray-300 mr-1.5"
>{rankedModels.length}</span
>
</div>
<div class=" flex space-x-2">
<Tooltip content={$i18n.t('Re-rank models by topic similarity')}>
<div class="flex flex-1">
<div class=" self-center ml-1 mr-3">
<MagnifyingGlass className="size-3" />
</div>
<input
class=" w-full text-sm pr-4 py-1 rounded-r-xl outline-none bg-transparent"
bind:value={query}
placeholder={$i18n.t('Search')}
on:focus={() => {
loadEmbeddingModel();
}}
/>
</div>
</Tooltip>
</div>
</div>
<div
class="scrollbar-hidden relative whitespace-nowrap overflow-x-auto max-w-full rounded pt-0.5"
>
{#if loadingLeaderboard}
<div class=" absolute top-0 bottom-0 left-0 right-0 flex">
<div class="m-auto">
<Spinner />
</div>
</div>
{/if}
{#if (rankedModels ?? []).length === 0}
<div class="text-center text-xs text-gray-500 dark:text-gray-400 py-1">
{$i18n.t('No models found')}
</div>
{:else}
<table
class="w-full text-sm text-left text-gray-500 dark:text-gray-400 table-auto max-w-full rounded {loadingLeaderboard
? 'opacity-20'
: ''}"
>
<thead
class="text-xs text-gray-700 uppercase bg-gray-50 dark:bg-gray-850 dark:text-gray-400 -translate-y-0.5"
>
<tr class="">
<th scope="col" class="px-3 py-1.5 cursor-pointer select-none w-3">
{$i18n.t('RK')}
</th>
<th scope="col" class="px-3 py-1.5 cursor-pointer select-none">
{$i18n.t('Model')}
</th>
<th scope="col" class="px-3 py-1.5 text-right cursor-pointer select-none w-fit">
{$i18n.t('Rating')}
</th>
<th scope="col" class="px-3 py-1.5 text-right cursor-pointer select-none w-5">
{$i18n.t('Won')}
</th>
<th scope="col" class="px-3 py-1.5 text-right cursor-pointer select-none w-5">
{$i18n.t('Lost')}
</th>
</tr>
</thead>
<tbody class="">
{#each rankedModels as model, modelIdx (model.id)}
<tr class="bg-white dark:bg-gray-900 dark:border-gray-850 text-xs group">
<td class="px-3 py-1.5 text-left font-medium text-gray-900 dark:text-white w-fit">
<div class=" line-clamp-1">
{model?.rating !== '-' ? modelIdx + 1 : '-'}
</div>
</td>
<td class="px-3 py-1.5 flex flex-col justify-center">
<div class="flex items-center gap-2">
<div class="flex-shrink-0">
<img
src={model?.info?.meta?.profile_image_url ?? '/favicon.png'}
alt={model.name}
class="size-5 rounded-full object-cover shrink-0"
/>
</div>
<div class="font-medium text-gray-800 dark:text-gray-200 pr-4">
{model.name}
</div>
</div>
</td>
<td class="px-3 py-1.5 text-right font-medium text-gray-900 dark:text-white w-max">
{model.rating}
</td>
<td class=" px-3 py-1.5 text-right font-semibold text-green-500">
<div class=" w-10">
{#if model.stats.won === '-'}
-
{:else}
<span class="hidden group-hover:inline"
>{((model.stats.won / model.stats.count) * 100).toFixed(1)}%</span
>
<span class=" group-hover:hidden">{model.stats.won}</span>
{/if}
</div>
</td>
<td class="px-3 py-1.5 text-right font-semibold text-red-500">
<div class=" w-10">
{#if model.stats.lost === '-'}
-
{:else}
<span class="hidden group-hover:inline"
>{((model.stats.lost / model.stats.count) * 100).toFixed(1)}%</span
>
<span class=" group-hover:hidden">{model.stats.lost}</span>
{/if}
</div>
</td>
</tr>
{/each}
</tbody>
</table>
{/if}
</div>
<div class=" text-gray-500 text-xs mt-1.5 w-full flex justify-end">
<div class=" text-right">
<div class="line-clamp-1">
{$i18n.t(
'The evaluation leaderboard is based on the Elo rating system and is updated in real-time.'
)}
</div>
{$i18n.t(
'The leaderboard is currently in beta, and we may adjust the rating calculations as we refine the algorithm.'
)}
</div>
</div>
<div class="pb-4"></div>
<div class="mt-0.5 mb-2 gap-1 flex flex-col md:flex-row justify-between">
<div class="flex md:self-center text-lg font-medium px-0.5">
{$i18n.t('Feedback History')}
<div class="flex self-center w-[1px] h-6 mx-2.5 bg-gray-50 dark:bg-gray-850" />
<span class="text-lg font-medium text-gray-500 dark:text-gray-300">{feedbacks.length}</span>
</div>
<div>
<div>
<Tooltip content={$i18n.t('Export')}>
<button
class=" p-2 rounded-xl hover:bg-gray-100 dark:bg-gray-900 dark:hover:bg-gray-850 transition font-medium text-sm flex items-center space-x-1"
on:click={() => {
exportHandler();
}}
>
<ArrowDownTray className="size-3" />
</button>
</Tooltip>
</div>
</div>
</div>
<div
class="scrollbar-hidden relative whitespace-nowrap overflow-x-auto max-w-full rounded pt-0.5"
>
{#if (feedbacks ?? []).length === 0}
<div class="text-center text-xs text-gray-500 dark:text-gray-400 py-1">
{$i18n.t('No feedbacks found')}
</div>
{:else}
<table
class="w-full text-sm text-left text-gray-500 dark:text-gray-400 table-auto max-w-full rounded"
>
<thead
class="text-xs text-gray-700 uppercase bg-gray-50 dark:bg-gray-850 dark:text-gray-400 -translate-y-0.5"
>
<tr class="">
<th scope="col" class="px-3 text-right cursor-pointer select-none w-0">
{$i18n.t('User')}
</th>
<th scope="col" class="px-3 pr-1.5 cursor-pointer select-none">
{$i18n.t('Models')}
</th>
<th scope="col" class="px-3 py-1.5 text-right cursor-pointer select-none w-fit">
{$i18n.t('Result')}
</th>
<th scope="col" class="px-3 py-1.5 text-right cursor-pointer select-none w-0">
{$i18n.t('Updated At')}
</th>
<th scope="col" class="px-3 py-1.5 text-right cursor-pointer select-none w-0"> </th>
</tr>
</thead>
<tbody class="">
{#each paginatedFeedbacks as feedback (feedback.id)}
<tr class="bg-white dark:bg-gray-900 dark:border-gray-850 text-xs">
<td class=" py-0.5 text-right font-semibold">
<div class="flex justify-center">
<Tooltip content={feedback?.user?.name}>
<div class="flex-shrink-0">
<img
src={feedback?.user?.profile_image_url ?? '/user.png'}
alt={feedback?.user?.name}
class="size-5 rounded-full object-cover shrink-0"
/>
</div>
</Tooltip>
</div>
</td>
<td class=" py-1 pl-3 flex flex-col">
<div class="flex flex-col items-start gap-0.5 h-full">
<div class="flex flex-col h-full">
{#if feedback.data?.sibling_model_ids}
<div class="font-semibold text-gray-600 dark:text-gray-400 flex-1">
{feedback.data?.model_id}
</div>
<Tooltip content={feedback.data.sibling_model_ids.join(', ')}>
<div class=" text-[0.65rem] text-gray-600 dark:text-gray-400 line-clamp-1">
{#if feedback.data.sibling_model_ids.length > 2}
<!-- {$i18n.t('and {{COUNT}} more')} -->
{feedback.data.sibling_model_ids.slice(0, 2).join(', ')}, {$i18n.t(
'and {{COUNT}} more',
{ COUNT: feedback.data.sibling_model_ids.length - 2 }
)}
{:else}
{feedback.data.sibling_model_ids.join(', ')}
{/if}
</div>
</Tooltip>
{:else}
<div
class=" text-sm font-medium text-gray-600 dark:text-gray-400 flex-1 py-1.5"
>
{feedback.data?.model_id}
</div>
{/if}
</div>
</div>
</td>
<td class="px-3 py-1 text-right font-medium text-gray-900 dark:text-white w-max">
<div class=" flex justify-end">
{#if feedback.data.rating.toString() === '1'}
<Badge type="info" content={$i18n.t('Won')} />
{:else if feedback.data.rating.toString() === '0'}
<Badge type="muted" content={$i18n.t('Draw')} />
{:else if feedback.data.rating.toString() === '-1'}
<Badge type="error" content={$i18n.t('Lost')} />
{/if}
</div>
</td>
<td class=" px-3 py-1 text-right font-medium">
{dayjs(feedback.updated_at * 1000).fromNow()}
</td>
<td class=" px-3 py-1 text-right font-semibold">
<FeedbackMenu
on:delete={(e) => {
deleteFeedbackHandler(feedback.id);
}}
>
<button
class="self-center w-fit text-sm p-1.5 dark:text-gray-300 dark:hover:text-white hover:bg-black/5 dark:hover:bg-white/5 rounded-xl"
>
<EllipsisHorizontal />
</button>
</FeedbackMenu>
</td>
</tr>
{/each}
</tbody>
</table>
{/if}
</div>
{#if feedbacks.length > 0}
<div class=" flex flex-col justify-end w-full text-right gap-1">
<div class="line-clamp-1 text-gray-500 text-xs">
{$i18n.t('Help us create the best community leaderboard by sharing your feedback history!')}
</div>
<div class="flex space-x-1 ml-auto">
<Tooltip
content={$i18n.t(
'To protect your privacy, only ratings, model IDs, tags, and metadata are shared from your feedback—your chat logs remain private and are not included.'
)}
>
<button
class="flex text-xs items-center px-3 py-1.5 rounded-xl bg-gray-50 hover:bg-gray-100 dark:bg-gray-850 dark:hover:bg-gray-800 dark:text-gray-200 transition"
on:click={async () => {
shareHandler();
}}
>
<div class=" self-center mr-2 font-medium line-clamp-1">
{$i18n.t('Share to OpenWebUI Community')}
</div>
<div class=" self-center">
<CloudArrowUp className="size-3" strokeWidth="3" />
</div>
</button>
</Tooltip>
</div>
</div>
{/if}
{#if feedbacks.length > 10}
<Pagination bind:page count={feedbacks.length} perPage={10} />
{/if}
<div class="pb-12"></div>
{/if}

View file

@@ -0,0 +1,46 @@
<script lang="ts">
import { DropdownMenu } from 'bits-ui';
import { flyAndScale } from '$lib/utils/transitions';
import { getContext, createEventDispatcher } from 'svelte';
import fileSaver from 'file-saver';
const { saveAs } = fileSaver;
const dispatch = createEventDispatcher();
const i18n = getContext('i18n');
import Dropdown from '$lib/components/common/Dropdown.svelte';
import GarbageBin from '$lib/components/icons/GarbageBin.svelte';
import Pencil from '$lib/components/icons/Pencil.svelte';
import Tooltip from '$lib/components/common/Tooltip.svelte';
import Download from '$lib/components/icons/Download.svelte';
let show = false;
</script>
<Dropdown bind:show on:change={(e) => {}}>
<Tooltip content={$i18n.t('More')}>
<slot />
</Tooltip>
<div slot="content">
<DropdownMenu.Content
class="w-full max-w-[150px] rounded-xl px-1 py-1.5 z-50 bg-white dark:bg-gray-850 dark:text-white shadow-lg"
sideOffset={-2}
side="bottom"
align="start"
transition={flyAndScale}
>
<DropdownMenu.Item
class="flex gap-2 items-center px-3 py-1.5 text-sm cursor-pointer hover:bg-gray-50 dark:hover:bg-gray-800 rounded-md"
on:click={() => {
dispatch('delete');
show = false;
}}
>
<GarbageBin strokeWidth="2" />
<div class="flex items-center">{$i18n.t('Delete')}</div>
</DropdownMenu.Item>
</DropdownMenu.Content>
</div>
</Dropdown>

View file

@ -17,6 +17,9 @@
import WebSearch from './Settings/WebSearch.svelte'; import WebSearch from './Settings/WebSearch.svelte';
import { config } from '$lib/stores'; import { config } from '$lib/stores';
import { getBackendConfig } from '$lib/apis'; import { getBackendConfig } from '$lib/apis';
import ChartBar from '../icons/ChartBar.svelte';
import DocumentChartBar from '../icons/DocumentChartBar.svelte';
import Evaluations from './Settings/Evaluations.svelte';
const i18n = getContext('i18n'); const i18n = getContext('i18n');
@ -42,7 +45,7 @@
class="tabs flex flex-row overflow-x-auto space-x-1 max-w-full lg:space-x-0 lg:space-y-1 lg:flex-col lg:flex-none lg:w-44 dark:text-gray-200 text-xs text-left scrollbar-none" class="tabs flex flex-row overflow-x-auto space-x-1 max-w-full lg:space-x-0 lg:space-y-1 lg:flex-col lg:flex-none lg:w-44 dark:text-gray-200 text-xs text-left scrollbar-none"
> >
<button <button
class="px-2.5 py-2.5 min-w-fit rounded-lg flex-1 lg:flex-none flex text-right transition {selectedTab === class="px-2.5 py-2 min-w-fit rounded-lg flex-1 lg:flex-none flex text-right transition {selectedTab ===
'general' 'general'
? 'bg-gray-100 dark:bg-gray-800' ? 'bg-gray-100 dark:bg-gray-800'
: ' hover:bg-gray-50 dark:hover:bg-gray-850'}" : ' hover:bg-gray-50 dark:hover:bg-gray-850'}"
@ -68,7 +71,7 @@
</button> </button>
<button <button
class="px-2.5 py-2.5 min-w-fit rounded-lg flex-1 md:flex-none flex text-right transition {selectedTab === class="px-2.5 py-2 min-w-fit rounded-lg flex-1 md:flex-none flex text-right transition {selectedTab ===
'users' 'users'
? 'bg-gray-100 dark:bg-gray-800' ? 'bg-gray-100 dark:bg-gray-800'
: ' hover:bg-gray-50 dark:hover:bg-gray-850'}" : ' hover:bg-gray-50 dark:hover:bg-gray-850'}"
@ -92,7 +95,7 @@
</button> </button>
<button <button
class="px-2.5 py-2.5 min-w-fit rounded-lg flex-1 md:flex-none flex text-right transition {selectedTab === class="px-2.5 py-2 min-w-fit rounded-lg flex-1 md:flex-none flex text-right transition {selectedTab ===
'connections' 'connections'
? 'bg-gray-100 dark:bg-gray-800' ? 'bg-gray-100 dark:bg-gray-800'
: ' hover:bg-gray-50 dark:hover:bg-gray-850'}" : ' hover:bg-gray-50 dark:hover:bg-gray-850'}"
@ -116,7 +119,7 @@
</button> </button>
<button <button
class="px-2.5 py-2.5 min-w-fit rounded-lg flex-1 md:flex-none flex text-right transition {selectedTab === class="px-2.5 py-2 min-w-fit rounded-lg flex-1 md:flex-none flex text-right transition {selectedTab ===
'models' 'models'
? 'bg-gray-100 dark:bg-gray-800' ? 'bg-gray-100 dark:bg-gray-800'
: ' hover:bg-gray-50 dark:hover:bg-gray-850'}" : ' hover:bg-gray-50 dark:hover:bg-gray-850'}"
@ -142,7 +145,22 @@
</button> </button>
<button <button
class="px-2.5 py-2.5 min-w-fit rounded-lg flex-1 md:flex-none flex text-right transition {selectedTab === class="px-2.5 py-2 min-w-fit rounded-lg flex-1 md:flex-none flex text-right transition {selectedTab ===
'evaluations'
? 'bg-gray-100 dark:bg-gray-800'
: ' hover:bg-gray-50 dark:hover:bg-gray-850'}"
on:click={() => {
selectedTab = 'evaluations';
}}
>
<div class=" self-center mr-2">
<DocumentChartBar />
</div>
<div class=" self-center">{$i18n.t('Evaluations')}</div>
</button>
<button
class="px-2.5 py-2 min-w-fit rounded-lg flex-1 md:flex-none flex text-right transition {selectedTab ===
'documents' 'documents'
? 'bg-gray-100 dark:bg-gray-800' ? 'bg-gray-100 dark:bg-gray-800'
: ' hover:bg-gray-50 dark:hover:bg-gray-850'}" : ' hover:bg-gray-50 dark:hover:bg-gray-850'}"
@ -172,7 +190,7 @@
</button> </button>
<button <button
class="px-2.5 py-2.5 min-w-fit rounded-lg flex-1 md:flex-none flex text-right transition {selectedTab === class="px-2.5 py-2 min-w-fit rounded-lg flex-1 md:flex-none flex text-right transition {selectedTab ===
'web' 'web'
? 'bg-gray-100 dark:bg-gray-800' ? 'bg-gray-100 dark:bg-gray-800'
: ' hover:bg-gray-50 dark:hover:bg-gray-850'}" : ' hover:bg-gray-50 dark:hover:bg-gray-850'}"
@ -196,7 +214,7 @@
</button> </button>
<button <button
class="px-2.5 py-2.5 min-w-fit rounded-lg flex-1 md:flex-none flex text-right transition {selectedTab === class="px-2.5 py-2 min-w-fit rounded-lg flex-1 md:flex-none flex text-right transition {selectedTab ===
'interface' 'interface'
? 'bg-gray-100 dark:bg-gray-800' ? 'bg-gray-100 dark:bg-gray-800'
: ' hover:bg-gray-50 dark:hover:bg-gray-850'}" : ' hover:bg-gray-50 dark:hover:bg-gray-850'}"
@ -222,7 +240,7 @@
</button> </button>
<button <button
class="px-2.5 py-2.5 min-w-fit rounded-lg flex-1 md:flex-none flex text-right transition {selectedTab === class="px-2.5 py-2 min-w-fit rounded-lg flex-1 md:flex-none flex text-right transition {selectedTab ===
'audio' 'audio'
? 'bg-gray-100 dark:bg-gray-800' ? 'bg-gray-100 dark:bg-gray-800'
: ' hover:bg-gray-50 dark:hover:bg-gray-850'}" : ' hover:bg-gray-50 dark:hover:bg-gray-850'}"
@ -249,7 +267,7 @@
</button> </button>
<button <button
class="px-2.5 py-2.5 min-w-fit rounded-lg flex-1 md:flex-none flex text-right transition {selectedTab === class="px-2.5 py-2 min-w-fit rounded-lg flex-1 md:flex-none flex text-right transition {selectedTab ===
'images' 'images'
? 'bg-gray-100 dark:bg-gray-800' ? 'bg-gray-100 dark:bg-gray-800'
: ' hover:bg-gray-50 dark:hover:bg-gray-850'}" : ' hover:bg-gray-50 dark:hover:bg-gray-850'}"
@ -275,7 +293,7 @@
</button> </button>
<button <button
class="px-2.5 py-2.5 min-w-fit rounded-lg flex-1 md:flex-none flex text-right transition {selectedTab === class="px-2.5 py-2 min-w-fit rounded-lg flex-1 md:flex-none flex text-right transition {selectedTab ===
'pipelines' 'pipelines'
? 'bg-gray-100 dark:bg-gray-800' ? 'bg-gray-100 dark:bg-gray-800'
: ' hover:bg-gray-50 dark:hover:bg-gray-850'}" : ' hover:bg-gray-50 dark:hover:bg-gray-850'}"
@ -305,7 +323,7 @@
</button> </button>
<button <button
class="px-2.5 py-2.5 min-w-fit rounded-lg flex-1 md:flex-none flex text-right transition {selectedTab === class="px-2.5 py-2 min-w-fit rounded-lg flex-1 md:flex-none flex text-right transition {selectedTab ===
'db' 'db'
? 'bg-gray-100 dark:bg-gray-800' ? 'bg-gray-100 dark:bg-gray-800'
: ' hover:bg-gray-50 dark:hover:bg-gray-850'}" : ' hover:bg-gray-50 dark:hover:bg-gray-850'}"
@ -357,6 +375,8 @@
/> />
{:else if selectedTab === 'models'} {:else if selectedTab === 'models'}
<Models /> <Models />
{:else if selectedTab === 'evaluations'}
<Evaluations />
{:else if selectedTab === 'documents'} {:else if selectedTab === 'documents'}
<Documents <Documents
on:save={async () => { on:save={async () => {

View file

@@ -38,6 +38,9 @@
let STT_OPENAI_API_KEY = '';
let STT_ENGINE = '';
let STT_MODEL = '';
+ let STT_WHISPER_MODEL = '';
+ let STT_WHISPER_MODEL_LOADING = false;
// eslint-disable-next-line no-undef
let voices: SpeechSynthesisVoice[] = [];
@@ -99,18 +102,23 @@
OPENAI_API_BASE_URL: STT_OPENAI_API_BASE_URL,
OPENAI_API_KEY: STT_OPENAI_API_KEY,
ENGINE: STT_ENGINE,
- MODEL: STT_MODEL
+ MODEL: STT_MODEL,
+ WHISPER_MODEL: STT_WHISPER_MODEL
}
});
if (res) {
saveHandler();
- getBackendConfig()
- .then(config.set)
- .catch(() => {});
+ config.set(await getBackendConfig());
}
};
+ const sttModelUpdateHandler = async () => {
+ STT_WHISPER_MODEL_LOADING = true;
+ await updateConfigHandler();
+ STT_WHISPER_MODEL_LOADING = false;
+ };
onMount(async () => {
const res = await getAudioConfig(localStorage.token);
@@ -134,6 +142,7 @@
STT_ENGINE = res.stt.ENGINE;
STT_MODEL = res.stt.MODEL;
+ STT_WHISPER_MODEL = res.stt.WHISPER_MODEL;
}
await getVoices();
{:else if STT_ENGINE === ''}
<div>
<div class=" mb-1.5 text-sm font-medium">{$i18n.t('STT Model')}</div>
<div class="flex w-full">
<div class="flex-1 mr-2">
<input
class="w-full rounded-lg py-2 px-4 text-sm bg-gray-50 dark:text-gray-300 dark:bg-gray-850 outline-none"
placeholder={$i18n.t('Set whisper model')}
bind:value={STT_WHISPER_MODEL}
/>
</div>
<button
class="px-2.5 bg-gray-50 hover:bg-gray-200 text-gray-800 dark:bg-gray-850 dark:hover:bg-gray-800 dark:text-gray-100 rounded-lg transition"
on:click={() => {
sttModelUpdateHandler();
}}
disabled={STT_WHISPER_MODEL_LOADING}
>
{#if STT_WHISPER_MODEL_LOADING}
<div class="self-center">
<svg
class=" w-4 h-4"
viewBox="0 0 24 24"
fill="currentColor"
xmlns="http://www.w3.org/2000/svg"
>
<style>
.spinner_ajPY {
transform-origin: center;
animation: spinner_AtaB 0.75s infinite linear;
}
@keyframes spinner_AtaB {
100% {
transform: rotate(360deg);
}
}
</style>
<path
d="M12,1A11,11,0,1,0,23,12,11,11,0,0,0,12,1Zm0,19a8,8,0,1,1,8-8A8,8,0,0,1,12,20Z"
opacity=".25"
/>
<path
d="M10.14,1.16a11,11,0,0,0-9,8.92A1.59,1.59,0,0,0,2.46,12,1.52,1.52,0,0,0,4.11,10.7a8,8,0,0,1,6.66-6.61A1.42,1.42,0,0,0,12,2.69h0A1.57,1.57,0,0,0,10.14,1.16Z"
class="spinner_ajPY"
/>
</svg>
</div>
{:else}
<svg
xmlns="http://www.w3.org/2000/svg"
viewBox="0 0 16 16"
fill="currentColor"
class="w-4 h-4"
>
<path
d="M8.75 2.75a.75.75 0 0 0-1.5 0v5.69L5.03 6.22a.75.75 0 0 0-1.06 1.06l3.5 3.5a.75.75 0 0 0 1.06 0l3.5-3.5a.75.75 0 0 0-1.06-1.06L8.75 8.44V2.75Z"
/>
<path
d="M3.5 9.75a.75.75 0 0 0-1.5 0v1.5A2.75 2.75 0 0 0 4.75 14h6.5A2.75 2.75 0 0 0 14 11.25v-1.5a.75.75 0 0 0-1.5 0v1.5c0 .69-.56 1.25-1.25 1.25h-6.5c-.69 0-1.25-.56-1.25-1.25v-1.5Z"
/>
</svg>
{/if}
</button>
</div>
<div class="mt-2 mb-1 text-xs text-gray-400 dark:text-gray-500">
{$i18n.t(`Open WebUI uses faster-whisper internally.`)}
<a
class=" hover:underline dark:text-gray-200 text-gray-800"
href="https://github.com/SYSTRAN/faster-whisper"
target="_blank"
>
{$i18n.t(
`Click here to learn more about faster-whisper and see the available models.`
)}
</a>
</div>
</div>
{/if}
</div>
@@ -339,7 +430,7 @@
<datalist id="tts-model-list">
{#each models as model}
- <option value={model.id} />
+ <option value={model.id} class="bg-gray-50 dark:bg-gray-700" />
{/each}
</datalist>
</div>
@@ -380,7 +471,7 @@
<datalist id="tts-model-list">
{#each models as model}
- <option value={model.id} />
+ <option value={model.id} class="bg-gray-50 dark:bg-gray-700" />
{/each}
</datalist>
</div>
@@ -460,7 +551,7 @@
</div>
<div class="flex justify-end text-sm font-medium">
<button
- class=" px-4 py-2 bg-emerald-700 hover:bg-emerald-800 text-gray-100 transition rounded-lg"
+ class="px-3.5 py-1.5 text-sm font-medium bg-black hover:bg-gray-900 text-white dark:bg-white dark:text-black dark:hover:bg-gray-100 transition rounded-full"
type="submit"
>
{$i18n.t('Save')}

View file

@@ -439,7 +439,7 @@
<div class="flex justify-end pt-3 text-sm font-medium">
<button
- class=" px-4 py-2 bg-emerald-700 hover:bg-emerald-800 text-gray-100 transition rounded-lg"
+ class="px-3.5 py-1.5 text-sm font-medium bg-black hover:bg-gray-900 text-white dark:bg-white dark:text-black dark:hover:bg-gray-100 transition rounded-full"
type="submit"
>
{$i18n.t('Save')}

View file

@ -26,6 +26,9 @@
import ResetVectorDBConfirmDialog from '$lib/components/common/ConfirmDialog.svelte'; import ResetVectorDBConfirmDialog from '$lib/components/common/ConfirmDialog.svelte';
import SensitiveInput from '$lib/components/common/SensitiveInput.svelte'; import SensitiveInput from '$lib/components/common/SensitiveInput.svelte';
import Tooltip from '$lib/components/common/Tooltip.svelte'; import Tooltip from '$lib/components/common/Tooltip.svelte';
import Switch from '$lib/components/common/Switch.svelte';
import { text } from '@sveltejs/kit';
import Textarea from '$lib/components/common/Textarea.svelte';
const i18n = getContext('i18n'); const i18n = getContext('i18n');
@ -48,6 +51,7 @@
let tikaServerUrl = ''; let tikaServerUrl = '';
let showTikaServerUrl = false; let showTikaServerUrl = false;
let textSplitter = '';
let chunkSize = 0; let chunkSize = 0;
let chunkOverlap = 0; let chunkOverlap = 0;
let pdfExtractImages = true; let pdfExtractImages = true;
@ -177,6 +181,7 @@
max_count: fileMaxCount === '' ? null : fileMaxCount max_count: fileMaxCount === '' ? null : fileMaxCount
}, },
chunk: { chunk: {
text_splitter: textSplitter,
chunk_overlap: chunkOverlap, chunk_overlap: chunkOverlap,
chunk_size: chunkSize chunk_size: chunkSize
}, },
@ -222,11 +227,13 @@
await setRerankingConfig(); await setRerankingConfig();
querySettings = await getQuerySettings(localStorage.token); querySettings = await getQuerySettings(localStorage.token);
const res = await getRAGConfig(localStorage.token); const res = await getRAGConfig(localStorage.token);
if (res) { if (res) {
pdfExtractImages = res.pdf_extract_images; pdfExtractImages = res.pdf_extract_images;
textSplitter = res.chunk.text_splitter;
chunkSize = res.chunk.chunk_size; chunkSize = res.chunk.chunk_size;
chunkOverlap = res.chunk.chunk_overlap; chunkOverlap = res.chunk.chunk_overlap;
@ -535,13 +542,13 @@
<hr class=" dark:border-gray-850" /> <hr class=" dark:border-gray-850" />
<div class=""> <div class="">
<div class="text-sm font-medium">{$i18n.t('Content Extraction')}</div> <div class="text-sm font-medium mb-1">{$i18n.t('Content Extraction')}</div>
<div class="flex w-full justify-between mt-2"> <div class="flex w-full justify-between">
<div class="self-center text-xs font-medium">{$i18n.t('Engine')}</div> <div class="self-center text-xs font-medium">{$i18n.t('Engine')}</div>
<div class="flex items-center relative"> <div class="flex items-center relative">
<select <select
class="dark:bg-gray-900 w-fit pr-8 rounded px-2 p-1 text-xs bg-transparent outline-none text-right" class="dark:bg-gray-900 w-fit pr-8 rounded px-2 text-xs bg-transparent outline-none text-right"
bind:value={contentExtractionEngine} bind:value={contentExtractionEngine}
on:change={(e) => { on:change={(e) => {
showTikaServerUrl = e.target.value === 'tika'; showTikaServerUrl = e.target.value === 'tika';
@ -554,7 +561,7 @@
</div> </div>
{#if showTikaServerUrl} {#if showTikaServerUrl}
<div class="flex w-full mt-2"> <div class="flex w-full mt-1">
<div class="flex-1 mr-2"> <div class="flex-1 mr-2">
<input <input
class="w-full rounded-lg py-2 px-4 text-sm bg-gray-50 dark:text-gray-300 dark:bg-gray-850 outline-none" class="w-full rounded-lg py-2 px-4 text-sm bg-gray-50 dark:text-gray-300 dark:bg-gray-850 outline-none"
@ -569,9 +576,136 @@
<hr class=" dark:border-gray-850" /> <hr class=" dark:border-gray-850" />
<div class=" "> <div class=" ">
<div class="text-sm font-medium">{$i18n.t('Files')}</div> <div class=" text-sm font-medium mb-1">{$i18n.t('Query Params')}</div>
<div class=" my-2 flex gap-1.5"> <div class=" flex gap-1.5">
<div class="flex flex-col w-full gap-1">
<div class=" text-xs font-medium w-full">{$i18n.t('Top K')}</div>
<div class="w-full">
<input
class=" w-full rounded-lg py-1.5 px-4 text-sm bg-gray-50 dark:text-gray-300 dark:bg-gray-850 outline-none"
type="number"
placeholder={$i18n.t('Enter Top K')}
bind:value={querySettings.k}
autocomplete="off"
min="0"
/>
</div>
</div>
{#if querySettings.hybrid === true}
<div class=" flex flex-col w-full gap-1">
<div class="text-xs font-medium w-full">
{$i18n.t('Minimum Score')}
</div>
<div class="w-full">
<input
class=" w-full rounded-lg py-1.5 px-4 text-sm bg-gray-50 dark:text-gray-300 dark:bg-gray-850 outline-none"
type="number"
step="0.01"
placeholder={$i18n.t('Enter Score')}
bind:value={querySettings.r}
autocomplete="off"
min="0.0"
title={$i18n.t('The score should be a value between 0.0 (0%) and 1.0 (100%).')}
/>
</div>
</div>
{/if}
</div>
{#if querySettings.hybrid === true}
<div class="mt-2 text-xs text-gray-400 dark:text-gray-500">
{$i18n.t(
'Note: If you set a minimum score, the search will only return documents with a score greater than or equal to the minimum score.'
)}
</div>
{/if}
<div class="mt-2">
<div class=" mb-1 text-xs font-medium">{$i18n.t('RAG Template')}</div>
<Tooltip
content={$i18n.t('Leave empty to use the default prompt, or enter a custom prompt')}
placement="top-start"
>
<Textarea
bind:value={querySettings.template}
placeholder={$i18n.t('Leave empty to use the default prompt, or enter a custom prompt')}
/>
</Tooltip>
</div>
</div>
<hr class=" dark:border-gray-850" />
<div class=" ">
<div class="mb-1 text-sm font-medium">{$i18n.t('Chunk Params')}</div>
<div class="flex w-full justify-between mb-1.5">
<div class="self-center text-xs font-medium">{$i18n.t('Text Splitter')}</div>
<div class="flex items-center relative">
<select
class="dark:bg-gray-900 w-fit pr-8 rounded px-2 text-xs bg-transparent outline-none text-right"
bind:value={textSplitter}
>
<option value="">{$i18n.t('Default')} ({$i18n.t('Character')})</option>
<option value="token">{$i18n.t('Token')} ({$i18n.t('Tiktoken')})</option>
</select>
</div>
</div>
<div class=" flex gap-1.5">
<div class=" w-full justify-between">
<div class="self-center text-xs font-medium min-w-fit mb-1">{$i18n.t('Chunk Size')}</div>
<div class="self-center">
<input
class=" w-full rounded-lg py-1.5 px-4 text-sm bg-gray-50 dark:text-gray-300 dark:bg-gray-850 outline-none"
type="number"
placeholder={$i18n.t('Enter Chunk Size')}
bind:value={chunkSize}
autocomplete="off"
min="0"
/>
</div>
</div>
<div class="w-full">
<div class=" self-center text-xs font-medium min-w-fit mb-1">
{$i18n.t('Chunk Overlap')}
</div>
<div class="self-center">
<input
class="w-full rounded-lg py-1.5 px-4 text-sm bg-gray-50 dark:text-gray-300 dark:bg-gray-850 outline-none"
type="number"
placeholder={$i18n.t('Enter Chunk Overlap')}
bind:value={chunkOverlap}
autocomplete="off"
min="0"
/>
</div>
</div>
</div>
<div class="my-2">
<div class="flex justify-between items-center text-xs">
<div class=" text-xs font-medium">{$i18n.t('PDF Extract Images (OCR)')}</div>
<div>
<Switch bind:state={pdfExtractImages} />
</div>
</div>
</div>
</div>
<hr class=" dark:border-gray-850" />
<div class="">
<div class="text-sm font-medium mb-1">{$i18n.t('Files')}</div>
<div class=" flex gap-1.5">
<div class="w-full"> <div class="w-full">
<div class=" self-center text-xs font-medium min-w-fit mb-1"> <div class=" self-center text-xs font-medium min-w-fit mb-1">
{$i18n.t('Max Upload Size')} {$i18n.t('Max Upload Size')}
@ -623,128 +757,6 @@
<hr class=" dark:border-gray-850" /> <hr class=" dark:border-gray-850" />
<div class=" ">
<div class=" text-sm font-medium">{$i18n.t('Query Params')}</div>
<div class=" flex gap-1">
<div class=" flex w-full justify-between">
<div class="self-center text-xs font-medium min-w-fit">{$i18n.t('Top K')}</div>
<div class="self-center p-3">
<input
class=" w-full rounded-lg py-1.5 px-4 text-sm bg-gray-50 dark:text-gray-300 dark:bg-gray-850 outline-none"
type="number"
placeholder={$i18n.t('Enter Top K')}
bind:value={querySettings.k}
autocomplete="off"
min="0"
/>
</div>
</div>
{#if querySettings.hybrid === true}
<div class=" flex w-full justify-between">
<div class=" self-center text-xs font-medium min-w-fit">
{$i18n.t('Minimum Score')}
</div>
<div class="self-center p-3">
<input
class=" w-full rounded-lg py-1.5 px-4 text-sm bg-gray-50 dark:text-gray-300 dark:bg-gray-850 outline-none"
type="number"
step="0.01"
placeholder={$i18n.t('Enter Score')}
bind:value={querySettings.r}
autocomplete="off"
min="0.0"
title={$i18n.t('The score should be a value between 0.0 (0%) and 1.0 (100%).')}
/>
</div>
</div>
{/if}
</div>
{#if querySettings.hybrid === true}
<div class="mt-2 mb-1 text-xs text-gray-400 dark:text-gray-500">
{$i18n.t(
'Note: If you set a minimum score, the search will only return documents with a score greater than or equal to the minimum score.'
)}
</div>
<hr class=" dark:border-gray-850 my-3" />
{/if}
<div>
<div class=" mb-2.5 text-sm font-medium">{$i18n.t('RAG Template')}</div>
<Tooltip
content={$i18n.t('Leave empty to use the default prompt, or enter a custom prompt')}
placement="top-start"
>
<textarea
bind:value={querySettings.template}
placeholder={$i18n.t('Leave empty to use the default prompt, or enter a custom prompt')}
class="w-full rounded-lg px-4 py-3 text-sm bg-gray-50 dark:text-gray-300 dark:bg-gray-850 outline-none resize-none"
rows="4"
/>
</Tooltip>
</div>
</div>
<hr class=" dark:border-gray-850" />
<div class=" ">
<div class=" text-sm font-medium">{$i18n.t('Chunk Params')}</div>
<div class=" my-2 flex gap-1.5">
<div class=" w-full justify-between">
<div class="self-center text-xs font-medium min-w-fit mb-1">{$i18n.t('Chunk Size')}</div>
<div class="self-center">
<input
class=" w-full rounded-lg py-1.5 px-4 text-sm bg-gray-50 dark:text-gray-300 dark:bg-gray-850 outline-none"
type="number"
placeholder={$i18n.t('Enter Chunk Size')}
bind:value={chunkSize}
autocomplete="off"
min="0"
/>
</div>
</div>
<div class="w-full">
<div class=" self-center text-xs font-medium min-w-fit mb-1">
{$i18n.t('Chunk Overlap')}
</div>
<div class="self-center">
<input
class="w-full rounded-lg py-1.5 px-4 text-sm bg-gray-50 dark:text-gray-300 dark:bg-gray-850 outline-none"
type="number"
placeholder={$i18n.t('Enter Chunk Overlap')}
bind:value={chunkOverlap}
autocomplete="off"
min="0"
/>
</div>
</div>
</div>
<div class="my-3">
<div class="flex justify-between items-center text-xs">
<div class=" text-xs font-medium">{$i18n.t('PDF Extract Images (OCR)')}</div>
<button
class=" text-xs font-medium text-gray-500"
type="button"
on:click={() => {
pdfExtractImages = !pdfExtractImages;
}}>{pdfExtractImages ? $i18n.t('On') : $i18n.t('Off')}</button
>
</div>
</div>
</div>
<hr class=" dark:border-gray-850" />
<div> <div>
<button <button
class=" flex rounded-xl py-2 px-3.5 w-full hover:bg-gray-200 dark:hover:bg-gray-800 transition" class=" flex rounded-xl py-2 px-3.5 w-full hover:bg-gray-200 dark:hover:bg-gray-800 transition"
@ -794,13 +806,15 @@
/> />
</svg> </svg>
</div> </div>
<div class=" self-center text-sm font-medium">{$i18n.t('Reset Vector Storage')}</div> <div class=" self-center text-sm font-medium">
{$i18n.t('Reset Vector Storage/Knowledge')}
</div>
</button> </button>
</div> </div>
</div> </div>
<div class="flex justify-end pt-3 text-sm font-medium"> <div class="flex justify-end pt-3 text-sm font-medium">
<button <button
class=" px-4 py-2 bg-emerald-700 hover:bg-emerald-800 text-gray-100 transition rounded-lg" class="px-3.5 py-1.5 text-sm font-medium bg-black hover:bg-gray-900 text-white dark:bg-white dark:text-black dark:hover:bg-gray-100 transition rounded-full"
type="submit" type="submit"
> >
{$i18n.t('Save')} {$i18n.t('Save')}

View file

@@ -0,0 +1,158 @@
<script lang="ts">
import { toast } from 'svelte-sonner';
import { models, user } from '$lib/stores';
import { createEventDispatcher, onMount, getContext, tick } from 'svelte';
const dispatch = createEventDispatcher();
import { getModels } from '$lib/apis';
import Switch from '$lib/components/common/Switch.svelte';
import Spinner from '$lib/components/common/Spinner.svelte';
import Tooltip from '$lib/components/common/Tooltip.svelte';
import Plus from '$lib/components/icons/Plus.svelte';
import Model from './Evaluations/Model.svelte';
import ModelModal from './Evaluations/ModelModal.svelte';
import { getConfig, updateConfig } from '$lib/apis/evaluations';
const i18n = getContext('i18n');
let config = null;
let showAddModel = false;
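// Persist the arena configuration to the backend and show a toast for success or failure.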
const submitHandler = async () => {
config = await updateConfig(localStorage.token, config).catch((err) => {
toast.error(err);
return null;
});
if (config) {
toast.success('Settings saved successfully');
}
};
const addModelHandler = async (model) => {
config.EVALUATION_ARENA_MODELS.push(model);
config.EVALUATION_ARENA_MODELS = [...config.EVALUATION_ARENA_MODELS];
await submitHandler();
models.set(await getModels(localStorage.token));
};
const editModelHandler = async (model, modelIdx) => {
config.EVALUATION_ARENA_MODELS[modelIdx] = model;
config.EVALUATION_ARENA_MODELS = [...config.EVALUATION_ARENA_MODELS];
await submitHandler();
models.set(await getModels(localStorage.token));
};
const deleteModelHandler = async (modelIdx) => {
config.EVALUATION_ARENA_MODELS = config.EVALUATION_ARENA_MODELS.filter(
(m, mIdx) => mIdx !== modelIdx
);
await submitHandler();
models.set(await getModels(localStorage.token));
};
onMount(async () => {
if ($user.role === 'admin') {
config = await getConfig(localStorage.token).catch((err) => {
toast.error(err);
return null;
});
}
});
</script>
<ModelModal
bind:show={showAddModel}
on:submit={async (e) => {
addModelHandler(e.detail);
}}
/>
<form
class="flex flex-col h-full justify-between text-sm"
on:submit|preventDefault={() => {
submitHandler();
dispatch('save');
}}
>
<div class="overflow-y-scroll scrollbar-hidden h-full">
{#if config !== null}
<div class="">
<div class="text-sm font-medium mb-2">{$i18n.t('General Settings')}</div>
<div class=" mb-2">
<div class="flex justify-between items-center text-xs">
<div class=" text-xs font-medium">{$i18n.t('Arena Models')}</div>
<Tooltip content={$i18n.t(`Message rating should be enabled to use this feature`)}>
<Switch bind:state={config.ENABLE_EVALUATION_ARENA_MODELS} />
</Tooltip>
</div>
</div>
{#if config.ENABLE_EVALUATION_ARENA_MODELS}
<hr class=" border-gray-50 dark:border-gray-700/10 my-2" />
<div class="flex justify-between items-center mb-2">
<div class="text-sm font-medium">{$i18n.t('Manage Arena Models')}</div>
<div>
<Tooltip content={$i18n.t('Add Arena Model')}>
<button
class="p-1"
type="button"
on:click={() => {
showAddModel = true;
}}
>
<Plus />
</button>
</Tooltip>
</div>
</div>
<div class="flex flex-col gap-2">
{#if (config?.EVALUATION_ARENA_MODELS ?? []).length > 0}
{#each config.EVALUATION_ARENA_MODELS as model, index}
<Model
{model}
on:edit={(e) => {
editModelHandler(e.detail, index);
}}
on:delete={(e) => {
deleteModelHandler(index);
}}
/>
{/each}
{:else}
<div class=" text-center text-xs text-gray-500">
{$i18n.t(
`Using the default arena model with all models. Click the plus button to add custom models.`
)}
</div>
{/if}
</div>
{/if}
</div>
{:else}
<div class="flex h-full justify-center">
<div class="my-auto">
<Spinner className="size-6" />
</div>
</div>
{/if}
</div>
<div class="flex justify-end pt-3 text-sm font-medium">
<button
class="px-3.5 py-1.5 text-sm font-medium bg-black hover:bg-gray-900 text-white dark:bg-white dark:text-black dark:hover:bg-gray-100 transition rounded-full"
type="submit"
>
{$i18n.t('Save')}
</button>
</div>
</form>

View file

@@ -0,0 +1,63 @@
<script lang="ts">
import { getContext, createEventDispatcher } from 'svelte';
const dispatch = createEventDispatcher();
const i18n = getContext('i18n');
import Cog6 from '$lib/components/icons/Cog6.svelte';
import ModelModal from './ModelModal.svelte';
export let model;
let showModel = false;
</script>
<ModelModal
bind:show={showModel}
edit={true}
{model}
on:submit={async (e) => {
dispatch('edit', e.detail);
}}
on:delete={async () => {
dispatch('delete');
}}
/>
<div class="py-0.5">
<div class="flex justify-between items-center mb-1">
<div class="flex flex-col flex-1">
<div class="flex gap-2.5 items-center">
<img
src={model.meta.profile_image_url}
alt={model.name}
class="size-8 rounded-full object-cover shrink-0"
/>
<div class="w-full flex flex-col">
<div class="flex items-center gap-1">
<div class="flex-shrink-0 line-clamp-1">
{model.name}
</div>
</div>
<div class="flex items-center gap-1">
<div class=" text-xs w-full text-gray-500 bg-transparent line-clamp-1">
{model?.meta?.description ?? model.id}
</div>
</div>
</div>
</div>
</div>
<div class="flex items-center">
<button
class="self-center w-fit text-sm p-1.5 dark:text-gray-300 dark:hover:text-white hover:bg-black/5 dark:hover:bg-white/5 rounded-xl"
type="button"
on:click={() => {
showModel = true;
}}
>
<Cog6 />
</button>
</div>
</div>
</div>

View file

@@ -0,0 +1,418 @@
<script>
import { createEventDispatcher, getContext, onMount } from 'svelte';
const i18n = getContext('i18n');
const dispatch = createEventDispatcher();
import Modal from '$lib/components/common/Modal.svelte';
import { models } from '$lib/stores';
import Plus from '$lib/components/icons/Plus.svelte';
import Minus from '$lib/components/icons/Minus.svelte';
import PencilSolid from '$lib/components/icons/PencilSolid.svelte';
import { toast } from 'svelte-sonner';
export let show = false;
export let edit = false;
export let model = null;
let name = '';
let id = '';
$: if (name) {
generateId();
}
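// Derive a URL-safe model ID from the name (lowercased, non-alphanumerics collapsed to '-'); only applies when creating, not when editing.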
const generateId = () => {
if (!edit) {
id = name
.toLowerCase()
.replace(/[^a-z0-9]/g, '-')
.replace(/-+/g, '-')
.replace(/^-|-$/g, '');
}
};
let profileImageUrl = '/favicon.png';
let description = '';
let selectedModelId = '';
let modelIds = [];
let filterMode = 'include';
let imageInputElement;
let loading = false;
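// Add the currently selected model ID to the include/exclude list and reset the selector.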
const addModelHandler = () => {
if (selectedModelId) {
modelIds = [...modelIds, selectedModelId];
selectedModelId = '';
}
};
const submitHandler = () => {
loading = true;
if (!name || !id) {
loading = false;
toast.error('Name and ID are required, please fill them out');
return;
}
if (!edit) {
if ($models.find((model) => model.name === name)) {
loading = false;
name = '';
toast.error('Model name already exists, please choose a different one');
return;
}
}
const model = {
id: id,
name: name,
meta: {
profile_image_url: profileImageUrl,
description: description || null,
model_ids: modelIds.length > 0 ? modelIds : null,
filter_mode: modelIds.length > 0 ? (filterMode ? filterMode : null) : null
}
};
dispatch('submit', model);
loading = false;
show = false;
name = '';
id = '';
profileImageUrl = '/favicon.png';
description = '';
modelIds = [];
selectedModelId = '';
};
const initModel = () => {
if (model) {
name = model.name;
id = model.id;
profileImageUrl = model.meta.profile_image_url;
description = model.meta.description;
modelIds = model.meta.model_ids || [];
filterMode = model.meta?.filter_mode ?? 'include';
}
};
$: if (show) {
initModel();
}
onMount(() => {
initModel();
});
</script>
<Modal size="sm" bind:show>
<div>
<div class=" flex justify-between dark:text-gray-100 px-5 pt-4 pb-2">
<div class=" text-lg font-medium self-center font-primary">
{#if edit}
{$i18n.t('Edit Arena Model')}
{:else}
{$i18n.t('Add Arena Model')}
{/if}
</div>
<button
class="self-center"
on:click={() => {
show = false;
}}
>
<svg
xmlns="http://www.w3.org/2000/svg"
viewBox="0 0 20 20"
fill="currentColor"
class="w-5 h-5"
>
<path
d="M6.28 5.22a.75.75 0 00-1.06 1.06L8.94 10l-3.72 3.72a.75.75 0 101.06 1.06L10 11.06l3.72 3.72a.75.75 0 101.06-1.06L11.06 10l3.72-3.72a.75.75 0 00-1.06-1.06L10 8.94 6.28 5.22z"
/>
</svg>
</button>
</div>
<div class="flex flex-col md:flex-row w-full px-4 pb-4 md:space-x-4 dark:text-gray-200">
<div class=" flex flex-col w-full sm:flex-row sm:justify-center sm:space-x-6">
<form
class="flex flex-col w-full"
on:submit|preventDefault={() => {
submitHandler();
}}
>
<div class="px-1">
<div class="flex justify-center pb-3">
<input
bind:this={imageInputElement}
type="file"
hidden
accept="image/*"
on:change={(e) => {
const files = e.target.files ?? [];
let reader = new FileReader();
reader.onload = (event) => {
let originalImageUrl = `${event.target.result}`;
const img = new Image();
img.src = originalImageUrl;
img.onload = function () {
const canvas = document.createElement('canvas');
const ctx = canvas.getContext('2d');
// Calculate the aspect ratio of the image
const aspectRatio = img.width / img.height;
// Calculate the new width and height to fit within 250x250
let newWidth, newHeight;
if (aspectRatio > 1) {
newWidth = 250 * aspectRatio;
newHeight = 250;
} else {
newWidth = 250;
newHeight = 250 / aspectRatio;
}
// Set the canvas size
canvas.width = 250;
canvas.height = 250;
// Calculate the position to center the image
const offsetX = (250 - newWidth) / 2;
const offsetY = (250 - newHeight) / 2;
// Draw the image on the canvas
ctx.drawImage(img, offsetX, offsetY, newWidth, newHeight);
// Get the base64 representation of the compressed image
const compressedSrc = canvas.toDataURL('image/jpeg');
// Display the compressed image
profileImageUrl = compressedSrc;
e.target.files = null;
};
};
if (
files.length > 0 &&
['image/gif', 'image/webp', 'image/jpeg', 'image/png'].includes(
files[0]['type']
)
) {
reader.readAsDataURL(files[0]);
}
}}
/>
<button
class="relative rounded-full w-fit h-fit shrink-0"
type="button"
on:click={() => {
imageInputElement.click();
}}
>
<img
src={profileImageUrl}
class="size-16 rounded-full object-cover shrink-0"
alt="Profile"
/>
<div
class="absolute flex justify-center rounded-full bottom-0 left-0 right-0 top-0 h-full w-full overflow-hidden bg-gray-700 bg-fixed opacity-0 transition duration-300 ease-in-out hover:opacity-50"
>
<div class="my-auto text-white">
<PencilSolid className="size-4" />
</div>
</div>
</button>
</div>
<div class="flex gap-2">
<div class="flex flex-col w-full">
<div class=" mb-0.5 text-xs text-gray-500">{$i18n.t('Name')}</div>
<div class="flex-1">
<input
class="w-full text-sm bg-transparent placeholder:text-gray-300 dark:placeholder:text-gray-700 outline-none"
type="text"
bind:value={name}
placeholder={$i18n.t('Model Name')}
autocomplete="off"
required
/>
</div>
</div>
<div class="flex flex-col w-full">
<div class=" mb-0.5 text-xs text-gray-500">{$i18n.t('ID')}</div>
<div class="flex-1">
<input
class="w-full text-sm bg-transparent placeholder:text-gray-300 dark:placeholder:text-gray-700 outline-none"
type="text"
bind:value={id}
placeholder={$i18n.t('Model ID')}
autocomplete="off"
required
disabled={edit}
/>
</div>
</div>
</div>
<div class="flex flex-col w-full mt-2">
<div class=" mb-1 text-xs text-gray-500">{$i18n.t('Description')}</div>
<div class="flex-1">
<input
class="w-full text-sm bg-transparent placeholder:text-gray-300 dark:placeholder:text-gray-700 outline-none"
type="text"
bind:value={description}
placeholder={$i18n.t('Enter description')}
autocomplete="off"
/>
</div>
</div>
<hr class=" border-gray-100 dark:border-gray-700/10 my-2.5 w-full" />
<div class="flex flex-col w-full">
<div class="mb-1 flex justify-between">
<div class="text-xs text-gray-500">{$i18n.t('Models')}</div>
<div>
<button
class=" text-xs text-gray-500"
type="button"
on:click={() => {
filterMode = filterMode === 'include' ? 'exclude' : 'include';
}}
>
{#if filterMode === 'include'}
{$i18n.t('Include')}
{:else}
{$i18n.t('Exclude')}
{/if}
</button>
</div>
</div>
{#if modelIds.length > 0}
<div class="flex flex-col">
{#each modelIds as modelId, modelIdx}
<div class=" flex gap-2 w-full justify-between items-center">
<div class=" text-sm flex-1 py-1 rounded-lg">
{$models.find((model) => model.id === modelId)?.name}
</div>
<div class="flex-shrink-0">
<button
type="button"
on:click={() => {
modelIds = modelIds.filter((_, idx) => idx !== modelIdx);
}}
>
<Minus strokeWidth="2" className="size-3.5" />
</button>
</div>
</div>
{/each}
</div>
{:else}
<div class="text-gray-500 text-xs text-center py-2">
{$i18n.t('Leave empty to include all models or select specific models')}
</div>
{/if}
</div>
<hr class=" border-gray-100 dark:border-gray-700/10 my-2.5 w-full" />
<div class="flex items-center">
<select
class="w-full py-1 text-sm rounded-lg bg-transparent {selectedModelId
? ''
: 'text-gray-500'} placeholder:text-gray-300 dark:placeholder:text-gray-700 outline-none"
bind:value={selectedModelId}
>
<option value="">{$i18n.t('Select a model')}</option>
{#each $models.filter((m) => m?.owned_by !== 'arena') as model}
<option value={model.id} class="bg-gray-50 dark:bg-gray-700">{model.name}</option>
{/each}
</select>
<div>
<button
type="button"
on:click={() => {
addModelHandler();
}}
>
<Plus className="size-3.5" strokeWidth="2" />
</button>
</div>
</div>
</div>
<div class="flex justify-end pt-3 text-sm font-medium gap-1.5">
{#if edit}
<button
class="px-3.5 py-1.5 text-sm font-medium dark:bg-black dark:hover:bg-gray-900 dark:text-white bg-white text-black hover:bg-gray-100 transition rounded-full flex flex-row space-x-1 items-center"
type="button"
on:click={() => {
dispatch('delete', model);
show = false;
}}
>
{$i18n.t('Delete')}
</button>
{/if}
<button
class="px-3.5 py-1.5 text-sm font-medium bg-black hover:bg-gray-900 text-white dark:bg-white dark:text-black dark:hover:bg-gray-100 transition rounded-full flex flex-row space-x-1 items-center {loading
? ' cursor-not-allowed'
: ''}"
type="submit"
disabled={loading}
>
{$i18n.t('Save')}
{#if loading}
<div class="ml-2 self-center">
<svg
class=" w-4 h-4"
viewBox="0 0 24 24"
fill="currentColor"
xmlns="http://www.w3.org/2000/svg"
><style>
.spinner_ajPY {
transform-origin: center;
animation: spinner_AtaB 0.75s infinite linear;
}
@keyframes spinner_AtaB {
100% {
transform: rotate(360deg);
}
}
</style><path
d="M12,1A11,11,0,1,0,23,12,11,11,0,0,0,12,1Zm0,19a8,8,0,1,1,8-8A8,8,0,0,1,12,20Z"
opacity=".25"
/><path
d="M10.14,1.16a11,11,0,0,0-9,8.92A1.59,1.59,0,0,0,2.46,12,1.52,1.52,0,0,0,4.11,10.7a8,8,0,0,1,6.66-6.61A1.42,1.42,0,0,0,12,2.69h0A1.57,1.57,0,0,0,10.14,1.16Z"
class="spinner_ajPY"
/></svg
>
</div>
{/if}
</button>
</div>
</form>
</div>
</div>
</div>
</Modal>

View file

@@ -137,7 +137,7 @@
<div class="flex justify-end pt-3 text-sm font-medium">
<button
- class=" px-4 py-2 bg-emerald-700 hover:bg-emerald-800 text-gray-100 transition rounded-lg"
+ class="px-3.5 py-1.5 text-sm font-medium bg-black hover:bg-gray-900 text-white dark:bg-white dark:text-black dark:hover:bg-gray-100 transition rounded-full"
type="submit"
>
{$i18n.t('Save')}

View file

@@ -648,7 +648,7 @@
<div class="flex justify-end pt-3 text-sm font-medium">
<button
- class=" px-4 py-2 bg-emerald-700 hover:bg-emerald-800 text-gray-100 transition rounded-lg flex flex-row space-x-1 items-center {loading
+ class="px-3.5 py-1.5 text-sm font-medium bg-black hover:bg-gray-900 text-white dark:bg-white dark:text-black dark:hover:bg-gray-100 transition rounded-full flex flex-row space-x-1 items-center {loading
? ' cursor-not-allowed'
: ''}"
type="submit"

View file

@ -14,6 +14,7 @@
import Tooltip from '$lib/components/common/Tooltip.svelte'; import Tooltip from '$lib/components/common/Tooltip.svelte';
import Switch from '$lib/components/common/Switch.svelte'; import Switch from '$lib/components/common/Switch.svelte';
import Textarea from '$lib/components/common/Textarea.svelte';
const dispatch = createEventDispatcher(); const dispatch = createEventDispatcher();
@ -23,6 +24,7 @@
TASK_MODEL: '', TASK_MODEL: '',
TASK_MODEL_EXTERNAL: '', TASK_MODEL_EXTERNAL: '',
TITLE_GENERATION_PROMPT_TEMPLATE: '', TITLE_GENERATION_PROMPT_TEMPLATE: '',
TAGS_GENERATION_PROMPT_TEMPLATE: '',
ENABLE_SEARCH_QUERY: true, ENABLE_SEARCH_QUERY: true,
SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE: '' SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE: ''
}; };
@ -60,7 +62,7 @@
> >
<div class=" overflow-y-scroll scrollbar-hidden h-full pr-1.5"> <div class=" overflow-y-scroll scrollbar-hidden h-full pr-1.5">
<div> <div>
<div class=" mb-2.5 text-sm font-medium flex"> <div class=" mb-2.5 text-sm font-medium flex items-center">
<div class=" mr-1">{$i18n.t('Set Task Model')}</div> <div class=" mr-1">{$i18n.t('Set Task Model')}</div>
<Tooltip <Tooltip
content={$i18n.t( content={$i18n.t(
@ -73,7 +75,7 @@
viewBox="0 0 24 24" viewBox="0 0 24 24"
stroke-width="1.5" stroke-width="1.5"
stroke="currentColor" stroke="currentColor"
class="w-5 h-5" class="size-3.5"
> >
<path <path
stroke-linecap="round" stroke-linecap="round"
@ -124,10 +126,22 @@
content={$i18n.t('Leave empty to use the default prompt, or enter a custom prompt')} content={$i18n.t('Leave empty to use the default prompt, or enter a custom prompt')}
placement="top-start" placement="top-start"
> >
<textarea <Textarea
bind:value={taskConfig.TITLE_GENERATION_PROMPT_TEMPLATE} bind:value={taskConfig.TITLE_GENERATION_PROMPT_TEMPLATE}
class="w-full rounded-lg py-3 px-4 text-sm bg-gray-50 dark:text-gray-300 dark:bg-gray-850 outline-none resize-none" placeholder={$i18n.t('Leave empty to use the default prompt, or enter a custom prompt')}
rows="3" />
</Tooltip>
</div>
<div class="mt-3">
<div class=" mb-2.5 text-xs font-medium">{$i18n.t('Tags Generation Prompt')}</div>
<Tooltip
content={$i18n.t('Leave empty to use the default prompt, or enter a custom prompt')}
placement="top-start"
>
<Textarea
bind:value={taskConfig.TAGS_GENERATION_PROMPT_TEMPLATE}
placeholder={$i18n.t('Leave empty to use the default prompt, or enter a custom prompt')} placeholder={$i18n.t('Leave empty to use the default prompt, or enter a custom prompt')}
/> />
</Tooltip> </Tooltip>
@ -151,10 +165,8 @@
content={$i18n.t('Leave empty to use the default prompt, or enter a custom prompt')} content={$i18n.t('Leave empty to use the default prompt, or enter a custom prompt')}
placement="top-start" placement="top-start"
> >
<textarea <Textarea
bind:value={taskConfig.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE} bind:value={taskConfig.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE}
class="w-full rounded-lg py-3 px-4 text-sm bg-gray-50 dark:text-gray-300 dark:bg-gray-850 outline-none resize-none"
rows="3"
placeholder={$i18n.t( placeholder={$i18n.t(
'Leave empty to use the default prompt, or enter a custom prompt' 'Leave empty to use the default prompt, or enter a custom prompt'
)} )}
@ -349,7 +361,7 @@
<div class="flex justify-end text-sm font-medium"> <div class="flex justify-end text-sm font-medium">
<button <button
class=" px-4 py-2 bg-emerald-700 hover:bg-emerald-800 text-gray-100 transition rounded-lg" class="px-3.5 py-1.5 text-sm font-medium bg-black hover:bg-gray-900 text-white dark:bg-white dark:text-black dark:hover:bg-gray-100 transition rounded-full"
type="submit" type="submit"
> >
{$i18n.t('Save')} {$i18n.t('Save')}

View file

@@ -763,7 +763,7 @@
<option value="" disabled selected>{$i18n.t('Select a model')}</option>
{/if}
{#each $models.filter((m) => !(m?.preset ?? false) && m.owned_by === 'ollama' && (selectedOllamaUrlIdx === null ? true : (m?.ollama?.urls ?? []).includes(selectedOllamaUrlIdx))) as model}
- <option value={model.name} class="bg-gray-50 dark:bg-gray-700"
+ <option value={model.id} class="bg-gray-50 dark:bg-gray-700"
>{model.name +
' (' +
(model.ollama.size / 1024 ** 3).toFixed(1) +

View file

@@ -546,12 +546,14 @@
{/if}
</div>
+ {#if PIPELINES_LIST !== null && PIPELINES_LIST.length > 0}
<div class="flex justify-end pt-3 text-sm font-medium">
<button
- class=" px-4 py-2 bg-emerald-700 hover:bg-emerald-800 text-gray-100 transition rounded-lg"
+ class="px-3.5 py-1.5 text-sm font-medium bg-black hover:bg-gray-900 text-white dark:bg-white dark:text-black dark:hover:bg-gray-100 transition rounded-full"
type="submit"
>
{$i18n.t('Save')}
</button>
</div>
+ {/if}
</form>

View file

@ -24,9 +24,17 @@
} }
}; };
let chatDeletion = true;
let chatEdit = true;
let chatTemporary = true;
onMount(async () => { onMount(async () => {
permissions = await getUserPermissions(localStorage.token); permissions = await getUserPermissions(localStorage.token);
chatDeletion = permissions?.chat?.deletion ?? true;
chatEdit = permissions?.chat?.editing ?? true;
chatTemporary = permissions?.chat?.temporary ?? true;
const res = await getModelFilterConfig(localStorage.token); const res = await getModelFilterConfig(localStorage.token);
if (res) { if (res) {
whitelistEnabled = res.enabled; whitelistEnabled = res.enabled;
@ -43,7 +51,13 @@
// console.log('submit'); // console.log('submit');
await setDefaultModels(localStorage.token, defaultModelId); await setDefaultModels(localStorage.token, defaultModelId);
await updateUserPermissions(localStorage.token, permissions); await updateUserPermissions(localStorage.token, {
chat: {
deletion: chatDeletion,
editing: chatEdit,
temporary: chatTemporary
}
});
await updateModelFilterConfig(localStorage.token, whitelistEnabled, whitelistModels); await updateModelFilterConfig(localStorage.token, whitelistEnabled, whitelistModels);
saveHandler(); saveHandler();
@ -54,127 +68,22 @@
<div> <div>
<div class=" mb-2 text-sm font-medium">{$i18n.t('User Permissions')}</div> <div class=" mb-2 text-sm font-medium">{$i18n.t('User Permissions')}</div>
<div class=" flex w-full justify-between"> <div class=" flex w-full justify-between my-2 pr-2">
<div class=" self-center text-xs font-medium">{$i18n.t('Allow Chat Deletion')}</div> <div class=" self-center text-xs font-medium">{$i18n.t('Allow Chat Deletion')}</div>
<button <Switch bind:state={chatDeletion} />
class="p-1 px-3 text-xs flex rounded transition"
on:click={() => {
permissions.chat.deletion = !(permissions?.chat?.deletion ?? true);
}}
type="button"
>
{#if permissions?.chat?.deletion ?? true}
<svg
xmlns="http://www.w3.org/2000/svg"
viewBox="0 0 16 16"
fill="currentColor"
class="w-4 h-4"
>
<path
d="M11.5 1A3.5 3.5 0 0 0 8 4.5V7H2.5A1.5 1.5 0 0 0 1 8.5v5A1.5 1.5 0 0 0 2.5 15h7a1.5 1.5 0 0 0 1.5-1.5v-5A1.5 1.5 0 0 0 9.5 7V4.5a2 2 0 1 1 4 0v1.75a.75.75 0 0 0 1.5 0V4.5A3.5 3.5 0 0 0 11.5 1Z"
/>
</svg>
<span class="ml-2 self-center">{$i18n.t('Allow')}</span>
{:else}
<svg
xmlns="http://www.w3.org/2000/svg"
viewBox="0 0 16 16"
fill="currentColor"
class="w-4 h-4"
>
<path
fill-rule="evenodd"
d="M8 1a3.5 3.5 0 0 0-3.5 3.5V7A1.5 1.5 0 0 0 3 8.5v5A1.5 1.5 0 0 0 4.5 15h7a1.5 1.5 0 0 0 1.5-1.5v-5A1.5 1.5 0 0 0 11.5 7V4.5A3.5 3.5 0 0 0 8 1Zm2 6V4.5a2 2 0 1 0-4 0V7h4Z"
clip-rule="evenodd"
/>
</svg>
<span class="ml-2 self-center">{$i18n.t("Don't Allow")}</span>
{/if}
</button>
</div> </div>
<div class=" flex w-full justify-between"> <div class=" flex w-full justify-between my-2 pr-2">
<div class=" self-center text-xs font-medium">{$i18n.t('Allow Chat Editing')}</div> <div class=" self-center text-xs font-medium">{$i18n.t('Allow Chat Editing')}</div>
<button <Switch bind:state={chatEdit} />
class="p-1 px-3 text-xs flex rounded transition"
on:click={() => {
permissions.chat.editing = !(permissions?.chat?.editing ?? true);
}}
type="button"
>
{#if permissions?.chat?.editing ?? true}
<svg
xmlns="http://www.w3.org/2000/svg"
viewBox="0 0 16 16"
fill="currentColor"
class="w-4 h-4"
>
<path
d="M11.5 1A3.5 3.5 0 0 0 8 4.5V7H2.5A1.5 1.5 0 0 0 1 8.5v5A1.5 1.5 0 0 0 2.5 15h7a1.5 1.5 0 0 0 1.5-1.5v-5A1.5 1.5 0 0 0 9.5 7V4.5a2 2 0 1 1 4 0v1.75a.75.75 0 0 0 1.5 0V4.5A3.5 3.5 0 0 0 11.5 1Z"
/>
</svg>
<span class="ml-2 self-center">{$i18n.t('Allow')}</span>
{:else}
<svg
xmlns="http://www.w3.org/2000/svg"
viewBox="0 0 16 16"
fill="currentColor"
class="w-4 h-4"
>
<path
fill-rule="evenodd"
d="M8 1a3.5 3.5 0 0 0-3.5 3.5V7A1.5 1.5 0 0 0 3 8.5v5A1.5 1.5 0 0 0 4.5 15h7a1.5 1.5 0 0 0 1.5-1.5v-5A1.5 1.5 0 0 0 11.5 7V4.5A3.5 3.5 0 0 0 8 1Zm2 6V4.5a2 2 0 1 0-4 0V7h4Z"
clip-rule="evenodd"
/>
</svg>
<span class="ml-2 self-center">{$i18n.t("Don't Allow")}</span>
{/if}
</button>
</div> </div>
<div class=" flex w-full justify-between"> <div class=" flex w-full justify-between my-2 pr-2">
<div class=" self-center text-xs font-medium">{$i18n.t('Allow Temporary Chat')}</div> <div class=" self-center text-xs font-medium">{$i18n.t('Allow Temporary Chat')}</div>
<button <Switch bind:state={chatTemporary} />
class="p-1 px-3 text-xs flex rounded transition"
on:click={() => {
permissions.chat.temporary = !(permissions?.chat?.temporary ?? true);
}}
type="button"
>
{#if permissions?.chat?.temporary ?? true}
<svg
xmlns="http://www.w3.org/2000/svg"
viewBox="0 0 16 16"
fill="currentColor"
class="w-4 h-4"
>
<path
d="M11.5 1A3.5 3.5 0 0 0 8 4.5V7H2.5A1.5 1.5 0 0 0 1 8.5v5A1.5 1.5 0 0 0 2.5 15h7a1.5 1.5 0 0 0 1.5-1.5v-5A1.5 1.5 0 0 0 9.5 7V4.5a2 2 0 1 1 4 0v1.75a.75.75 0 0 0 1.5 0V4.5A3.5 3.5 0 0 0 11.5 1Z"
/>
</svg>
<span class="ml-2 self-center">{$i18n.t('Allow')}</span>
{:else}
<svg
xmlns="http://www.w3.org/2000/svg"
viewBox="0 0 16 16"
fill="currentColor"
class="w-4 h-4"
>
<path
fill-rule="evenodd"
d="M8 1a3.5 3.5 0 0 0-3.5 3.5V7A1.5 1.5 0 0 0 3 8.5v5A1.5 1.5 0 0 0 4.5 15h7a1.5 1.5 0 0 0 1.5-1.5v-5A1.5 1.5 0 0 0 11.5 7V4.5A3.5 3.5 0 0 0 8 1Zm2 6V4.5a2 2 0 1 0-4 0V7h4Z"
clip-rule="evenodd"
/>
</svg>
<span class="ml-2 self-center">{$i18n.t("Don't Allow")}</span>
{/if}
</button>
</div> </div>
</div> </div>
@ -210,7 +119,7 @@
<div class=" space-y-1"> <div class=" space-y-1">
<div class="mb-2"> <div class="mb-2">
<div class="flex justify-between items-center text-xs"> <div class="flex justify-between items-center text-xs my-3 pr-2">
<div class=" text-xs font-medium">{$i18n.t('Model Whitelisting')}</div> <div class=" text-xs font-medium">{$i18n.t('Model Whitelisting')}</div>
<Switch bind:state={whitelistEnabled} /> <Switch bind:state={whitelistEnabled} />
@ -296,7 +205,7 @@
<div class="flex justify-end pt-3 text-sm font-medium"> <div class="flex justify-end pt-3 text-sm font-medium">
<button <button
class=" px-4 py-2 bg-emerald-700 hover:bg-emerald-800 text-gray-100 transition rounded-lg" class="px-3.5 py-1.5 text-sm font-medium bg-black hover:bg-gray-900 text-white dark:bg-white dark:text-black dark:hover:bg-gray-100 transition rounded-full"
type="submit" type="submit"
> >
{$i18n.t('Save')} {$i18n.t('Save')}

View file

@@ -310,7 +310,7 @@
</div>
<div class="flex justify-end pt-3 text-sm font-medium">
<button
- class=" px-4 py-2 bg-emerald-700 hover:bg-emerald-800 text-gray-100 transition rounded-lg"
+ class="px-3.5 py-1.5 text-sm font-medium bg-black hover:bg-gray-900 text-white dark:bg-white dark:text-black dark:hover:bg-gray-100 transition rounded-full"
type="submit"
>
{$i18n.t('Save')}

View file

@@ -120,6 +120,11 @@
}
});
+ if (contents.length === 0) {
+ showControls.set(false);
+ showArtifacts.set(false);
+ }
selectedContentIdx = contents ? contents.length - 1 : 0;
};
@@ -191,7 +196,7 @@
showArtifacts.set(false);
}}
>
- <ArrowLeft className="size-3.5" />
+ <ArrowLeft className="size-3.5 text-gray-900 dark:text-white" />
</button>
</div>

View file

@ -10,7 +10,7 @@
import { goto } from '$app/navigation'; import { goto } from '$app/navigation';
import { page } from '$app/stores'; import { page } from '$app/stores';
import type { Unsubscriber, Writable } from 'svelte/store'; import { get, type Unsubscriber, type Writable } from 'svelte/store';
import type { i18n as i18nType } from 'i18next'; import type { i18n as i18nType } from 'i18next';
import { WEBUI_BASE_URL } from '$lib/constants'; import { WEBUI_BASE_URL } from '$lib/constants';
@ -20,6 +20,7 @@
config, config,
type Model, type Model,
models, models,
tags as allTags,
settings, settings,
showSidebar, showSidebar,
WEBUI_NAME, WEBUI_NAME,
@ -46,7 +47,11 @@
import { generateChatCompletion } from '$lib/apis/ollama'; import { generateChatCompletion } from '$lib/apis/ollama';
import { import {
addTagById,
createNewChat, createNewChat,
deleteTagById,
deleteTagsById,
getAllTags,
getChatById, getChatById,
getChatList, getChatList,
getTagsById, getTagsById,
@ -62,7 +67,8 @@
generateTitle, generateTitle,
generateSearchQuery, generateSearchQuery,
chatAction, chatAction,
generateMoACompletion generateMoACompletion,
generateTags
} from '$lib/apis'; } from '$lib/apis';
import Banner from '../common/Banner.svelte'; import Banner from '../common/Banner.svelte';
@ -78,12 +84,15 @@
let loaded = false; let loaded = false;
const eventTarget = new EventTarget(); const eventTarget = new EventTarget();
let controlPane; let controlPane;
let controlPaneComponent;
let stopResponseFlag = false; let stopResponseFlag = false;
let autoScroll = true; let autoScroll = true;
let processing = ''; let processing = '';
let messagesContainerElement: HTMLDivElement; let messagesContainerElement: HTMLDivElement;
let navbarElement;
let showEventConfirmation = false; let showEventConfirmation = false;
let eventConfirmationTitle = ''; let eventConfirmationTitle = '';
let eventConfirmationMessage = ''; let eventConfirmationMessage = '';
@ -124,7 +133,7 @@
loaded = true; loaded = true;
window.setTimeout(() => scrollToBottom(), 0); window.setTimeout(() => scrollToBottom(), 0);
const chatInput = document.getElementById('chat-textarea'); const chatInput = document.getElementById('chat-input');
chatInput?.focus(); chatInput?.focus();
} else { } else {
await goto('/'); await goto('/');
@ -174,11 +183,31 @@
message.statusHistory = [data]; message.statusHistory = [data];
} }
} else if (type === 'citation') { } else if (type === 'citation') {
if (data?.type === 'code_execution') {
// Code execution; update existing code execution by ID, or add new one.
if (!message?.code_executions) {
message.code_executions = [];
}
const existingCodeExecutionIndex = message.code_executions.findIndex(
(execution) => execution.id === data.id
);
if (existingCodeExecutionIndex !== -1) {
message.code_executions[existingCodeExecutionIndex] = data;
} else {
message.code_executions.push(data);
}
message.code_executions = message.code_executions;
} else {
// Regular citation.
if (message?.citations) { if (message?.citations) {
message.citations.push(data); message.citations.push(data);
} else { } else {
message.citations = [data]; message.citations = [data];
} }
}
} else if (type === 'message') { } else if (type === 'message') {
message.content += data.content; message.content += data.content;
} else if (type === 'replace') { } else if (type === 'replace') {
@ -243,7 +272,7 @@
if (event.data.type === 'input:prompt') { if (event.data.type === 'input:prompt') {
console.debug(event.data.text); console.debug(event.data.text);
const inputElement = document.getElementById('chat-textarea'); const inputElement = document.getElementById('chat-input');
if (inputElement) { if (inputElement) {
prompt = event.data.text; prompt = event.data.text;
@ -290,14 +319,9 @@
if (controlPane && !$mobile) { if (controlPane && !$mobile) {
try { try {
if (value) { if (value) {
const currentSize = controlPane.getSize(); controlPaneComponent.openPane();
if (currentSize === 0) {
const size = parseInt(localStorage?.chatControlsSize ?? '30');
controlPane.resize(size ? size : 30);
}
} else { } else {
controlPane.resize(0); controlPane.collapse();
} }
} catch (e) { } catch (e) {
// ignore // ignore
@ -307,10 +331,11 @@
if (!value) { if (!value) {
showCallOverlay.set(false); showCallOverlay.set(false);
showOverview.set(false); showOverview.set(false);
showArtifacts.set(false);
} }
}); });
const chatInput = document.getElementById('chat-textarea'); const chatInput = document.getElementById('chat-input');
chatInput?.focus(); chatInput?.focus();
chats.subscribe(() => {}); chats.subscribe(() => {});
@ -420,22 +445,53 @@
if ($page.url.searchParams.get('models')) { if ($page.url.searchParams.get('models')) {
selectedModels = $page.url.searchParams.get('models')?.split(','); selectedModels = $page.url.searchParams.get('models')?.split(',');
} else if ($page.url.searchParams.get('model')) { } else if ($page.url.searchParams.get('model')) {
selectedModels = $page.url.searchParams.get('model')?.split(','); const urlModels = $page.url.searchParams.get('model')?.split(',');
if (urlModels.length === 1) {
const m = $models.find((m) => m.id === urlModels[0]);
if (!m) {
const modelSelectorButton = document.getElementById('model-selector-0-button');
if (modelSelectorButton) {
modelSelectorButton.click();
await tick();
const modelSelectorInput = document.getElementById('model-search-input');
if (modelSelectorInput) {
modelSelectorInput.focus();
modelSelectorInput.value = urlModels[0];
modelSelectorInput.dispatchEvent(new Event('input'));
}
}
} else {
selectedModels = urlModels;
}
} else {
selectedModels = urlModels;
}
} else if ($settings?.models) { } else if ($settings?.models) {
selectedModels = $settings?.models; selectedModels = $settings?.models;
} else if ($config?.default_models) { } else if ($config?.default_models) {
console.log($config?.default_models.split(',') ?? ''); console.log($config?.default_models.split(',') ?? '');
selectedModels = $config?.default_models.split(','); selectedModels = $config?.default_models.split(',');
}
selectedModels = selectedModels.filter((modelId) => $models.map((m) => m.id).includes(modelId));
if (selectedModels.length === 0 || (selectedModels.length === 1 && selectedModels[0] === '')) {
if ($models.length > 0) {
selectedModels = [$models[0].id];
} else { } else {
selectedModels = ['']; selectedModels = [''];
} }
}
console.log(selectedModels);
if ($page.url.searchParams.get('youtube')) { if ($page.url.searchParams.get('youtube')) {
uploadYoutubeTranscription( uploadYoutubeTranscription(
`https://www.youtube.com/watch?v=${$page.url.searchParams.get('youtube')}` `https://www.youtube.com/watch?v=${$page.url.searchParams.get('youtube')}`
); );
} }
if ($page.url.searchParams.get('web-search') === 'true') { if ($page.url.searchParams.get('web-search') === 'true') {
webSearchEnabled = true; webSearchEnabled = true;
} }
@ -480,7 +536,7 @@
settings.set(JSON.parse(localStorage.getItem('settings') ?? '{}')); settings.set(JSON.parse(localStorage.getItem('settings') ?? '{}'));
} }
const chatInput = document.getElementById('chat-textarea'); const chatInput = document.getElementById('chat-input');
setTimeout(() => chatInput?.focus(), 0); setTimeout(() => chatInput?.focus(), 0);
}; };
@ -492,7 +548,10 @@
}); });
if (chat) { if (chat) {
tags = await getTags(); tags = await getTagsById(localStorage.token, $chatId).catch(async (error) => {
return [];
});
const chatContent = chat.chat; const chatContent = chat.chat;
if (chatContent) { if (chatContent) {
@ -736,53 +795,62 @@
////////////////////////// //////////////////////////
const submitPrompt = async (userPrompt, { _raw = false } = {}) => { const submitPrompt = async (userPrompt, { _raw = false } = {}) => {
let _responses = []; console.log('submitPrompt', userPrompt, $chatId);
console.log('submitPrompt', $chatId);
const messages = createMessagesList(history.currentId);
const messages = createMessagesList(history.currentId);
selectedModels = selectedModels.map((modelId) => selectedModels = selectedModels.map((modelId) =>
$models.map((m) => m.id).includes(modelId) ? modelId : '' $models.map((m) => m.id).includes(modelId) ? modelId : ''
); );
if (userPrompt === '') {
toast.error($i18n.t('Please enter a prompt'));
return;
}
if (selectedModels.includes('')) { if (selectedModels.includes('')) {
toast.error($i18n.t('Model not selected')); toast.error($i18n.t('Model not selected'));
} else if (messages.length != 0 && messages.at(-1).done != true) { return;
}
if (messages.length != 0 && messages.at(-1).done != true) {
// Response not done // Response not done
console.log('wait'); return;
} else if (messages.length != 0 && messages.at(-1).error) { }
if (messages.length != 0 && messages.at(-1).error) {
// Error in response // Error in response
toast.error( toast.error($i18n.t(`Oops! There was an error in the previous response.`));
$i18n.t( return;
`Oops! There was an error in the previous response. Please try again or contact admin.` }
) if (
);
} else if (
files.length > 0 && files.length > 0 &&
files.filter((file) => file.type !== 'image' && file.status === 'uploading').length > 0 files.filter((file) => file.type !== 'image' && file.status === 'uploading').length > 0
) { ) {
// Upload not done
toast.error( toast.error(
$i18n.t( $i18n.t(`Oops! There are files still uploading. Please wait for the upload to complete.`)
`Oops! Hold tight! Your files are still in the processing oven. We're cooking them up to perfection. Please be patient and we'll let you know once they're ready.`
)
); );
} else if ( return;
}
if (
($config?.file?.max_count ?? null) !== null && ($config?.file?.max_count ?? null) !== null &&
files.length + chatFiles.length > $config?.file?.max_count files.length + chatFiles.length > $config?.file?.max_count
) { ) {
console.log(chatFiles.length, files.length);
toast.error( toast.error(
$i18n.t(`You can only chat with a maximum of {{maxCount}} file(s) at a time.`, { $i18n.t(`You can only chat with a maximum of {{maxCount}} file(s) at a time.`, {
maxCount: $config?.file?.max_count maxCount: $config?.file?.max_count
}) })
); );
} else { return;
// Reset chat input textarea }
const chatTextAreaElement = document.getElementById('chat-textarea');
if (chatTextAreaElement) { let _responses = [];
chatTextAreaElement.value = ''; prompt = '';
chatTextAreaElement.style.height = ''; await tick();
// Reset chat input textarea
const chatInputContainer = document.getElementById('chat-input-container');
if (chatInputContainer) {
chatInputContainer.value = '';
chatInputContainer.style.height = '';
} }
const _files = JSON.parse(JSON.stringify(files)); const _files = JSON.parse(JSON.stringify(files));
@ -820,8 +888,12 @@
// Wait until history/message have been updated // Wait until history/message have been updated
await tick(); await tick();
// focus on chat input
const chatInput = document.getElementById('chat-input');
chatInput?.focus();
_responses = await sendPrompt(userPrompt, userMessageId, { newChat: true }); _responses = await sendPrompt(userPrompt, userMessageId, { newChat: true });
}
return _responses; return _responses;
}; };
@ -942,10 +1014,10 @@
} }
let _response = null; let _response = null;
if (model?.owned_by === 'openai') { if (model?.owned_by === 'ollama') {
_response = await sendPromptOpenAI(model, prompt, responseMessageId, _chatId);
} else if (model) {
_response = await sendPromptOllama(model, prompt, responseMessageId, _chatId); _response = await sendPromptOllama(model, prompt, responseMessageId, _chatId);
} else if (model) {
_response = await sendPromptOpenAI(model, prompt, responseMessageId, _chatId);
} }
_responses.push(_response); _responses.push(_response);
@ -998,7 +1070,7 @@
// Prepare the base message object // Prepare the base message object
const baseMessage = { const baseMessage = {
role: message.role, role: message.role,
content: message.content content: message?.merged?.content ?? message.content
}; };
// Extract and format image URLs if any exist // Extract and format image URLs if any exist
@ -1341,8 +1413,13 @@
const messages = createMessagesList(responseMessageId); const messages = createMessagesList(responseMessageId);
if (messages.length == 2 && messages.at(-1).content !== '' && selectedModels[0] === model.id) { if (messages.length == 2 && messages.at(-1).content !== '' && selectedModels[0] === model.id) {
window.history.replaceState(history.state, '', `/c/${_chatId}`); window.history.replaceState(history.state, '', `/c/${_chatId}`);
const title = await generateChatTitle(userPrompt);
const title = await generateChatTitle(messages);
await setChatTitle(_chatId, title); await setChatTitle(_chatId, title);
if ($settings?.autoTags ?? true) {
await setChatTags(messages);
}
} }
return _response; return _response;
@ -1458,10 +1535,7 @@
content: [ content: [
{ {
type: 'text', type: 'text',
text: text: message?.merged?.content ?? message.content
arr.length - 1 !== idx
? message.content
: (message?.raContent ?? message.content)
}, },
...message.files ...message.files
.filter((file) => file.type === 'image') .filter((file) => file.type === 'image')
@ -1474,10 +1548,7 @@
] ]
} }
: { : {
content: content: message?.merged?.content ?? message.content
arr.length - 1 !== idx
? message.content
: (message?.raContent ?? message.content)
}) })
})), })),
seed: params?.seed ?? $settings?.params?.seed ?? undefined, seed: params?.seed ?? $settings?.params?.seed ?? undefined,
@ -1518,7 +1589,7 @@
const textStream = await createOpenAITextStream(res.body, $settings.splitLargeChunks); const textStream = await createOpenAITextStream(res.body, $settings.splitLargeChunks);
for await (const update of textStream) { for await (const update of textStream) {
const { value, done, citations, error, usage } = update; const { value, done, citations, selectedModelId, error, usage } = update;
if (error) { if (error) {
await handleOpenAIError(error, null, model, responseMessage); await handleOpenAIError(error, null, model, responseMessage);
break; break;
@ -1538,6 +1609,12 @@
responseMessage.info = { ...usage, openai: true, usage }; responseMessage.info = { ...usage, openai: true, usage };
} }
if (selectedModelId) {
responseMessage.selectedModelId = selectedModelId;
responseMessage.arena = true;
continue;
}
if (citations) { if (citations) {
responseMessage.citations = citations; responseMessage.citations = citations;
// Only remove status if it was initially set // Only remove status if it was initially set
@ -1655,8 +1732,13 @@
const messages = createMessagesList(responseMessageId); const messages = createMessagesList(responseMessageId);
if (messages.length == 2 && selectedModels[0] === model.id) { if (messages.length == 2 && selectedModels[0] === model.id) {
window.history.replaceState(history.state, '', `/c/${_chatId}`); window.history.replaceState(history.state, '', `/c/${_chatId}`);
const title = await generateChatTitle(userPrompt);
const title = await generateChatTitle(messages);
await setChatTitle(_chatId, title); await setChatTitle(_chatId, title);
if ($settings?.autoTags ?? true) {
await setChatTags(messages);
}
} }
return _response; return _response;
@ -1812,21 +1894,21 @@
} }
}; };
const generateChatTitle = async (userPrompt) => { const generateChatTitle = async (messages) => {
if ($settings?.title?.auto ?? true) { if ($settings?.title?.auto ?? true) {
const title = await generateTitle( const lastMessage = messages.at(-1);
localStorage.token, const modelId = selectedModels[0];
selectedModels[0],
userPrompt, const title = await generateTitle(localStorage.token, modelId, messages, $chatId).catch(
$chatId (error) => {
).catch((error) => {
console.error(error); console.error(error);
return 'New Chat'; return 'New Chat';
}); }
);
return title; return title;
} else { } else {
return `${userPrompt}`; return 'New Chat';
} }
}; };
@ -1843,6 +1925,40 @@
} }
}; };
const setChatTags = async (messages) => {
if (!$temporaryChatEnabled) {
const currentTags = await getTagsById(localStorage.token, $chatId);
if (currentTags.length > 0) {
const res = await deleteTagsById(localStorage.token, $chatId);
if (res) {
allTags.set(await getAllTags(localStorage.token));
}
}
const lastMessage = messages.at(-1);
const modelId = selectedModels[0];
let generatedTags = await generateTags(localStorage.token, modelId, messages, $chatId).catch(
(error) => {
console.error(error);
return [];
}
);
generatedTags = generatedTags.filter(
(tag) => !currentTags.find((t) => t.id === tag.replaceAll(' ', '_').toLowerCase())
);
console.log(generatedTags);
for (const tag of generatedTags) {
await addTagById(localStorage.token, $chatId, tag);
}
chat = await getChatById(localStorage.token, $chatId);
allTags.set(await getAllTags(localStorage.token));
}
};
const getWebSearchResults = async ( const getWebSearchResults = async (
model: string, model: string,
parentId: string, parentId: string,
@ -1928,12 +2044,6 @@
} }
}; };
const getTags = async () => {
return await getTagsById(localStorage.token, $chatId).catch(async (error) => {
return [];
});
};
const initChatHandler = async () => { const initChatHandler = async () => {
if (!$temporaryChatEnabled) { if (!$temporaryChatEnabled) {
chat = await createNewChat(localStorage.token, { chat = await createNewChat(localStorage.token, {
@ -2025,6 +2135,7 @@
{/if} {/if}
<Navbar <Navbar
bind:this={navbarElement}
chat={{ chat={{
id: $chatId, id: $chatId,
chat: { chat: {
@ -2161,9 +2272,8 @@
}} }}
on:submit={async (e) => { on:submit={async (e) => {
if (e.detail) { if (e.detail) {
prompt = '';
await tick(); await tick();
submitPrompt(e.detail); submitPrompt(e.detail.replaceAll('\n\n', '\n'));
} }
}} }}
/> />
@ -2206,9 +2316,8 @@
}} }}
on:submit={async (e) => { on:submit={async (e) => {
if (e.detail) { if (e.detail) {
prompt = '';
await tick(); await tick();
submitPrompt(e.detail); submitPrompt(e.detail.replaceAll('\n\n', '\n'));
} }
}} }}
/> />
@ -2218,6 +2327,7 @@
</Pane> </Pane>
<ChatControls <ChatControls
bind:this={controlPaneComponent}
bind:history bind:history
bind:chatFiles bind:chatFiles
bind:params bind:params


@ -1,6 +1,7 @@
<script lang="ts"> <script lang="ts">
import { SvelteFlowProvider } from '@xyflow/svelte'; import { SvelteFlowProvider } from '@xyflow/svelte';
import { slide } from 'svelte/transition'; import { slide } from 'svelte/transition';
import { Pane, PaneResizer } from 'paneforge';
import { onDestroy, onMount, tick } from 'svelte'; import { onDestroy, onMount, tick } from 'svelte';
import { mobile, showControls, showCallOverlay, showOverview, showArtifacts } from '$lib/stores'; import { mobile, showControls, showCallOverlay, showOverview, showArtifacts } from '$lib/stores';
@ -10,9 +11,9 @@
import CallOverlay from './MessageInput/CallOverlay.svelte'; import CallOverlay from './MessageInput/CallOverlay.svelte';
import Drawer from '../common/Drawer.svelte'; import Drawer from '../common/Drawer.svelte';
import Overview from './Overview.svelte'; import Overview from './Overview.svelte';
import { Pane, PaneResizer } from 'paneforge';
import EllipsisVertical from '../icons/EllipsisVertical.svelte'; import EllipsisVertical from '../icons/EllipsisVertical.svelte';
import Artifacts from './Artifacts.svelte'; import Artifacts from './Artifacts.svelte';
import { min } from '@floating-ui/utils';
export let history; export let history;
export let models = []; export let models = [];
@ -35,6 +36,16 @@
let largeScreen = false; let largeScreen = false;
let dragged = false; let dragged = false;
let minSize = 0;
export const openPane = () => {
if (parseInt(localStorage?.chatControlsSize)) {
pane.resize(parseInt(localStorage?.chatControlsSize));
} else {
pane.resize(minSize);
}
};
const handleMediaQuery = async (e) => { const handleMediaQuery = async (e) => {
if (e.matches) { if (e.matches) {
largeScreen = true; largeScreen = true;
@ -71,6 +82,32 @@
mediaQuery.addEventListener('change', handleMediaQuery); mediaQuery.addEventListener('change', handleMediaQuery);
handleMediaQuery(mediaQuery); handleMediaQuery(mediaQuery);
// Select the container element you want to observe
const container = document.getElementById('chat-container');
// initialize the minSize based on the container width
minSize = Math.floor((350 / container.clientWidth) * 100);
// Create a new ResizeObserver instance
const resizeObserver = new ResizeObserver((entries) => {
for (let entry of entries) {
const width = entry.contentRect.width;
// calculate the percentage of 200px
const percentage = (350 / width) * 100;
// set the minSize to the percentage, must be an integer
minSize = Math.floor(percentage);
if ($showControls) {
if (pane && pane.isExpanded() && pane.getSize() < minSize) {
pane.resize(minSize);
}
}
}
});
// Start observing the container's size changes
resizeObserver.observe(container);
document.addEventListener('mousedown', onMouseDown); document.addEventListener('mousedown', onMouseDown);
document.addEventListener('mouseup', onMouseUp); document.addEventListener('mouseup', onMouseUp);
}); });
@ -163,23 +200,29 @@
</div> </div>
</PaneResizer> </PaneResizer>
{/if} {/if}
<Pane <Pane
bind:pane bind:pane
defaultSize={$showControls defaultSize={0}
? parseInt(localStorage?.chatControlsSize ?? '30')
? parseInt(localStorage?.chatControlsSize ?? '30')
: 30
: 0}
onResize={(size) => { onResize={(size) => {
if (size === 0) { console.log('size', size, minSize);
showControls.set(false);
} else { if ($showControls && pane.isExpanded()) {
if (!$showControls) { if (size < minSize) {
showControls.set(true); pane.resize(minSize);
} }
if (size < minSize) {
localStorage.chatControlsSize = 0;
} else {
localStorage.chatControlsSize = size; localStorage.chatControlsSize = size;
} }
}
}} }}
onCollapse={() => {
showControls.set(false);
}}
collapsible={true}
class="pt-8" class="pt-8"
> >
{#if $showControls} {#if $showControls}
@ -187,7 +230,7 @@
<div <div
class="w-full {($showOverview || $showArtifacts) && !$showCallOverlay class="w-full {($showOverview || $showArtifacts) && !$showCallOverlay
? ' ' ? ' '
: 'px-4 py-4 bg-white dark:shadow-lg dark:bg-gray-850 border border-gray-50 dark:border-gray-800'} rounded-lg z-40 pointer-events-auto overflow-y-auto scrollbar-hidden" : 'px-4 py-4 bg-white dark:shadow-lg dark:bg-gray-850 border border-gray-50 dark:border-gray-850'} rounded-xl z-40 pointer-events-auto overflow-y-auto scrollbar-hidden"
> >
{#if $showCallOverlay} {#if $showCallOverlay}
<div class="w-full h-full flex justify-center"> <div class="w-full h-full flex justify-center">


@ -82,8 +82,8 @@
> >
<div> <div>
<div class=" capitalize line-clamp-1" in:fade={{ duration: 200 }}> <div class=" capitalize line-clamp-1" in:fade={{ duration: 200 }}>
{#if models[selectedModelIdx]?.info} {#if models[selectedModelIdx]?.name}
{models[selectedModelIdx]?.info?.name} {models[selectedModelIdx]?.name}
{:else} {:else}
{$i18n.t('Hello, {{name}}', { name: $user.name })} {$i18n.t('Hello, {{name}}', { name: $user.name })}
{/if} {/if}


@ -30,7 +30,7 @@
<div class=" dark:text-gray-200 text-sm font-primary py-0.5 px-0.5"> <div class=" dark:text-gray-200 text-sm font-primary py-0.5 px-0.5">
{#if chatFiles.length > 0} {#if chatFiles.length > 0}
<Collapsible title={$i18n.t('Files')} open={true}> <Collapsible title={$i18n.t('Files')} open={true} buttonClassName="w-full">
<div class="flex flex-col gap-1 mt-1.5" slot="content"> <div class="flex flex-col gap-1 mt-1.5" slot="content">
{#each chatFiles as file, fileIdx} {#each chatFiles as file, fileIdx}
<FileItem <FileItem
@ -56,31 +56,31 @@
</div> </div>
</Collapsible> </Collapsible>
<hr class="my-2 border-gray-100 dark:border-gray-800" /> <hr class="my-2 border-gray-50 dark:border-gray-700/10" />
{/if} {/if}
<Collapsible title={$i18n.t('Valves')}> <Collapsible title={$i18n.t('Valves')} buttonClassName="w-full">
<div class="text-sm mt-1.5" slot="content"> <div class="text-sm" slot="content">
<Valves /> <Valves />
</div> </div>
</Collapsible> </Collapsible>
<hr class="my-2 border-gray-100 dark:border-gray-800" /> <hr class="my-2 border-gray-50 dark:border-gray-700/10" />
<Collapsible title={$i18n.t('System Prompt')} open={true}> <Collapsible title={$i18n.t('System Prompt')} open={true} buttonClassName="w-full">
<div class=" mt-1.5" slot="content"> <div class="" slot="content">
<textarea <textarea
bind:value={params.system} bind:value={params.system}
class="w-full rounded-lg px-3.5 py-2.5 text-sm dark:text-gray-300 dark:bg-gray-850 border border-gray-100 dark:border-gray-800 outline-none resize-none" class="w-full text-xs py-1.5 bg-transparent outline-none resize-none"
rows="4" rows="4"
placeholder={$i18n.t('Enter system prompt')} placeholder={$i18n.t('Enter system prompt')}
/> />
</div> </div>
</Collapsible> </Collapsible>
<hr class="my-2 border-gray-100 dark:border-gray-800" /> <hr class="my-2 border-gray-50 dark:border-gray-700/10" />
<Collapsible title={$i18n.t('Advanced Params')} open={true}> <Collapsible title={$i18n.t('Advanced Params')} open={true} buttonClassName="w-full">
<div class="text-sm mt-1.5" slot="content"> <div class="text-sm mt-1.5" slot="content">
<div> <div>
<AdvancedParams admin={$user?.role === 'admin'} bind:params /> <AdvancedParams admin={$user?.role === 'admin'} bind:params />


@ -1,5 +1,7 @@
<script lang="ts"> <script lang="ts">
import { toast } from 'svelte-sonner'; import { toast } from 'svelte-sonner';
import { v4 as uuidv4 } from 'uuid';
import { onMount, tick, getContext, createEventDispatcher, onDestroy } from 'svelte'; import { onMount, tick, getContext, createEventDispatcher, onDestroy } from 'svelte';
const dispatch = createEventDispatcher(); const dispatch = createEventDispatcher();
@ -29,6 +31,7 @@
import FilesOverlay from './MessageInput/FilesOverlay.svelte'; import FilesOverlay from './MessageInput/FilesOverlay.svelte';
import Commands from './MessageInput/Commands.svelte'; import Commands from './MessageInput/Commands.svelte';
import XMark from '../icons/XMark.svelte'; import XMark from '../icons/XMark.svelte';
import RichTextInput from '../common/RichTextInput.svelte';
const i18n = getContext('i18n'); const i18n = getContext('i18n');
@ -52,9 +55,10 @@
let recording = false; let recording = false;
let chatTextAreaElement: HTMLTextAreaElement; let chatInputContainerElement;
let filesInputElement; let chatInputElement;
let filesInputElement;
let commandsElement; let commandsElement;
let inputFiles; let inputFiles;
@ -69,9 +73,10 @@
); );
$: if (prompt) { $: if (prompt) {
if (chatTextAreaElement) { if (chatInputContainerElement) {
chatTextAreaElement.style.height = ''; chatInputContainerElement.style.height = '';
chatTextAreaElement.style.height = Math.min(chatTextAreaElement.scrollHeight, 200) + 'px'; chatInputContainerElement.style.height =
Math.min(chatInputContainerElement.scrollHeight, 200) + 'px';
} }
} }
@ -86,6 +91,7 @@
const uploadFileHandler = async (file) => { const uploadFileHandler = async (file) => {
console.log(file); console.log(file);
const tempItemId = uuidv4();
const fileItem = { const fileItem = {
type: 'file', type: 'file',
file: '', file: '',
@ -95,10 +101,16 @@
collection_name: '', collection_name: '',
status: 'uploading', status: 'uploading',
size: file.size, size: file.size,
error: '' error: '',
itemId: tempItemId
}; };
files = [...files, fileItem];
if (fileItem.size == 0) {
toast.error($i18n.t('You cannot upload an empty file.'));
return null;
}
files = [...files, fileItem];
// Check if the file is an audio file and transcribe/convert it to text file // Check if the file is an audio file and transcribe/convert it to text file
if (['audio/mpeg', 'audio/wav', 'audio/ogg', 'audio/x-m4a'].includes(file['type'])) { if (['audio/mpeg', 'audio/wav', 'audio/ogg', 'audio/x-m4a'].includes(file['type'])) {
const res = await transcribeAudio(localStorage.token, file).catch((error) => { const res = await transcribeAudio(localStorage.token, file).catch((error) => {
@ -121,6 +133,10 @@
const uploadedFile = await uploadFile(localStorage.token, file); const uploadedFile = await uploadFile(localStorage.token, file);
if (uploadedFile) { if (uploadedFile) {
if (uploadedFile.error) {
toast.warning(uploadedFile.error);
}
fileItem.status = 'uploaded'; fileItem.status = 'uploaded';
fileItem.file = uploadedFile; fileItem.file = uploadedFile;
fileItem.id = uploadedFile.id; fileItem.id = uploadedFile.id;
@ -129,11 +145,11 @@
files = files; files = files;
} else { } else {
files = files.filter((item) => item.status !== null); files = files.filter((item) => item?.itemId !== tempItemId);
} }
} catch (e) { } catch (e) {
toast.error(e); toast.error(e);
files = files.filter((item) => item.status !== null); files = files.filter((item) => item?.itemId !== tempItemId);
} }
}; };
@ -184,7 +200,13 @@
const onDragOver = (e) => { const onDragOver = (e) => {
e.preventDefault(); e.preventDefault();
// Check if a file is being dragged.
if (e.dataTransfer?.types?.includes('Files')) {
dragged = true; dragged = true;
} else {
dragged = false;
}
}; };
const onDragLeave = () => { const onDragLeave = () => {
@ -200,8 +222,6 @@
if (inputFiles && inputFiles.length > 0) { if (inputFiles && inputFiles.length > 0) {
console.log(inputFiles); console.log(inputFiles);
inputFilesHandler(inputFiles); inputFilesHandler(inputFiles);
} else {
toast.error($i18n.t(`File not found.`));
} }
} }
@ -209,7 +229,10 @@
}; };
onMount(() => { onMount(() => {
window.setTimeout(() => chatTextAreaElement?.focus(), 0); window.setTimeout(() => {
const chatInput = document.getElementById('chat-input');
chatInput?.focus();
}, 0);
window.addEventListener('keydown', handleKeyDown); window.addEventListener('keydown', handleKeyDown);
@ -312,7 +335,8 @@
atSelectedModel = data.data; atSelectedModel = data.data;
} }
chatTextAreaElement?.focus(); const chatInputElement = document.getElementById('chat-input');
chatInputElement?.focus();
}} }}
/> />
</div> </div>
@ -347,16 +371,16 @@
recording = false; recording = false;
await tick(); await tick();
document.getElementById('chat-textarea')?.focus(); document.getElementById('chat-input')?.focus();
}} }}
on:confirm={async (e) => { on:confirm={async (e) => {
const response = e.detail; const { text, filename } = e.detail;
prompt = `${prompt}${response} `; prompt = `${prompt}${text} `;
recording = false; recording = false;
await tick(); await tick();
document.getElementById('chat-textarea')?.focus(); document.getElementById('chat-input')?.focus();
if ($settings?.speechAutoSend ?? false) { if ($settings?.speechAutoSend ?? false) {
dispatch('submit', prompt); dispatch('submit', prompt);
@ -474,7 +498,9 @@
}} }}
onClose={async () => { onClose={async () => {
await tick(); await tick();
chatTextAreaElement?.focus();
const chatInput = document.getElementById('chat-input');
chatInput?.focus();
}} }}
> >
<button <button
@ -496,9 +522,178 @@
</InputMenu> </InputMenu>
</div> </div>
{#if $settings?.richTextInput ?? true}
<div
bind:this={chatInputContainerElement}
id="chat-input-container"
class="scrollbar-hidden text-left bg-gray-50 dark:bg-gray-850 dark:text-gray-100 outline-none w-full py-2.5 px-1 rounded-xl resize-none h-[48px] overflow-auto"
>
<RichTextInput
bind:this={chatInputElement}
id="chat-input"
trim={true}
placeholder={placeholder ? placeholder : $i18n.t('Send a Message')}
bind:value={prompt}
shiftEnter={!$mobile ||
!(
'ontouchstart' in window ||
navigator.maxTouchPoints > 0 ||
navigator.msMaxTouchPoints > 0
)}
on:enter={async (e) => {
if (prompt !== '') {
dispatch('submit', prompt);
}
}}
on:input={async (e) => {
if (chatInputContainerElement) {
chatInputContainerElement.style.height = '';
chatInputContainerElement.style.height =
Math.min(chatInputContainerElement.scrollHeight, 200) + 'px';
}
}}
on:focus={async (e) => {
if (chatInputContainerElement) {
chatInputContainerElement.style.height = '';
chatInputContainerElement.style.height =
Math.min(chatInputContainerElement.scrollHeight, 200) + 'px';
}
}}
on:keypress={(e) => {
e = e.detail.event;
}}
on:keydown={async (e) => {
e = e.detail.event;
if (chatInputContainerElement) {
chatInputContainerElement.style.height = '';
chatInputContainerElement.style.height =
Math.min(chatInputContainerElement.scrollHeight, 200) + 'px';
}
const isCtrlPressed = e.ctrlKey || e.metaKey; // metaKey is for Cmd key on Mac
const commandsContainerElement =
document.getElementById('commands-container');
// Command/Ctrl + Shift + Enter to submit a message pair
if (isCtrlPressed && e.key === 'Enter' && e.shiftKey) {
e.preventDefault();
createMessagePair(prompt);
}
// Check if Ctrl + R is pressed
if (prompt === '' && isCtrlPressed && e.key.toLowerCase() === 'r') {
e.preventDefault();
console.log('regenerate');
const regenerateButton = [
...document.getElementsByClassName('regenerate-response-button')
]?.at(-1);
regenerateButton?.click();
}
if (prompt === '' && e.key == 'ArrowUp') {
e.preventDefault();
const userMessageElement = [
...document.getElementsByClassName('user-message')
]?.at(-1);
const editButton = [
...document.getElementsByClassName('edit-user-message-button')
]?.at(-1);
console.log(userMessageElement);
userMessageElement.scrollIntoView({ block: 'center' });
editButton?.click();
}
if (commandsContainerElement && e.key === 'ArrowUp') {
e.preventDefault();
commandsElement.selectUp();
const commandOptionButton = [
...document.getElementsByClassName('selected-command-option-button')
]?.at(-1);
commandOptionButton.scrollIntoView({ block: 'center' });
}
if (commandsContainerElement && e.key === 'ArrowDown') {
e.preventDefault();
commandsElement.selectDown();
const commandOptionButton = [
...document.getElementsByClassName('selected-command-option-button')
]?.at(-1);
commandOptionButton.scrollIntoView({ block: 'center' });
}
if (commandsContainerElement && e.key === 'Enter') {
e.preventDefault();
const commandOptionButton = [
...document.getElementsByClassName('selected-command-option-button')
]?.at(-1);
if (e.shiftKey) {
prompt = `${prompt}\n`;
} else if (commandOptionButton) {
commandOptionButton?.click();
} else {
document.getElementById('send-message-button')?.click();
}
}
if (commandsContainerElement && e.key === 'Tab') {
e.preventDefault();
const commandOptionButton = [
...document.getElementsByClassName('selected-command-option-button')
]?.at(-1);
commandOptionButton?.click();
}
if (e.key === 'Escape') {
console.log('Escape');
atSelectedModel = undefined;
}
}}
on:paste={async (e) => {
e = e.detail.event;
console.log(e);
const clipboardData = e.clipboardData || window.clipboardData;
if (clipboardData && clipboardData.items) {
for (const item of clipboardData.items) {
if (item.type.indexOf('image') !== -1) {
const blob = item.getAsFile();
const reader = new FileReader();
reader.onload = function (e) {
files = [
...files,
{
type: 'image',
url: `${e.target.result}`
}
];
};
reader.readAsDataURL(blob);
}
}
}
}}
/>
</div>
{:else}
<textarea <textarea
id="chat-textarea" id="chat-input"
bind:this={chatTextAreaElement} bind:this={chatInputElement}
class="scrollbar-hidden bg-gray-50 dark:bg-gray-850 dark:text-gray-100 outline-none w-full py-3 px-1 rounded-xl resize-none h-[48px]" class="scrollbar-hidden bg-gray-50 dark:bg-gray-850 dark:text-gray-100 outline-none w-full py-3 px-1 rounded-xl resize-none h-[48px]"
placeholder={placeholder ? placeholder : $i18n.t('Send a Message')} placeholder={placeholder ? placeholder : $i18n.t('Send a Message')}
bind:value={prompt} bind:value={prompt}
@ -524,7 +719,8 @@
}} }}
on:keydown={async (e) => { on:keydown={async (e) => {
const isCtrlPressed = e.ctrlKey || e.metaKey; // metaKey is for Cmd key on Mac const isCtrlPressed = e.ctrlKey || e.metaKey; // metaKey is for Cmd key on Mac
const commandsContainerElement = document.getElementById('commands-container'); const commandsContainerElement =
document.getElementById('commands-container');
// Command/Ctrl + Shift + Enter to submit a message pair // Command/Ctrl + Shift + Enter to submit a message pair
if (isCtrlPressed && e.key === 'Enter' && e.shiftKey) { if (isCtrlPressed && e.key === 'Enter' && e.shiftKey) {
@ -667,6 +863,7 @@
} }
}} }}
/> />
{/if}
<div class="self-end mb-2 flex space-x-1 mr-1"> <div class="self-end mb-2 flex space-x-1 mr-1">
{#if !history?.currentId || history.messages[history.currentId]?.done == true} {#if !history?.currentId || history.messages[history.currentId]?.done == true}
@ -796,6 +993,7 @@
{/if} {/if}
{:else} {:else}
<div class=" flex items-center mb-1.5"> <div class=" flex items-center mb-1.5">
<Tooltip content={$i18n.t('Stop')}>
<button <button
class="bg-white hover:bg-gray-100 text-gray-800 dark:bg-gray-700 dark:text-white dark:hover:bg-gray-800 transition rounded-full p-1.5" class="bg-white hover:bg-gray-100 text-gray-800 dark:bg-gray-700 dark:text-white dark:hover:bg-gray-800 transition rounded-full p-1.5"
on:click={() => { on:click={() => {
@ -815,6 +1013,7 @@
/> />
</svg> </svg>
</button> </button>
</Tooltip>
</div> </div>
{/if} {/if}
</div> </div>


@ -25,17 +25,17 @@
}; };
let command = ''; let command = '';
$: command = (prompt?.trim() ?? '').split(' ')?.at(-1) ?? ''; $: command = prompt?.split('\n').pop()?.split(' ')?.pop() ?? '';
</script> </script>
{#if ['/', '#', '@'].includes(command?.charAt(0))} {#if ['/', '#', '@'].includes(command?.charAt(0)) || '\\#' === command.slice(0, 2)}
{#if command?.charAt(0) === '/'} {#if command?.charAt(0) === '/'}
<Prompts bind:this={commandElement} bind:prompt bind:files {command} /> <Prompts bind:this={commandElement} bind:prompt bind:files {command} />
{:else if command?.charAt(0) === '#'} {:else if (command?.charAt(0) === '#' && command.startsWith('#') && !command.includes('# ')) || ('\\#' === command.slice(0, 2) && command.startsWith('#') && !command.includes('# '))}
<Knowledge <Knowledge
bind:this={commandElement} bind:this={commandElement}
bind:prompt bind:prompt
{command} command={command.includes('\\#') ? command.slice(2) : command}
on:youtube={(e) => { on:youtube={(e) => {
console.log(e); console.log(e);
dispatch('upload', { dispatch('upload', {
@ -55,7 +55,6 @@
files = [ files = [
...files, ...files,
{ {
type: e?.detail?.meta?.document ? 'file' : 'collection',
...e.detail, ...e.detail,
status: 'processed' status: 'processed'
} }


@ -2,6 +2,10 @@
import { toast } from 'svelte-sonner'; import { toast } from 'svelte-sonner';
import Fuse from 'fuse.js'; import Fuse from 'fuse.js';
import dayjs from 'dayjs';
import relativeTime from 'dayjs/plugin/relativeTime';
dayjs.extend(relativeTime);
import { createEventDispatcher, tick, getContext, onMount } from 'svelte'; import { createEventDispatcher, tick, getContext, onMount } from 'svelte';
import { removeLastWordFromString, isValidHttpUrl } from '$lib/utils'; import { removeLastWordFromString, isValidHttpUrl } from '$lib/utils';
import { knowledge } from '$lib/stores'; import { knowledge } from '$lib/stores';
@ -42,7 +46,7 @@
dispatch('select', item); dispatch('select', item);
prompt = removeLastWordFromString(prompt, command); prompt = removeLastWordFromString(prompt, command);
const chatInputElement = document.getElementById('chat-textarea'); const chatInputElement = document.getElementById('chat-input');
await tick(); await tick();
chatInputElement?.focus(); chatInputElement?.focus();
@ -53,7 +57,7 @@
dispatch('url', url); dispatch('url', url);
prompt = removeLastWordFromString(prompt, command); prompt = removeLastWordFromString(prompt, command);
const chatInputElement = document.getElementById('chat-textarea'); const chatInputElement = document.getElementById('chat-input');
await tick(); await tick();
chatInputElement?.focus(); chatInputElement?.focus();
@ -64,7 +68,7 @@
dispatch('youtube', url); dispatch('youtube', url);
prompt = removeLastWordFromString(prompt, command); prompt = removeLastWordFromString(prompt, command);
const chatInputElement = document.getElementById('chat-textarea'); const chatInputElement = document.getElementById('chat-input');
await tick(); await tick();
chatInputElement?.focus(); chatInputElement?.focus();
@ -72,7 +76,13 @@
}; };
onMount(() => { onMount(() => {
let legacy_documents = $knowledge.filter((item) => item?.meta?.document); let legacy_documents = $knowledge
.filter((item) => item?.meta?.document)
.map((item) => ({
...item,
type: 'file'
}));
let legacy_collections = let legacy_collections =
legacy_documents.length > 0 legacy_documents.length > 0
? [ ? [
@ -101,12 +111,44 @@
] ]
: []; : [];
items = [...$knowledge, ...legacy_collections].map((item) => { let collections = $knowledge
.filter((item) => !item?.meta?.document)
.map((item) => ({
...item,
type: 'collection'
}));
let collection_files =
$knowledge.length > 0
? [
...$knowledge
.reduce((a, item) => {
return [
...new Set([
...a,
...(item?.files ?? []).map((file) => ({
...file,
collection: { name: item.name, description: item.description }
}))
])
];
}, [])
.map((file) => ({
...file,
name: file?.meta?.name,
description: `${file?.collection?.name} - ${file?.collection?.description}`,
type: 'file'
}))
]
: [];
items = [...collections, ...collection_files, ...legacy_collections, ...legacy_documents].map(
(item) => {
return { return {
...item, ...item,
...(item?.legacy || item?.meta?.legacy || item?.meta?.document ? { legacy: true } : {}) ...(item?.legacy || item?.meta?.legacy || item?.meta?.document ? { legacy: true } : {})
}; };
}); }
);
fuse = new Fuse(items, { fuse = new Fuse(items, {
keys: ['name', 'description'] keys: ['name', 'description']
@ -117,20 +159,17 @@
{#if filteredItems.length > 0 || prompt.split(' ')?.at(0)?.substring(1).startsWith('http')} {#if filteredItems.length > 0 || prompt.split(' ')?.at(0)?.substring(1).startsWith('http')}
<div <div
id="commands-container" id="commands-container"
class="pl-2 pr-14 mb-3 text-left w-full absolute bottom-0 left-0 right-0 z-10" class="pl-3 pr-14 mb-3 text-left w-full absolute bottom-0 left-0 right-0 z-10"
> >
<div class="flex w-full dark:border dark:border-gray-850 rounded-lg"> <div class="flex w-full rounded-xl border border-gray-50 dark:border-gray-850">
<div class=" bg-gray-50 dark:bg-gray-850 w-10 rounded-l-lg text-center">
<div class=" text-lg font-medium mt-2">#</div>
</div>
<div <div
class="max-h-60 flex flex-col w-full rounded-r-xl bg-white dark:bg-gray-900 dark:text-gray-100" class="max-h-60 flex flex-col w-full rounded-xl bg-white dark:bg-gray-900 dark:text-gray-100"
> >
<div class="m-1 overflow-y-auto p-1 rounded-r-xl space-y-0.5 scrollbar-hidden"> <div class="m-1 overflow-y-auto p-1 rounded-r-xl space-y-0.5 scrollbar-hidden">
{#each filteredItems as item, idx} {#each filteredItems as item, idx}
<button <button
class=" px-3 py-1.5 rounded-xl w-full text-left {idx === selectedIdx class=" px-3 py-1.5 rounded-xl w-full text-left flex justify-between items-center {idx ===
selectedIdx
? ' bg-gray-50 dark:bg-gray-850 dark:text-gray-100 selected-command-option-button' ? ' bg-gray-50 dark:bg-gray-850 dark:text-gray-100 selected-command-option-button'
: ''}" : ''}"
type="button" type="button"
@ -141,38 +180,87 @@
on:mousemove={() => { on:mousemove={() => {
selectedIdx = idx; selectedIdx = idx;
}} }}
on:focus={() => {}}
> >
<div>
<div class=" font-medium text-black dark:text-gray-100 flex items-center gap-1"> <div class=" font-medium text-black dark:text-gray-100 flex items-center gap-1">
{#if item.legacy} {#if item.legacy}
<div <div
class="bg-gray-500/20 text-gray-700 dark:text-gray-200 rounded uppercase text-xs font-bold px-1" class="bg-gray-500/20 text-gray-700 dark:text-gray-200 rounded uppercase text-xs font-bold px-1 flex-shrink-0"
> >
Legacy Legacy
</div> </div>
{:else if item?.meta?.document} {:else if item?.meta?.document}
<div <div
class="bg-gray-500/20 text-gray-700 dark:text-gray-200 rounded uppercase text-xs font-bold px-1" class="bg-gray-500/20 text-gray-700 dark:text-gray-200 rounded uppercase text-xs font-bold px-1 flex-shrink-0"
> >
Document Document
</div> </div>
{:else if item?.type === 'file'}
<div
class="bg-gray-500/20 text-gray-700 dark:text-gray-200 rounded uppercase text-xs font-bold px-1 flex-shrink-0"
>
File
</div>
{:else} {:else}
<div <div
class="bg-green-500/20 text-green-700 dark:text-green-200 rounded uppercase text-xs font-bold px-1" class="bg-green-500/20 text-green-700 dark:text-green-200 rounded uppercase text-xs font-bold px-1 flex-shrink-0"
> >
Collection Collection
</div> </div>
{/if} {/if}
<div class="line-clamp-1"> <div class="line-clamp-1">
{item.name} {item?.name}
</div> </div>
</div> </div>
<div class=" text-xs text-gray-600 dark:text-gray-100 line-clamp-1"> <div class=" text-xs text-gray-600 dark:text-gray-100 line-clamp-1">
{item?.description} {item?.description}
</div> </div>
</div>
</button> </button>
<!-- <div slot="content" class=" pl-2 pt-1 flex flex-col gap-0.5">
{#if !item.legacy && (item?.files ?? []).length > 0}
{#each item?.files ?? [] as file, fileIdx}
<button
class=" px-3 py-1.5 rounded-xl w-full text-left flex justify-between items-center hover:bg-gray-50 dark:hover:bg-gray-850 dark:hover:text-gray-100 selected-command-option-button"
type="button"
on:click={() => {
console.log(file);
}}
on:mousemove={() => {
selectedIdx = idx;
}}
>
<div>
<div
class=" font-medium text-black dark:text-gray-100 flex items-center gap-1"
>
<div
class="bg-gray-500/20 text-gray-700 dark:text-gray-200 rounded uppercase text-xs font-bold px-1 flex-shrink-0"
>
File
</div>
<div class="line-clamp-1">
{file?.meta?.name}
</div>
</div>
<div class=" text-xs text-gray-600 dark:text-gray-100 line-clamp-1">
{$i18n.t('Updated')}
{dayjs(file.updated_at * 1000).fromNow()}
</div>
</div>
</button>
{/each}
{:else}
<div class=" text-gray-500 text-xs mt-1 mb-2">
{$i18n.t('No files found.')}
</div>
{/if}
</div> -->
{/each} {/each}
{#if prompt {#if prompt


@ -58,7 +58,7 @@
onMount(async () => { onMount(async () => {
await tick(); await tick();
const chatInputElement = document.getElementById('chat-textarea'); const chatInputElement = document.getElementById('chat-input');
await tick(); await tick();
chatInputElement?.focus(); chatInputElement?.focus();
await tick(); await tick();
@ -68,15 +68,11 @@
{#if filteredItems.length > 0} {#if filteredItems.length > 0}
<div <div
id="commands-container" id="commands-container"
class="pl-2 pr-14 mb-3 text-left w-full absolute bottom-0 left-0 right-0 z-10" class="pl-3 pr-14 mb-3 text-left w-full absolute bottom-0 left-0 right-0 z-10"
> >
<div class="flex w-full dark:border dark:border-gray-850 rounded-lg"> <div class="flex w-full rounded-xl border border-gray-50 dark:border-gray-850">
<div class=" bg-gray-50 dark:bg-gray-850 w-10 rounded-l-lg text-center">
<div class=" text-lg font-medium mt-2">@</div>
</div>
<div <div
class="max-h-60 flex flex-col w-full rounded-r-lg bg-white dark:bg-gray-900 dark:text-gray-100" class="max-h-60 flex flex-col w-full rounded-xl bg-white dark:bg-gray-900 dark:text-gray-100"
> >
<div class="m-1 overflow-y-auto p-1 rounded-r-lg space-y-0.5 scrollbar-hidden"> <div class="m-1 overflow-y-auto p-1 rounded-r-lg space-y-0.5 scrollbar-hidden">
{#each filteredItems as model, modelIdx} {#each filteredItems as model, modelIdx}


@ -1,5 +1,5 @@
<script lang="ts"> <script lang="ts">
import { prompts } from '$lib/stores'; import { prompts, user } from '$lib/stores';
import { import {
findWordIndices, findWordIndices,
getUserPosition, getUserPosition,
@ -78,6 +78,12 @@
text = text.replaceAll('{{USER_LOCATION}}', String(location)); text = text.replaceAll('{{USER_LOCATION}}', String(location));
} }
if (command.content.includes('{{USER_NAME}}')) {
console.log($user);
const name = $user.name || 'User';
text = text.replaceAll('{{USER_NAME}}', name);
}
if (command.content.includes('{{USER_LANGUAGE}}')) { if (command.content.includes('{{USER_LANGUAGE}}')) {
const language = localStorage.getItem('locale') || 'en-US'; const language = localStorage.getItem('locale') || 'en-US';
text = text.replaceAll('{{USER_LANGUAGE}}', language); text = text.replaceAll('{{USER_LANGUAGE}}', language);
@ -110,21 +116,20 @@
prompt = text; prompt = text;
const chatInputElement = document.getElementById('chat-textarea'); const chatInputContainerElement = document.getElementById('chat-input-container');
const chatInputElement = document.getElementById('chat-input');
await tick(); await tick();
if (chatInputContainerElement) {
chatInputElement.style.height = ''; chatInputContainerElement.style.height = '';
chatInputElement.style.height = Math.min(chatInputElement.scrollHeight, 200) + 'px'; chatInputContainerElement.style.height =
Math.min(chatInputContainerElement.scrollHeight, 200) + 'px';
chatInputElement?.focus(); }
await tick(); await tick();
if (chatInputElement) {
const words = findWordIndices(prompt); chatInputElement.focus();
if (words.length > 0) { chatInputElement.dispatchEvent(new Event('input'));
const word = words.at(0);
chatInputElement.setSelectionRange(word?.startIndex, word.endIndex + 1);
} }
}; };
</script> </script>
@ -132,17 +137,13 @@
{#if filteredPrompts.length > 0} {#if filteredPrompts.length > 0}
<div <div
id="commands-container" id="commands-container"
class="pl-2 pr-14 mb-3 text-left w-full absolute bottom-0 left-0 right-0 z-10" class="pl-3 pr-14 mb-3 text-left w-full absolute bottom-0 left-0 right-0 z-10"
> >
<div class="flex w-full dark:border dark:border-gray-850 rounded-lg"> <div class="flex w-full rounded-xl border border-gray-50 dark:border-gray-850">
<div class=" bg-gray-50 dark:bg-gray-850 w-10 rounded-l-lg text-center">
<div class=" text-lg font-medium mt-2">/</div>
</div>
<div <div
class="max-h-60 flex flex-col w-full rounded-r-lg bg-white dark:bg-gray-900 dark:text-gray-100" class="max-h-60 flex flex-col w-full rounded-xl bg-white dark:bg-gray-900 dark:text-gray-100"
> >
<div class="m-1 overflow-y-auto p-1 rounded-r-lg space-y-0.5 scrollbar-hidden"> <div class="m-1 overflow-y-auto p-1 space-y-0.5 scrollbar-hidden">
{#each filteredPrompts as prompt, promptIdx} {#each filteredPrompts as prompt, promptIdx}
<button <button
class=" px-3 py-1.5 rounded-xl w-full text-left {promptIdx === selectedPromptIdx class=" px-3 py-1.5 rounded-xl w-full text-left {promptIdx === selectedPromptIdx
@ -169,7 +170,7 @@
</div> </div>
<div <div
class=" px-2 pb-1 text-xs text-gray-600 dark:text-gray-100 bg-white dark:bg-gray-900 rounded-br-xl flex items-center space-x-1" class=" px-2 pt-0.5 pb-1 text-xs text-gray-600 dark:text-gray-100 bg-white dark:bg-gray-900 rounded-b-xl flex items-center space-x-1"
> >
<div> <div>
<svg <svg


@ -61,9 +61,14 @@
<div <div
class="flex gap-2 items-center px-3 py-2 text-sm font-medium cursor-pointer rounded-xl" class="flex gap-2 items-center px-3 py-2 text-sm font-medium cursor-pointer rounded-xl"
> >
<div class="flex-1 flex items-center gap-2"> <div class="flex-1">
<Tooltip
content={tools[toolId]?.description ?? ''}
placement="top-start"
className="flex flex-1 gap-2 items-center"
>
<WrenchSolid /> <WrenchSolid />
<Tooltip content={tools[toolId]?.description ?? ''} className="flex-1">
<div class=" line-clamp-1">{tools[toolId].name}</div> <div class=" line-clamp-1">{tools[toolId].name}</div>
</Tooltip> </Tooltip>
</div> </div>


@ -1,6 +1,6 @@
<script lang="ts"> <script lang="ts">
import { toast } from 'svelte-sonner'; import { toast } from 'svelte-sonner';
import { createEventDispatcher, tick, getContext } from 'svelte'; import { createEventDispatcher, tick, getContext, onMount, onDestroy } from 'svelte';
import { config, settings } from '$lib/stores'; import { config, settings } from '$lib/stores';
import { blobToFile, calculateSHA256, findWordIndices } from '$lib/utils'; import { blobToFile, calculateSHA256, findWordIndices } from '$lib/utils';
@ -11,6 +11,7 @@
const dispatch = createEventDispatcher(); const dispatch = createEventDispatcher();
export let recording = false; export let recording = false;
export let className = ' p-2.5 w-full max-w-full';
let loading = false; let loading = false;
let confirmed = false; let confirmed = false;
@ -51,7 +52,7 @@
let audioChunks = []; let audioChunks = [];
const MIN_DECIBELS = -45; const MIN_DECIBELS = -45;
const VISUALIZER_BUFFER_LENGTH = 300; let VISUALIZER_BUFFER_LENGTH = 300;
let visualizerData = Array(VISUALIZER_BUFFER_LENGTH).fill(0); let visualizerData = Array(VISUALIZER_BUFFER_LENGTH).fill(0);
@ -141,8 +142,8 @@
}); });
if (res) { if (res) {
console.log(res.text); console.log(res);
dispatch('confirm', res.text); dispatch('confirm', res);
} }
}; };
@ -213,7 +214,7 @@
transcription = `${transcription}${transcript}`; transcription = `${transcription}${transcript}`;
await tick(); await tick();
document.getElementById('chat-textarea')?.focus(); document.getElementById('chat-input')?.focus();
// Restart the inactivity timeout // Restart the inactivity timeout
timeoutId = setTimeout(() => { timeoutId = setTimeout(() => {
@ -269,13 +270,48 @@
await mediaRecorder.stop(); await mediaRecorder.stop();
} }
clearInterval(durationCounter); clearInterval(durationCounter);
if (stream) {
const tracks = stream.getTracks();
tracks.forEach((track) => track.stop());
}
stream = null;
}; };
let resizeObserver;
let containerWidth;
let maxVisibleItems = 300;
$: maxVisibleItems = Math.floor(containerWidth / 5); // 2px width + 0.5px gap
onMount(() => {
// listen to width changes
resizeObserver = new ResizeObserver(() => {
VISUALIZER_BUFFER_LENGTH = Math.floor(window.innerWidth / 4);
if (visualizerData.length > VISUALIZER_BUFFER_LENGTH) {
visualizerData = visualizerData.slice(visualizerData.length - VISUALIZER_BUFFER_LENGTH);
} else {
visualizerData = Array(VISUALIZER_BUFFER_LENGTH - visualizerData.length)
.fill(0)
.concat(visualizerData);
}
});
resizeObserver.observe(document.body);
});
onDestroy(() => {
// remove resize observer
resizeObserver.disconnect();
});
</script> </script>
<div <div
bind:clientWidth={containerWidth}
class="{loading class="{loading
? ' bg-gray-100/50 dark:bg-gray-850/50' ? ' bg-gray-100/50 dark:bg-gray-850/50'
: 'bg-indigo-300/10 dark:bg-indigo-500/10 '} rounded-full flex p-2.5" : 'bg-indigo-300/10 dark:bg-indigo-500/10 '} rounded-full flex justify-between {className}"
> >
<div class="flex items-center mr-1"> <div class="flex items-center mr-1">
<button <button
@ -310,10 +346,13 @@
class="flex flex-1 self-center items-center justify-between ml-2 mx-1 overflow-hidden h-6" class="flex flex-1 self-center items-center justify-between ml-2 mx-1 overflow-hidden h-6"
dir="rtl" dir="rtl"
> >
<div class="flex-1 flex items-center gap-0.5 h-6">
{#each visualizerData.slice().reverse() as rms}
<div <div
class="w-[2px] class="flex items-center gap-0.5 h-6 w-full max-w-full overflow-hidden overflow-x-hidden flex-wrap"
>
{#each visualizerData.slice().reverse() as rms}
<div class="flex items-center h-full">
<div
class="w-[2px] flex-shrink-0
{loading {loading
? ' bg-gray-500 dark:bg-gray-400 ' ? ' bg-gray-500 dark:bg-gray-400 '
@ -322,10 +361,12 @@
inline-block h-full" inline-block h-full"
style="height: {Math.min(100, Math.max(14, rms * 100))}%;" style="height: {Math.min(100, Math.max(14, rms * 100))}%;"
/> />
</div>
{/each} {/each}
</div> </div>
</div> </div>
<div class="flex">
<div class=" mx-1.5 pr-1 flex justify-center items-center"> <div class=" mx-1.5 pr-1 flex justify-center items-center">
<div <div
class="text-sm class="text-sm
@ -338,7 +379,7 @@
</div> </div>
</div> </div>
<div class="flex items-center mr-1"> <div class="flex items-center">
{#if loading} {#if loading}
<div class=" text-gray-500 rounded-full cursor-not-allowed"> <div class=" text-gray-500 rounded-full cursor-not-allowed">
<svg <svg
@ -452,6 +493,7 @@
{/if} {/if}
</div> </div>
</div> </div>
</div>
<style> <style>
.visualizer { .visualizer {


@ -330,20 +330,14 @@
await tick(); await tick();
const chatInputElement = document.getElementById('chat-textarea'); const chatInputContainerElement = document.getElementById('chat-input-container');
if (chatInputElement) { if (chatInputContainerElement) {
prompt = p; prompt = p;
chatInputElement.style.height = ''; chatInputContainerElement.style.height = '';
chatInputElement.style.height = Math.min(chatInputElement.scrollHeight, 200) + 'px'; chatInputContainerElement.style.height =
chatInputElement.focus(); Math.min(chatInputContainerElement.scrollHeight, 200) + 'px';
chatInputContainerElement.focus();
const words = findWordIndices(prompt);
if (words.length > 0) {
const word = words.at(0);
chatInputElement.setSelectionRange(word?.startIndex, word.endIndex + 1);
}
} }
await tick(); await tick();


@ -1,13 +1,51 @@
<script lang="ts"> <script lang="ts">
import { getContext } from 'svelte';
import CitationsModal from './CitationsModal.svelte'; import CitationsModal from './CitationsModal.svelte';
import Collapsible from '$lib/components/common/Collapsible.svelte';
import ChevronDown from '$lib/components/icons/ChevronDown.svelte';
import ChevronUp from '$lib/components/icons/ChevronUp.svelte';
const i18n = getContext('i18n');
export let citations = []; export let citations = [];
let _citations = []; let _citations = [];
let showPercentage = false;
let showRelevance = true;
$: _citations = citations.reduce((acc, citation) => { let showCitationModal = false;
let selectedCitation: any = null;
let isCollapsibleOpen = false;
function calculateShowRelevance(citations: any[]) {
const distances = citations.flatMap((citation) => citation.distances ?? []);
const inRange = distances.filter((d) => d !== undefined && d >= -1 && d <= 1).length;
const outOfRange = distances.filter((d) => d !== undefined && (d < -1 || d > 1)).length;
if (distances.length === 0) {
return false;
}
if (
(inRange === distances.length - 1 && outOfRange === 1) ||
(outOfRange === distances.length - 1 && inRange === 1)
) {
return false;
}
return true;
}
function shouldShowPercentage(citations: any[]) {
const distances = citations.flatMap((citation) => citation.distances ?? []);
return distances.every((d) => d !== undefined && d >= -1 && d <= 1);
}
$: {
_citations = citations.reduce((acc, citation) => {
citation.document.forEach((document, index) => { citation.document.forEach((document, index) => {
const metadata = citation.metadata?.[index]; const metadata = citation.metadata?.[index];
const distance = citation.distances?.[index];
const id = metadata?.source ?? 'N/A'; const id = metadata?.source ?? 'N/A';
let source = citation?.source; let source = citation?.source;
@ -15,9 +53,8 @@
source = { ...source, name: metadata.name }; source = { ...source, name: metadata.name };
} }
// Check if ID looks like a URL
if (id.startsWith('http://') || id.startsWith('https://')) { if (id.startsWith('http://') || id.startsWith('https://')) {
source = { name: id }; source = { name: id, ...source, url: id };
} }
const existingSource = acc.find((item) => item.id === id); const existingSource = acc.find((item) => item.id === id);
@ -25,38 +62,149 @@
if (existingSource) { if (existingSource) {
existingSource.document.push(document); existingSource.document.push(document);
existingSource.metadata.push(metadata); existingSource.metadata.push(metadata);
if (distance !== undefined) existingSource.distances.push(distance);
} else { } else {
acc.push({ acc.push({
id: id, id: id,
source: source, source: source,
document: [document], document: [document],
metadata: metadata ? [metadata] : [] metadata: metadata ? [metadata] : [],
distances: distance !== undefined ? [distance] : undefined
}); });
} }
}); });
return acc; return acc;
}, []); }, []);
let showCitationModal = false; showRelevance = calculateShowRelevance(_citations);
let selectedCitation = null; showPercentage = shouldShowPercentage(_citations);
}
</script> </script>
<CitationsModal bind:show={showCitationModal} citation={selectedCitation} /> <CitationsModal
bind:show={showCitationModal}
citation={selectedCitation}
{showPercentage}
{showRelevance}
/>
{#if _citations.length > 0} {#if _citations.length > 0}
<div class="mt-1 mb-2 w-full flex gap-1 items-center flex-wrap"> <div class="mt-1 mb-2 w-full flex gap-1 items-center flex-wrap">
{#if _citations.length <= 3}
{#each _citations as citation, idx} {#each _citations as citation, idx}
<div class="flex gap-1 text-xs font-semibold"> <div class="flex gap-1 text-xs font-semibold">
<button <button
class="flex dark:text-gray-300 py-1 px-1 bg-gray-50 hover:bg-gray-100 dark:bg-gray-850 dark:hover:bg-gray-800 transition rounded-xl max-w-96" class="no-toggle flex dark:text-gray-300 py-1 px-1 bg-gray-50 hover:bg-gray-100 dark:bg-gray-850 dark:hover:bg-gray-800 transition rounded-xl max-w-96"
on:click={() => { on:click={() => {
showCitationModal = true; showCitationModal = true;
selectedCitation = citation; selectedCitation = citation;
}} }}
> >
{#if _citations.every((c) => c.distances !== undefined)}
<div class="bg-white dark:bg-gray-700 rounded-full size-4"> <div class="bg-white dark:bg-gray-700 rounded-full size-4">
{idx + 1} {idx + 1}
</div> </div>
{/if}
<div class="flex-1 mx-2 line-clamp-1 truncate">
{citation.source.name}
</div>
</button>
</div>
{/each}
{:else}
<Collapsible bind:open={isCollapsibleOpen} className="w-full">
<div
class="flex items-center gap-1 text-gray-500 hover:text-gray-600 dark:hover:text-gray-400 transition cursor-pointer"
>
<div class="flex-grow flex items-center gap-1 overflow-hidden">
<span class="whitespace-nowrap hidden sm:inline">{$i18n.t('References from')}</span>
<div class="flex items-center">
{#if _citations.length > 1 && _citations
.slice(0, 2)
.reduce((acc, citation) => acc + citation.source.name.length, 0) <= 50}
{#each _citations.slice(0, 2) as citation, idx}
<div class="flex items-center">
<button
class="no-toggle flex dark:text-gray-300 py-1 px-1 bg-gray-50 hover:bg-gray-100 dark:bg-gray-850 dark:hover:bg-gray-800 transition rounded-xl max-w-96 text-xs font-semibold"
on:click={() => {
showCitationModal = true;
selectedCitation = citation;
}}
>
{#if _citations.every((c) => c.distances !== undefined)}
<div class="bg-white dark:bg-gray-700 rounded-full size-4">
{idx + 1}
</div>
{/if}
<div class="flex-1 mx-2 line-clamp-1">
{citation.source.name}
</div>
</button>
{#if idx === 0}<span class="mr-1">,</span>
{/if}
</div>
{/each}
{:else}
{#each _citations.slice(0, 1) as citation, idx}
<div class="flex items-center">
<button
class="no-toggle flex dark:text-gray-300 py-1 px-1 bg-gray-50 hover:bg-gray-100 dark:bg-gray-850 dark:hover:bg-gray-800 transition rounded-xl max-w-96 text-xs font-semibold"
on:click={() => {
showCitationModal = true;
selectedCitation = citation;
}}
>
{#if _citations.every((c) => c.distances !== undefined)}
<div class="bg-white dark:bg-gray-700 rounded-full size-4">
{idx + 1}
</div>
{/if}
<div class="flex-1 mx-2 line-clamp-1">
{citation.source.name}
</div>
</button>
</div>
{/each}
{/if}
</div>
<div class="flex items-center gap-1 whitespace-nowrap">
<span class="hidden sm:inline">{$i18n.t('and')}</span>
<span class="text-gray-600 dark:text-gray-400">
{_citations.length -
(_citations.length > 1 &&
_citations
.slice(0, 2)
.reduce((acc, citation) => acc + citation.source.name.length, 0) <= 50
? 2
: 1)}
</span>
<span>{$i18n.t('more')}</span>
</div>
</div>
<div class="flex-shrink-0">
{#if isCollapsibleOpen}
<ChevronUp strokeWidth="3.5" className="size-3.5" />
{:else}
<ChevronDown strokeWidth="3.5" className="size-3.5" />
{/if}
</div>
</div>
<div slot="content" class="mt-2">
<div class="flex flex-wrap gap-2">
{#each _citations as citation, idx}
<div class="flex gap-1 text-xs font-semibold">
<button
class="no-toggle flex dark:text-gray-300 py-1 px-1 bg-gray-50 hover:bg-gray-100 dark:bg-gray-850 dark:hover:bg-gray-800 transition rounded-xl max-w-96"
on:click={() => {
showCitationModal = true;
selectedCitation = citation;
}}
>
{#if _citations.every((c) => c.distances !== undefined)}
<div class="bg-white dark:bg-gray-700 rounded-full size-4">
{idx + 1}
</div>
{/if}
<div class="flex-1 mx-2 line-clamp-1"> <div class="flex-1 mx-2 line-clamp-1">
{citation.source.name} {citation.source.name}
</div> </div>
@ -64,4 +212,8 @@
</div> </div>
{/each} {/each}
</div> </div>
</div>
</Collapsible>
{/if}
</div>
{/if} {/if}
