diff --git a/CHANGELOG.md b/CHANGELOG.md
index 3294d80949..54053bbd82 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -5,6 +5,65 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
+## [0.6.16] - 2025-07-14
+
+### Added
+
+- 🗂️ **Folders as Projects**: Organize your workflow with folder-based projects—set folder-level system prompts and associate custom knowledge, bringing seamless, context-rich management to teams and users handling multiple initiatives or clients.
+- 📁 **Instant Folder-Based Chat Creation**: Start a new chat directly from any folder; just click and your new conversation is automatically embedded in the right project context—no more manual dragging or setup, saving time and eliminating mistakes.
+- 🧩 **Prompt Variables with Automatic Input Modal**: Prompts containing variables now display a clean, auto-generated input modal that **autofocuses on the first field** for instant value entry—just select the prompt and fill in exactly what’s needed, reducing friction and guesswork.
+- 🔡 **Variable Input Typing in Prompts**: Define input types for prompt variables (e.g., text, textarea, number, select, color, date, map and more), giving everyone a clearer and more precise prompt-building experience for advanced automation or workflows.
+- 🚀 **Base Model List Caching**: Cache your base model list to speed up model selection and reduce repeated API calls; toggle this in Admin Settings > Connections for responsive model management even in large or multi-provider setups.
+- ⏱️ **Configurable Model List Cache TTL**: Take control over model list caching with the new MODELS_CACHE_TTL environment variable. Set a custom cache duration in seconds to balance performance and freshness, reducing API requests in stable environments or ensuring rapid updates when models change frequently.
+- 🔖 **Reference Notes as Knowledge or in Chats**: Use any note as knowledge for a model or folder, or reference it directly from chat—integrate living documentation into your Retrieval Augmented Generation workflows or discussions, bridging knowledge and action.
+- 📝 **Chat Directly with Notes (Experimental)**: Ask questions about any note, and directly edit or update notes from within a chat—unlock direct AI-powered brainstorming, summarization, and cleanup, like having your own collaborative AI canvas.
+- 🤝 **Collaborative Notes with Multi-User Editing**: Share notes with others and collaborate live—multiple users can edit a note in real-time, boosting cooperative knowledge building and workflow documentation.
+- 🛡️ **Collaborative Note Permissions**: Control who can view or edit each note with robust sharing permissions, ensuring privacy or collaboration per your organizational needs.
+- 🔗 **Copy Link to Notes**: Quickly copy and share direct links to notes for easier knowledge transfer within your team or external collaborators.
+- 📋 **Task List Support in Notes**: Add, organize, and manage checklists or tasks inside your notes—plan projects, track to-dos, and keep everything actionable in a single space.
+- 🧠 **AI-Generated Note Titles**: Instantly generate relevant and concise titles for your notes using AI—keep your knowledge library organized without tedious manual editing.
+- 🔄 **Full Undo/Redo Support in Notes**: Effortlessly undo or redo your latest note changes—never fear mistakes or accidental edits while collaborating or writing.
+- 📝 **Enhanced Note Word/Character Counter**: Always know the size of your notes with built-in counters, making it easier to adhere to length guidelines for shared or published content.
+- 🖊️ **Floating & Bubble Formatting Menus in Note Editor**: Access text formatting tools through both a floating menu and an intuitive bubble menu directly in the note editor—making rich text editing faster, more discoverable, and easier than ever.
+- ✍️ **Rich Text Prompt Insertion**: A new setting allows prompts to be inserted directly into the chat box as fully-formatted rich text, preserving Markdown elements like headings, lists, and bold text for a more intuitive and visually consistent editing experience.
+- 🌐 **Configurable Database URL**: WebUI now supports more flexible database configuration via new environment variables—making deployment and scaling simpler across various infrastructure setups.
+- 🎛️ **Completely Frontend-Handled File Upload in Temporary Chats**: When using temporary chats, file extraction now occurs fully in your browser with zero files sent to the backend, further strengthening privacy and giving you instant feedback.
+- 🔄 **Enhanced Banner and Chat Command Visibility**: Banner handling and command feedback in chat are now clearer and more contextually visible, making alerts, suggestions, and automation easier to spot and interact with for all users.
+- 📱 **Mobile Experience Polished**: The "new chat" button is back on mobile, plus core navigation and input controls have been smoothed out for better usability on phones and tablets.
+- 📄 **OpenDocument Text (.odt) Support**: Seamlessly upload and process .odt files from open-source office suites like LibreOffice and OpenOffice, expanding your ability to build knowledge from a wider range of document formats.
+- 📑 **Enhanced Markdown Document Splitting**: Improve knowledge retrieval from Markdown files with a new header-aware splitting strategy. This method intelligently chunks documents based on their header structure, preserving the original context and hierarchy for more accurate and relevant RAG results.
+- 📚 **Full Context Mode for Knowledge Bases**: When adding a knowledge base to a folder or custom model, you can now toggle full context mode for the entire knowledge base. This bypasses the usual chunking and retrieval process, making it perfect for leaner knowledge bases.
+- 🕰️ **Configurable OAuth Timeout**: Enhance login reliability by setting a custom timeout (OAUTH_TIMEOUT) for all OAuth providers (Google, Microsoft, GitHub, OIDC), preventing authentication failures on slow or restricted networks.
+- 🎨 **Accessibility & High-Contrast Theme Enhancements**: Major accessibility overhaul with significant updates to the high-contrast theme. Improved focus visibility, ARIA labels, and semantic HTML ensure core components like the chat interface and model selector are fully compliant and readable for visually impaired users.
+- ↕️ **Resizable System Prompt Fields**: Conveniently resize system prompt input fields to comfortably view and edit lengthy or complex instructions, improving the user experience for advanced model configuration.
+- 🔧 **Granular Update Check Control**: Gain finer control over outbound connections with the new ENABLE_VERSION_UPDATE_CHECK flag. This allows administrators to disable version update checks independently of the full OFFLINE_MODE, perfect for environments with restricted internet access that still need to download embedding models.
+- 🗃️ **Configurable Qdrant Collection Prefix**: Enhance scalability by setting a custom QDRANT_COLLECTION_PREFIX. This allows multiple Open WebUI instances to share a single Qdrant cluster safely, ensuring complete data isolation between separate deployments without conflicts.
+- ⚙️ **Improved Default Database Performance**: Enhanced out-of-the-box performance by setting smarter database connection pooling defaults, reducing API response times for users on non-SQLite databases without requiring manual configuration.
+- 🔧 **Configurable Redis Key Prefix**: Added support for the REDIS_KEY_PREFIX environment variable, allowing multiple Open WebUI instances to share a Redis cluster with isolated key namespaces for improved multi-tenancy.
+- ➡️ **Forward User Context to Reranker**: For advanced RAG integrations, user information (ID, name, email, role) can now be forwarded as HTTP headers to external reranking services, enabling personalized results or per-user access control.
+- ⚙️ **PGVector Connection Pooling**: Enhance performance and stability for PGVector-based RAG by enabling and configuring the database connection pool. New environment variables allow fine-tuning of pool size, timeout, and overflow settings to handle high-concurrency workloads efficiently.
+- ⚙️ **General Backend Refactoring**: Extensive refactoring delivers a faster, more reliable, and robust backend experience—improving chat speed, model management, and day-to-day reliability.
+- 🌍 **Expanded & Improved Translations**: Enjoy a more accessible and intuitive experience thanks to comprehensive updates and enhancements for Chinese (Simplified and Traditional), German, French, Catalan, Irish, and Spanish translations throughout the interface.
+
+### Fixed
+
+- 🛠️ **Rich Text Input Stability and Performance**: Multiple improvements ensure faster, cleaner text editing and rendering with reduced glitches—especially supporting links, color picking, checkbox controls, and code blocks in notes and chats.
+- 📷 **Seamless iPhone Image Uploads**: Effortlessly upload photos from iPhones and other devices using HEIC format—images are now correctly recognized and processed, eliminating compatibility issues.
+- 🔄 **Audio MIME Type Registration**: Issues with audio file content types have been resolved, guaranteeing smoother, error-free uploads and playback for transcription or note attachments.
+- 🖍️ **Input Commands Now Always Visible**: Input commands (like prompts or knowledge) dynamically adjust their height on small screens, ensuring nothing is cut off and every tool remains easily accessible.
+- 🛑 **Tool Result Rendering**: Fixed display problems with tool results, providing fast, clear feedback when using external or internal tools.
+- 🗂️ **Table Alignment in Markdown**: Markdown tables are now rendered and aligned as expected, keeping reports and documentation readable.
+- 🖼️ **Thread Image Handling**: Fixed an issue where messages containing only images in threads weren’t displayed correctly.
+- 🗝️ **Note Access Control Security**: Tightened access control logic for notes to guarantee that shared or collaborative notes respect all user permissions and privacy safeguards.
+- 🧾 **Ollama API Compatibility**: Fixed model parameter naming in the API to ensure uninterrupted compatibility for all Ollama endpoints.
+- 🛠️ **Detection for 'text/html' Files**: Files loaded with docling/tika are now reliably detected as the correct type, improving knowledge ingestion and document parsing.
+- 🔐 **OAuth Login Stability**: Resolved a critical OAuth bug that caused login failures on subsequent attempts after logging out. The user session is now completely cleared on logout, ensuring reliable and secure authentication across all supported providers (Google, Microsoft, GitHub, OIDC).
+- 🚪 **OAuth Logout and Redirect Reliability**: The OAuth logout process has been made more robust. Logout requests now correctly use proxy environment variables, ensuring they succeed in corporate networks. Additionally, the custom WEBUI_AUTH_SIGNOUT_REDIRECT_URL is now properly respected for all OAuth/OIDC configurations, ensuring a seamless sign-out experience.
+- 📜 **Banner Newline Rendering**: Banners now correctly render newline characters, ensuring that multi-line announcements and messages are displayed with their intended formatting.
+- ℹ️ **Consistent Model Description Rendering**: Model descriptions now render Markdown correctly in the main chat interface, matching the formatting seen in the model selection dropdown for a consistent user experience.
+- 🔄 **Offline Mode Update Check Display**: Corrected a UI bug where the "Checking for Updates..." message would display indefinitely when the application was set to offline mode.
+- 🛠️ **Tool Result Encoding**: Fixed a bug where tool calls returning non-ASCII characters would fail, ensuring robust handling of international text and special characters in tool outputs.
+
## [0.6.15] - 2025-06-16
### Added
diff --git a/backend/open_webui/config.py b/backend/open_webui/config.py
index 898ac1b594..46d3b719a6 100644
--- a/backend/open_webui/config.py
+++ b/backend/open_webui/config.py
@@ -13,12 +13,15 @@ from urllib.parse import urlparse
import requests
from pydantic import BaseModel
from sqlalchemy import JSON, Column, DateTime, Integer, func
+from authlib.integrations.starlette_client import OAuth
+
from open_webui.env import (
DATA_DIR,
DATABASE_URL,
ENV,
REDIS_URL,
+ REDIS_KEY_PREFIX,
REDIS_SENTINEL_HOSTS,
REDIS_SENTINEL_PORT,
FRONTEND_BUILD_DIR,
@@ -211,11 +214,16 @@ class PersistentConfig(Generic[T]):
class AppConfig:
_state: dict[str, PersistentConfig]
_redis: Optional[redis.Redis] = None
+ _redis_key_prefix: str
def __init__(
- self, redis_url: Optional[str] = None, redis_sentinels: Optional[list] = []
+ self,
+ redis_url: Optional[str] = None,
+ redis_sentinels: Optional[list] = [],
+ redis_key_prefix: str = "open-webui",
):
super().__setattr__("_state", {})
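+        # Key prefix for config values mirrored to Redis ("<prefix>:config:<name>"),
+        # so multiple Open WebUI instances can share one Redis without collisions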
+ super().__setattr__("_redis_key_prefix", redis_key_prefix)
if redis_url:
super().__setattr__(
"_redis",
@@ -230,7 +238,7 @@ class AppConfig:
self._state[key].save()
if self._redis:
- redis_key = f"open-webui:config:{key}"
+ redis_key = f"{self._redis_key_prefix}:config:{key}"
self._redis.set(redis_key, json.dumps(self._state[key].value))
def __getattr__(self, key):
@@ -239,7 +247,7 @@ class AppConfig:
# If Redis is available, check for an updated value
if self._redis:
- redis_key = f"open-webui:config:{key}"
+ redis_key = f"{self._redis_key_prefix}:config:{key}"
redis_value = self._redis.get(redis_key)
if redis_value is not None:
@@ -431,6 +439,12 @@ OAUTH_SCOPES = PersistentConfig(
os.environ.get("OAUTH_SCOPES", "openid email profile"),
)
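+# Optional timeout (in seconds) applied to each OAuth provider's HTTP client via
+# client_kwargs below; leave empty to keep the library default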
+OAUTH_TIMEOUT = PersistentConfig(
+ "OAUTH_TIMEOUT",
+ "oauth.oidc.oauth_timeout",
+ os.environ.get("OAUTH_TIMEOUT", ""),
+)
+
OAUTH_CODE_CHALLENGE_METHOD = PersistentConfig(
"OAUTH_CODE_CHALLENGE_METHOD",
"oauth.oidc.code_challenge_method",
@@ -534,13 +548,20 @@ def load_oauth_providers():
OAUTH_PROVIDERS.clear()
if GOOGLE_CLIENT_ID.value and GOOGLE_CLIENT_SECRET.value:
- def google_oauth_register(client):
+ def google_oauth_register(client: OAuth):
client.register(
name="google",
client_id=GOOGLE_CLIENT_ID.value,
client_secret=GOOGLE_CLIENT_SECRET.value,
server_metadata_url="https://accounts.google.com/.well-known/openid-configuration",
- client_kwargs={"scope": GOOGLE_OAUTH_SCOPE.value},
+ client_kwargs={
+ "scope": GOOGLE_OAUTH_SCOPE.value,
+ **(
+ {"timeout": int(OAUTH_TIMEOUT.value)}
+ if OAUTH_TIMEOUT.value
+ else {}
+ ),
+ },
redirect_uri=GOOGLE_REDIRECT_URI.value,
)
@@ -555,7 +576,7 @@ def load_oauth_providers():
and MICROSOFT_CLIENT_TENANT_ID.value
):
- def microsoft_oauth_register(client):
+ def microsoft_oauth_register(client: OAuth):
client.register(
name="microsoft",
client_id=MICROSOFT_CLIENT_ID.value,
@@ -563,6 +584,11 @@ def load_oauth_providers():
server_metadata_url=f"{MICROSOFT_CLIENT_LOGIN_BASE_URL.value}/{MICROSOFT_CLIENT_TENANT_ID.value}/v2.0/.well-known/openid-configuration?appid={MICROSOFT_CLIENT_ID.value}",
client_kwargs={
"scope": MICROSOFT_OAUTH_SCOPE.value,
+ **(
+ {"timeout": int(OAUTH_TIMEOUT.value)}
+ if OAUTH_TIMEOUT.value
+ else {}
+ ),
},
redirect_uri=MICROSOFT_REDIRECT_URI.value,
)
@@ -575,7 +601,7 @@ def load_oauth_providers():
if GITHUB_CLIENT_ID.value and GITHUB_CLIENT_SECRET.value:
- def github_oauth_register(client):
+ def github_oauth_register(client: OAuth):
client.register(
name="github",
client_id=GITHUB_CLIENT_ID.value,
@@ -584,7 +610,14 @@ def load_oauth_providers():
authorize_url="https://github.com/login/oauth/authorize",
api_base_url="https://api.github.com",
userinfo_endpoint="https://api.github.com/user",
- client_kwargs={"scope": GITHUB_CLIENT_SCOPE.value},
+ client_kwargs={
+ "scope": GITHUB_CLIENT_SCOPE.value,
+ **(
+ {"timeout": int(OAUTH_TIMEOUT.value)}
+ if OAUTH_TIMEOUT.value
+ else {}
+ ),
+ },
redirect_uri=GITHUB_CLIENT_REDIRECT_URI.value,
)
@@ -600,9 +633,12 @@ def load_oauth_providers():
and OPENID_PROVIDER_URL.value
):
- def oidc_oauth_register(client):
+ def oidc_oauth_register(client: OAuth):
client_kwargs = {
"scope": OAUTH_SCOPES.value,
+ **(
+ {"timeout": int(OAUTH_TIMEOUT.value)} if OAUTH_TIMEOUT.value else {}
+ ),
}
if (
@@ -895,6 +931,18 @@ except Exception:
pass
OPENAI_API_BASE_URL = "https://api.openai.com/v1"
+
+####################################
+# MODELS
+####################################
+
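+# Cache the aggregated base model list in app state so /api/models does not query
+# every connection on each request; clients can pass ?refresh=true to rebuild it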
+ENABLE_BASE_MODELS_CACHE = PersistentConfig(
+ "ENABLE_BASE_MODELS_CACHE",
+ "models.base_models_cache",
+ os.environ.get("ENABLE_BASE_MODELS_CACHE", "False").lower() == "true",
+)
+
+
####################################
# TOOL_SERVERS
####################################
@@ -1794,11 +1842,12 @@ MILVUS_IVF_FLAT_NLIST = int(os.environ.get("MILVUS_IVF_FLAT_NLIST", "128"))
QDRANT_URI = os.environ.get("QDRANT_URI", None)
QDRANT_API_KEY = os.environ.get("QDRANT_API_KEY", None)
QDRANT_ON_DISK = os.environ.get("QDRANT_ON_DISK", "false").lower() == "true"
-QDRANT_PREFER_GRPC = os.environ.get("QDRANT_PREFER_GRPC", "False").lower() == "true"
+QDRANT_PREFER_GRPC = os.environ.get("QDRANT_PREFER_GRPC", "false").lower() == "true"
QDRANT_GRPC_PORT = int(os.environ.get("QDRANT_GRPC_PORT", "6334"))
ENABLE_QDRANT_MULTITENANCY_MODE = (
- os.environ.get("ENABLE_QDRANT_MULTITENANCY_MODE", "false").lower() == "true"
+ os.environ.get("ENABLE_QDRANT_MULTITENANCY_MODE", "true").lower() == "true"
)
+QDRANT_COLLECTION_PREFIX = os.environ.get("QDRANT_COLLECTION_PREFIX", "open-webui")
# OpenSearch
OPENSEARCH_URI = os.environ.get("OPENSEARCH_URI", "https://localhost:9200")
@@ -1837,6 +1886,45 @@ if PGVECTOR_PGCRYPTO and not PGVECTOR_PGCRYPTO_KEY:
"PGVECTOR_PGCRYPTO is enabled but PGVECTOR_PGCRYPTO_KEY is not set. Please provide a valid key."
)
+
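+# PGVector engine pooling: a positive PGVECTOR_POOL_SIZE enables QueuePool with the
+# overflow/timeout/recycle settings below, 0 forces NullPool, and leaving it unset
+# keeps SQLAlchemy's default engine behavior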
+PGVECTOR_POOL_SIZE = os.environ.get("PGVECTOR_POOL_SIZE", None)
+
+if PGVECTOR_POOL_SIZE != None:
+ try:
+ PGVECTOR_POOL_SIZE = int(PGVECTOR_POOL_SIZE)
+ except Exception:
+ PGVECTOR_POOL_SIZE = None
+
+PGVECTOR_POOL_MAX_OVERFLOW = os.environ.get("PGVECTOR_POOL_MAX_OVERFLOW", 0)
+
+if PGVECTOR_POOL_MAX_OVERFLOW == "":
+ PGVECTOR_POOL_MAX_OVERFLOW = 0
+else:
+ try:
+ PGVECTOR_POOL_MAX_OVERFLOW = int(PGVECTOR_POOL_MAX_OVERFLOW)
+ except Exception:
+ PGVECTOR_POOL_MAX_OVERFLOW = 0
+
+PGVECTOR_POOL_TIMEOUT = os.environ.get("PGVECTOR_POOL_TIMEOUT", 30)
+
+if PGVECTOR_POOL_TIMEOUT == "":
+ PGVECTOR_POOL_TIMEOUT = 30
+else:
+ try:
+ PGVECTOR_POOL_TIMEOUT = int(PGVECTOR_POOL_TIMEOUT)
+ except Exception:
+ PGVECTOR_POOL_TIMEOUT = 30
+
+PGVECTOR_POOL_RECYCLE = os.environ.get("PGVECTOR_POOL_RECYCLE", 3600)
+
+if PGVECTOR_POOL_RECYCLE == "":
+ PGVECTOR_POOL_RECYCLE = 3600
+else:
+ try:
+ PGVECTOR_POOL_RECYCLE = int(PGVECTOR_POOL_RECYCLE)
+ except Exception:
+ PGVECTOR_POOL_RECYCLE = 3600
+
# Pinecone
PINECONE_API_KEY = os.environ.get("PINECONE_API_KEY", None)
PINECONE_ENVIRONMENT = os.environ.get("PINECONE_ENVIRONMENT", None)
diff --git a/backend/open_webui/env.py b/backend/open_webui/env.py
index 0f7b5611f5..4db919121a 100644
--- a/backend/open_webui/env.py
+++ b/backend/open_webui/env.py
@@ -199,6 +199,7 @@ CHANGELOG = changelog_json
SAFE_MODE = os.environ.get("SAFE_MODE", "false").lower() == "true"
+
####################################
# ENABLE_FORWARD_USER_INFO_HEADERS
####################################
@@ -266,21 +267,43 @@ else:
DATABASE_URL = os.environ.get("DATABASE_URL", f"sqlite:///{DATA_DIR}/webui.db")
+DATABASE_TYPE = os.environ.get("DATABASE_TYPE")
+DATABASE_USER = os.environ.get("DATABASE_USER")
+DATABASE_PASSWORD = os.environ.get("DATABASE_PASSWORD")
+
+DATABASE_CRED = ""
+if DATABASE_USER:
+ DATABASE_CRED += f"{DATABASE_USER}"
+if DATABASE_PASSWORD:
+ DATABASE_CRED += f":{DATABASE_PASSWORD}"
+if DATABASE_CRED:
+ DATABASE_CRED += "@"
+
+
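+# Build DATABASE_URL from the individual parts only when type, credentials, host,
+# port, and name are all provided, e.g. (hypothetical values)
+#   DATABASE_TYPE=postgresql DATABASE_USER=owui DATABASE_PASSWORD=secret
+#   DATABASE_HOST=db DATABASE_PORT=5432 DATABASE_NAME=openwebui
+# -> postgresql://owui:secret@db:5432/openwebui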
+DB_VARS = {
+ "db_type": DATABASE_TYPE,
+ "db_cred": DATABASE_CRED,
+ "db_host": os.environ.get("DATABASE_HOST"),
+ "db_port": os.environ.get("DATABASE_PORT"),
+ "db_name": os.environ.get("DATABASE_NAME"),
+}
+
+if all(DB_VARS.values()):
+ DATABASE_URL = f"{DB_VARS['db_type']}://{DB_VARS['db_cred']}@{DB_VARS['db_host']}:{DB_VARS['db_port']}/{DB_VARS['db_name']}"
+
# Replace the postgres:// with postgresql://
if "postgres://" in DATABASE_URL:
DATABASE_URL = DATABASE_URL.replace("postgres://", "postgresql://")
DATABASE_SCHEMA = os.environ.get("DATABASE_SCHEMA", None)
-DATABASE_POOL_SIZE = os.environ.get("DATABASE_POOL_SIZE", 0)
+DATABASE_POOL_SIZE = os.environ.get("DATABASE_POOL_SIZE", None)
-if DATABASE_POOL_SIZE == "":
- DATABASE_POOL_SIZE = 0
-else:
+if DATABASE_POOL_SIZE != None:
try:
DATABASE_POOL_SIZE = int(DATABASE_POOL_SIZE)
except Exception:
- DATABASE_POOL_SIZE = 0
+ DATABASE_POOL_SIZE = None
DATABASE_POOL_MAX_OVERFLOW = os.environ.get("DATABASE_POOL_MAX_OVERFLOW", 0)
@@ -325,6 +348,7 @@ ENABLE_REALTIME_CHAT_SAVE = (
####################################
REDIS_URL = os.environ.get("REDIS_URL", "")
+REDIS_KEY_PREFIX = os.environ.get("REDIS_KEY_PREFIX", "open-webui")
REDIS_SENTINEL_HOSTS = os.environ.get("REDIS_SENTINEL_HOSTS", "")
REDIS_SENTINEL_PORT = os.environ.get("REDIS_SENTINEL_PORT", "26379")
@@ -396,10 +420,33 @@ WEBUI_AUTH_COOKIE_SECURE = (
if WEBUI_AUTH and WEBUI_SECRET_KEY == "":
raise ValueError(ERROR_MESSAGES.ENV_VAR_NOT_FOUND)
+ENABLE_COMPRESSION_MIDDLEWARE = (
+ os.environ.get("ENABLE_COMPRESSION_MIDDLEWARE", "True").lower() == "true"
+)
+
+####################################
+# MODELS
+####################################
+
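+# TTL (in seconds) for the cached model list: defaults to 1, an empty string leaves it
+# unset (None), and an invalid value falls back to 1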
+MODELS_CACHE_TTL = os.environ.get("MODELS_CACHE_TTL", "1")
+if MODELS_CACHE_TTL == "":
+ MODELS_CACHE_TTL = None
+else:
+ try:
+ MODELS_CACHE_TTL = int(MODELS_CACHE_TTL)
+ except Exception:
+ MODELS_CACHE_TTL = 1
+
+
+####################################
+# WEBSOCKET SUPPORT
+####################################
+
ENABLE_WEBSOCKET_SUPPORT = (
os.environ.get("ENABLE_WEBSOCKET_SUPPORT", "True").lower() == "true"
)
+
WEBSOCKET_MANAGER = os.environ.get("WEBSOCKET_MANAGER", "")
WEBSOCKET_REDIS_URL = os.environ.get("WEBSOCKET_REDIS_URL", REDIS_URL)
@@ -506,11 +553,14 @@ else:
# OFFLINE_MODE
####################################
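+# Allow disabling the version update check independently of OFFLINE_MODE;
+# enabling OFFLINE_MODE below still forces it off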
+ENABLE_VERSION_UPDATE_CHECK = (
+ os.environ.get("ENABLE_VERSION_UPDATE_CHECK", "true").lower() == "true"
+)
OFFLINE_MODE = os.environ.get("OFFLINE_MODE", "false").lower() == "true"
if OFFLINE_MODE:
os.environ["HF_HUB_OFFLINE"] = "1"
-
+ ENABLE_VERSION_UPDATE_CHECK = False
####################################
# AUDIT LOGGING
@@ -519,6 +569,14 @@ if OFFLINE_MODE:
AUDIT_LOGS_FILE_PATH = f"{DATA_DIR}/audit.log"
# Maximum size of a file before rotating into a new log file
AUDIT_LOG_FILE_ROTATION_SIZE = os.getenv("AUDIT_LOG_FILE_ROTATION_SIZE", "10MB")
+
+# Comma separated list of logger names to use for audit logging
+# Default is "uvicorn.access" which is the access log for Uvicorn
+# You can add more logger names to this list if you want to capture more logs
+AUDIT_UVICORN_LOGGER_NAMES = os.getenv(
+ "AUDIT_UVICORN_LOGGER_NAMES", "uvicorn.access"
+).split(",")
+
# METADATA | REQUEST | REQUEST_RESPONSE
AUDIT_LOG_LEVEL = os.getenv("AUDIT_LOG_LEVEL", "NONE").upper()
try:
@@ -543,6 +601,9 @@ ENABLE_OTEL_METRICS = os.environ.get("ENABLE_OTEL_METRICS", "False").lower() ==
OTEL_EXPORTER_OTLP_ENDPOINT = os.environ.get(
"OTEL_EXPORTER_OTLP_ENDPOINT", "http://localhost:4317"
)
+OTEL_EXPORTER_OTLP_INSECURE = (
+ os.environ.get("OTEL_EXPORTER_OTLP_INSECURE", "False").lower() == "true"
+)
OTEL_SERVICE_NAME = os.environ.get("OTEL_SERVICE_NAME", "open-webui")
OTEL_RESOURCE_ATTRIBUTES = os.environ.get(
"OTEL_RESOURCE_ATTRIBUTES", ""
@@ -550,6 +611,14 @@ OTEL_RESOURCE_ATTRIBUTES = os.environ.get(
OTEL_TRACES_SAMPLER = os.environ.get(
"OTEL_TRACES_SAMPLER", "parentbased_always_on"
).lower()
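+# Optional basic-auth credentials for the OTLP exporter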
+OTEL_BASIC_AUTH_USERNAME = os.environ.get("OTEL_BASIC_AUTH_USERNAME", "")
+OTEL_BASIC_AUTH_PASSWORD = os.environ.get("OTEL_BASIC_AUTH_PASSWORD", "")
+
+
+OTEL_OTLP_SPAN_EXPORTER = os.environ.get(
+ "OTEL_OTLP_SPAN_EXPORTER", "grpc"
+).lower() # grpc or http
+
####################################
# TOOLS/FUNCTIONS PIP OPTIONS
diff --git a/backend/open_webui/internal/db.py b/backend/open_webui/internal/db.py
index 840f571cc9..e1ffc1eb27 100644
--- a/backend/open_webui/internal/db.py
+++ b/backend/open_webui/internal/db.py
@@ -62,6 +62,9 @@ def handle_peewee_migration(DATABASE_URL):
except Exception as e:
log.error(f"Failed to initialize the database connection: {e}")
+ log.warning(
+ "Hint: If your database password contains special characters, you may need to URL-encode it."
+ )
raise
finally:
# Properly closing the database connection
@@ -81,20 +84,23 @@ if "sqlite" in SQLALCHEMY_DATABASE_URL:
SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False}
)
else:
- if DATABASE_POOL_SIZE > 0:
- engine = create_engine(
- SQLALCHEMY_DATABASE_URL,
- pool_size=DATABASE_POOL_SIZE,
- max_overflow=DATABASE_POOL_MAX_OVERFLOW,
- pool_timeout=DATABASE_POOL_TIMEOUT,
- pool_recycle=DATABASE_POOL_RECYCLE,
- pool_pre_ping=True,
- poolclass=QueuePool,
- )
+ if isinstance(DATABASE_POOL_SIZE, int):
+ if DATABASE_POOL_SIZE > 0:
+ engine = create_engine(
+ SQLALCHEMY_DATABASE_URL,
+ pool_size=DATABASE_POOL_SIZE,
+ max_overflow=DATABASE_POOL_MAX_OVERFLOW,
+ pool_timeout=DATABASE_POOL_TIMEOUT,
+ pool_recycle=DATABASE_POOL_RECYCLE,
+ pool_pre_ping=True,
+ poolclass=QueuePool,
+ )
+ else:
+ engine = create_engine(
+ SQLALCHEMY_DATABASE_URL, pool_pre_ping=True, poolclass=NullPool
+ )
else:
- engine = create_engine(
- SQLALCHEMY_DATABASE_URL, pool_pre_ping=True, poolclass=NullPool
- )
+ engine = create_engine(SQLALCHEMY_DATABASE_URL, pool_pre_ping=True)
SessionLocal = sessionmaker(
diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py
index 544756a6e8..595d551d75 100644
--- a/backend/open_webui/main.py
+++ b/backend/open_webui/main.py
@@ -36,7 +36,6 @@ from fastapi import (
applications,
BackgroundTasks,
)
-
from fastapi.openapi.docs import get_swagger_ui_html
from fastapi.middleware.cors import CORSMiddleware
@@ -49,6 +48,7 @@ from starlette.exceptions import HTTPException as StarletteHTTPException
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.middleware.sessions import SessionMiddleware
from starlette.responses import Response, StreamingResponse
+from starlette.datastructures import Headers
from open_webui.utils import logger
@@ -89,6 +89,7 @@ from open_webui.routers import (
from open_webui.routers.retrieval import (
get_embedding_function,
+ get_reranking_function,
get_ef,
get_rf,
)
@@ -116,6 +117,8 @@ from open_webui.config import (
OPENAI_API_CONFIGS,
# Direct Connections
ENABLE_DIRECT_CONNECTIONS,
+ # Model list
+ ENABLE_BASE_MODELS_CACHE,
# Thread pool size for FastAPI/AnyIO
THREAD_POOL_SIZE,
# Tool Server Configs
@@ -396,6 +399,7 @@ from open_webui.env import (
AUDIT_LOG_LEVEL,
CHANGELOG,
REDIS_URL,
+ REDIS_KEY_PREFIX,
REDIS_SENTINEL_HOSTS,
REDIS_SENTINEL_PORT,
GLOBAL_LOG_LEVEL,
@@ -411,10 +415,11 @@ from open_webui.env import (
WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
WEBUI_AUTH_TRUSTED_NAME_HEADER,
WEBUI_AUTH_SIGNOUT_REDIRECT_URL,
+ ENABLE_COMPRESSION_MIDDLEWARE,
ENABLE_WEBSOCKET_SUPPORT,
BYPASS_MODEL_ACCESS_CONTROL,
RESET_CONFIG_ON_START,
- OFFLINE_MODE,
+ ENABLE_VERSION_UPDATE_CHECK,
ENABLE_OTEL,
EXTERNAL_PWA_MANIFEST_URL,
AIOHTTP_CLIENT_SESSION_SSL,
@@ -449,7 +454,7 @@ from open_webui.utils.redis import get_redis_connection
from open_webui.tasks import (
redis_task_command_listener,
- list_task_ids_by_chat_id,
+ list_task_ids_by_item_id,
stop_task,
list_tasks,
) # Import from tasks.py
@@ -533,6 +538,27 @@ async def lifespan(app: FastAPI):
asyncio.create_task(periodic_usage_pool_cleanup())
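+    # Warm the base model list cache once at startup so the first /api/models request
+    # does not have to wait on every configured connection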
+ if app.state.config.ENABLE_BASE_MODELS_CACHE:
+ await get_all_models(
+ Request(
+ # Creating a mock request object to pass to get_all_models
+ {
+ "type": "http",
+ "asgi.version": "3.0",
+ "asgi.spec_version": "2.0",
+ "method": "GET",
+ "path": "/internal",
+ "query_string": b"",
+ "headers": Headers({}).raw,
+ "client": ("127.0.0.1", 12345),
+ "server": ("127.0.0.1", 80),
+ "scheme": "http",
+ "app": app,
+ }
+ ),
+ None,
+ )
+
yield
if hasattr(app.state, "redis_task_command_listener"):
@@ -553,6 +579,7 @@ app.state.instance_id = None
app.state.config = AppConfig(
redis_url=REDIS_URL,
redis_sentinels=get_sentinels_from_env(REDIS_SENTINEL_HOSTS, REDIS_SENTINEL_PORT),
+ redis_key_prefix=REDIS_KEY_PREFIX,
)
app.state.redis = None
@@ -615,6 +642,15 @@ app.state.TOOL_SERVERS = []
app.state.config.ENABLE_DIRECT_CONNECTIONS = ENABLE_DIRECT_CONNECTIONS
+########################################
+#
+# MODELS
+#
+########################################
+
+app.state.config.ENABLE_BASE_MODELS_CACHE = ENABLE_BASE_MODELS_CACHE
+app.state.BASE_MODELS = []
+
########################################
#
# WEBUI
@@ -843,6 +879,7 @@ app.state.config.FIRECRAWL_API_KEY = FIRECRAWL_API_KEY
app.state.config.TAVILY_EXTRACT_DEPTH = TAVILY_EXTRACT_DEPTH
app.state.EMBEDDING_FUNCTION = None
+app.state.RERANKING_FUNCTION = None
app.state.ef = None
app.state.rf = None
@@ -871,8 +908,8 @@ except Exception as e:
app.state.EMBEDDING_FUNCTION = get_embedding_function(
app.state.config.RAG_EMBEDDING_ENGINE,
app.state.config.RAG_EMBEDDING_MODEL,
- app.state.ef,
- (
+ embedding_function=app.state.ef,
+ url=(
app.state.config.RAG_OPENAI_API_BASE_URL
if app.state.config.RAG_EMBEDDING_ENGINE == "openai"
else (
@@ -881,7 +918,7 @@ app.state.EMBEDDING_FUNCTION = get_embedding_function(
else app.state.config.RAG_AZURE_OPENAI_BASE_URL
)
),
- (
+ key=(
app.state.config.RAG_OPENAI_API_KEY
if app.state.config.RAG_EMBEDDING_ENGINE == "openai"
else (
@@ -890,7 +927,7 @@ app.state.EMBEDDING_FUNCTION = get_embedding_function(
else app.state.config.RAG_AZURE_OPENAI_API_KEY
)
),
- app.state.config.RAG_EMBEDDING_BATCH_SIZE,
+ embedding_batch_size=app.state.config.RAG_EMBEDDING_BATCH_SIZE,
azure_api_version=(
app.state.config.RAG_AZURE_OPENAI_API_VERSION
if app.state.config.RAG_EMBEDDING_ENGINE == "azure_openai"
@@ -898,6 +935,12 @@ app.state.EMBEDDING_FUNCTION = get_embedding_function(
),
)
+app.state.RERANKING_FUNCTION = get_reranking_function(
+ app.state.config.RAG_RERANKING_ENGINE,
+ app.state.config.RAG_RERANKING_MODEL,
+ reranking_function=app.state.rf,
+)
+
########################################
#
# CODE EXECUTION
@@ -1072,7 +1115,9 @@ class RedirectMiddleware(BaseHTTPMiddleware):
# Add the middleware to the app
-app.add_middleware(CompressMiddleware)
+if ENABLE_COMPRESSION_MIDDLEWARE:
+ app.add_middleware(CompressMiddleware)
+
app.add_middleware(RedirectMiddleware)
app.add_middleware(SecurityHeadersMiddleware)
@@ -1188,7 +1233,9 @@ if audit_level != AuditLevel.NONE:
@app.get("/api/models")
-async def get_models(request: Request, user=Depends(get_verified_user)):
+async def get_models(
+ request: Request, refresh: bool = False, user=Depends(get_verified_user)
+):
def get_filtered_models(models, user):
filtered_models = []
for model in models:
@@ -1212,7 +1259,7 @@ async def get_models(request: Request, user=Depends(get_verified_user)):
return filtered_models
- all_models = await get_all_models(request, user=user)
+ all_models = await get_all_models(request, refresh=refresh, user=user)
models = []
for model in all_models:
@@ -1447,7 +1494,7 @@ async def stop_task_endpoint(
request: Request, task_id: str, user=Depends(get_verified_user)
):
try:
- result = await stop_task(request, task_id)
+ result = await stop_task(request.app.state.redis, task_id)
return result
except ValueError as e:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e))
@@ -1455,7 +1502,7 @@ async def stop_task_endpoint(
@app.get("/api/tasks")
async def list_tasks_endpoint(request: Request, user=Depends(get_verified_user)):
- return {"tasks": await list_tasks(request)}
+ return {"tasks": await list_tasks(request.app.state.redis)}
@app.get("/api/tasks/chat/{chat_id}")
@@ -1466,9 +1513,9 @@ async def list_tasks_by_chat_id_endpoint(
if chat is None or chat.user_id != user.id:
return {"task_ids": []}
- task_ids = await list_task_ids_by_chat_id(request, chat_id)
+ task_ids = await list_task_ids_by_item_id(request.app.state.redis, chat_id)
- print(f"Task IDs for chat {chat_id}: {task_ids}")
+ log.debug(f"Task IDs for chat {chat_id}: {task_ids}")
return {"task_ids": task_ids}
@@ -1521,6 +1568,7 @@ async def get_app_config(request: Request):
"enable_signup": app.state.config.ENABLE_SIGNUP,
"enable_login_form": app.state.config.ENABLE_LOGIN_FORM,
"enable_websocket": ENABLE_WEBSOCKET_SUPPORT,
+ "enable_version_update_check": ENABLE_VERSION_UPDATE_CHECK,
**(
{
"enable_direct_connections": app.state.config.ENABLE_DIRECT_CONNECTIONS,
@@ -1594,7 +1642,19 @@ async def get_app_config(request: Request):
),
}
if user is not None
- else {}
+ else {
+ **(
+ {
+ "metadata": {
+ "login_footer": app.state.LICENSE_METADATA.get(
+ "login_footer", ""
+ )
+ }
+ }
+ if app.state.LICENSE_METADATA
+ else {}
+ )
+ }
),
}
@@ -1626,9 +1686,9 @@ async def get_app_version():
@app.get("/api/version/updates")
async def get_app_latest_release_version(user=Depends(get_verified_user)):
- if OFFLINE_MODE:
+ if not ENABLE_VERSION_UPDATE_CHECK:
log.debug(
- f"Offline mode is enabled, returning current version as latest version"
+ f"Version update check is disabled, returning current version as latest version"
)
return {"current": VERSION, "latest": VERSION}
try:
diff --git a/backend/open_webui/migrations/versions/d31026856c01_update_folder_table_data.py b/backend/open_webui/migrations/versions/d31026856c01_update_folder_table_data.py
new file mode 100644
index 0000000000..3c916964e9
--- /dev/null
+++ b/backend/open_webui/migrations/versions/d31026856c01_update_folder_table_data.py
@@ -0,0 +1,23 @@
+"""Update folder table data
+
+Revision ID: d31026856c01
+Revises: 9f0c9cd09105
+Create Date: 2025-07-13 03:00:00.000000
+
+"""
+
+from alembic import op
+import sqlalchemy as sa
+
+revision = "d31026856c01"
+down_revision = "9f0c9cd09105"
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+ op.add_column("folder", sa.Column("data", sa.JSON(), nullable=True))
+
+
+def downgrade():
+ op.drop_column("folder", "data")
diff --git a/backend/open_webui/models/chats.py b/backend/open_webui/models/chats.py
index 0ac53a0233..b9de2e5a4e 100644
--- a/backend/open_webui/models/chats.py
+++ b/backend/open_webui/models/chats.py
@@ -12,6 +12,7 @@ from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON
from sqlalchemy import or_, func, select, and_, text
from sqlalchemy.sql import exists
+from sqlalchemy.sql.expression import bindparam
####################
# Chat DB Schema
@@ -66,12 +67,14 @@ class ChatModel(BaseModel):
class ChatForm(BaseModel):
chat: dict
+ folder_id: Optional[str] = None
class ChatImportForm(ChatForm):
meta: Optional[dict] = {}
pinned: Optional[bool] = False
- folder_id: Optional[str] = None
+ created_at: Optional[int] = None
+ updated_at: Optional[int] = None
class ChatTitleMessagesForm(BaseModel):
@@ -118,6 +121,7 @@ class ChatTable:
else "New Chat"
),
"chat": form_data.chat,
+ "folder_id": form_data.folder_id,
"created_at": int(time.time()),
"updated_at": int(time.time()),
}
@@ -147,8 +151,16 @@ class ChatTable:
"meta": form_data.meta,
"pinned": form_data.pinned,
"folder_id": form_data.folder_id,
- "created_at": int(time.time()),
- "updated_at": int(time.time()),
+ "created_at": (
+ form_data.created_at
+ if form_data.created_at
+ else int(time.time())
+ ),
+ "updated_at": (
+ form_data.updated_at
+ if form_data.updated_at
+ else int(time.time())
+ ),
}
)
@@ -232,6 +244,10 @@ class ChatTable:
if chat is None:
return None
+ # Sanitize message content for null characters before upserting
+ if isinstance(message.get("content"), str):
+ message["content"] = message["content"].replace("\x00", "")
+
chat = chat.chat
history = chat.get("history", {})
@@ -580,7 +596,7 @@ class ChatTable:
"""
Filters chats based on a search query using Python, allowing pagination using skip and limit.
"""
- search_text = search_text.lower().strip()
+ search_text = search_text.replace("\u0000", "").lower().strip()
if not search_text:
return self.get_chat_list_by_user_id(
@@ -614,21 +630,18 @@ class ChatTable:
dialect_name = db.bind.dialect.name
if dialect_name == "sqlite":
# SQLite case: using JSON1 extension for JSON searching
+ sqlite_content_sql = (
+ "EXISTS ("
+ " SELECT 1 "
+ " FROM json_each(Chat.chat, '$.messages') AS message "
+ " WHERE LOWER(message.value->>'content') LIKE '%' || :content_key || '%'"
+ ")"
+ )
+ sqlite_content_clause = text(sqlite_content_sql)
query = query.filter(
- (
- Chat.title.ilike(
- f"%{search_text}%"
- ) # Case-insensitive search in title
- | text(
- """
- EXISTS (
- SELECT 1
- FROM json_each(Chat.chat, '$.messages') AS message
- WHERE LOWER(message.value->>'content') LIKE '%' || :search_text || '%'
- )
- """
- )
- ).params(search_text=search_text)
+ or_(
+ Chat.title.ilike(bindparam("title_key")), sqlite_content_clause
+ ).params(title_key=f"%{search_text}%", content_key=search_text)
)
# Check if there are any tags to filter, it should have all the tags
@@ -663,21 +676,19 @@ class ChatTable:
elif dialect_name == "postgresql":
# PostgreSQL relies on proper JSON query for search
+ postgres_content_sql = (
+ "EXISTS ("
+ " SELECT 1 "
+ " FROM json_array_elements(Chat.chat->'messages') AS message "
+ " WHERE LOWER(message->>'content') LIKE '%' || :content_key || '%'"
+ ")"
+ )
+ postgres_content_clause = text(postgres_content_sql)
query = query.filter(
- (
- Chat.title.ilike(
- f"%{search_text}%"
- ) # Case-insensitive search in title
- | text(
- """
- EXISTS (
- SELECT 1
- FROM json_array_elements(Chat.chat->'messages') AS message
- WHERE LOWER(message->>'content') LIKE '%' || :search_text || '%'
- )
- """
- )
- ).params(search_text=search_text)
+ or_(
+ Chat.title.ilike(bindparam("title_key")),
+ postgres_content_clause,
+ ).params(title_key=f"%{search_text}%", content_key=search_text)
)
# Check if there are any tags to filter, it should have all the tags
diff --git a/backend/open_webui/models/folders.py b/backend/open_webui/models/folders.py
index 1c97de26c9..56ef81f167 100644
--- a/backend/open_webui/models/folders.py
+++ b/backend/open_webui/models/folders.py
@@ -29,6 +29,7 @@ class Folder(Base):
name = Column(Text)
items = Column(JSON, nullable=True)
meta = Column(JSON, nullable=True)
+ data = Column(JSON, nullable=True)
is_expanded = Column(Boolean, default=False)
created_at = Column(BigInteger)
updated_at = Column(BigInteger)
@@ -41,6 +42,7 @@ class FolderModel(BaseModel):
name: str
items: Optional[dict] = None
meta: Optional[dict] = None
+ data: Optional[dict] = None
is_expanded: bool = False
created_at: int
updated_at: int
@@ -55,6 +57,7 @@ class FolderModel(BaseModel):
class FolderForm(BaseModel):
name: str
+ data: Optional[dict] = None
model_config = ConfigDict(extra="allow")
@@ -187,8 +190,8 @@ class FolderTable:
log.error(f"update_folder: {e}")
return
- def update_folder_name_by_id_and_user_id(
- self, id: str, user_id: str, name: str
+ def update_folder_by_id_and_user_id(
+ self, id: str, user_id: str, form_data: FolderForm
) -> Optional[FolderModel]:
try:
with get_db() as db:
@@ -197,16 +200,28 @@ class FolderTable:
if not folder:
return None
+ form_data = form_data.model_dump(exclude_unset=True)
+
existing_folder = (
db.query(Folder)
- .filter_by(name=name, parent_id=folder.parent_id, user_id=user_id)
+ .filter_by(
+ name=form_data.get("name"),
+ parent_id=folder.parent_id,
+ user_id=user_id,
+ )
.first()
)
- if existing_folder:
+ if existing_folder and existing_folder.id != id:
return None
- folder.name = name
+ folder.name = form_data.get("name", folder.name)
+ if "data" in form_data:
+ folder.data = {
+ **(folder.data or {}),
+ **form_data["data"],
+ }
+
folder.updated_at = int(time.time())
db.commit()
diff --git a/backend/open_webui/models/notes.py b/backend/open_webui/models/notes.py
index 114ccdc574..ce3b9f2e20 100644
--- a/backend/open_webui/models/notes.py
+++ b/backend/open_webui/models/notes.py
@@ -62,6 +62,13 @@ class NoteForm(BaseModel):
access_control: Optional[dict] = None
+class NoteUpdateForm(BaseModel):
+ title: Optional[str] = None
+ data: Optional[dict] = None
+ meta: Optional[dict] = None
+ access_control: Optional[dict] = None
+
+
class NoteUserResponse(NoteModel):
user: Optional[UserResponse] = None
@@ -110,16 +117,26 @@ class NoteTable:
note = db.query(Note).filter(Note.id == id).first()
return NoteModel.model_validate(note) if note else None
- def update_note_by_id(self, id: str, form_data: NoteForm) -> Optional[NoteModel]:
+ def update_note_by_id(
+ self, id: str, form_data: NoteUpdateForm
+ ) -> Optional[NoteModel]:
with get_db() as db:
note = db.query(Note).filter(Note.id == id).first()
if not note:
return None
- note.title = form_data.title
- note.data = form_data.data
- note.meta = form_data.meta
- note.access_control = form_data.access_control
+ form_data = form_data.model_dump(exclude_unset=True)
+
+ if "title" in form_data:
+ note.title = form_data["title"]
+ if "data" in form_data:
+ note.data = {**note.data, **form_data["data"]}
+ if "meta" in form_data:
+ note.meta = {**note.meta, **form_data["meta"]}
+
+ if "access_control" in form_data:
+ note.access_control = form_data["access_control"]
+
note.updated_at = int(time.time_ns())
db.commit()
diff --git a/backend/open_webui/retrieval/loaders/main.py b/backend/open_webui/retrieval/loaders/main.py
index 8ac878fc22..e57323e1eb 100644
--- a/backend/open_webui/retrieval/loaders/main.py
+++ b/backend/open_webui/retrieval/loaders/main.py
@@ -14,7 +14,7 @@ from langchain_community.document_loaders import (
TextLoader,
UnstructuredEPubLoader,
UnstructuredExcelLoader,
- UnstructuredMarkdownLoader,
+ UnstructuredODTLoader,
UnstructuredPowerPointLoader,
UnstructuredRSTLoader,
UnstructuredXMLLoader,
@@ -226,7 +226,10 @@ class Loader:
def _is_text_file(self, file_ext: str, file_content_type: str) -> bool:
return file_ext in known_source_ext or (
- file_content_type and file_content_type.find("text/") >= 0
+ file_content_type
+ and file_content_type.find("text/") >= 0
+ # Avoid text/html files being detected as text
+ and not file_content_type.find("html") >= 0
)
def _get_loader(self, filename: str, file_content_type: str, file_path: str):
@@ -389,6 +392,8 @@ class Loader:
loader = UnstructuredPowerPointLoader(file_path)
elif file_ext == "msg":
loader = OutlookMessageLoader(file_path)
+ elif file_ext == "odt":
+ loader = UnstructuredODTLoader(file_path)
elif self._is_text_file(file_ext, file_content_type):
loader = TextLoader(file_path, autodetect_encoding=True)
else:
diff --git a/backend/open_webui/retrieval/loaders/mistral.py b/backend/open_webui/retrieval/loaders/mistral.py
index b00e9d7ce5..b7f2622f5e 100644
--- a/backend/open_webui/retrieval/loaders/mistral.py
+++ b/backend/open_webui/retrieval/loaders/mistral.py
@@ -507,6 +507,7 @@ class MistralLoader:
timeout=timeout,
headers={"User-Agent": "OpenWebUI-MistralLoader/2.0"},
raise_for_status=False, # We handle status codes manually
+ trust_env=True,
) as session:
yield session
diff --git a/backend/open_webui/retrieval/models/external.py b/backend/open_webui/retrieval/models/external.py
index 5ebc3e52ea..a9be526b6d 100644
--- a/backend/open_webui/retrieval/models/external.py
+++ b/backend/open_webui/retrieval/models/external.py
@@ -1,8 +1,10 @@
import logging
import requests
from typing import Optional, List, Tuple
+from urllib.parse import quote
-from open_webui.env import SRC_LOG_LEVELS
+
+from open_webui.env import ENABLE_FORWARD_USER_INFO_HEADERS, SRC_LOG_LEVELS
from open_webui.retrieval.models.base_reranker import BaseReranker
@@ -21,7 +23,9 @@ class ExternalReranker(BaseReranker):
self.url = url
self.model = model
- def predict(self, sentences: List[Tuple[str, str]]) -> Optional[List[float]]:
+ def predict(
+ self, sentences: List[Tuple[str, str]], user=None
+ ) -> Optional[List[float]]:
query = sentences[0][0]
docs = [i[1] for i in sentences]
@@ -41,6 +45,16 @@ class ExternalReranker(BaseReranker):
headers={
"Content-Type": "application/json",
"Authorization": f"Bearer {self.api_key}",
+ **(
+ {
+ "X-OpenWebUI-User-Name": quote(user.name, safe=" "),
+ "X-OpenWebUI-User-Id": user.id,
+ "X-OpenWebUI-User-Email": user.email,
+ "X-OpenWebUI-User-Role": user.role,
+ }
+ if ENABLE_FORWARD_USER_INFO_HEADERS and user
+ else {}
+ ),
},
json=payload,
)
diff --git a/backend/open_webui/retrieval/utils.py b/backend/open_webui/retrieval/utils.py
index 683f42819b..154873749f 100644
--- a/backend/open_webui/retrieval/utils.py
+++ b/backend/open_webui/retrieval/utils.py
@@ -7,6 +7,7 @@ import hashlib
from concurrent.futures import ThreadPoolExecutor
import time
+from urllib.parse import quote
from huggingface_hub import snapshot_download
from langchain.retrievers import ContextualCompressionRetriever, EnsembleRetriever
from langchain_community.retrievers import BM25Retriever
@@ -17,8 +18,11 @@ from open_webui.retrieval.vector.factory import VECTOR_DB_CLIENT
from open_webui.models.users import UserModel
from open_webui.models.files import Files
+from open_webui.models.knowledge import Knowledges
+from open_webui.models.notes import Notes
from open_webui.retrieval.vector.main import GetResult
+from open_webui.utils.access_control import has_access
from open_webui.env import (
@@ -441,9 +445,20 @@ def get_embedding_function(
raise ValueError(f"Unknown embedding engine: {embedding_engine}")
-def get_sources_from_files(
+def get_reranking_function(reranking_engine, reranking_model, reranking_function):
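+    # Normalize the reranker call signature: only the external engine accepts a user
+    # (to forward X-OpenWebUI-User-* headers); local rerankers ignore it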
+ if reranking_function is None:
+ return None
+ if reranking_engine == "external":
+ return lambda sentences, user=None: reranking_function.predict(
+ sentences, user=user
+ )
+ else:
+ return lambda sentences, user=None: reranking_function.predict(sentences)
+
+
+def get_sources_from_items(
request,
- files,
+ items,
queries,
embedding_function,
k,
@@ -453,159 +468,210 @@ def get_sources_from_files(
hybrid_bm25_weight,
hybrid_search,
full_context=False,
+ user: Optional[UserModel] = None,
):
log.debug(
- f"files: {files} {queries} {embedding_function} {reranking_function} {full_context}"
+ f"items: {items} {queries} {embedding_function} {reranking_function} {full_context}"
)
extracted_collections = []
- relevant_contexts = []
+ query_results = []
- for file in files:
+ for item in items:
+ query_result = None
+ collection_names = []
- context = None
- if file.get("docs"):
- # BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL
- context = {
- "documents": [[doc.get("content") for doc in file.get("docs")]],
- "metadatas": [[doc.get("metadata") for doc in file.get("docs")]],
- }
- elif file.get("context") == "full":
- # Manual Full Mode Toggle
- context = {
- "documents": [[file.get("file").get("data", {}).get("content")]],
- "metadatas": [[{"file_id": file.get("id"), "name": file.get("name")}]],
- }
- elif (
- file.get("type") != "web_search"
- and request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL
- ):
- # BYPASS_EMBEDDING_AND_RETRIEVAL
- if file.get("type") == "collection":
- file_ids = file.get("data", {}).get("file_ids", [])
+ if item.get("type") == "text":
+ # Raw Text
+ # Used during temporary chat file uploads
- documents = []
- metadatas = []
- for file_id in file_ids:
- file_object = Files.get_file_by_id(file_id)
-
- if file_object:
- documents.append(file_object.data.get("content", ""))
- metadatas.append(
- {
- "file_id": file_id,
- "name": file_object.filename,
- "source": file_object.filename,
- }
- )
-
- context = {
- "documents": [documents],
- "metadatas": [metadatas],
+ if item.get("file"):
+ # if item has file data, use it
+ query_result = {
+ "documents": [
+ [item.get("file", {}).get("data", {}).get("content")]
+ ],
+ "metadatas": [
+ [item.get("file", {}).get("data", {}).get("meta", {})]
+ ],
+ }
+ else:
+ # Fallback to item content
+ query_result = {
+ "documents": [[item.get("content")]],
+ "metadatas": [
+ [{"file_id": item.get("id"), "name": item.get("name")}]
+ ],
}
- elif file.get("id"):
- file_object = Files.get_file_by_id(file.get("id"))
- if file_object:
- context = {
- "documents": [[file_object.data.get("content", "")]],
+ elif item.get("type") == "note":
+ # Note Attached
+ note = Notes.get_note_by_id(item.get("id"))
+
+ if user.role == "admin" or has_access(user.id, "read", note.access_control):
+ # User has access to the note
+ query_result = {
+ "documents": [[note.data.get("content", {}).get("md", "")]],
+ "metadatas": [[{"file_id": note.id, "name": note.title}]],
+ }
+
+ elif item.get("type") == "file":
+ if (
+ item.get("context") == "full"
+ or request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL
+ ):
+ if item.get("file", {}).get("data", {}).get("content", ""):
+ # Manual Full Mode Toggle
+                    # Used from the chat file modal; we can assume the file content is available at item.get("file").get("data", {}).get("content")
+ query_result = {
+ "documents": [
+ [item.get("file", {}).get("data", {}).get("content", "")]
+ ],
"metadatas": [
[
{
- "file_id": file.get("id"),
- "name": file_object.filename,
- "source": file_object.filename,
+ "file_id": item.get("id"),
+ "name": item.get("name"),
+ **item.get("file")
+ .get("data", {})
+ .get("metadata", {}),
}
]
],
}
- elif file.get("file").get("data"):
- context = {
- "documents": [[file.get("file").get("data", {}).get("content")]],
- "metadatas": [
- [file.get("file").get("data", {}).get("metadata", {})]
- ],
- }
- else:
- collection_names = []
- if file.get("type") == "collection":
- if file.get("legacy"):
- collection_names = file.get("collection_names", [])
+ elif item.get("id"):
+ file_object = Files.get_file_by_id(item.get("id"))
+ if file_object:
+ query_result = {
+ "documents": [[file_object.data.get("content", "")]],
+ "metadatas": [
+ [
+ {
+ "file_id": item.get("id"),
+ "name": file_object.filename,
+ "source": file_object.filename,
+ }
+ ]
+ ],
+ }
+ else:
+ # Fallback to collection names
+ if item.get("legacy"):
+ collection_names.append(f"{item['id']}")
else:
- collection_names.append(file["id"])
- elif file.get("collection_name"):
- collection_names.append(file["collection_name"])
- elif file.get("id"):
- if file.get("legacy"):
- collection_names.append(f"{file['id']}")
- else:
- collection_names.append(f"file-{file['id']}")
+ collection_names.append(f"file-{item['id']}")
+ elif item.get("type") == "collection":
+ if (
+ item.get("context") == "full"
+ or request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL
+ ):
+ # Manual Full Mode Toggle for Collection
+ knowledge_base = Knowledges.get_knowledge_by_id(item.get("id"))
+
+ if knowledge_base and (
+ user.role == "admin"
+ or has_access(user.id, "read", knowledge_base.access_control)
+ ):
+
+ file_ids = knowledge_base.data.get("file_ids", [])
+
+ documents = []
+ metadatas = []
+ for file_id in file_ids:
+ file_object = Files.get_file_by_id(file_id)
+
+ if file_object:
+ documents.append(file_object.data.get("content", ""))
+ metadatas.append(
+ {
+ "file_id": file_id,
+ "name": file_object.filename,
+ "source": file_object.filename,
+ }
+ )
+
+ query_result = {
+ "documents": [documents],
+ "metadatas": [metadatas],
+ }
+ else:
+ # Fallback to collection names
+ if item.get("legacy"):
+ collection_names = item.get("collection_names", [])
+ else:
+ collection_names.append(item["id"])
+
+ elif item.get("docs"):
+ # BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL
+ query_result = {
+ "documents": [[doc.get("content") for doc in item.get("docs")]],
+ "metadatas": [[doc.get("metadata") for doc in item.get("docs")]],
+ }
+ elif item.get("collection_name"):
+ # Direct Collection Name
+ collection_names.append(item["collection_name"])
+
+ # If query_result is None
+ # Fallback to collection names and vector search the collections
+ if query_result is None and collection_names:
collection_names = set(collection_names).difference(extracted_collections)
if not collection_names:
- log.debug(f"skipping {file} as it has already been extracted")
+ log.debug(f"skipping {item} as it has already been extracted")
continue
- if full_context:
- try:
- context = get_all_items_from_collections(collection_names)
- except Exception as e:
- log.exception(e)
-
- else:
- try:
- context = None
- if file.get("type") == "text":
- context = file["content"]
- else:
- if hybrid_search:
- try:
- context = query_collection_with_hybrid_search(
- collection_names=collection_names,
- queries=queries,
- embedding_function=embedding_function,
- k=k,
- reranking_function=reranking_function,
- k_reranker=k_reranker,
- r=r,
- hybrid_bm25_weight=hybrid_bm25_weight,
- )
- except Exception as e:
- log.debug(
- "Error when using hybrid search, using"
- " non hybrid search as fallback."
- )
-
- if (not hybrid_search) or (context is None):
- context = query_collection(
+ try:
+ if full_context:
+ query_result = get_all_items_from_collections(collection_names)
+ else:
+ query_result = None # Initialize to None
+ if hybrid_search:
+ try:
+ query_result = query_collection_with_hybrid_search(
collection_names=collection_names,
queries=queries,
embedding_function=embedding_function,
k=k,
+ reranking_function=reranking_function,
+ k_reranker=k_reranker,
+ r=r,
+ hybrid_bm25_weight=hybrid_bm25_weight,
)
- except Exception as e:
- log.exception(e)
+ except Exception as e:
+ log.debug(
+ "Error when using hybrid search, using non hybrid search as fallback."
+ )
+
+ # fallback to non-hybrid search
+ if not hybrid_search and query_result is None:
+ query_result = query_collection(
+ collection_names=collection_names,
+ queries=queries,
+ embedding_function=embedding_function,
+ k=k,
+ )
+ except Exception as e:
+ log.exception(e)
extracted_collections.extend(collection_names)
- if context:
- if "data" in file:
- del file["data"]
-
- relevant_contexts.append({**context, "file": file})
+ if query_result:
+ if "data" in item:
+ del item["data"]
+ query_results.append({**query_result, "file": item})
sources = []
- for context in relevant_contexts:
+ for query_result in query_results:
try:
- if "documents" in context:
- if "metadatas" in context:
+ if "documents" in query_result:
+ if "metadatas" in query_result:
source = {
- "source": context["file"],
- "document": context["documents"][0],
- "metadata": context["metadatas"][0],
+ "source": query_result["file"],
+ "document": query_result["documents"][0],
+ "metadata": query_result["metadatas"][0],
}
- if "distances" in context and context["distances"]:
- source["distances"] = context["distances"][0]
+ if "distances" in query_result and query_result["distances"]:
+ source["distances"] = query_result["distances"][0]
sources.append(source)
except Exception as e:
@@ -678,7 +744,7 @@ def generate_openai_batch_embeddings(
"Authorization": f"Bearer {key}",
**(
{
- "X-OpenWebUI-User-Name": user.name,
+ "X-OpenWebUI-User-Name": quote(user.name, safe=" "),
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
@@ -727,7 +793,7 @@ def generate_azure_openai_batch_embeddings(
"api-key": key,
**(
{
- "X-OpenWebUI-User-Name": user.name,
+ "X-OpenWebUI-User-Name": quote(user.name, safe=" "),
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
@@ -777,7 +843,7 @@ def generate_ollama_batch_embeddings(
"Authorization": f"Bearer {key}",
**(
{
- "X-OpenWebUI-User-Name": user.name,
+ "X-OpenWebUI-User-Name": quote(user.name, safe=" "),
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
@@ -874,7 +940,7 @@ class RerankCompressor(BaseDocumentCompressor):
reranking = self.reranking_function is not None
if reranking:
- scores = self.reranking_function.predict(
+ scores = self.reranking_function(
[(query, doc.page_content) for doc in documents]
)
else:
diff --git a/backend/open_webui/retrieval/vector/dbs/opensearch.py b/backend/open_webui/retrieval/vector/dbs/opensearch.py
index 60ef2d906c..7e16df3cfb 100644
--- a/backend/open_webui/retrieval/vector/dbs/opensearch.py
+++ b/backend/open_webui/retrieval/vector/dbs/opensearch.py
@@ -157,10 +157,10 @@ class OpenSearchClient(VectorDBBase):
for field, value in filter.items():
query_body["query"]["bool"]["filter"].append(
- {"match": {"metadata." + str(field): value}}
+ {"term": {"metadata." + str(field) + ".keyword": value}}
)
- size = limit if limit else 10
+ size = limit if limit else 10000
try:
result = self.client.search(
@@ -206,6 +206,7 @@ class OpenSearchClient(VectorDBBase):
for item in batch
]
bulk(self.client, actions)
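+        # Refresh the index so newly written vectors are immediately searchable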
+ self.client.indices.refresh(self._get_index_name(collection_name))
def upsert(self, collection_name: str, items: list[VectorItem]):
self._create_index_if_not_exists(
@@ -228,6 +229,7 @@ class OpenSearchClient(VectorDBBase):
for item in batch
]
bulk(self.client, actions)
+ self.client.indices.refresh(self._get_index_name(collection_name))
def delete(
self,
@@ -251,11 +253,12 @@ class OpenSearchClient(VectorDBBase):
}
for field, value in filter.items():
query_body["query"]["bool"]["filter"].append(
- {"match": {"metadata." + str(field): value}}
+ {"term": {"metadata." + str(field) + ".keyword": value}}
)
self.client.delete_by_query(
index=self._get_index_name(collection_name), body=query_body
)
+ self.client.indices.refresh(self._get_index_name(collection_name))
def reset(self):
indices = self.client.indices.get(index=f"{self.index_prefix}_*")
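
The OpenSearch changes above swap the metadata filter from a `match` query to a `term` query on the `.keyword` subfield, raise the default result window from 10 to 10000, and refresh the index after bulk writes and deletes so changes are visible to the next search. A rough sketch of the filter body this produces; the field name and value are illustrative:

```python
def metadata_filter_query(filter: dict, limit: int | None = None) -> dict:
    # `term` on the `.keyword` subfield matches the exact stored string,
    # whereas `match` analyzes the value and can match on partial tokens.
    return {
        "query": {
            "bool": {
                "filter": [
                    {"term": {f"metadata.{field}.keyword": value}}
                    for field, value in filter.items()
                ]
            }
        },
        "size": limit if limit else 10000,
    }

print(metadata_filter_query({"file_id": "abc-123"}))
```
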
diff --git a/backend/open_webui/retrieval/vector/dbs/pgvector.py b/backend/open_webui/retrieval/vector/dbs/pgvector.py
index 632937ef5b..64f12aa6d0 100644
--- a/backend/open_webui/retrieval/vector/dbs/pgvector.py
+++ b/backend/open_webui/retrieval/vector/dbs/pgvector.py
@@ -18,7 +18,7 @@ from sqlalchemy import (
values,
)
from sqlalchemy.sql import true
-from sqlalchemy.pool import NullPool
+from sqlalchemy.pool import NullPool, QueuePool
from sqlalchemy.orm import declarative_base, scoped_session, sessionmaker
from sqlalchemy.dialects.postgresql import JSONB, array
@@ -37,6 +37,10 @@ from open_webui.config import (
PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH,
PGVECTOR_PGCRYPTO,
PGVECTOR_PGCRYPTO_KEY,
+ PGVECTOR_POOL_SIZE,
+ PGVECTOR_POOL_MAX_OVERFLOW,
+ PGVECTOR_POOL_TIMEOUT,
+ PGVECTOR_POOL_RECYCLE,
)
from open_webui.env import SRC_LOG_LEVELS
@@ -80,9 +84,24 @@ class PgvectorClient(VectorDBBase):
self.session = Session
else:
- engine = create_engine(
- PGVECTOR_DB_URL, pool_pre_ping=True, poolclass=NullPool
- )
+ if isinstance(PGVECTOR_POOL_SIZE, int):
+ if PGVECTOR_POOL_SIZE > 0:
+ engine = create_engine(
+ PGVECTOR_DB_URL,
+ pool_size=PGVECTOR_POOL_SIZE,
+ max_overflow=PGVECTOR_POOL_MAX_OVERFLOW,
+ pool_timeout=PGVECTOR_POOL_TIMEOUT,
+ pool_recycle=PGVECTOR_POOL_RECYCLE,
+ pool_pre_ping=True,
+ poolclass=QueuePool,
+ )
+ else:
+ engine = create_engine(
+ PGVECTOR_DB_URL, pool_pre_ping=True, poolclass=NullPool
+ )
+ else:
+ engine = create_engine(PGVECTOR_DB_URL, pool_pre_ping=True)
+
SessionLocal = sessionmaker(
autocommit=False, autoflush=False, bind=engine, expire_on_commit=False
)
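
For the pgvector engine above, the pool settings only take effect when `PGVECTOR_POOL_SIZE` is a positive integer; zero keeps the previous `NullPool` behaviour, and a non-integer value falls back to SQLAlchemy's default pooling. A minimal sketch of the same branching, assuming a PostgreSQL driver such as psycopg2 is installed and with hard-coded placeholders standing in for the config variables:

```python
from sqlalchemy import create_engine
from sqlalchemy.pool import NullPool, QueuePool

PGVECTOR_DB_URL = "postgresql://user:pass@localhost/openwebui"  # placeholder URL
PGVECTOR_POOL_SIZE = 10  # e.g. read from the environment; 0 disables pooling

if isinstance(PGVECTOR_POOL_SIZE, int) and PGVECTOR_POOL_SIZE > 0:
    # Persistent pool: up to pool_size + max_overflow connections, each recycled
    # after pool_recycle seconds to avoid stale server-side sessions.
    engine = create_engine(
        PGVECTOR_DB_URL,
        pool_size=PGVECTOR_POOL_SIZE,
        max_overflow=10,
        pool_timeout=30,
        pool_recycle=3600,
        pool_pre_ping=True,
        poolclass=QueuePool,
    )
else:
    # No pooling: every checkout opens a fresh connection (previous behaviour).
    engine = create_engine(PGVECTOR_DB_URL, pool_pre_ping=True, poolclass=NullPool)
```
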
diff --git a/backend/open_webui/retrieval/vector/dbs/qdrant.py b/backend/open_webui/retrieval/vector/dbs/qdrant.py
index dfe2979076..2276e713fc 100644
--- a/backend/open_webui/retrieval/vector/dbs/qdrant.py
+++ b/backend/open_webui/retrieval/vector/dbs/qdrant.py
@@ -18,6 +18,7 @@ from open_webui.config import (
QDRANT_ON_DISK,
QDRANT_GRPC_PORT,
QDRANT_PREFER_GRPC,
+ QDRANT_COLLECTION_PREFIX,
)
from open_webui.env import SRC_LOG_LEVELS
@@ -29,7 +30,7 @@ log.setLevel(SRC_LOG_LEVELS["RAG"])
class QdrantClient(VectorDBBase):
def __init__(self):
- self.collection_prefix = "open-webui"
+ self.collection_prefix = QDRANT_COLLECTION_PREFIX
self.QDRANT_URI = QDRANT_URI
self.QDRANT_API_KEY = QDRANT_API_KEY
self.QDRANT_ON_DISK = QDRANT_ON_DISK
@@ -86,6 +87,25 @@ class QdrantClient(VectorDBBase):
),
)
+ # Create payload indexes for efficient filtering
+ self.client.create_payload_index(
+ collection_name=collection_name_with_prefix,
+ field_name="metadata.hash",
+ field_schema=models.KeywordIndexParams(
+ type=models.KeywordIndexType.KEYWORD,
+ is_tenant=False,
+ on_disk=self.QDRANT_ON_DISK,
+ ),
+ )
+ self.client.create_payload_index(
+ collection_name=collection_name_with_prefix,
+ field_name="metadata.file_id",
+ field_schema=models.KeywordIndexParams(
+ type=models.KeywordIndexType.KEYWORD,
+ is_tenant=False,
+ on_disk=self.QDRANT_ON_DISK,
+ ),
+ )
log.info(f"collection {collection_name_with_prefix} successfully created!")
def _create_collection_if_not_exists(self, collection_name, dimension):
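
The new payload indexes above let Qdrant evaluate exact-match filters on `metadata.hash` and `metadata.file_id` without scanning every point in the collection. A hedged sketch of creating such an index with `qdrant-client`; the collection name and endpoint are placeholders:

```python
from qdrant_client import QdrantClient, models

client = QdrantClient(url="http://localhost:6333")  # placeholder endpoint

# Keyword payload indexes make exact-match filters on these fields cheap.
for field_name in ("metadata.hash", "metadata.file_id"):
    client.create_payload_index(
        collection_name="open-webui_knowledge",  # illustrative name
        field_name=field_name,
        field_schema=models.KeywordIndexParams(
            type=models.KeywordIndexType.KEYWORD,
            is_tenant=False,
            on_disk=False,
        ),
    )
```
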
diff --git a/backend/open_webui/retrieval/vector/dbs/qdrant_multitenancy.py b/backend/open_webui/retrieval/vector/dbs/qdrant_multitenancy.py
index e83c437ef7..17c054ee50 100644
--- a/backend/open_webui/retrieval/vector/dbs/qdrant_multitenancy.py
+++ b/backend/open_webui/retrieval/vector/dbs/qdrant_multitenancy.py
@@ -1,5 +1,5 @@
import logging
-from typing import Optional, Tuple
+from typing import Optional, Tuple, List, Dict, Any
from urllib.parse import urlparse
import grpc
@@ -9,6 +9,7 @@ from open_webui.config import (
QDRANT_ON_DISK,
QDRANT_PREFER_GRPC,
QDRANT_URI,
+ QDRANT_COLLECTION_PREFIX,
)
from open_webui.env import SRC_LOG_LEVELS
from open_webui.retrieval.vector.main import (
@@ -23,14 +24,28 @@ from qdrant_client.http.models import PointStruct
from qdrant_client.models import models
NO_LIMIT = 999999999
+TENANT_ID_FIELD = "tenant_id"
+DEFAULT_DIMENSION = 384
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
+def _tenant_filter(tenant_id: str) -> models.FieldCondition:
+ return models.FieldCondition(
+ key=TENANT_ID_FIELD, match=models.MatchValue(value=tenant_id)
+ )
+
+
+def _metadata_filter(key: str, value: Any) -> models.FieldCondition:
+ return models.FieldCondition(
+ key=f"metadata.{key}", match=models.MatchValue(value=value)
+ )
+
+
class QdrantClient(VectorDBBase):
def __init__(self):
- self.collection_prefix = "open-webui"
+ self.collection_prefix = QDRANT_COLLECTION_PREFIX
self.QDRANT_URI = QDRANT_URI
self.QDRANT_API_KEY = QDRANT_API_KEY
self.QDRANT_ON_DISK = QDRANT_ON_DISK
@@ -38,24 +53,26 @@ class QdrantClient(VectorDBBase):
self.GRPC_PORT = QDRANT_GRPC_PORT
if not self.QDRANT_URI:
- self.client = None
- return
+ raise ValueError(
+ "QDRANT_URI is not set. Please configure it in the environment variables."
+ )
# Unified handling for either scheme
parsed = urlparse(self.QDRANT_URI)
host = parsed.hostname or self.QDRANT_URI
http_port = parsed.port or 6333 # default REST port
- if self.PREFER_GRPC:
- self.client = Qclient(
+ self.client = (
+ Qclient(
host=host,
port=http_port,
grpc_port=self.GRPC_PORT,
prefer_grpc=self.PREFER_GRPC,
api_key=self.QDRANT_API_KEY,
)
- else:
- self.client = Qclient(url=self.QDRANT_URI, api_key=self.QDRANT_API_KEY)
+ if self.PREFER_GRPC
+ else Qclient(url=self.QDRANT_URI, api_key=self.QDRANT_API_KEY)
+ )
# Main collection types for multi-tenancy
self.MEMORY_COLLECTION = f"{self.collection_prefix}_memories"
@@ -65,23 +82,13 @@ class QdrantClient(VectorDBBase):
self.HASH_BASED_COLLECTION = f"{self.collection_prefix}_hash-based"
def _result_to_get_result(self, points) -> GetResult:
- ids = []
- documents = []
- metadatas = []
-
+ ids, documents, metadatas = [], [], []
for point in points:
payload = point.payload
ids.append(point.id)
documents.append(payload["text"])
metadatas.append(payload["metadata"])
-
- return GetResult(
- **{
- "ids": [ids],
- "documents": [documents],
- "metadatas": [metadatas],
- }
- )
+ return GetResult(ids=[ids], documents=[documents], metadatas=[metadatas])
def _get_collection_and_tenant_id(self, collection_name: str) -> Tuple[str, str]:
"""
@@ -113,143 +120,47 @@ class QdrantClient(VectorDBBase):
else:
return self.KNOWLEDGE_COLLECTION, tenant_id
- def _extract_error_message(self, exception):
- """
- Extract error message from either HTTP or gRPC exceptions
-
- Returns:
- tuple: (status_code, error_message)
- """
- # Check if it's an HTTP exception
- if isinstance(exception, UnexpectedResponse):
- try:
- error_data = exception.structured()
- error_msg = error_data.get("status", {}).get("error", "")
- return exception.status_code, error_msg
- except Exception as inner_e:
- log.error(f"Failed to parse HTTP error: {inner_e}")
- return exception.status_code, str(exception)
-
- # Check if it's a gRPC exception
- elif isinstance(exception, grpc.RpcError):
- # Extract status code from gRPC error
- status_code = None
- if hasattr(exception, "code") and callable(exception.code):
- status_code = exception.code().value[0]
-
- # Extract error message
- error_msg = str(exception)
- if "details =" in error_msg:
- # Parse the details line which contains the actual error message
- try:
- details_line = [
- line.strip()
- for line in error_msg.split("\n")
- if "details =" in line
- ][0]
- error_msg = details_line.split("details =")[1].strip(' "')
- except (IndexError, AttributeError):
- # Fall back to full message if parsing fails
- pass
-
- return status_code, error_msg
-
- # For any other type of exception
- return None, str(exception)
-
- def _is_collection_not_found_error(self, exception):
- """
- Check if the exception is due to collection not found, supporting both HTTP and gRPC
- """
- status_code, error_msg = self._extract_error_message(exception)
-
- # HTTP error (404)
- if (
- status_code == 404
- and "Collection" in error_msg
- and "doesn't exist" in error_msg
- ):
- return True
-
- # gRPC error (NOT_FOUND status)
- if (
- isinstance(exception, grpc.RpcError)
- and exception.code() == grpc.StatusCode.NOT_FOUND
- ):
- return True
-
- return False
-
- def _is_dimension_mismatch_error(self, exception):
- """
- Check if the exception is due to dimension mismatch, supporting both HTTP and gRPC
- """
- status_code, error_msg = self._extract_error_message(exception)
-
- # Common patterns in both HTTP and gRPC
- return (
- "Vector dimension error" in error_msg
- or "dimensions mismatch" in error_msg
- or "invalid vector size" in error_msg
- )
-
- def _create_multi_tenant_collection_if_not_exists(
- self, mt_collection_name: str, dimension: int = 384
+ def _create_multi_tenant_collection(
+ self, mt_collection_name: str, dimension: int = DEFAULT_DIMENSION
):
"""
- Creates a collection with multi-tenancy configuration if it doesn't exist.
- Default dimension is set to 384 which corresponds to 'sentence-transformers/all-MiniLM-L6-v2'.
- When creating collections dynamically (insert/upsert), the actual vector dimensions will be used.
+ Creates a collection with multi-tenancy configuration and payload indexes for tenant_id and metadata fields.
"""
- try:
- # Try to create the collection directly - will fail if it already exists
- self.client.create_collection(
- collection_name=mt_collection_name,
- vectors_config=models.VectorParams(
- size=dimension,
- distance=models.Distance.COSINE,
- on_disk=self.QDRANT_ON_DISK,
- ),
- hnsw_config=models.HnswConfigDiff(
- payload_m=16, # Enable per-tenant indexing
- m=0,
- on_disk=self.QDRANT_ON_DISK,
- ),
- )
+ self.client.create_collection(
+ collection_name=mt_collection_name,
+ vectors_config=models.VectorParams(
+ size=dimension,
+ distance=models.Distance.COSINE,
+ on_disk=self.QDRANT_ON_DISK,
+ ),
+ )
+ log.info(
+ f"Multi-tenant collection {mt_collection_name} created with dimension {dimension}!"
+ )
- # Create tenant ID payload index
+ self.client.create_payload_index(
+ collection_name=mt_collection_name,
+ field_name=TENANT_ID_FIELD,
+ field_schema=models.KeywordIndexParams(
+ type=models.KeywordIndexType.KEYWORD,
+ is_tenant=True,
+ on_disk=self.QDRANT_ON_DISK,
+ ),
+ )
+
+ for field in ("metadata.hash", "metadata.file_id"):
self.client.create_payload_index(
collection_name=mt_collection_name,
- field_name="tenant_id",
+ field_name=field,
field_schema=models.KeywordIndexParams(
type=models.KeywordIndexType.KEYWORD,
- is_tenant=True,
on_disk=self.QDRANT_ON_DISK,
),
- wait=True,
)
- log.info(
- f"Multi-tenant collection {mt_collection_name} created with dimension {dimension}!"
- )
- except (UnexpectedResponse, grpc.RpcError) as e:
- # Check for the specific error indicating collection already exists
- status_code, error_msg = self._extract_error_message(e)
-
- # HTTP status code 409 or gRPC ALREADY_EXISTS
- if (isinstance(e, UnexpectedResponse) and status_code == 409) or (
- isinstance(e, grpc.RpcError)
- and e.code() == grpc.StatusCode.ALREADY_EXISTS
- ):
- if "already exists" in error_msg:
- log.debug(f"Collection {mt_collection_name} already exists")
- return
- # If it's not an already exists error, re-raise
- raise e
- except Exception as e:
- raise e
-
- def _create_points(self, items: list[VectorItem], tenant_id: str):
+ def _create_points(
+ self, items: List[VectorItem], tenant_id: str
+ ) -> List[PointStruct]:
"""
Create point structs from vector items with tenant ID.
"""
@@ -260,56 +171,42 @@ class QdrantClient(VectorDBBase):
payload={
"text": item["text"],
"metadata": item["metadata"],
- "tenant_id": tenant_id,
+ TENANT_ID_FIELD: tenant_id,
},
)
for item in items
]
+ def _ensure_collection(
+ self, mt_collection_name: str, dimension: int = DEFAULT_DIMENSION
+ ):
+ """
+ Ensure the collection exists and payload indexes are created for tenant_id and metadata fields.
+ """
+ if not self.client.collection_exists(collection_name=mt_collection_name):
+ self._create_multi_tenant_collection(mt_collection_name, dimension)
+
def has_collection(self, collection_name: str) -> bool:
"""
Check if a logical collection exists by checking for any points with the tenant ID.
"""
if not self.client:
return False
-
- # Map to multi-tenant collection and tenant ID
mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)
-
- # Create tenant filter
- tenant_filter = models.FieldCondition(
- key="tenant_id", match=models.MatchValue(value=tenant_id)
- )
-
- try:
- # Try directly querying - most of the time collection should exist
- response = self.client.query_points(
- collection_name=mt_collection,
- query_filter=models.Filter(must=[tenant_filter]),
- limit=1,
- )
-
- # Collection exists with this tenant ID if there are points
- return len(response.points) > 0
- except (UnexpectedResponse, grpc.RpcError) as e:
- if self._is_collection_not_found_error(e):
- log.debug(f"Collection {mt_collection} doesn't exist")
- return False
- else:
- # For other API errors, log and return False
- _, error_msg = self._extract_error_message(e)
- log.warning(f"Unexpected Qdrant error: {error_msg}")
- return False
- except Exception as e:
- # For any other errors, log and return False
- log.debug(f"Error checking collection {mt_collection}: {e}")
+ if not self.client.collection_exists(collection_name=mt_collection):
return False
+ tenant_filter = _tenant_filter(tenant_id)
+ count_result = self.client.count(
+ collection_name=mt_collection,
+ count_filter=models.Filter(must=[tenant_filter]),
+ )
+ return count_result.count > 0
def delete(
self,
collection_name: str,
- ids: Optional[list[str]] = None,
- filter: Optional[dict] = None,
+ ids: Optional[List[str]] = None,
+ filter: Optional[Dict[str, Any]] = None,
):
"""
Delete vectors by ID or filter from a collection with tenant isolation.
@@ -317,189 +214,76 @@ class QdrantClient(VectorDBBase):
if not self.client:
return None
- # Map to multi-tenant collection and tenant ID
mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)
+ if not self.client.collection_exists(collection_name=mt_collection):
+ log.debug(f"Collection {mt_collection} doesn't exist, nothing to delete")
+ return None
- # Create tenant filter
- tenant_filter = models.FieldCondition(
- key="tenant_id", match=models.MatchValue(value=tenant_id)
+ must_conditions = [_tenant_filter(tenant_id)]
+ should_conditions = []
+ if ids:
+ should_conditions = [_metadata_filter("id", id_value) for id_value in ids]
+ elif filter:
+ must_conditions += [_metadata_filter(k, v) for k, v in filter.items()]
+
+ return self.client.delete(
+ collection_name=mt_collection,
+ points_selector=models.FilterSelector(
+ filter=models.Filter(must=must_conditions, should=should_conditions)
+ ),
)
- must_conditions = [tenant_filter]
- should_conditions = []
-
- if ids:
- for id_value in ids:
- should_conditions.append(
- models.FieldCondition(
- key="metadata.id",
- match=models.MatchValue(value=id_value),
- ),
- )
- elif filter:
- for key, value in filter.items():
- must_conditions.append(
- models.FieldCondition(
- key=f"metadata.{key}",
- match=models.MatchValue(value=value),
- ),
- )
-
- try:
- # Try to delete directly - most of the time collection should exist
- update_result = self.client.delete(
- collection_name=mt_collection,
- points_selector=models.FilterSelector(
- filter=models.Filter(must=must_conditions, should=should_conditions)
- ),
- )
-
- return update_result
- except (UnexpectedResponse, grpc.RpcError) as e:
- if self._is_collection_not_found_error(e):
- log.debug(
- f"Collection {mt_collection} doesn't exist, nothing to delete"
- )
- return None
- else:
- # For other API errors, log and re-raise
- _, error_msg = self._extract_error_message(e)
- log.warning(f"Unexpected Qdrant error: {error_msg}")
- raise
- except Exception as e:
- # For non-Qdrant exceptions, re-raise
- raise
-
def search(
- self, collection_name: str, vectors: list[list[float | int]], limit: int
+ self, collection_name: str, vectors: List[List[float | int]], limit: int
) -> Optional[SearchResult]:
"""
Search for the nearest neighbor items based on the vectors with tenant isolation.
"""
- if not self.client:
+ if not self.client or not vectors:
return None
-
- # Map to multi-tenant collection and tenant ID
mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)
-
- # Get the vector dimension from the query vector
- dimension = len(vectors[0]) if vectors and len(vectors) > 0 else None
-
- try:
- # Try the search operation directly - most of the time collection should exist
-
- # Create tenant filter
- tenant_filter = models.FieldCondition(
- key="tenant_id", match=models.MatchValue(value=tenant_id)
- )
-
- # Ensure vector dimensions match the collection
- collection_dim = self.client.get_collection(
- mt_collection
- ).config.params.vectors.size
-
- if collection_dim != dimension:
- if collection_dim < dimension:
- vectors = [vector[:collection_dim] for vector in vectors]
- else:
- vectors = [
- vector + [0] * (collection_dim - dimension)
- for vector in vectors
- ]
-
- # Search with tenant filter
- prefetch_query = models.Prefetch(
- filter=models.Filter(must=[tenant_filter]),
- limit=NO_LIMIT,
- )
- query_response = self.client.query_points(
- collection_name=mt_collection,
- query=vectors[0],
- prefetch=prefetch_query,
- limit=limit,
- )
-
- get_result = self._result_to_get_result(query_response.points)
- return SearchResult(
- ids=get_result.ids,
- documents=get_result.documents,
- metadatas=get_result.metadatas,
- # qdrant distance is [-1, 1], normalize to [0, 1]
- distances=[
- [(point.score + 1.0) / 2.0 for point in query_response.points]
- ],
- )
- except (UnexpectedResponse, grpc.RpcError) as e:
- if self._is_collection_not_found_error(e):
- log.debug(
- f"Collection {mt_collection} doesn't exist, search returns None"
- )
- return None
- else:
- # For other API errors, log and re-raise
- _, error_msg = self._extract_error_message(e)
- log.warning(f"Unexpected Qdrant error during search: {error_msg}")
- raise
- except Exception as e:
- # For non-Qdrant exceptions, log and return None
- log.exception(f"Error searching collection '{collection_name}': {e}")
+ if not self.client.collection_exists(collection_name=mt_collection):
+ log.debug(f"Collection {mt_collection} doesn't exist, search returns None")
return None
- def query(self, collection_name: str, filter: dict, limit: Optional[int] = None):
+ tenant_filter = _tenant_filter(tenant_id)
+ query_response = self.client.query_points(
+ collection_name=mt_collection,
+ query=vectors[0],
+ limit=limit,
+ query_filter=models.Filter(must=[tenant_filter]),
+ )
+ get_result = self._result_to_get_result(query_response.points)
+ return SearchResult(
+ ids=get_result.ids,
+ documents=get_result.documents,
+ metadatas=get_result.metadatas,
+ distances=[[(point.score + 1.0) / 2.0 for point in query_response.points]],
+ )
+
+ def query(
+ self, collection_name: str, filter: Dict[str, Any], limit: Optional[int] = None
+ ):
"""
Query points with filters and tenant isolation.
"""
if not self.client:
return None
-
- # Map to multi-tenant collection and tenant ID
mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)
-
- # Set default limit if not provided
+ if not self.client.collection_exists(collection_name=mt_collection):
+ log.debug(f"Collection {mt_collection} doesn't exist, query returns None")
+ return None
if limit is None:
limit = NO_LIMIT
-
- # Create tenant filter
- tenant_filter = models.FieldCondition(
- key="tenant_id", match=models.MatchValue(value=tenant_id)
- )
-
- # Create metadata filters
- field_conditions = []
- for key, value in filter.items():
- field_conditions.append(
- models.FieldCondition(
- key=f"metadata.{key}", match=models.MatchValue(value=value)
- )
- )
-
- # Combine tenant filter with metadata filters
+ tenant_filter = _tenant_filter(tenant_id)
+ field_conditions = [_metadata_filter(k, v) for k, v in filter.items()]
combined_filter = models.Filter(must=[tenant_filter, *field_conditions])
-
- try:
- # Try the query directly - most of the time collection should exist
- points = self.client.query_points(
- collection_name=mt_collection,
- query_filter=combined_filter,
- limit=limit,
- )
-
- return self._result_to_get_result(points.points)
- except (UnexpectedResponse, grpc.RpcError) as e:
- if self._is_collection_not_found_error(e):
- log.debug(
- f"Collection {mt_collection} doesn't exist, query returns None"
- )
- return None
- else:
- # For other API errors, log and re-raise
- _, error_msg = self._extract_error_message(e)
- log.warning(f"Unexpected Qdrant error during query: {error_msg}")
- raise
- except Exception as e:
- # For non-Qdrant exceptions, log and re-raise
- log.exception(f"Error querying collection '{collection_name}': {e}")
- return None
+ points = self.client.query_points(
+ collection_name=mt_collection,
+ query_filter=combined_filter,
+ limit=limit,
+ )
+ return self._result_to_get_result(points.points)
def get(self, collection_name: str) -> Optional[GetResult]:
"""
@@ -507,169 +291,36 @@ class QdrantClient(VectorDBBase):
"""
if not self.client:
return None
-
- # Map to multi-tenant collection and tenant ID
mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)
-
- # Create tenant filter
- tenant_filter = models.FieldCondition(
- key="tenant_id", match=models.MatchValue(value=tenant_id)
- )
-
- try:
- # Try to get points directly - most of the time collection should exist
- points = self.client.query_points(
- collection_name=mt_collection,
- query_filter=models.Filter(must=[tenant_filter]),
- limit=NO_LIMIT,
- )
-
- return self._result_to_get_result(points.points)
- except (UnexpectedResponse, grpc.RpcError) as e:
- if self._is_collection_not_found_error(e):
- log.debug(f"Collection {mt_collection} doesn't exist, get returns None")
- return None
- else:
- # For other API errors, log and re-raise
- _, error_msg = self._extract_error_message(e)
- log.warning(f"Unexpected Qdrant error during get: {error_msg}")
- raise
- except Exception as e:
- # For non-Qdrant exceptions, log and return None
- log.exception(f"Error getting collection '{collection_name}': {e}")
+ if not self.client.collection_exists(collection_name=mt_collection):
+ log.debug(f"Collection {mt_collection} doesn't exist, get returns None")
return None
-
- def _handle_operation_with_error_retry(
- self, operation_name, mt_collection, points, dimension
- ):
- """
- Private helper to handle common error cases for insert and upsert operations.
-
- Args:
- operation_name: 'insert' or 'upsert'
- mt_collection: The multi-tenant collection name
- points: The vector points to insert/upsert
- dimension: The dimension of the vectors
-
- Returns:
- The operation result (for upsert) or None (for insert)
- """
- try:
- if operation_name == "insert":
- self.client.upload_points(mt_collection, points)
- return None
- else: # upsert
- return self.client.upsert(mt_collection, points)
- except (UnexpectedResponse, grpc.RpcError) as e:
- # Handle collection not found
- if self._is_collection_not_found_error(e):
- log.info(
- f"Collection {mt_collection} doesn't exist. Creating it with dimension {dimension}."
- )
- # Create collection with correct dimensions from our vectors
- self._create_multi_tenant_collection_if_not_exists(
- mt_collection_name=mt_collection, dimension=dimension
- )
- # Try operation again - no need for dimension adjustment since we just created with correct dimensions
- if operation_name == "insert":
- self.client.upload_points(mt_collection, points)
- return None
- else: # upsert
- return self.client.upsert(mt_collection, points)
-
- # Handle dimension mismatch
- elif self._is_dimension_mismatch_error(e):
- # For dimension errors, the collection must exist, so get its configuration
- mt_collection_info = self.client.get_collection(mt_collection)
- existing_size = mt_collection_info.config.params.vectors.size
-
- log.info(
- f"Dimension mismatch: Collection {mt_collection} expects {existing_size}, got {dimension}"
- )
-
- if existing_size < dimension:
- # Truncate vectors to fit
- log.info(
- f"Truncating vectors from {dimension} to {existing_size} dimensions"
- )
- points = [
- PointStruct(
- id=point.id,
- vector=point.vector[:existing_size],
- payload=point.payload,
- )
- for point in points
- ]
- elif existing_size > dimension:
- # Pad vectors with zeros
- log.info(
- f"Padding vectors from {dimension} to {existing_size} dimensions with zeros"
- )
- points = [
- PointStruct(
- id=point.id,
- vector=point.vector
- + [0] * (existing_size - len(point.vector)),
- payload=point.payload,
- )
- for point in points
- ]
- # Try operation again with adjusted dimensions
- if operation_name == "insert":
- self.client.upload_points(mt_collection, points)
- return None
- else: # upsert
- return self.client.upsert(mt_collection, points)
- else:
- # Not a known error we can handle, log and re-raise
- _, error_msg = self._extract_error_message(e)
- log.warning(f"Unhandled Qdrant error: {error_msg}")
- raise
- except Exception as e:
- # For non-Qdrant exceptions, re-raise
- raise
-
- def insert(self, collection_name: str, items: list[VectorItem]):
- """
- Insert items with tenant ID.
- """
- if not self.client or not items:
- return None
-
- # Map to multi-tenant collection and tenant ID
- mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)
-
- # Get dimensions from the actual vectors
- dimension = len(items[0]["vector"]) if items else None
-
- # Create points with tenant ID
- points = self._create_points(items, tenant_id)
-
- # Handle the operation with error retry
- return self._handle_operation_with_error_retry(
- "insert", mt_collection, points, dimension
+ tenant_filter = _tenant_filter(tenant_id)
+ points = self.client.query_points(
+ collection_name=mt_collection,
+ query_filter=models.Filter(must=[tenant_filter]),
+ limit=NO_LIMIT,
)
+ return self._result_to_get_result(points.points)
- def upsert(self, collection_name: str, items: list[VectorItem]):
+ def upsert(self, collection_name: str, items: List[VectorItem]):
"""
Upsert items with tenant ID.
"""
if not self.client or not items:
return None
-
- # Map to multi-tenant collection and tenant ID
mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)
-
- # Get dimensions from the actual vectors
- dimension = len(items[0]["vector"]) if items else None
-
- # Create points with tenant ID
+ dimension = len(items[0]["vector"])
+ self._ensure_collection(mt_collection, dimension)
points = self._create_points(items, tenant_id)
+ self.client.upload_points(mt_collection, points)
+ return None
- # Handle the operation with error retry
- return self._handle_operation_with_error_retry(
- "upsert", mt_collection, points, dimension
- )
+ def insert(self, collection_name: str, items: List[VectorItem]):
+ """
+ Insert items with tenant ID.
+ """
+ return self.upsert(collection_name, items)
def reset(self):
"""
@@ -677,11 +328,9 @@ class QdrantClient(VectorDBBase):
"""
if not self.client:
return None
-
- collection_names = self.client.get_collections().collections
- for collection_name in collection_names:
- if collection_name.name.startswith(self.collection_prefix):
- self.client.delete_collection(collection_name=collection_name.name)
+ for collection in self.client.get_collections().collections:
+ if collection.name.startswith(self.collection_prefix):
+ self.client.delete_collection(collection_name=collection.name)
def delete_collection(self, collection_name: str):
"""
@@ -689,24 +338,13 @@ class QdrantClient(VectorDBBase):
"""
if not self.client:
return None
-
- # Map to multi-tenant collection and tenant ID
mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)
-
- tenant_filter = models.FieldCondition(
- key="tenant_id", match=models.MatchValue(value=tenant_id)
- )
-
- field_conditions = [tenant_filter]
-
- update_result = self.client.delete(
+ if not self.client.collection_exists(collection_name=mt_collection):
+ log.debug(f"Collection {mt_collection} doesn't exist, nothing to delete")
+ return None
+ self.client.delete(
collection_name=mt_collection,
points_selector=models.FilterSelector(
- filter=models.Filter(must=field_conditions)
+ filter=models.Filter(must=[_tenant_filter(tenant_id)])
),
)
-
- if self.client.get_collection(mt_collection).points_count == 0:
- self.client.delete_collection(mt_collection)
-
- return update_result
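
With the multitenancy refactor above, every logical collection maps onto one shared physical collection plus a `tenant_id` payload field, and the `_tenant_filter`/`_metadata_filter` helpers build the Qdrant filter conditions used by query, delete, and `has_collection`. A small standalone sketch of how a tenant-scoped filter is assembled; the helper names follow this diff, the values are illustrative:

```python
from typing import Any
from qdrant_client import models

TENANT_ID_FIELD = "tenant_id"

def _tenant_filter(tenant_id: str) -> models.FieldCondition:
    # Restricts matches to one logical collection inside the shared collection.
    return models.FieldCondition(
        key=TENANT_ID_FIELD, match=models.MatchValue(value=tenant_id)
    )

def _metadata_filter(key: str, value: Any) -> models.FieldCondition:
    return models.FieldCondition(
        key=f"metadata.{key}", match=models.MatchValue(value=value)
    )

# "All points of tenant user-123 whose metadata.file_id equals f-42."
combined = models.Filter(
    must=[_tenant_filter("user-123"), _metadata_filter("file_id", "f-42")]
)
print(combined)
```
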
diff --git a/backend/open_webui/retrieval/web/brave.py b/backend/open_webui/retrieval/web/brave.py
index 3075db990f..7bea575620 100644
--- a/backend/open_webui/retrieval/web/brave.py
+++ b/backend/open_webui/retrieval/web/brave.py
@@ -36,7 +36,9 @@ def search_brave(
return [
SearchResult(
- link=result["url"], title=result.get("title"), snippet=result.get("snippet")
+ link=result["url"],
+ title=result.get("title"),
+ snippet=result.get("description"),
)
for result in results[:count]
]
diff --git a/backend/open_webui/retrieval/web/duckduckgo.py b/backend/open_webui/retrieval/web/duckduckgo.py
index bf8ae6880b..a32fc358ed 100644
--- a/backend/open_webui/retrieval/web/duckduckgo.py
+++ b/backend/open_webui/retrieval/web/duckduckgo.py
@@ -2,8 +2,8 @@ import logging
from typing import Optional
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
-from duckduckgo_search import DDGS
-from duckduckgo_search.exceptions import RatelimitException
+from ddgs import DDGS
+from ddgs.exceptions import RatelimitException
from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
diff --git a/backend/open_webui/routers/audio.py b/backend/open_webui/routers/audio.py
index 27634cec19..c63ad3bfe7 100644
--- a/backend/open_webui/routers/audio.py
+++ b/backend/open_webui/routers/audio.py
@@ -15,6 +15,7 @@ import aiohttp
import aiofiles
import requests
import mimetypes
+from urllib.parse import quote
from fastapi import (
Depends,
@@ -327,6 +328,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
log.exception(e)
raise HTTPException(status_code=400, detail="Invalid JSON payload")
+ r = None
if request.app.state.config.TTS_ENGINE == "openai":
payload["model"] = request.app.state.config.TTS_MODEL
@@ -335,7 +337,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
async with aiohttp.ClientSession(
timeout=timeout, trust_env=True
) as session:
- async with session.post(
+ r = await session.post(
url=f"{request.app.state.config.TTS_OPENAI_API_BASE_URL}/audio/speech",
json=payload,
headers={
@@ -343,7 +345,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
"Authorization": f"Bearer {request.app.state.config.TTS_OPENAI_API_KEY}",
**(
{
- "X-OpenWebUI-User-Name": user.name,
+ "X-OpenWebUI-User-Name": quote(user.name, safe=" "),
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
@@ -353,14 +355,15 @@ async def speech(request: Request, user=Depends(get_verified_user)):
),
},
ssl=AIOHTTP_CLIENT_SESSION_SSL,
- ) as r:
- r.raise_for_status()
+ )
- async with aiofiles.open(file_path, "wb") as f:
- await f.write(await r.read())
+ r.raise_for_status()
- async with aiofiles.open(file_body_path, "w") as f:
- await f.write(json.dumps(payload))
+ async with aiofiles.open(file_path, "wb") as f:
+ await f.write(await r.read())
+
+ async with aiofiles.open(file_body_path, "w") as f:
+ await f.write(json.dumps(payload))
return FileResponse(file_path)
@@ -368,18 +371,18 @@ async def speech(request: Request, user=Depends(get_verified_user)):
log.exception(e)
detail = None
- try:
- if r.status != 200:
- res = await r.json()
+ status_code = 500
+ detail = f"Open WebUI: Server Connection Error"
- if "error" in res:
- detail = f"External: {res['error'].get('message', '')}"
- except Exception:
- detail = f"External: {e}"
+ if r is not None:
+ status_code = r.status
+ res = await r.json()
+ if "error" in res:
+ detail = f"External: {res['error'].get('message', '')}"
raise HTTPException(
- status_code=getattr(r, "status", 500) if r else 500,
- detail=detail if detail else "Open WebUI: Server Connection Error",
+ status_code=status_code,
+ detail=detail,
)
elif request.app.state.config.TTS_ENGINE == "elevenlabs":
@@ -919,14 +922,18 @@ def transcription(
):
log.info(f"file.content_type: {file.content_type}")
- supported_content_types = request.app.state.config.STT_SUPPORTED_CONTENT_TYPES or [
- "audio/*",
- "video/webm",
- ]
+ stt_supported_content_types = getattr(
+ request.app.state.config, "STT_SUPPORTED_CONTENT_TYPES", []
+ )
if not any(
fnmatch(file.content_type, content_type)
- for content_type in supported_content_types
+ for content_type in (
+ stt_supported_content_types
+ if stt_supported_content_types
+ and any(t.strip() for t in stt_supported_content_types)
+ else ["audio/*", "video/webm"]
+ )
):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
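
The transcription guard above now treats an unset or all-blank `STT_SUPPORTED_CONTENT_TYPES` the same as the default `["audio/*", "video/webm"]`, matching content types with shell-style patterns via `fnmatch`. A quick sketch of that matching logic in isolation; the sample values are made up:

```python
from fnmatch import fnmatch

def is_supported(content_type: str, configured: list[str] | None) -> bool:
    # Fall back to the defaults when the setting is missing or only blanks.
    patterns = (
        configured
        if configured and any(t.strip() for t in configured)
        else ["audio/*", "video/webm"]
    )
    return any(fnmatch(content_type, pattern) for pattern in patterns)

print(is_supported("audio/mpeg", None))        # True via the audio/* default
print(is_supported("video/mp4", ["", " "]))    # False: blanks -> defaults, no match
print(is_supported("video/mp4", ["video/*"]))  # True with an explicit override
```
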
diff --git a/backend/open_webui/routers/auths.py b/backend/open_webui/routers/auths.py
index 60a12db4b3..fd3c317ab9 100644
--- a/backend/open_webui/routers/auths.py
+++ b/backend/open_webui/routers/auths.py
@@ -669,12 +669,13 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
@router.get("/signout")
async def signout(request: Request, response: Response):
response.delete_cookie("token")
+ response.delete_cookie("oui-session")
if ENABLE_OAUTH_SIGNUP.value:
oauth_id_token = request.cookies.get("oauth_id_token")
if oauth_id_token:
try:
- async with ClientSession() as session:
+ async with ClientSession(trust_env=True) as session:
async with session.get(OPENID_PROVIDER_URL.value) as resp:
if resp.status == 200:
openid_data = await resp.json()
@@ -686,7 +687,12 @@ async def signout(request: Request, response: Response):
status_code=200,
content={
"status": True,
- "redirect_url": f"{logout_url}?id_token_hint={oauth_id_token}",
+ "redirect_url": f"{logout_url}?id_token_hint={oauth_id_token}"
+ + (
+ f"&post_logout_redirect_uri={WEBUI_AUTH_SIGNOUT_REDIRECT_URL}"
+ if WEBUI_AUTH_SIGNOUT_REDIRECT_URL
+ else ""
+ ),
},
headers=response.headers,
)
diff --git a/backend/open_webui/routers/channels.py b/backend/open_webui/routers/channels.py
index 6da3f04cee..a4173fbd8d 100644
--- a/backend/open_webui/routers/channels.py
+++ b/backend/open_webui/routers/channels.py
@@ -40,10 +40,14 @@ router = APIRouter()
@router.get("/", response_model=list[ChannelModel])
async def get_channels(user=Depends(get_verified_user)):
+ return Channels.get_channels_by_user_id(user.id)
+
+
+@router.get("/list", response_model=list[ChannelModel])
+async def get_all_channels(user=Depends(get_verified_user)):
if user.role == "admin":
return Channels.get_channels()
- else:
- return Channels.get_channels_by_user_id(user.id)
+ return Channels.get_channels_by_user_id(user.id)
############################
diff --git a/backend/open_webui/routers/chats.py b/backend/open_webui/routers/chats.py
index 29b12ed676..13b6040102 100644
--- a/backend/open_webui/routers/chats.py
+++ b/backend/open_webui/routers/chats.py
@@ -684,8 +684,10 @@ async def archive_chat_by_id(id: str, user=Depends(get_verified_user)):
@router.post("/{id}/share", response_model=Optional[ChatResponse])
async def share_chat_by_id(request: Request, id: str, user=Depends(get_verified_user)):
- if not has_permission(
- user.id, "chat.share", request.app.state.config.USER_PERMISSIONS
+ if (user.role != "admin") and (
+ not has_permission(
+ user.id, "chat.share", request.app.state.config.USER_PERMISSIONS
+ )
):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
diff --git a/backend/open_webui/routers/configs.py b/backend/open_webui/routers/configs.py
index 44b2ef40cf..a329584ca2 100644
--- a/backend/open_webui/routers/configs.py
+++ b/backend/open_webui/routers/configs.py
@@ -39,32 +39,39 @@ async def export_config(user=Depends(get_admin_user)):
############################
-# Direct Connections Config
+# Connections Config
############################
-class DirectConnectionsConfigForm(BaseModel):
+class ConnectionsConfigForm(BaseModel):
ENABLE_DIRECT_CONNECTIONS: bool
+ ENABLE_BASE_MODELS_CACHE: bool
-@router.get("/direct_connections", response_model=DirectConnectionsConfigForm)
-async def get_direct_connections_config(request: Request, user=Depends(get_admin_user)):
+@router.get("/connections", response_model=ConnectionsConfigForm)
+async def get_connections_config(request: Request, user=Depends(get_admin_user)):
return {
"ENABLE_DIRECT_CONNECTIONS": request.app.state.config.ENABLE_DIRECT_CONNECTIONS,
+ "ENABLE_BASE_MODELS_CACHE": request.app.state.config.ENABLE_BASE_MODELS_CACHE,
}
-@router.post("/direct_connections", response_model=DirectConnectionsConfigForm)
-async def set_direct_connections_config(
+@router.post("/connections", response_model=ConnectionsConfigForm)
+async def set_connections_config(
request: Request,
- form_data: DirectConnectionsConfigForm,
+ form_data: ConnectionsConfigForm,
user=Depends(get_admin_user),
):
request.app.state.config.ENABLE_DIRECT_CONNECTIONS = (
form_data.ENABLE_DIRECT_CONNECTIONS
)
+ request.app.state.config.ENABLE_BASE_MODELS_CACHE = (
+ form_data.ENABLE_BASE_MODELS_CACHE
+ )
+
return {
"ENABLE_DIRECT_CONNECTIONS": request.app.state.config.ENABLE_DIRECT_CONNECTIONS,
+ "ENABLE_BASE_MODELS_CACHE": request.app.state.config.ENABLE_BASE_MODELS_CACHE,
}
diff --git a/backend/open_webui/routers/files.py b/backend/open_webui/routers/files.py
index b9bb15c7b4..bdf5780fc4 100644
--- a/backend/open_webui/routers/files.py
+++ b/backend/open_webui/routers/files.py
@@ -155,17 +155,18 @@ def upload_file(
if process:
try:
if file.content_type:
- stt_supported_content_types = (
- request.app.state.config.STT_SUPPORTED_CONTENT_TYPES
- or [
- "audio/*",
- "video/webm",
- ]
+ stt_supported_content_types = getattr(
+ request.app.state.config, "STT_SUPPORTED_CONTENT_TYPES", []
)
if any(
fnmatch(file.content_type, content_type)
- for content_type in stt_supported_content_types
+ for content_type in (
+ stt_supported_content_types
+ if stt_supported_content_types
+ and any(t.strip() for t in stt_supported_content_types)
+ else ["audio/*", "video/webm"]
+ )
):
file_path = Storage.get_file(file_path)
result = transcribe(request, file_path, file_metadata)
diff --git a/backend/open_webui/routers/folders.py b/backend/open_webui/routers/folders.py
index 2c41c92854..edc9f85ff2 100644
--- a/backend/open_webui/routers/folders.py
+++ b/backend/open_webui/routers/folders.py
@@ -120,16 +120,14 @@ async def update_folder_name_by_id(
existing_folder = Folders.get_folder_by_parent_id_and_user_id_and_name(
folder.parent_id, user.id, form_data.name
)
- if existing_folder:
+ if existing_folder and existing_folder.id != id:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT("Folder already exists"),
)
try:
- folder = Folders.update_folder_name_by_id_and_user_id(
- id, user.id, form_data.name
- )
+ folder = Folders.update_folder_by_id_and_user_id(id, user.id, form_data)
return folder
except Exception as e:
diff --git a/backend/open_webui/routers/functions.py b/backend/open_webui/routers/functions.py
index 355093335a..96d8215fb3 100644
--- a/backend/open_webui/routers/functions.py
+++ b/backend/open_webui/routers/functions.py
@@ -105,7 +105,7 @@ async def load_function_from_url(
)
try:
- async with aiohttp.ClientSession() as session:
+ async with aiohttp.ClientSession(trust_env=True) as session:
async with session.get(
url, headers={"Content-Type": "application/json"}
) as resp:
diff --git a/backend/open_webui/routers/images.py b/backend/open_webui/routers/images.py
index 52686a5841..5cd07e3d54 100644
--- a/backend/open_webui/routers/images.py
+++ b/backend/open_webui/routers/images.py
@@ -8,6 +8,7 @@ import re
from pathlib import Path
from typing import Optional
+from urllib.parse import quote
import requests
from fastapi import APIRouter, Depends, HTTPException, Request, UploadFile
from open_webui.config import CACHE_DIR
@@ -302,8 +303,16 @@ async def update_image_config(
):
set_image_model(request, form_data.MODEL)
+ if form_data.IMAGE_SIZE == "auto" and form_data.MODEL != "gpt-image-1":
+ raise HTTPException(
+ status_code=400,
+ detail=ERROR_MESSAGES.INCORRECT_FORMAT(
+ " (auto is only allowed with gpt-image-1)."
+ ),
+ )
+
pattern = r"^\d+x\d+$"
- if re.match(pattern, form_data.IMAGE_SIZE):
+ if form_data.IMAGE_SIZE == "auto" or re.match(pattern, form_data.IMAGE_SIZE):
request.app.state.config.IMAGE_SIZE = form_data.IMAGE_SIZE
else:
raise HTTPException(
@@ -471,7 +480,14 @@ async def image_generations(
form_data: GenerateImageForm,
user=Depends(get_verified_user),
):
- width, height = tuple(map(int, request.app.state.config.IMAGE_SIZE.split("x")))
+    # If IMAGE_SIZE is 'auto', fall back to the default 512x512 width/height.
+    # This only matters when IMAGE_SIZE was set to 'auto' for an image model
+    # other than gpt-image-1, which the settings endpoint warns about on save.
+ width, height = (
+ tuple(map(int, request.app.state.config.IMAGE_SIZE.split("x")))
+ if "x" in request.app.state.config.IMAGE_SIZE
+ else (512, 512)
+ )
r = None
try:
@@ -483,7 +499,7 @@ async def image_generations(
headers["Content-Type"] = "application/json"
if ENABLE_FORWARD_USER_INFO_HEADERS:
- headers["X-OpenWebUI-User-Name"] = user.name
+ headers["X-OpenWebUI-User-Name"] = quote(user.name, safe=" ")
headers["X-OpenWebUI-User-Id"] = user.id
headers["X-OpenWebUI-User-Email"] = user.email
headers["X-OpenWebUI-User-Role"] = user.role
diff --git a/backend/open_webui/routers/notes.py b/backend/open_webui/routers/notes.py
index 2cbbd331b5..0c4842909f 100644
--- a/backend/open_webui/routers/notes.py
+++ b/backend/open_webui/routers/notes.py
@@ -51,7 +51,14 @@ async def get_notes(request: Request, user=Depends(get_verified_user)):
return notes
-@router.get("/list", response_model=list[NoteUserResponse])
+class NoteTitleIdResponse(BaseModel):
+ id: str
+ title: str
+ updated_at: int
+ created_at: int
+
+
+@router.get("/list", response_model=list[NoteTitleIdResponse])
async def get_note_list(request: Request, user=Depends(get_verified_user)):
if user.role != "admin" and not has_permission(
@@ -63,13 +70,8 @@ async def get_note_list(request: Request, user=Depends(get_verified_user)):
)
notes = [
- NoteUserResponse(
- **{
- **note.model_dump(),
- "user": UserResponse(**Users.get_user_by_id(note.user_id).model_dump()),
- }
- )
- for note in Notes.get_notes_by_user_id(user.id, "read")
+ NoteTitleIdResponse(**note.model_dump())
+ for note in Notes.get_notes_by_user_id(user.id, "write")
]
return notes
diff --git a/backend/open_webui/routers/ollama.py b/backend/open_webui/routers/ollama.py
index 1353599374..000f4b48d2 100644
--- a/backend/open_webui/routers/ollama.py
+++ b/backend/open_webui/routers/ollama.py
@@ -16,6 +16,7 @@ from urllib.parse import urlparse
import aiohttp
from aiocache import cached
import requests
+from urllib.parse import quote
from open_webui.models.chats import Chats
from open_webui.models.users import UserModel
@@ -58,6 +59,7 @@ from open_webui.config import (
from open_webui.env import (
ENV,
SRC_LOG_LEVELS,
+ MODELS_CACHE_TTL,
AIOHTTP_CLIENT_SESSION_SSL,
AIOHTTP_CLIENT_TIMEOUT,
AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST,
@@ -87,7 +89,7 @@ async def send_get_request(url, key=None, user: UserModel = None):
**({"Authorization": f"Bearer {key}"} if key else {}),
**(
{
- "X-OpenWebUI-User-Name": user.name,
+ "X-OpenWebUI-User-Name": quote(user.name, safe=" "),
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
@@ -138,7 +140,7 @@ async def send_post_request(
**({"Authorization": f"Bearer {key}"} if key else {}),
**(
{
- "X-OpenWebUI-User-Name": user.name,
+ "X-OpenWebUI-User-Name": quote(user.name, safe=" "),
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
@@ -242,7 +244,7 @@ async def verify_connection(
**({"Authorization": f"Bearer {key}"} if key else {}),
**(
{
- "X-OpenWebUI-User-Name": user.name,
+ "X-OpenWebUI-User-Name": quote(user.name, safe=" "),
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
@@ -329,7 +331,7 @@ def merge_ollama_models_lists(model_lists):
return list(merged_models.values())
-@cached(ttl=1)
+@cached(ttl=MODELS_CACHE_TTL)
async def get_all_models(request: Request, user: UserModel = None):
log.info("get_all_models()")
if request.app.state.config.ENABLE_OLLAMA_API:
@@ -462,7 +464,7 @@ async def get_ollama_tags(
**({"Authorization": f"Bearer {key}"} if key else {}),
**(
{
- "X-OpenWebUI-User-Name": user.name,
+ "X-OpenWebUI-User-Name": quote(user.name, safe=" "),
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
@@ -634,7 +636,10 @@ async def get_ollama_versions(request: Request, url_idx: Optional[int] = None):
class ModelNameForm(BaseModel):
- name: str
+ model: Optional[str] = None
+ model_config = ConfigDict(
+ extra="allow",
+ )
@router.post("/api/unload")
@@ -643,10 +648,12 @@ async def unload_model(
form_data: ModelNameForm,
user=Depends(get_admin_user),
):
- model_name = form_data.name
+ form_data = form_data.model_dump(exclude_none=True)
+ model_name = form_data.get("model", form_data.get("name"))
+
if not model_name:
raise HTTPException(
- status_code=400, detail="Missing 'name' of model to unload."
+ status_code=400, detail="Missing name of the model to unload."
)
# Refresh/load models if needed, get mapping from name to URLs
@@ -709,11 +716,14 @@ async def pull_model(
url_idx: int = 0,
user=Depends(get_admin_user),
):
+ form_data = form_data.model_dump(exclude_none=True)
+ form_data["model"] = form_data.get("model", form_data.get("name"))
+
url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}")
# Admin should be able to pull models from any source
- payload = {**form_data.model_dump(exclude_none=True), "insecure": True}
+ payload = {**form_data, "insecure": True}
return await send_post_request(
url=f"{url}/api/pull",
@@ -724,7 +734,7 @@ async def pull_model(
class PushModelForm(BaseModel):
- name: str
+ model: str
insecure: Optional[bool] = None
stream: Optional[bool] = None
@@ -741,12 +751,12 @@ async def push_model(
await get_all_models(request, user=user)
models = request.app.state.OLLAMA_MODELS
- if form_data.name in models:
- url_idx = models[form_data.name]["urls"][0]
+ if form_data.model in models:
+ url_idx = models[form_data.model]["urls"][0]
else:
raise HTTPException(
status_code=400,
- detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name),
+ detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
)
url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
@@ -824,7 +834,7 @@ async def copy_model(
**({"Authorization": f"Bearer {key}"} if key else {}),
**(
{
- "X-OpenWebUI-User-Name": user.name,
+ "X-OpenWebUI-User-Name": quote(user.name, safe=" "),
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
@@ -865,16 +875,21 @@ async def delete_model(
url_idx: Optional[int] = None,
user=Depends(get_admin_user),
):
+ form_data = form_data.model_dump(exclude_none=True)
+ form_data["model"] = form_data.get("model", form_data.get("name"))
+
+ model = form_data.get("model")
+
if url_idx is None:
await get_all_models(request, user=user)
models = request.app.state.OLLAMA_MODELS
- if form_data.name in models:
- url_idx = models[form_data.name]["urls"][0]
+ if model in models:
+ url_idx = models[model]["urls"][0]
else:
raise HTTPException(
status_code=400,
- detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name),
+ detail=ERROR_MESSAGES.MODEL_NOT_FOUND(model),
)
url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
@@ -884,13 +899,13 @@ async def delete_model(
r = requests.request(
method="DELETE",
url=f"{url}/api/delete",
- data=form_data.model_dump_json(exclude_none=True).encode(),
+ data=json.dumps(form_data).encode(),
headers={
"Content-Type": "application/json",
**({"Authorization": f"Bearer {key}"} if key else {}),
**(
{
- "X-OpenWebUI-User-Name": user.name,
+ "X-OpenWebUI-User-Name": quote(user.name, safe=" "),
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
@@ -926,16 +941,21 @@ async def delete_model(
async def show_model_info(
request: Request, form_data: ModelNameForm, user=Depends(get_verified_user)
):
+ form_data = form_data.model_dump(exclude_none=True)
+ form_data["model"] = form_data.get("model", form_data.get("name"))
+
await get_all_models(request, user=user)
models = request.app.state.OLLAMA_MODELS
- if form_data.name not in models:
+ model = form_data.get("model")
+
+ if model not in models:
raise HTTPException(
status_code=400,
- detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name),
+ detail=ERROR_MESSAGES.MODEL_NOT_FOUND(model),
)
- url_idx = random.choice(models[form_data.name]["urls"])
+ url_idx = random.choice(models[model]["urls"])
url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
key = get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS)
@@ -949,7 +969,7 @@ async def show_model_info(
**({"Authorization": f"Bearer {key}"} if key else {}),
**(
{
- "X-OpenWebUI-User-Name": user.name,
+ "X-OpenWebUI-User-Name": quote(user.name, safe=" "),
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
@@ -958,7 +978,7 @@ async def show_model_info(
else {}
),
},
- data=form_data.model_dump_json(exclude_none=True).encode(),
+ data=json.dumps(form_data).encode(),
)
r.raise_for_status()
@@ -1036,7 +1056,7 @@ async def embed(
**({"Authorization": f"Bearer {key}"} if key else {}),
**(
{
- "X-OpenWebUI-User-Name": user.name,
+ "X-OpenWebUI-User-Name": quote(user.name, safe=" "),
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
@@ -1123,7 +1143,7 @@ async def embeddings(
**({"Authorization": f"Bearer {key}"} if key else {}),
**(
{
- "X-OpenWebUI-User-Name": user.name,
+ "X-OpenWebUI-User-Name": quote(user.name, safe=" "),
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
diff --git a/backend/open_webui/routers/openai.py b/backend/open_webui/routers/openai.py
index 7649271fee..a759ec7eee 100644
--- a/backend/open_webui/routers/openai.py
+++ b/backend/open_webui/routers/openai.py
@@ -8,7 +8,7 @@ from typing import Literal, Optional, overload
import aiohttp
from aiocache import cached
import requests
-
+from urllib.parse import quote
from fastapi import Depends, FastAPI, HTTPException, Request, APIRouter
from fastapi.middleware.cors import CORSMiddleware
@@ -21,6 +21,7 @@ from open_webui.config import (
CACHE_DIR,
)
from open_webui.env import (
+ MODELS_CACHE_TTL,
AIOHTTP_CLIENT_SESSION_SSL,
AIOHTTP_CLIENT_TIMEOUT,
AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST,
@@ -66,7 +67,7 @@ async def send_get_request(url, key=None, user: UserModel = None):
**({"Authorization": f"Bearer {key}"} if key else {}),
**(
{
- "X-OpenWebUI-User-Name": user.name,
+ "X-OpenWebUI-User-Name": quote(user.name, safe=" "),
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
@@ -225,7 +226,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
),
**(
{
- "X-OpenWebUI-User-Name": user.name,
+ "X-OpenWebUI-User-Name": quote(user.name, safe=" "),
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
@@ -386,7 +387,7 @@ async def get_filtered_models(models, user):
return filtered_models
-@cached(ttl=1)
+@cached(ttl=MODELS_CACHE_TTL)
async def get_all_models(request: Request, user: UserModel) -> dict[str, list]:
log.info("get_all_models()")
@@ -478,7 +479,7 @@ async def get_models(
"Content-Type": "application/json",
**(
{
- "X-OpenWebUI-User-Name": user.name,
+ "X-OpenWebUI-User-Name": quote(user.name, safe=" "),
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
@@ -573,7 +574,7 @@ async def verify_connection(
"Content-Type": "application/json",
**(
{
- "X-OpenWebUI-User-Name": user.name,
+ "X-OpenWebUI-User-Name": quote(user.name, safe=" "),
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
@@ -633,13 +634,7 @@ async def verify_connection(
raise HTTPException(status_code=500, detail=error_detail)
-def convert_to_azure_payload(
- url,
- payload: dict,
-):
- model = payload.get("model", "")
-
- # Filter allowed parameters based on Azure OpenAI API
+def get_azure_allowed_params(api_version: str) -> set[str]:
allowed_params = {
"messages",
"temperature",
@@ -669,6 +664,23 @@ def convert_to_azure_payload(
"max_completion_tokens",
}
+ try:
+ if api_version >= "2024-09-01-preview":
+ allowed_params.add("stream_options")
+ except ValueError:
+ log.debug(
+ f"Invalid API version {api_version} for Azure OpenAI. Defaulting to allowed parameters."
+ )
+
+ return allowed_params
+
+
+def convert_to_azure_payload(url, payload: dict, api_version: str):
+ model = payload.get("model", "")
+
+ # Filter allowed parameters based on Azure OpenAI API
+ allowed_params = get_azure_allowed_params(api_version)
+
# Special handling for o-series models
if model.startswith("o") and model.endswith("-mini"):
# Convert max_tokens to max_completion_tokens for o-series models
@@ -806,7 +818,7 @@ async def generate_chat_completion(
),
**(
{
- "X-OpenWebUI-User-Name": user.name,
+ "X-OpenWebUI-User-Name": quote(user.name, safe=" "),
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
@@ -817,8 +829,8 @@ async def generate_chat_completion(
}
if api_config.get("azure", False):
- request_url, payload = convert_to_azure_payload(url, payload)
- api_version = api_config.get("api_version", "") or "2023-03-15-preview"
+ api_version = api_config.get("api_version", "2023-03-15-preview")
+ request_url, payload = convert_to_azure_payload(url, payload, api_version)
headers["api-key"] = key
headers["api-version"] = api_version
request_url = f"{request_url}/chat/completions?api-version={api_version}"
@@ -924,7 +936,7 @@ async def embeddings(request: Request, form_data: dict, user):
"Content-Type": "application/json",
**(
{
- "X-OpenWebUI-User-Name": user.name,
+ "X-OpenWebUI-User-Name": quote(user.name, safe=" "),
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
@@ -996,7 +1008,7 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
"Content-Type": "application/json",
**(
{
- "X-OpenWebUI-User-Name": user.name,
+ "X-OpenWebUI-User-Name": quote(user.name, safe=" "),
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
@@ -1007,16 +1019,15 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
}
if api_config.get("azure", False):
+ api_version = api_config.get("api_version", "2023-03-15-preview")
headers["api-key"] = key
- headers["api-version"] = (
- api_config.get("api_version", "") or "2023-03-15-preview"
- )
+ headers["api-version"] = api_version
payload = json.loads(body)
- url, payload = convert_to_azure_payload(url, payload)
+ url, payload = convert_to_azure_payload(url, payload, api_version)
body = json.dumps(payload).encode()
- request_url = f"{url}/{path}?api-version={api_config.get('api_version', '2023-03-15-preview')}"
+ request_url = f"{url}/{path}?api-version={api_version}"
else:
headers["Authorization"] = f"Bearer {key}"
request_url = f"{url}/{path}"
diff --git a/backend/open_webui/routers/retrieval.py b/backend/open_webui/routers/retrieval.py
index ee6f99fbb5..97db0a72f7 100644
--- a/backend/open_webui/routers/retrieval.py
+++ b/backend/open_webui/routers/retrieval.py
@@ -29,6 +29,7 @@ import tiktoken
from langchain.text_splitter import RecursiveCharacterTextSplitter, TokenTextSplitter
+from langchain_text_splitters import MarkdownHeaderTextSplitter
from langchain_core.documents import Document
from open_webui.models.files import FileModel, Files
@@ -69,6 +70,7 @@ from open_webui.retrieval.web.external import search_external
from open_webui.retrieval.utils import (
get_embedding_function,
+ get_reranking_function,
get_model_path,
query_collection,
query_collection_with_hybrid_search,
@@ -823,6 +825,12 @@ async def update_rag_config(
request.app.state.config.RAG_EXTERNAL_RERANKER_API_KEY,
True,
)
+
+ request.app.state.RERANKING_FUNCTION = get_reranking_function(
+ request.app.state.config.RAG_RERANKING_ENGINE,
+ request.app.state.config.RAG_RERANKING_MODEL,
+ request.app.state.rf,
+ )
except Exception as e:
log.error(f"Error loading reranking model: {e}")
request.app.state.config.ENABLE_RAG_HYBRID_SEARCH = False
@@ -1146,6 +1154,7 @@ def save_docs_to_vector_db(
chunk_overlap=request.app.state.config.CHUNK_OVERLAP,
add_start_index=True,
)
+ docs = text_splitter.split_documents(docs)
elif request.app.state.config.TEXT_SPLITTER == "token":
log.info(
f"Using token text splitter: {request.app.state.config.TIKTOKEN_ENCODING_NAME}"
@@ -1158,11 +1167,56 @@ def save_docs_to_vector_db(
chunk_overlap=request.app.state.config.CHUNK_OVERLAP,
add_start_index=True,
)
+ docs = text_splitter.split_documents(docs)
+ elif request.app.state.config.TEXT_SPLITTER == "markdown_header":
+ log.info("Using markdown header text splitter")
+
+            # Define headers to split on (all six markdown heading levels)
+ headers_to_split_on = [
+ ("#", "Header 1"),
+ ("##", "Header 2"),
+ ("###", "Header 3"),
+ ("####", "Header 4"),
+ ("#####", "Header 5"),
+ ("######", "Header 6"),
+ ]
+
+ markdown_splitter = MarkdownHeaderTextSplitter(
+ headers_to_split_on=headers_to_split_on,
+ strip_headers=False, # Keep headers in content for context
+ )
+
+ md_split_docs = []
+ for doc in docs:
+ md_header_splits = markdown_splitter.split_text(doc.page_content)
+ text_splitter = RecursiveCharacterTextSplitter(
+ chunk_size=request.app.state.config.CHUNK_SIZE,
+ chunk_overlap=request.app.state.config.CHUNK_OVERLAP,
+ add_start_index=True,
+ )
+ md_header_splits = text_splitter.split_documents(md_header_splits)
+
+ # Convert back to Document objects, preserving original metadata
+ for split_chunk in md_header_splits:
+ headings_list = []
+ # Extract header values in order based on headers_to_split_on
+ for _, header_meta_key_name in headers_to_split_on:
+ if header_meta_key_name in split_chunk.metadata:
+ headings_list.append(
+ split_chunk.metadata[header_meta_key_name]
+ )
+
+ md_split_docs.append(
+ Document(
+ page_content=split_chunk.page_content,
+ metadata={**doc.metadata, "headings": headings_list},
+ )
+ )
+
+ docs = md_split_docs
else:
raise ValueError(ERROR_MESSAGES.DEFAULT("Invalid text splitter"))
- docs = text_splitter.split_documents(docs)
-
if len(docs) == 0:
raise ValueError(ERROR_MESSAGES.EMPTY_CONTENT)
@@ -1747,6 +1801,16 @@ def search_web(request: Request, engine: str, query: str) -> list[SearchResult]:
)
else:
raise Exception("No TAVILY_API_KEY found in environment variables")
+ elif engine == "exa":
+ if request.app.state.config.EXA_API_KEY:
+ return search_exa(
+ request.app.state.config.EXA_API_KEY,
+ query,
+ request.app.state.config.WEB_SEARCH_RESULT_COUNT,
+ request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST,
+ )
+ else:
+ raise Exception("No EXA_API_KEY found in environment variables")
elif engine == "searchapi":
if request.app.state.config.SEARCHAPI_API_KEY:
return search_searchapi(
@@ -1784,6 +1848,13 @@ def search_web(request: Request, engine: str, query: str) -> list[SearchResult]:
request.app.state.config.WEB_SEARCH_RESULT_COUNT,
request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST,
)
+ elif engine == "exa":
+ return search_exa(
+ request.app.state.config.EXA_API_KEY,
+ query,
+ request.app.state.config.WEB_SEARCH_RESULT_COUNT,
+ request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST,
+ )
elif engine == "perplexity":
return search_perplexity(
request.app.state.config.PERPLEXITY_API_KEY,
@@ -1978,7 +2049,13 @@ def query_doc_handler(
query, prefix=prefix, user=user
),
k=form_data.k if form_data.k else request.app.state.config.TOP_K,
- reranking_function=request.app.state.rf,
+ reranking_function=(
+ lambda sentences: (
+ request.app.state.RERANKING_FUNCTION(sentences, user=user)
+ if request.app.state.RERANKING_FUNCTION
+ else None
+ )
+ ),
k_reranker=form_data.k_reranker
or request.app.state.config.TOP_K_RERANKER,
r=(
@@ -2035,7 +2112,9 @@ def query_collection_handler(
query, prefix=prefix, user=user
),
k=form_data.k if form_data.k else request.app.state.config.TOP_K,
- reranking_function=request.app.state.rf,
+                reranking_function=(
+                    lambda sentences: (
+                        request.app.state.RERANKING_FUNCTION(sentences, user=user)
+                        if request.app.state.RERANKING_FUNCTION
+                        else None
+                    )
+                ),
k_reranker=form_data.k_reranker
or request.app.state.config.TOP_K_RERANKER,
r=(
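
For reference, a minimal sketch of what the new `markdown_header` splitter path does in isolation, using made-up content. Each chunk's metadata carries the heading hierarchy that the loop above collects into the `headings` list:

```python
from langchain_text_splitters import MarkdownHeaderTextSplitter

headers_to_split_on = [("#", "Header 1"), ("##", "Header 2")]
splitter = MarkdownHeaderTextSplitter(
    headers_to_split_on=headers_to_split_on,
    strip_headers=False,  # keep the heading line inside the chunk text for context
)

chunks = splitter.split_text("# Intro\nHello.\n\n## Setup\nInstall the package.")
for chunk in chunks:
    # e.g. metadata == {"Header 1": "Intro", "Header 2": "Setup"} for the second chunk
    print(chunk.metadata, "->", chunk.page_content)
```
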
diff --git a/backend/open_webui/routers/tools.py b/backend/open_webui/routers/tools.py
index f726368eba..41415bff04 100644
--- a/backend/open_webui/routers/tools.py
+++ b/backend/open_webui/routers/tools.py
@@ -153,7 +153,7 @@ async def load_tool_from_url(
)
try:
- async with aiohttp.ClientSession() as session:
+ async with aiohttp.ClientSession(trust_env=True) as session:
async with session.get(
url, headers={"Content-Type": "application/json"}
) as resp:
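
`trust_env=True` makes aiohttp pick up proxy settings (`HTTP_PROXY`, `HTTPS_PROXY`, `NO_PROXY`) and `~/.netrc` credentials from the environment, so loading a tool spec works behind a corporate proxy. A small sketch of the same pattern (the URL is hypothetical):

```python
import aiohttp
import asyncio

async def fetch_json(url: str) -> dict:
    # trust_env=True: honour HTTP_PROXY / HTTPS_PROXY / NO_PROXY from the environment
    async with aiohttp.ClientSession(trust_env=True) as session:
        async with session.get(url, headers={"Accept": "application/json"}) as resp:
            return await resp.json()

# asyncio.run(fetch_json("https://example.com/openapi.json"))
```
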
diff --git a/backend/open_webui/socket/main.py b/backend/open_webui/socket/main.py
index 35e40dccb2..ceeecba8c3 100644
--- a/backend/open_webui/socket/main.py
+++ b/backend/open_webui/socket/main.py
@@ -1,13 +1,18 @@
import asyncio
+import random
+
import socketio
import logging
import sys
import time
+from typing import Dict, Set
from redis import asyncio as aioredis
+import pycrdt as Y
from open_webui.models.users import Users, UserNameResponse
from open_webui.models.channels import Channels
from open_webui.models.chats import Chats
+from open_webui.models.notes import Notes, NoteUpdateForm
from open_webui.utils.redis import (
get_sentinels_from_env,
get_sentinel_url_from_env,
@@ -22,7 +27,11 @@ from open_webui.env import (
WEBSOCKET_SENTINEL_HOSTS,
)
from open_webui.utils.auth import decode_token
-from open_webui.socket.utils import RedisDict, RedisLock
+from open_webui.socket.utils import RedisDict, RedisLock, YdocManager
+from open_webui.tasks import create_task, stop_item_tasks
+from open_webui.utils.redis import get_redis_connection
+from open_webui.utils.access_control import has_access, get_users_with_access
+
from open_webui.env import (
GLOBAL_LOG_LEVEL,
@@ -35,6 +44,8 @@ log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["SOCKET"])
+REDIS = None
+
if WEBSOCKET_MANAGER == "redis":
if WEBSOCKET_SENTINEL_HOSTS:
mgr = socketio.AsyncRedisManager(
@@ -69,6 +80,14 @@ TIMEOUT_DURATION = 3
if WEBSOCKET_MANAGER == "redis":
log.debug("Using Redis to manage websockets.")
+ REDIS = get_redis_connection(
+ redis_url=WEBSOCKET_REDIS_URL,
+ redis_sentinels=get_sentinels_from_env(
+ WEBSOCKET_SENTINEL_HOSTS, WEBSOCKET_SENTINEL_PORT
+ ),
+ async_mode=True,
+ )
+
redis_sentinels = get_sentinels_from_env(
WEBSOCKET_SENTINEL_HOSTS, WEBSOCKET_SENTINEL_PORT
)
@@ -101,14 +120,37 @@ else:
SESSION_POOL = {}
USER_POOL = {}
USAGE_POOL = {}
+
aquire_func = release_func = renew_func = lambda: True
+YDOC_MANAGER = YdocManager(
+ redis=REDIS,
+ redis_key_prefix="open-webui:ydoc:documents",
+)
+
+
async def periodic_usage_pool_cleanup():
- if not aquire_func():
- log.debug("Usage pool cleanup lock already exists. Not running it.")
- return
- log.debug("Running periodic_usage_pool_cleanup")
+ max_retries = 2
+ retry_delay = random.uniform(
+ WEBSOCKET_REDIS_LOCK_TIMEOUT / 2, WEBSOCKET_REDIS_LOCK_TIMEOUT
+ )
+ for attempt in range(max_retries + 1):
+ if aquire_func():
+ break
+ else:
+ if attempt < max_retries:
+ log.debug(
+ f"Cleanup lock already exists. Retry {attempt + 1} after {retry_delay}s..."
+ )
+ await asyncio.sleep(retry_delay)
+ else:
+ log.warning(
+ "Failed to acquire cleanup lock after retries. Skipping cleanup."
+ )
+ return
+
+ log.debug("Running periodic_cleanup")
try:
while True:
if not renew_func():
@@ -169,16 +211,20 @@ def get_user_id_from_session_pool(sid):
return None
-def get_user_ids_from_room(room):
+def get_session_ids_from_room(room):
+ """Get all session IDs from a specific room."""
active_session_ids = sio.manager.get_participants(
namespace="/",
room=room,
)
+ return [session_id[0] for session_id in active_session_ids]
+
+
+def get_user_ids_from_room(room):
+ active_session_ids = get_session_ids_from_room(room)
active_user_ids = list(
- set(
- [SESSION_POOL.get(session_id[0])["id"] for session_id in active_session_ids]
- )
+ set([SESSION_POOL.get(session_id)["id"] for session_id in active_session_ids])
)
return active_user_ids
@@ -298,6 +344,241 @@ async def channel_events(sid, data):
)
+@sio.on("ydoc:document:join")
+async def ydoc_document_join(sid, data):
+ """Handle user joining a document"""
+ user = SESSION_POOL.get(sid)
+
+ try:
+ document_id = data["document_id"]
+
+ if document_id.startswith("note:"):
+ note_id = document_id.split(":")[1]
+ note = Notes.get_note_by_id(note_id)
+ if not note:
+ log.error(f"Note {note_id} not found")
+ return
+
+ if (
+ user.get("role") != "admin"
+ and user.get("id") != note.user_id
+ and not has_access(
+ user.get("id"), type="read", access_control=note.access_control
+ )
+ ):
+ log.error(
+ f"User {user.get('id')} does not have access to note {note_id}"
+ )
+ return
+
+ user_id = data.get("user_id", sid)
+ user_name = data.get("user_name", "Anonymous")
+ user_color = data.get("user_color", "#000000")
+
+ log.info(f"User {user_id} joining document {document_id}")
+ await YDOC_MANAGER.add_user(document_id=document_id, user_id=sid)
+
+ # Join Socket.IO room
+ await sio.enter_room(sid, f"doc_{document_id}")
+
+ active_session_ids = get_session_ids_from_room(f"doc_{document_id}")
+
+ # Get the Yjs document state
+ ydoc = Y.Doc()
+ updates = await YDOC_MANAGER.get_updates(document_id)
+ for update in updates:
+ ydoc.apply_update(bytes(update))
+
+ # Encode the entire document state as an update
+ state_update = ydoc.get_update()
+ await sio.emit(
+ "ydoc:document:state",
+ {
+ "document_id": document_id,
+ "state": list(state_update), # Convert bytes to list for JSON
+ "sessions": active_session_ids,
+ },
+ room=sid,
+ )
+
+ # Notify other users about the new user
+ await sio.emit(
+ "ydoc:user:joined",
+ {
+ "document_id": document_id,
+ "user_id": user_id,
+ "user_name": user_name,
+ "user_color": user_color,
+ },
+ room=f"doc_{document_id}",
+ skip_sid=sid,
+ )
+
+ log.info(f"User {user_id} successfully joined document {document_id}")
+
+ except Exception as e:
+ log.error(f"Error in yjs_document_join: {e}")
+ await sio.emit("error", {"message": "Failed to join document"}, room=sid)
+
+
+async def document_save_handler(document_id, data, user):
+ if document_id.startswith("note:"):
+ note_id = document_id.split(":")[1]
+ note = Notes.get_note_by_id(note_id)
+ if not note:
+ log.error(f"Note {note_id} not found")
+ return
+
+ if (
+ user.get("role") != "admin"
+ and user.get("id") != note.user_id
+ and not has_access(
+ user.get("id"), type="read", access_control=note.access_control
+ )
+ ):
+ log.error(f"User {user.get('id')} does not have access to note {note_id}")
+ return
+
+ Notes.update_note_by_id(note_id, NoteUpdateForm(data=data))
+
+
+@sio.on("ydoc:document:state")
+async def yjs_document_state(sid, data):
+ """Send the current state of the Yjs document to the user"""
+ try:
+ document_id = data["document_id"]
+ room = f"doc_{document_id}"
+
+ active_session_ids = get_session_ids_from_room(room)
+        log.debug(f"Active sessions in {room}: {active_session_ids}")
+ if sid not in active_session_ids:
+ log.warning(f"Session {sid} not in room {room}. Cannot send state.")
+ return
+
+ if not await YDOC_MANAGER.document_exists(document_id):
+ log.warning(f"Document {document_id} not found")
+ return
+
+ # Get the Yjs document state
+ ydoc = Y.Doc()
+ updates = await YDOC_MANAGER.get_updates(document_id)
+ for update in updates:
+ ydoc.apply_update(bytes(update))
+
+ # Encode the entire document state as an update
+ state_update = ydoc.get_update()
+
+ await sio.emit(
+ "ydoc:document:state",
+ {
+ "document_id": document_id,
+ "state": list(state_update), # Convert bytes to list for JSON
+ "sessions": active_session_ids,
+ },
+ room=sid,
+ )
+ except Exception as e:
+ log.error(f"Error in yjs_document_state: {e}")
+
+
+@sio.on("ydoc:document:update")
+async def yjs_document_update(sid, data):
+ """Handle Yjs document updates"""
+ try:
+ document_id = data["document_id"]
+
+ try:
+ await stop_item_tasks(REDIS, document_id)
+        except Exception as e:
+            log.debug(f"Failed to stop pending tasks for {document_id}: {e}")
+
+ user_id = data.get("user_id", sid)
+
+ update = data["update"] # List of bytes from frontend
+
+ await YDOC_MANAGER.append_to_updates(
+ document_id=document_id,
+            update=update,  # list of ints from the frontend; stored/encoded by the manager
+ )
+
+ # Broadcast update to all other users in the document
+ await sio.emit(
+ "ydoc:document:update",
+ {
+ "document_id": document_id,
+ "user_id": user_id,
+ "update": update,
+ "socket_id": sid, # Add socket_id to match frontend filtering
+ },
+ room=f"doc_{document_id}",
+ skip_sid=sid,
+ )
+
+ async def debounced_save():
+ await asyncio.sleep(0.5)
+ await document_save_handler(
+ document_id, data.get("data", {}), SESSION_POOL.get(sid)
+ )
+
+ await create_task(REDIS, debounced_save(), document_id)
+
+ except Exception as e:
+ log.error(f"Error in yjs_document_update: {e}")
+
+
+@sio.on("ydoc:document:leave")
+async def yjs_document_leave(sid, data):
+ """Handle user leaving a document"""
+ try:
+ document_id = data["document_id"]
+ user_id = data.get("user_id", sid)
+
+ log.info(f"User {user_id} leaving document {document_id}")
+
+ # Remove user from the document
+ await YDOC_MANAGER.remove_user(document_id=document_id, user_id=sid)
+
+ # Leave Socket.IO room
+ await sio.leave_room(sid, f"doc_{document_id}")
+
+ # Notify other users
+ await sio.emit(
+ "ydoc:user:left",
+ {"document_id": document_id, "user_id": user_id},
+ room=f"doc_{document_id}",
+ )
+
+ if (
+            await YDOC_MANAGER.document_exists(document_id)
+ and len(await YDOC_MANAGER.get_users(document_id)) == 0
+ ):
+ log.info(f"Cleaning up document {document_id} as no users are left")
+ await YDOC_MANAGER.clear_document(document_id)
+
+ except Exception as e:
+ log.error(f"Error in yjs_document_leave: {e}")
+
+
+@sio.on("ydoc:awareness:update")
+async def yjs_awareness_update(sid, data):
+ """Handle awareness updates (cursors, selections, etc.)"""
+ try:
+ document_id = data["document_id"]
+ user_id = data.get("user_id", sid)
+ update = data["update"]
+
+ # Broadcast awareness update to all other users in the document
+ await sio.emit(
+ "ydoc:awareness:update",
+ {"document_id": document_id, "user_id": user_id, "update": update},
+ room=f"doc_{document_id}",
+ skip_sid=sid,
+ )
+
+ except Exception as e:
+ log.error(f"Error in yjs_awareness_update: {e}")
+
+
@sio.event
async def disconnect(sid):
if sid in SESSION_POOL:
@@ -309,6 +590,8 @@ async def disconnect(sid):
if len(USER_POOL[user_id]) == 0:
del USER_POOL[user_id]
+
+ await YDOC_MANAGER.remove_user_from_all_documents(sid)
else:
pass
# print(f"Unknown session ID {sid} disconnected")
diff --git a/backend/open_webui/socket/utils.py b/backend/open_webui/socket/utils.py
index 85a8bb7909..a422d76207 100644
--- a/backend/open_webui/socket/utils.py
+++ b/backend/open_webui/socket/utils.py
@@ -1,6 +1,8 @@
import json
import uuid
from open_webui.utils.redis import get_redis_connection
+from typing import Optional, List, Tuple
+import pycrdt as Y
class RedisLock:
@@ -89,3 +91,109 @@ class RedisDict:
if key not in self:
self[key] = default
return self[key]
+
+
+class YdocManager:
+ def __init__(
+ self,
+ redis=None,
+ redis_key_prefix: str = "open-webui:ydoc:documents",
+ ):
+ self._updates = {}
+ self._users = {}
+ self._redis = redis
+ self._redis_key_prefix = redis_key_prefix
+
+ async def append_to_updates(self, document_id: str, update: bytes):
+ document_id = document_id.replace(":", "_")
+ if self._redis:
+ redis_key = f"{self._redis_key_prefix}:{document_id}:updates"
+ await self._redis.rpush(redis_key, json.dumps(list(update)))
+ else:
+ if document_id not in self._updates:
+ self._updates[document_id] = []
+ self._updates[document_id].append(update)
+
+ async def get_updates(self, document_id: str) -> List[bytes]:
+ document_id = document_id.replace(":", "_")
+
+ if self._redis:
+ redis_key = f"{self._redis_key_prefix}:{document_id}:updates"
+ updates = await self._redis.lrange(redis_key, 0, -1)
+ return [bytes(json.loads(update)) for update in updates]
+ else:
+ return self._updates.get(document_id, [])
+
+ async def document_exists(self, document_id: str) -> bool:
+ document_id = document_id.replace(":", "_")
+
+ if self._redis:
+ redis_key = f"{self._redis_key_prefix}:{document_id}:updates"
+ return await self._redis.exists(redis_key) > 0
+ else:
+ return document_id in self._updates
+
+ async def get_users(self, document_id: str) -> List[str]:
+ document_id = document_id.replace(":", "_")
+
+ if self._redis:
+ redis_key = f"{self._redis_key_prefix}:{document_id}:users"
+ users = await self._redis.smembers(redis_key)
+ return list(users)
+ else:
+ return self._users.get(document_id, [])
+
+ async def add_user(self, document_id: str, user_id: str):
+ document_id = document_id.replace(":", "_")
+
+ if self._redis:
+ redis_key = f"{self._redis_key_prefix}:{document_id}:users"
+ await self._redis.sadd(redis_key, user_id)
+ else:
+ if document_id not in self._users:
+ self._users[document_id] = set()
+ self._users[document_id].add(user_id)
+
+ async def remove_user(self, document_id: str, user_id: str):
+ document_id = document_id.replace(":", "_")
+
+ if self._redis:
+ redis_key = f"{self._redis_key_prefix}:{document_id}:users"
+ await self._redis.srem(redis_key, user_id)
+ else:
+ if document_id in self._users and user_id in self._users[document_id]:
+ self._users[document_id].remove(user_id)
+
+ async def remove_user_from_all_documents(self, user_id: str):
+ if self._redis:
+ keys = await self._redis.keys(f"{self._redis_key_prefix}:*")
+ for key in keys:
+ if key.endswith(":users"):
+ await self._redis.srem(key, user_id)
+
+ document_id = key.split(":")[-2]
+ if len(await self.get_users(document_id)) == 0:
+ await self.clear_document(document_id)
+
+ else:
+ for document_id in list(self._users.keys()):
+ if user_id in self._users[document_id]:
+ self._users[document_id].remove(user_id)
+ if not self._users[document_id]:
+ del self._users[document_id]
+
+ await self.clear_document(document_id)
+
+ async def clear_document(self, document_id: str):
+ document_id = document_id.replace(":", "_")
+
+ if self._redis:
+ redis_key = f"{self._redis_key_prefix}:{document_id}:updates"
+ await self._redis.delete(redis_key)
+ redis_users_key = f"{self._redis_key_prefix}:{document_id}:users"
+ await self._redis.delete(redis_users_key)
+ else:
+ if document_id in self._updates:
+ del self._updates[document_id]
+ if document_id in self._users:
+ del self._users[document_id]
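
A hypothetical usage sketch of the new `YdocManager` with no Redis configured (the in-memory fallback); it assumes the class is importable from `open_webui.socket.utils`. The Redis-backed path mirrors the same calls against per-document `:updates` lists and `:users` sets:

```python
import asyncio

async def demo(manager: "YdocManager") -> None:
    await manager.add_user("note:abc123", user_id="sid-1")
    await manager.append_to_updates("note:abc123", update=b"\x01\x02")

    assert await manager.document_exists("note:abc123")
    print(await manager.get_users("note:abc123"))    # {"sid-1"} (in-memory set)
    print(await manager.get_updates("note:abc123"))  # [b"\x01\x02"]

    # Removing the last participant also clears the stored updates.
    await manager.remove_user_from_all_documents("sid-1")
    print(await manager.document_exists("note:abc123"))  # False

# asyncio.run(demo(YdocManager(redis=None)))
```
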
diff --git a/backend/open_webui/tasks.py b/backend/open_webui/tasks.py
index 2d3955f0a2..a4132d9cf6 100644
--- a/backend/open_webui/tasks.py
+++ b/backend/open_webui/tasks.py
@@ -3,25 +3,27 @@ import asyncio
from typing import Dict
from uuid import uuid4
import json
+import logging
from redis.asyncio import Redis
from fastapi import Request
from typing import Dict, List, Optional
+from open_webui.env import SRC_LOG_LEVELS
+
+
+log = logging.getLogger(__name__)
+log.setLevel(SRC_LOG_LEVELS["MAIN"])
+
# A dictionary to keep track of active tasks
tasks: Dict[str, asyncio.Task] = {}
-chat_tasks = {}
+item_tasks = {}
REDIS_TASKS_KEY = "open-webui:tasks"
-REDIS_CHAT_TASKS_KEY = "open-webui:tasks:chat"
+REDIS_ITEM_TASKS_KEY = "open-webui:tasks:item"
REDIS_PUBSUB_CHANNEL = "open-webui:tasks:commands"
-def is_redis(request: Request) -> bool:
- # Called everywhere a request is available to check Redis
- return hasattr(request.app.state, "redis") and (request.app.state.redis is not None)
-
-
async def redis_task_command_listener(app):
redis: Redis = app.state.redis
pubsub = redis.pubsub()
@@ -38,7 +40,7 @@ async def redis_task_command_listener(app):
if local_task:
local_task.cancel()
except Exception as e:
- print(f"Error handling distributed task command: {e}")
+ log.exception(f"Error handling distributed task command: {e}")
### ------------------------------
@@ -46,21 +48,21 @@ async def redis_task_command_listener(app):
### ------------------------------
-async def redis_save_task(redis: Redis, task_id: str, chat_id: Optional[str]):
+async def redis_save_task(redis: Redis, task_id: str, item_id: Optional[str]):
pipe = redis.pipeline()
- pipe.hset(REDIS_TASKS_KEY, task_id, chat_id or "")
- if chat_id:
- pipe.sadd(f"{REDIS_CHAT_TASKS_KEY}:{chat_id}", task_id)
+ pipe.hset(REDIS_TASKS_KEY, task_id, item_id or "")
+ if item_id:
+ pipe.sadd(f"{REDIS_ITEM_TASKS_KEY}:{item_id}", task_id)
await pipe.execute()
-async def redis_cleanup_task(redis: Redis, task_id: str, chat_id: Optional[str]):
+async def redis_cleanup_task(redis: Redis, task_id: str, item_id: Optional[str]):
pipe = redis.pipeline()
pipe.hdel(REDIS_TASKS_KEY, task_id)
- if chat_id:
- pipe.srem(f"{REDIS_CHAT_TASKS_KEY}:{chat_id}", task_id)
- if (await pipe.scard(f"{REDIS_CHAT_TASKS_KEY}:{chat_id}").execute())[-1] == 0:
- pipe.delete(f"{REDIS_CHAT_TASKS_KEY}:{chat_id}") # Remove if empty set
+ if item_id:
+ pipe.srem(f"{REDIS_ITEM_TASKS_KEY}:{item_id}", task_id)
+ if (await pipe.scard(f"{REDIS_ITEM_TASKS_KEY}:{item_id}").execute())[-1] == 0:
+ pipe.delete(f"{REDIS_ITEM_TASKS_KEY}:{item_id}") # Remove if empty set
await pipe.execute()
@@ -68,31 +70,31 @@ async def redis_list_tasks(redis: Redis) -> List[str]:
return list(await redis.hkeys(REDIS_TASKS_KEY))
-async def redis_list_chat_tasks(redis: Redis, chat_id: str) -> List[str]:
- return list(await redis.smembers(f"{REDIS_CHAT_TASKS_KEY}:{chat_id}"))
+async def redis_list_item_tasks(redis: Redis, item_id: str) -> List[str]:
+ return list(await redis.smembers(f"{REDIS_ITEM_TASKS_KEY}:{item_id}"))
async def redis_send_command(redis: Redis, command: dict):
await redis.publish(REDIS_PUBSUB_CHANNEL, json.dumps(command))
-async def cleanup_task(request, task_id: str, id=None):
+async def cleanup_task(redis, task_id: str, id=None):
"""
Remove a completed or canceled task from the global `tasks` dictionary.
"""
- if is_redis(request):
- await redis_cleanup_task(request.app.state.redis, task_id, id)
+ if redis:
+ await redis_cleanup_task(redis, task_id, id)
tasks.pop(task_id, None) # Remove the task if it exists
- # If an ID is provided, remove the task from the chat_tasks dictionary
- if id and task_id in chat_tasks.get(id, []):
- chat_tasks[id].remove(task_id)
- if not chat_tasks[id]: # If no tasks left for this ID, remove the entry
- chat_tasks.pop(id, None)
+ # If an ID is provided, remove the task from the item_tasks dictionary
+ if id and task_id in item_tasks.get(id, []):
+ item_tasks[id].remove(task_id)
+ if not item_tasks[id]: # If no tasks left for this ID, remove the entry
+ item_tasks.pop(id, None)
-async def create_task(request, coroutine, id=None):
+async def create_task(redis, coroutine, id=None):
"""
Create a new asyncio task and add it to the global task dictionary.
"""
@@ -101,48 +103,48 @@ async def create_task(request, coroutine, id=None):
# Add a done callback for cleanup
task.add_done_callback(
- lambda t: asyncio.create_task(cleanup_task(request, task_id, id))
+ lambda t: asyncio.create_task(cleanup_task(redis, task_id, id))
)
tasks[task_id] = task
# If an ID is provided, associate the task with that ID
- if chat_tasks.get(id):
- chat_tasks[id].append(task_id)
+ if item_tasks.get(id):
+ item_tasks[id].append(task_id)
else:
- chat_tasks[id] = [task_id]
+ item_tasks[id] = [task_id]
- if is_redis(request):
- await redis_save_task(request.app.state.redis, task_id, id)
+ if redis:
+ await redis_save_task(redis, task_id, id)
return task_id, task
-async def list_tasks(request):
+async def list_tasks(redis):
"""
List all currently active task IDs.
"""
- if is_redis(request):
- return await redis_list_tasks(request.app.state.redis)
+ if redis:
+ return await redis_list_tasks(redis)
return list(tasks.keys())
-async def list_task_ids_by_chat_id(request, id):
+async def list_task_ids_by_item_id(redis, id):
"""
List all tasks associated with a specific ID.
"""
- if is_redis(request):
- return await redis_list_chat_tasks(request.app.state.redis, id)
- return chat_tasks.get(id, [])
+ if redis:
+ return await redis_list_item_tasks(redis, id)
+ return item_tasks.get(id, [])
-async def stop_task(request, task_id: str):
+async def stop_task(redis, task_id: str):
"""
Cancel a running task and remove it from the global task list.
"""
- if is_redis(request):
+ if redis:
# PUBSUB: All instances check if they have this task, and stop if so.
await redis_send_command(
- request.app.state.redis,
+ redis,
{
"action": "stop",
"task_id": task_id,
@@ -151,7 +153,7 @@ async def stop_task(request, task_id: str):
# Optionally check if task_id still in Redis a few moments later for feedback?
return {"status": True, "message": f"Stop signal sent for {task_id}"}
- task = tasks.get(task_id)
+    task = tasks.pop(task_id, None)
if not task:
raise ValueError(f"Task with ID {task_id} not found.")
@@ -160,7 +162,22 @@ async def stop_task(request, task_id: str):
await task # Wait for the task to handle the cancellation
except asyncio.CancelledError:
# Task successfully canceled
- tasks.pop(task_id, None) # Remove it from the dictionary
return {"status": True, "message": f"Task {task_id} successfully stopped."}
return {"status": False, "message": f"Failed to stop task {task_id}."}
+
+
+async def stop_item_tasks(redis: Redis, item_id: str):
+ """
+ Stop all tasks associated with a specific item ID.
+ """
+ task_ids = await list_task_ids_by_item_id(redis, item_id)
+ if not task_ids:
+ return {"status": True, "message": f"No tasks found for item {item_id}."}
+
+ for task_id in task_ids:
+ result = await stop_task(redis, task_id)
+ if not result["status"]:
+ return result # Return the first failure
+
+ return {"status": True, "message": f"All tasks for item {item_id} stopped."}
diff --git a/backend/open_webui/utils/auth.py b/backend/open_webui/utils/auth.py
index 9befaf2a91..3262c803f3 100644
--- a/backend/open_webui/utils/auth.py
+++ b/backend/open_webui/utils/auth.py
@@ -74,31 +74,37 @@ def override_static(path: str, content: str):
def get_license_data(app, key):
- if key:
- try:
- res = requests.post(
- "https://api.openwebui.com/api/v1/license/",
- json={"key": key, "version": "1"},
- timeout=5,
+ def handler(u):
+ res = requests.post(
+ f"{u}/api/v1/license/",
+ json={"key": key, "version": "1"},
+ timeout=5,
+ )
+
+ if getattr(res, "ok", False):
+ payload = getattr(res, "json", lambda: {})()
+ for k, v in payload.items():
+ if k == "resources":
+ for p, c in v.items():
+ globals().get("override_static", lambda a, b: None)(p, c)
+ elif k == "count":
+ setattr(app.state, "USER_COUNT", v)
+ elif k == "name":
+ setattr(app.state, "WEBUI_NAME", v)
+ elif k == "metadata":
+ setattr(app.state, "LICENSE_METADATA", v)
+ return True
+ else:
+ log.error(
+ f"License: retrieval issue: {getattr(res, 'text', 'unknown error')}"
)
- if getattr(res, "ok", False):
- payload = getattr(res, "json", lambda: {})()
- for k, v in payload.items():
- if k == "resources":
- for p, c in v.items():
- globals().get("override_static", lambda a, b: None)(p, c)
- elif k == "count":
- setattr(app.state, "USER_COUNT", v)
- elif k == "name":
- setattr(app.state, "WEBUI_NAME", v)
- elif k == "metadata":
- setattr(app.state, "LICENSE_METADATA", v)
- return True
- else:
- log.error(
- f"License: retrieval issue: {getattr(res, 'text', 'unknown error')}"
- )
+ if key:
+ us = ["https://api.openwebui.com", "https://licenses.api.openwebui.com"]
+ try:
+ for u in us:
+ if handler(u):
+ return True
except Exception as ex:
log.exception(f"License: Uncaught Exception: {ex}")
return False
diff --git a/backend/open_webui/utils/chat.py b/backend/open_webui/utils/chat.py
index 268c910e3e..83483f391b 100644
--- a/backend/open_webui/utils/chat.py
+++ b/backend/open_webui/utils/chat.py
@@ -419,7 +419,7 @@ async def chat_action(request: Request, action_id: str, form_data: dict, user: A
params[key] = value
if "__user__" in sig.parameters:
- __user__ = (user.model_dump() if isinstance(user, UserModel) else {},)
+ __user__ = user.model_dump() if isinstance(user, UserModel) else {}
try:
if hasattr(function_module, "UserValves"):
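
The one-character `__user__` change above removes a stray trailing comma that wrapped the dict in a 1-tuple, so later dict-style access on `__user__` failed. In isolation:

```python
user_dict = {"id": "u1", "role": "user"}

broken = (user_dict if True else {},)   # trailing comma -> a 1-element tuple
fixed = user_dict if True else {}       # the dict itself

print(type(broken).__name__)  # "tuple"
print(fixed["id"])            # "u1"
```
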
diff --git a/backend/open_webui/utils/logger.py b/backend/open_webui/utils/logger.py
index 2557610060..a21df2756b 100644
--- a/backend/open_webui/utils/logger.py
+++ b/backend/open_webui/utils/logger.py
@@ -5,7 +5,9 @@ from typing import TYPE_CHECKING
from loguru import logger
+
from open_webui.env import (
+ AUDIT_UVICORN_LOGGER_NAMES,
AUDIT_LOG_FILE_ROTATION_SIZE,
AUDIT_LOG_LEVEL,
AUDIT_LOGS_FILE_PATH,
@@ -128,11 +130,13 @@ def start_logger():
logging.basicConfig(
handlers=[InterceptHandler()], level=GLOBAL_LOG_LEVEL, force=True
)
+
for uvicorn_logger_name in ["uvicorn", "uvicorn.error"]:
uvicorn_logger = logging.getLogger(uvicorn_logger_name)
uvicorn_logger.setLevel(GLOBAL_LOG_LEVEL)
uvicorn_logger.handlers = []
- for uvicorn_logger_name in ["uvicorn.access"]:
+
+ for uvicorn_logger_name in AUDIT_UVICORN_LOGGER_NAMES:
uvicorn_logger = logging.getLogger(uvicorn_logger_name)
uvicorn_logger.setLevel(GLOBAL_LOG_LEVEL)
uvicorn_logger.handlers = [InterceptHandler()]
diff --git a/backend/open_webui/utils/middleware.py b/backend/open_webui/utils/middleware.py
index b1e69db264..fc21543457 100644
--- a/backend/open_webui/utils/middleware.py
+++ b/backend/open_webui/utils/middleware.py
@@ -23,6 +23,7 @@ from starlette.responses import Response, StreamingResponse
from open_webui.models.chats import Chats
+from open_webui.models.folders import Folders
from open_webui.models.users import Users
from open_webui.socket.main import (
get_event_call,
@@ -56,7 +57,7 @@ from open_webui.models.users import UserModel
from open_webui.models.functions import Functions
from open_webui.models.models import Models
-from open_webui.retrieval.utils import get_sources_from_files
+from open_webui.retrieval.utils import get_sources_from_items
from open_webui.utils.chat import generate_chat_completion
@@ -248,30 +249,28 @@ async def chat_completion_tools_handler(
if tool_id
else f"{tool_function_name}"
)
- if tool.get("metadata", {}).get("citation", False) or tool.get(
- "direct", False
- ):
- # Citation is enabled for this tool
- sources.append(
- {
- "source": {
- "name": (f"TOOL:{tool_name}"),
- },
- "document": [tool_result],
- "metadata": [
- {
- "source": (f"TOOL:{tool_name}"),
- "parameters": tool_function_params,
- }
- ],
- }
- )
- else:
- # Citation is not enabled for this tool
- body["messages"] = add_or_update_user_message(
- f"\nTool `{tool_name}` Output: {tool_result}",
- body["messages"],
- )
+
+                    # Always record the tool output as a citation source
+ sources.append(
+ {
+ "source": {
+ "name": (f"TOOL:{tool_name}"),
+ },
+ "document": [tool_result],
+ "metadata": [
+ {
+ "source": (f"TOOL:{tool_name}"),
+ "parameters": tool_function_params,
+ }
+ ],
+ "tool_result": True,
+ }
+ )
+                    # Also append the tool output to the user message for the model
+ body["messages"] = add_or_update_user_message(
+ f"\nTool `{tool_name}` Output: {tool_result}",
+ body["messages"],
+ )
if (
tools[tool_function_name]
@@ -640,25 +639,34 @@ async def chat_completion_files_handler(
queries = [get_last_user_message(body["messages"])]
try:
- # Offload get_sources_from_files to a separate thread
+ # Offload get_sources_from_items to a separate thread
loop = asyncio.get_running_loop()
with ThreadPoolExecutor() as executor:
sources = await loop.run_in_executor(
executor,
- lambda: get_sources_from_files(
+ lambda: get_sources_from_items(
request=request,
- files=files,
+ items=files,
queries=queries,
embedding_function=lambda query, prefix: request.app.state.EMBEDDING_FUNCTION(
query, prefix=prefix, user=user
),
k=request.app.state.config.TOP_K,
- reranking_function=request.app.state.rf,
+ reranking_function=(
+ lambda sentences: (
+ request.app.state.RERANKING_FUNCTION(
+ sentences, user=user
+ )
+ if request.app.state.RERANKING_FUNCTION
+ else None
+ )
+ ),
k_reranker=request.app.state.config.TOP_K_RERANKER,
r=request.app.state.config.RELEVANCE_THRESHOLD,
hybrid_bm25_weight=request.app.state.config.HYBRID_BM25_WEIGHT,
hybrid_search=request.app.state.config.ENABLE_RAG_HYBRID_SEARCH,
full_context=request.app.state.config.RAG_FULL_CONTEXT,
+ user=user,
),
)
except Exception as e:
@@ -718,6 +726,10 @@ def apply_params_to_form_data(form_data, model):
async def process_chat_payload(request, form_data, user, metadata, model):
+ # Pipeline Inlet -> Filter Inlet -> Chat Memory -> Chat Web Search -> Chat Image Generation
+ # -> Chat Code Interpreter (Form Data Update) -> (Default) Chat Tools Function Calling
+ # -> Chat Files
+
form_data = apply_params_to_form_data(form_data, model)
log.debug(f"form_data: {form_data}")
@@ -752,6 +764,26 @@ async def process_chat_payload(request, form_data, user, metadata, model):
events = []
sources = []
+ # Folder "Project" handling
+ # Check if the request has chat_id and is inside of a folder
+ chat_id = metadata.get("chat_id", None)
+ if chat_id and user:
+ chat = Chats.get_chat_by_id_and_user_id(chat_id, user.id)
+ if chat and chat.folder_id:
+ folder = Folders.get_folder_by_id_and_user_id(chat.folder_id, user.id)
+
+ if folder and folder.data:
+ if "system_prompt" in folder.data:
+ form_data["messages"] = add_or_update_system_message(
+ folder.data["system_prompt"], form_data["messages"]
+ )
+ if "files" in folder.data:
+ form_data["files"] = [
+ *folder.data["files"],
+ *form_data.get("files", []),
+ ]
+
+ # Model "Knowledge" handling
user_message = get_last_user_message(form_data["messages"])
model_knowledge = model.get("info", {}).get("meta", {}).get("knowledge", False)
@@ -804,7 +836,6 @@ async def process_chat_payload(request, form_data, user, metadata, model):
raise e
try:
-
filter_functions = [
Functions.get_function_by_id(filter_id)
for filter_id in get_sorted_filter_ids(
@@ -912,7 +943,6 @@ async def process_chat_payload(request, form_data, user, metadata, model):
request, form_data, extra_params, user, models, tools_dict
)
sources.extend(flags.get("sources", []))
-
except Exception as e:
log.exception(e)
@@ -925,55 +955,59 @@ async def process_chat_payload(request, form_data, user, metadata, model):
# If context is not empty, insert it into the messages
if len(sources) > 0:
context_string = ""
- citation_idx = {}
+ citation_idx_map = {}
+
for source in sources:
- if "document" in source:
- for doc_context, doc_meta in zip(
+ is_tool_result = source.get("tool_result", False)
+
+ if "document" in source and not is_tool_result:
+ for document_text, document_metadata in zip(
source["document"], source["metadata"]
):
source_name = source.get("source", {}).get("name", None)
- citation_id = (
- doc_meta.get("source", None)
+ source_id = (
+ document_metadata.get("source", None)
or source.get("source", {}).get("id", None)
or "N/A"
)
- if citation_id not in citation_idx:
- citation_idx[citation_id] = len(citation_idx) + 1
+
+ if source_id not in citation_idx_map:
+ citation_idx_map[source_id] = len(citation_idx_map) + 1
+
context_string += (