diff --git a/backend/open_webui/config.py b/backend/open_webui/config.py
index 3611659ca3..813eb400c1 100644
--- a/backend/open_webui/config.py
+++ b/backend/open_webui/config.py
@@ -13,12 +13,15 @@ from urllib.parse import urlparse
import requests
from pydantic import BaseModel
from sqlalchemy import JSON, Column, DateTime, Integer, func
+from authlib.integrations.starlette_client import OAuth
+
from open_webui.env import (
DATA_DIR,
DATABASE_URL,
ENV,
REDIS_URL,
+ REDIS_KEY_PREFIX,
REDIS_SENTINEL_HOSTS,
REDIS_SENTINEL_PORT,
FRONTEND_BUILD_DIR,
@@ -211,11 +214,16 @@ class PersistentConfig(Generic[T]):
class AppConfig:
_state: dict[str, PersistentConfig]
_redis: Optional[redis.Redis] = None
+ _redis_key_prefix: str
def __init__(
- self, redis_url: Optional[str] = None, redis_sentinels: Optional[list] = []
+ self,
+ redis_url: Optional[str] = None,
+ redis_sentinels: Optional[list] = [],
+ redis_key_prefix: str = "open-webui",
):
super().__setattr__("_state", {})
+ super().__setattr__("_redis_key_prefix", redis_key_prefix)
if redis_url:
super().__setattr__(
"_redis",
@@ -230,7 +238,7 @@ class AppConfig:
self._state[key].save()
if self._redis:
- redis_key = f"open-webui:config:{key}"
+ redis_key = f"{self._redis_key_prefix}:config:{key}"
self._redis.set(redis_key, json.dumps(self._state[key].value))
def __getattr__(self, key):
@@ -239,7 +247,7 @@ class AppConfig:
# If Redis is available, check for an updated value
if self._redis:
- redis_key = f"open-webui:config:{key}"
+ redis_key = f"{self._redis_key_prefix}:config:{key}"
redis_value = self._redis.get(redis_key)
if redis_value is not None:
@@ -431,6 +439,12 @@ OAUTH_SCOPES = PersistentConfig(
os.environ.get("OAUTH_SCOPES", "openid email profile"),
)
+OAUTH_TIMEOUT = PersistentConfig(
+ "OAUTH_TIMEOUT",
+ "oauth.oidc.oauth_timeout",
+ os.environ.get("OAUTH_TIMEOUT", ""),
+)
+
OAUTH_CODE_CHALLENGE_METHOD = PersistentConfig(
"OAUTH_CODE_CHALLENGE_METHOD",
"oauth.oidc.code_challenge_method",
@@ -534,13 +548,20 @@ def load_oauth_providers():
OAUTH_PROVIDERS.clear()
if GOOGLE_CLIENT_ID.value and GOOGLE_CLIENT_SECRET.value:
- def google_oauth_register(client):
+ def google_oauth_register(client: OAuth):
client.register(
name="google",
client_id=GOOGLE_CLIENT_ID.value,
client_secret=GOOGLE_CLIENT_SECRET.value,
server_metadata_url="https://accounts.google.com/.well-known/openid-configuration",
- client_kwargs={"scope": GOOGLE_OAUTH_SCOPE.value},
+ client_kwargs={
+ "scope": GOOGLE_OAUTH_SCOPE.value,
+ **(
+ {"timeout": int(OAUTH_TIMEOUT.value)}
+ if OAUTH_TIMEOUT.value
+ else {}
+ ),
+ },
redirect_uri=GOOGLE_REDIRECT_URI.value,
)
@@ -555,7 +576,7 @@ def load_oauth_providers():
and MICROSOFT_CLIENT_TENANT_ID.value
):
- def microsoft_oauth_register(client):
+ def microsoft_oauth_register(client: OAuth):
client.register(
name="microsoft",
client_id=MICROSOFT_CLIENT_ID.value,
@@ -563,6 +584,11 @@ def load_oauth_providers():
server_metadata_url=f"{MICROSOFT_CLIENT_LOGIN_BASE_URL.value}/{MICROSOFT_CLIENT_TENANT_ID.value}/v2.0/.well-known/openid-configuration?appid={MICROSOFT_CLIENT_ID.value}",
client_kwargs={
"scope": MICROSOFT_OAUTH_SCOPE.value,
+ **(
+ {"timeout": int(OAUTH_TIMEOUT.value)}
+ if OAUTH_TIMEOUT.value
+ else {}
+ ),
},
redirect_uri=MICROSOFT_REDIRECT_URI.value,
)
@@ -575,7 +601,7 @@ def load_oauth_providers():
if GITHUB_CLIENT_ID.value and GITHUB_CLIENT_SECRET.value:
- def github_oauth_register(client):
+ def github_oauth_register(client: OAuth):
client.register(
name="github",
client_id=GITHUB_CLIENT_ID.value,
@@ -584,7 +610,14 @@ def load_oauth_providers():
authorize_url="https://github.com/login/oauth/authorize",
api_base_url="https://api.github.com",
userinfo_endpoint="https://api.github.com/user",
- client_kwargs={"scope": GITHUB_CLIENT_SCOPE.value},
+ client_kwargs={
+ "scope": GITHUB_CLIENT_SCOPE.value,
+ **(
+ {"timeout": int(OAUTH_TIMEOUT.value)}
+ if OAUTH_TIMEOUT.value
+ else {}
+ ),
+ },
redirect_uri=GITHUB_CLIENT_REDIRECT_URI.value,
)
@@ -600,9 +633,12 @@ def load_oauth_providers():
and OPENID_PROVIDER_URL.value
):
- def oidc_oauth_register(client):
+ def oidc_oauth_register(client: OAuth):
client_kwargs = {
"scope": OAUTH_SCOPES.value,
+ **(
+ {"timeout": int(OAUTH_TIMEOUT.value)} if OAUTH_TIMEOUT.value else {}
+ ),
}
if (
@@ -911,6 +947,18 @@ except Exception:
pass
OPENAI_API_BASE_URL = "https://api.openai.com/v1"
+
+####################################
+# MODELS
+####################################
+
+ENABLE_BASE_MODELS_CACHE = PersistentConfig(
+ "ENABLE_BASE_MODELS_CACHE",
+ "models.base_models_cache",
+ os.environ.get("ENABLE_BASE_MODELS_CACHE", "False").lower() == "true",
+)
+
+
####################################
# TOOL_SERVERS
####################################
@@ -1810,11 +1858,12 @@ MILVUS_IVF_FLAT_NLIST = int(os.environ.get("MILVUS_IVF_FLAT_NLIST", "128"))
QDRANT_URI = os.environ.get("QDRANT_URI", None)
QDRANT_API_KEY = os.environ.get("QDRANT_API_KEY", None)
QDRANT_ON_DISK = os.environ.get("QDRANT_ON_DISK", "false").lower() == "true"
-QDRANT_PREFER_GRPC = os.environ.get("QDRANT_PREFER_GRPC", "False").lower() == "true"
+QDRANT_PREFER_GRPC = os.environ.get("QDRANT_PREFER_GRPC", "false").lower() == "true"
QDRANT_GRPC_PORT = int(os.environ.get("QDRANT_GRPC_PORT", "6334"))
ENABLE_QDRANT_MULTITENANCY_MODE = (
- os.environ.get("ENABLE_QDRANT_MULTITENANCY_MODE", "false").lower() == "true"
+ os.environ.get("ENABLE_QDRANT_MULTITENANCY_MODE", "true").lower() == "true"
)
+QDRANT_COLLECTION_PREFIX = os.environ.get("QDRANT_COLLECTION_PREFIX", "open-webui")
# OpenSearch
OPENSEARCH_URI = os.environ.get("OPENSEARCH_URI", "https://localhost:9200")
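
For reference, a minimal sketch of the conditional client_kwargs merge used for the new OAUTH_TIMEOUT option: when the value is an empty string, no timeout key is passed to Authlib at all (the value below is illustrative only).

    oauth_timeout = "30"  # e.g. os.environ.get("OAUTH_TIMEOUT", "")
    client_kwargs = {
        "scope": "openid email profile",
        **({"timeout": int(oauth_timeout)} if oauth_timeout else {}),
    }
    # -> {"scope": "openid email profile", "timeout": 30}
    # with oauth_timeout == "" the timeout key is simply omitted
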
diff --git a/backend/open_webui/env.py b/backend/open_webui/env.py
index 0f7b5611f5..4db919121a 100644
--- a/backend/open_webui/env.py
+++ b/backend/open_webui/env.py
@@ -199,6 +199,7 @@ CHANGELOG = changelog_json
SAFE_MODE = os.environ.get("SAFE_MODE", "false").lower() == "true"
+
####################################
# ENABLE_FORWARD_USER_INFO_HEADERS
####################################
@@ -266,21 +267,43 @@ else:
DATABASE_URL = os.environ.get("DATABASE_URL", f"sqlite:///{DATA_DIR}/webui.db")
+DATABASE_TYPE = os.environ.get("DATABASE_TYPE")
+DATABASE_USER = os.environ.get("DATABASE_USER")
+DATABASE_PASSWORD = os.environ.get("DATABASE_PASSWORD")
+
+DATABASE_CRED = ""
+if DATABASE_USER:
+ DATABASE_CRED += f"{DATABASE_USER}"
+if DATABASE_PASSWORD:
+ DATABASE_CRED += f":{DATABASE_PASSWORD}"
+if DATABASE_CRED:
+ DATABASE_CRED += "@"
+
+
+DB_VARS = {
+ "db_type": DATABASE_TYPE,
+ "db_cred": DATABASE_CRED,
+ "db_host": os.environ.get("DATABASE_HOST"),
+ "db_port": os.environ.get("DATABASE_PORT"),
+ "db_name": os.environ.get("DATABASE_NAME"),
+}
+
+if all(DB_VARS.values()):
+ DATABASE_URL = f"{DB_VARS['db_type']}://{DB_VARS['db_cred']}{DB_VARS['db_host']}:{DB_VARS['db_port']}/{DB_VARS['db_name']}"
+
# Replace the postgres:// with postgresql://
if "postgres://" in DATABASE_URL:
DATABASE_URL = DATABASE_URL.replace("postgres://", "postgresql://")
DATABASE_SCHEMA = os.environ.get("DATABASE_SCHEMA", None)
-DATABASE_POOL_SIZE = os.environ.get("DATABASE_POOL_SIZE", 0)
+DATABASE_POOL_SIZE = os.environ.get("DATABASE_POOL_SIZE", None)
-if DATABASE_POOL_SIZE == "":
- DATABASE_POOL_SIZE = 0
-else:
+if DATABASE_POOL_SIZE is not None:
try:
DATABASE_POOL_SIZE = int(DATABASE_POOL_SIZE)
except Exception:
- DATABASE_POOL_SIZE = 0
+ DATABASE_POOL_SIZE = None
DATABASE_POOL_MAX_OVERFLOW = os.environ.get("DATABASE_POOL_MAX_OVERFLOW", 0)
@@ -325,6 +348,7 @@ ENABLE_REALTIME_CHAT_SAVE = (
####################################
REDIS_URL = os.environ.get("REDIS_URL", "")
+REDIS_KEY_PREFIX = os.environ.get("REDIS_KEY_PREFIX", "open-webui")
REDIS_SENTINEL_HOSTS = os.environ.get("REDIS_SENTINEL_HOSTS", "")
REDIS_SENTINEL_PORT = os.environ.get("REDIS_SENTINEL_PORT", "26379")
@@ -396,10 +420,33 @@ WEBUI_AUTH_COOKIE_SECURE = (
if WEBUI_AUTH and WEBUI_SECRET_KEY == "":
raise ValueError(ERROR_MESSAGES.ENV_VAR_NOT_FOUND)
+ENABLE_COMPRESSION_MIDDLEWARE = (
+ os.environ.get("ENABLE_COMPRESSION_MIDDLEWARE", "True").lower() == "true"
+)
+
+####################################
+# MODELS
+####################################
+
+MODELS_CACHE_TTL = os.environ.get("MODELS_CACHE_TTL", "1")
+if MODELS_CACHE_TTL == "":
+ MODELS_CACHE_TTL = None
+else:
+ try:
+ MODELS_CACHE_TTL = int(MODELS_CACHE_TTL)
+ except Exception:
+ MODELS_CACHE_TTL = 1
+
+
+####################################
+# WEBSOCKET SUPPORT
+####################################
+
ENABLE_WEBSOCKET_SUPPORT = (
os.environ.get("ENABLE_WEBSOCKET_SUPPORT", "True").lower() == "true"
)
+
WEBSOCKET_MANAGER = os.environ.get("WEBSOCKET_MANAGER", "")
WEBSOCKET_REDIS_URL = os.environ.get("WEBSOCKET_REDIS_URL", REDIS_URL)
@@ -506,11 +553,14 @@ else:
# OFFLINE_MODE
####################################
+ENABLE_VERSION_UPDATE_CHECK = (
+ os.environ.get("ENABLE_VERSION_UPDATE_CHECK", "true").lower() == "true"
+)
OFFLINE_MODE = os.environ.get("OFFLINE_MODE", "false").lower() == "true"
if OFFLINE_MODE:
os.environ["HF_HUB_OFFLINE"] = "1"
-
+ ENABLE_VERSION_UPDATE_CHECK = False
####################################
# AUDIT LOGGING
@@ -519,6 +569,14 @@ if OFFLINE_MODE:
AUDIT_LOGS_FILE_PATH = f"{DATA_DIR}/audit.log"
# Maximum size of a file before rotating into a new log file
AUDIT_LOG_FILE_ROTATION_SIZE = os.getenv("AUDIT_LOG_FILE_ROTATION_SIZE", "10MB")
+
+# Comma-separated list of logger names to use for audit logging.
+# Default is "uvicorn.access", which is Uvicorn's access logger.
+# Add more logger names to this list to capture additional logs.
+AUDIT_UVICORN_LOGGER_NAMES = os.getenv(
+ "AUDIT_UVICORN_LOGGER_NAMES", "uvicorn.access"
+).split(",")
+
# METADATA | REQUEST | REQUEST_RESPONSE
AUDIT_LOG_LEVEL = os.getenv("AUDIT_LOG_LEVEL", "NONE").upper()
try:
@@ -543,6 +601,9 @@ ENABLE_OTEL_METRICS = os.environ.get("ENABLE_OTEL_METRICS", "False").lower() ==
OTEL_EXPORTER_OTLP_ENDPOINT = os.environ.get(
"OTEL_EXPORTER_OTLP_ENDPOINT", "http://localhost:4317"
)
+OTEL_EXPORTER_OTLP_INSECURE = (
+ os.environ.get("OTEL_EXPORTER_OTLP_INSECURE", "False").lower() == "true"
+)
OTEL_SERVICE_NAME = os.environ.get("OTEL_SERVICE_NAME", "open-webui")
OTEL_RESOURCE_ATTRIBUTES = os.environ.get(
"OTEL_RESOURCE_ATTRIBUTES", ""
@@ -550,6 +611,14 @@ OTEL_RESOURCE_ATTRIBUTES = os.environ.get(
OTEL_TRACES_SAMPLER = os.environ.get(
"OTEL_TRACES_SAMPLER", "parentbased_always_on"
).lower()
+OTEL_BASIC_AUTH_USERNAME = os.environ.get("OTEL_BASIC_AUTH_USERNAME", "")
+OTEL_BASIC_AUTH_PASSWORD = os.environ.get("OTEL_BASIC_AUTH_PASSWORD", "")
+
+
+OTEL_OTLP_SPAN_EXPORTER = os.environ.get(
+ "OTEL_OTLP_SPAN_EXPORTER", "grpc"
+).lower() # grpc or http
+
####################################
# TOOLS/FUNCTIONS PIP OPTIONS
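
A minimal sketch of the new DATABASE_* composition in env.py: the override only applies when every entry in DB_VARS is non-empty (which requires credentials to be present), and passwords containing special characters should be URL-encoded (all values below are illustrative).

    import os

    os.environ.update(
        {
            "DATABASE_TYPE": "postgresql",
            "DATABASE_USER": "webui",
            "DATABASE_PASSWORD": "s3cret",  # URL-encode special characters
            "DATABASE_HOST": "db.internal",
            "DATABASE_PORT": "5432",
            "DATABASE_NAME": "openwebui",
        }
    )
    # env.py then assembles:
    # DATABASE_URL == "postgresql://webui:s3cret@db.internal:5432/openwebui"
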
diff --git a/backend/open_webui/internal/db.py b/backend/open_webui/internal/db.py
index 840f571cc9..e1ffc1eb27 100644
--- a/backend/open_webui/internal/db.py
+++ b/backend/open_webui/internal/db.py
@@ -62,6 +62,9 @@ def handle_peewee_migration(DATABASE_URL):
except Exception as e:
log.error(f"Failed to initialize the database connection: {e}")
+ log.warning(
+ "Hint: If your database password contains special characters, you may need to URL-encode it."
+ )
raise
finally:
# Properly closing the database connection
@@ -81,20 +84,23 @@ if "sqlite" in SQLALCHEMY_DATABASE_URL:
SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False}
)
else:
- if DATABASE_POOL_SIZE > 0:
- engine = create_engine(
- SQLALCHEMY_DATABASE_URL,
- pool_size=DATABASE_POOL_SIZE,
- max_overflow=DATABASE_POOL_MAX_OVERFLOW,
- pool_timeout=DATABASE_POOL_TIMEOUT,
- pool_recycle=DATABASE_POOL_RECYCLE,
- pool_pre_ping=True,
- poolclass=QueuePool,
- )
+ if isinstance(DATABASE_POOL_SIZE, int):
+ if DATABASE_POOL_SIZE > 0:
+ engine = create_engine(
+ SQLALCHEMY_DATABASE_URL,
+ pool_size=DATABASE_POOL_SIZE,
+ max_overflow=DATABASE_POOL_MAX_OVERFLOW,
+ pool_timeout=DATABASE_POOL_TIMEOUT,
+ pool_recycle=DATABASE_POOL_RECYCLE,
+ pool_pre_ping=True,
+ poolclass=QueuePool,
+ )
+ else:
+ engine = create_engine(
+ SQLALCHEMY_DATABASE_URL, pool_pre_ping=True, poolclass=NullPool
+ )
else:
- engine = create_engine(
- SQLALCHEMY_DATABASE_URL, pool_pre_ping=True, poolclass=NullPool
- )
+ engine = create_engine(SQLALCHEMY_DATABASE_URL, pool_pre_ping=True)
SessionLocal = sessionmaker(
diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py
index e644e78897..f7bbedc0ed 100644
--- a/backend/open_webui/main.py
+++ b/backend/open_webui/main.py
@@ -36,7 +36,6 @@ from fastapi import (
applications,
BackgroundTasks,
)
-
from fastapi.openapi.docs import get_swagger_ui_html
from fastapi.middleware.cors import CORSMiddleware
@@ -49,6 +48,7 @@ from starlette.exceptions import HTTPException as StarletteHTTPException
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.middleware.sessions import SessionMiddleware
from starlette.responses import Response, StreamingResponse
+from starlette.datastructures import Headers
from open_webui.utils import logger
@@ -117,9 +117,14 @@ from open_webui.config import (
OPENAI_API_CONFIGS,
# Direct Connections
ENABLE_DIRECT_CONNECTIONS,
+
# SCIM
SCIM_ENABLED,
SCIM_TOKEN,
+
+ # Model list
+ ENABLE_BASE_MODELS_CACHE,
+
# Thread pool size for FastAPI/AnyIO
THREAD_POOL_SIZE,
# Tool Server Configs
@@ -400,6 +405,7 @@ from open_webui.env import (
AUDIT_LOG_LEVEL,
CHANGELOG,
REDIS_URL,
+ REDIS_KEY_PREFIX,
REDIS_SENTINEL_HOSTS,
REDIS_SENTINEL_PORT,
GLOBAL_LOG_LEVEL,
@@ -415,10 +421,11 @@ from open_webui.env import (
WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
WEBUI_AUTH_TRUSTED_NAME_HEADER,
WEBUI_AUTH_SIGNOUT_REDIRECT_URL,
+ ENABLE_COMPRESSION_MIDDLEWARE,
ENABLE_WEBSOCKET_SUPPORT,
BYPASS_MODEL_ACCESS_CONTROL,
RESET_CONFIG_ON_START,
- OFFLINE_MODE,
+ ENABLE_VERSION_UPDATE_CHECK,
ENABLE_OTEL,
EXTERNAL_PWA_MANIFEST_URL,
AIOHTTP_CLIENT_SESSION_SSL,
@@ -453,7 +460,7 @@ from open_webui.utils.redis import get_redis_connection
from open_webui.tasks import (
redis_task_command_listener,
- list_task_ids_by_chat_id,
+ list_task_ids_by_item_id,
stop_task,
list_tasks,
) # Import from tasks.py
@@ -537,6 +544,27 @@ async def lifespan(app: FastAPI):
asyncio.create_task(periodic_usage_pool_cleanup())
+ if app.state.config.ENABLE_BASE_MODELS_CACHE:
+ await get_all_models(
+ Request(
+ # Creating a mock request object to pass to get_all_models
+ {
+ "type": "http",
+ "asgi.version": "3.0",
+ "asgi.spec_version": "2.0",
+ "method": "GET",
+ "path": "/internal",
+ "query_string": b"",
+ "headers": Headers({}).raw,
+ "client": ("127.0.0.1", 12345),
+ "server": ("127.0.0.1", 80),
+ "scheme": "http",
+ "app": app,
+ }
+ ),
+ None,
+ )
+
yield
if hasattr(app.state, "redis_task_command_listener"):
@@ -557,6 +585,7 @@ app.state.instance_id = None
app.state.config = AppConfig(
redis_url=REDIS_URL,
redis_sentinels=get_sentinels_from_env(REDIS_SENTINEL_HOSTS, REDIS_SENTINEL_PORT),
+ redis_key_prefix=REDIS_KEY_PREFIX,
)
app.state.redis = None
@@ -628,6 +657,15 @@ app.state.config.ENABLE_DIRECT_CONNECTIONS = ENABLE_DIRECT_CONNECTIONS
app.state.config.SCIM_ENABLED = SCIM_ENABLED
app.state.config.SCIM_TOKEN = SCIM_TOKEN
+########################################
+#
+# MODELS
+#
+########################################
+
+app.state.config.ENABLE_BASE_MODELS_CACHE = ENABLE_BASE_MODELS_CACHE
+app.state.BASE_MODELS = []
+
########################################
#
# WEBUI
@@ -1085,7 +1123,9 @@ class RedirectMiddleware(BaseHTTPMiddleware):
# Add the middleware to the app
-app.add_middleware(CompressMiddleware)
+if ENABLE_COMPRESSION_MIDDLEWARE:
+ app.add_middleware(CompressMiddleware)
+
app.add_middleware(RedirectMiddleware)
app.add_middleware(SecurityHeadersMiddleware)
@@ -1204,7 +1244,9 @@ if audit_level != AuditLevel.NONE:
@app.get("/api/models")
-async def get_models(request: Request, user=Depends(get_verified_user)):
+async def get_models(
+ request: Request, refresh: bool = False, user=Depends(get_verified_user)
+):
def get_filtered_models(models, user):
filtered_models = []
for model in models:
@@ -1228,7 +1270,7 @@ async def get_models(request: Request, user=Depends(get_verified_user)):
return filtered_models
- all_models = await get_all_models(request, user=user)
+ all_models = await get_all_models(request, refresh=refresh, user=user)
models = []
for model in all_models:
@@ -1463,7 +1505,7 @@ async def stop_task_endpoint(
request: Request, task_id: str, user=Depends(get_verified_user)
):
try:
- result = await stop_task(request, task_id)
+ result = await stop_task(request.app.state.redis, task_id)
return result
except ValueError as e:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e))
@@ -1471,7 +1513,7 @@ async def stop_task_endpoint(
@app.get("/api/tasks")
async def list_tasks_endpoint(request: Request, user=Depends(get_verified_user)):
- return {"tasks": await list_tasks(request)}
+ return {"tasks": await list_tasks(request.app.state.redis)}
@app.get("/api/tasks/chat/{chat_id}")
@@ -1482,9 +1524,9 @@ async def list_tasks_by_chat_id_endpoint(
if chat is None or chat.user_id != user.id:
return {"task_ids": []}
- task_ids = await list_task_ids_by_chat_id(request, chat_id)
+ task_ids = await list_task_ids_by_item_id(request.app.state.redis, chat_id)
- print(f"Task IDs for chat {chat_id}: {task_ids}")
+ log.debug(f"Task IDs for chat {chat_id}: {task_ids}")
return {"task_ids": task_ids}
@@ -1537,6 +1579,7 @@ async def get_app_config(request: Request):
"enable_signup": app.state.config.ENABLE_SIGNUP,
"enable_login_form": app.state.config.ENABLE_LOGIN_FORM,
"enable_websocket": ENABLE_WEBSOCKET_SUPPORT,
+ "enable_version_update_check": ENABLE_VERSION_UPDATE_CHECK,
**(
{
"enable_direct_connections": app.state.config.ENABLE_DIRECT_CONNECTIONS,
@@ -1610,7 +1653,19 @@ async def get_app_config(request: Request):
),
}
if user is not None
- else {}
+ else {
+ **(
+ {
+ "metadata": {
+ "login_footer": app.state.LICENSE_METADATA.get(
+ "login_footer", ""
+ )
+ }
+ }
+ if app.state.LICENSE_METADATA
+ else {}
+ )
+ }
),
}
@@ -1642,9 +1697,9 @@ async def get_app_version():
@app.get("/api/version/updates")
async def get_app_latest_release_version(user=Depends(get_verified_user)):
- if OFFLINE_MODE:
+ if not ENABLE_VERSION_UPDATE_CHECK:
log.debug(
- f"Offline mode is enabled, returning current version as latest version"
+ f"Version update check is disabled, returning current version as latest version"
)
return {"current": VERSION, "latest": VERSION}
try:
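
A usage sketch for the new refresh query parameter on /api/models, which bypasses the base-model cache enabled by ENABLE_BASE_MODELS_CACHE (base URL and token are placeholders; assumes a valid bearer token or API key).

    import requests

    resp = requests.get(
        "http://localhost:8080/api/models",
        params={"refresh": "true"},  # force get_all_models() to rebuild the cached list
        headers={"Authorization": "Bearer <token>"},
    )
    models = resp.json()
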
diff --git a/backend/open_webui/migrations/versions/d31026856c01_update_folder_table_data.py b/backend/open_webui/migrations/versions/d31026856c01_update_folder_table_data.py
new file mode 100644
index 0000000000..3c916964e9
--- /dev/null
+++ b/backend/open_webui/migrations/versions/d31026856c01_update_folder_table_data.py
@@ -0,0 +1,23 @@
+"""Update folder table data
+
+Revision ID: d31026856c01
+Revises: 9f0c9cd09105
+Create Date: 2025-07-13 03:00:00.000000
+
+"""
+
+from alembic import op
+import sqlalchemy as sa
+
+revision = "d31026856c01"
+down_revision = "9f0c9cd09105"
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+ op.add_column("folder", sa.Column("data", sa.JSON(), nullable=True))
+
+
+def downgrade():
+ op.drop_column("folder", "data")
diff --git a/backend/open_webui/models/chats.py b/backend/open_webui/models/chats.py
index 0ac53a0233..b9de2e5a4e 100644
--- a/backend/open_webui/models/chats.py
+++ b/backend/open_webui/models/chats.py
@@ -12,6 +12,7 @@ from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON
from sqlalchemy import or_, func, select, and_, text
from sqlalchemy.sql import exists
+from sqlalchemy.sql.expression import bindparam
####################
# Chat DB Schema
@@ -66,12 +67,14 @@ class ChatModel(BaseModel):
class ChatForm(BaseModel):
chat: dict
+ folder_id: Optional[str] = None
class ChatImportForm(ChatForm):
meta: Optional[dict] = {}
pinned: Optional[bool] = False
- folder_id: Optional[str] = None
+ created_at: Optional[int] = None
+ updated_at: Optional[int] = None
class ChatTitleMessagesForm(BaseModel):
@@ -118,6 +121,7 @@ class ChatTable:
else "New Chat"
),
"chat": form_data.chat,
+ "folder_id": form_data.folder_id,
"created_at": int(time.time()),
"updated_at": int(time.time()),
}
@@ -147,8 +151,16 @@ class ChatTable:
"meta": form_data.meta,
"pinned": form_data.pinned,
"folder_id": form_data.folder_id,
- "created_at": int(time.time()),
- "updated_at": int(time.time()),
+ "created_at": (
+ form_data.created_at
+ if form_data.created_at
+ else int(time.time())
+ ),
+ "updated_at": (
+ form_data.updated_at
+ if form_data.updated_at
+ else int(time.time())
+ ),
}
)
@@ -232,6 +244,10 @@ class ChatTable:
if chat is None:
return None
+ # Sanitize message content for null characters before upserting
+ if isinstance(message.get("content"), str):
+ message["content"] = message["content"].replace("\x00", "")
+
chat = chat.chat
history = chat.get("history", {})
@@ -580,7 +596,7 @@ class ChatTable:
"""
Filters chats based on a search query using Python, allowing pagination using skip and limit.
"""
- search_text = search_text.lower().strip()
+ search_text = search_text.replace("\u0000", "").lower().strip()
if not search_text:
return self.get_chat_list_by_user_id(
@@ -614,21 +630,18 @@ class ChatTable:
dialect_name = db.bind.dialect.name
if dialect_name == "sqlite":
# SQLite case: using JSON1 extension for JSON searching
+ sqlite_content_sql = (
+ "EXISTS ("
+ " SELECT 1 "
+ " FROM json_each(Chat.chat, '$.messages') AS message "
+ " WHERE LOWER(message.value->>'content') LIKE '%' || :content_key || '%'"
+ ")"
+ )
+ sqlite_content_clause = text(sqlite_content_sql)
query = query.filter(
- (
- Chat.title.ilike(
- f"%{search_text}%"
- ) # Case-insensitive search in title
- | text(
- """
- EXISTS (
- SELECT 1
- FROM json_each(Chat.chat, '$.messages') AS message
- WHERE LOWER(message.value->>'content') LIKE '%' || :search_text || '%'
- )
- """
- )
- ).params(search_text=search_text)
+ or_(
+ Chat.title.ilike(bindparam("title_key")), sqlite_content_clause
+ ).params(title_key=f"%{search_text}%", content_key=search_text)
)
# Check if there are any tags to filter, it should have all the tags
@@ -663,21 +676,19 @@ class ChatTable:
elif dialect_name == "postgresql":
# PostgreSQL relies on proper JSON query for search
+ postgres_content_sql = (
+ "EXISTS ("
+ " SELECT 1 "
+ " FROM json_array_elements(Chat.chat->'messages') AS message "
+ " WHERE LOWER(message->>'content') LIKE '%' || :content_key || '%'"
+ ")"
+ )
+ postgres_content_clause = text(postgres_content_sql)
query = query.filter(
- (
- Chat.title.ilike(
- f"%{search_text}%"
- ) # Case-insensitive search in title
- | text(
- """
- EXISTS (
- SELECT 1
- FROM json_array_elements(Chat.chat->'messages') AS message
- WHERE LOWER(message->>'content') LIKE '%' || :search_text || '%'
- )
- """
- )
- ).params(search_text=search_text)
+ or_(
+ Chat.title.ilike(bindparam("title_key")),
+ postgres_content_clause,
+ ).params(title_key=f"%{search_text}%", content_key=search_text)
)
# Check if there are any tags to filter, it should have all the tags
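
With the extra optional fields on ChatImportForm, imported chats can keep their original timestamps instead of being stamped with the import time; a minimal construction sketch (values are illustrative, and the import endpoint that consumes the form is not shown here).

    form = ChatImportForm(
        chat={"title": "Imported chat", "messages": []},
        meta={},
        pinned=False,
        folder_id=None,
        created_at=1721000000,  # preserved when provided, else int(time.time())
        updated_at=1721000500,
    )
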
diff --git a/backend/open_webui/models/folders.py b/backend/open_webui/models/folders.py
index 1c97de26c9..56ef81f167 100644
--- a/backend/open_webui/models/folders.py
+++ b/backend/open_webui/models/folders.py
@@ -29,6 +29,7 @@ class Folder(Base):
name = Column(Text)
items = Column(JSON, nullable=True)
meta = Column(JSON, nullable=True)
+ data = Column(JSON, nullable=True)
is_expanded = Column(Boolean, default=False)
created_at = Column(BigInteger)
updated_at = Column(BigInteger)
@@ -41,6 +42,7 @@ class FolderModel(BaseModel):
name: str
items: Optional[dict] = None
meta: Optional[dict] = None
+ data: Optional[dict] = None
is_expanded: bool = False
created_at: int
updated_at: int
@@ -55,6 +57,7 @@ class FolderModel(BaseModel):
class FolderForm(BaseModel):
name: str
+ data: Optional[dict] = None
model_config = ConfigDict(extra="allow")
@@ -187,8 +190,8 @@ class FolderTable:
log.error(f"update_folder: {e}")
return
- def update_folder_name_by_id_and_user_id(
- self, id: str, user_id: str, name: str
+ def update_folder_by_id_and_user_id(
+ self, id: str, user_id: str, form_data: FolderForm
) -> Optional[FolderModel]:
try:
with get_db() as db:
@@ -197,16 +200,28 @@ class FolderTable:
if not folder:
return None
+ form_data = form_data.model_dump(exclude_unset=True)
+
existing_folder = (
db.query(Folder)
- .filter_by(name=name, parent_id=folder.parent_id, user_id=user_id)
+ .filter_by(
+ name=form_data.get("name"),
+ parent_id=folder.parent_id,
+ user_id=user_id,
+ )
.first()
)
- if existing_folder:
+ if existing_folder and existing_folder.id != id:
return None
- folder.name = name
+ folder.name = form_data.get("name", folder.name)
+ if "data" in form_data:
+ folder.data = {
+ **(folder.data or {}),
+ **form_data["data"],
+ }
+
folder.updated_at = int(time.time())
db.commit()
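
A usage sketch for the renamed helper: FolderForm now carries an optional data dict that is shallow-merged into the existing folder data, and the duplicate-name check now ignores the folder being updated. Folders below stands for the module's FolderTable instance; the data key is illustrative.

    Folders.update_folder_by_id_and_user_id(
        id="folder-123",
        user_id="user-456",
        form_data=FolderForm(name="Research", data={"description": "Papers and notes"}),
    )
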
diff --git a/backend/open_webui/models/notes.py b/backend/open_webui/models/notes.py
index 114ccdc574..ce3b9f2e20 100644
--- a/backend/open_webui/models/notes.py
+++ b/backend/open_webui/models/notes.py
@@ -62,6 +62,13 @@ class NoteForm(BaseModel):
access_control: Optional[dict] = None
+class NoteUpdateForm(BaseModel):
+ title: Optional[str] = None
+ data: Optional[dict] = None
+ meta: Optional[dict] = None
+ access_control: Optional[dict] = None
+
+
class NoteUserResponse(NoteModel):
user: Optional[UserResponse] = None
@@ -110,16 +117,26 @@ class NoteTable:
note = db.query(Note).filter(Note.id == id).first()
return NoteModel.model_validate(note) if note else None
- def update_note_by_id(self, id: str, form_data: NoteForm) -> Optional[NoteModel]:
+ def update_note_by_id(
+ self, id: str, form_data: NoteUpdateForm
+ ) -> Optional[NoteModel]:
with get_db() as db:
note = db.query(Note).filter(Note.id == id).first()
if not note:
return None
- note.title = form_data.title
- note.data = form_data.data
- note.meta = form_data.meta
- note.access_control = form_data.access_control
+ form_data = form_data.model_dump(exclude_unset=True)
+
+ if "title" in form_data:
+ note.title = form_data["title"]
+ if "data" in form_data:
+ note.data = {**note.data, **form_data["data"]}
+ if "meta" in form_data:
+ note.meta = {**note.meta, **form_data["meta"]}
+
+ if "access_control" in form_data:
+ note.access_control = form_data["access_control"]
+
note.updated_at = int(time.time_ns())
db.commit()
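
A usage sketch for the partial-update form: with exclude_unset, only explicitly provided fields are touched, and data/meta are shallow-merged rather than replaced (the note id is illustrative; Notes is the module's NoteTable instance).

    Notes.update_note_by_id(
        "note-abc",
        NoteUpdateForm(data={"content": {"md": "Updated body"}}),
    )
    # title, meta and access_control remain unchanged because they were never set
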
diff --git a/backend/open_webui/retrieval/loaders/main.py b/backend/open_webui/retrieval/loaders/main.py
index 8ac878fc22..e57323e1eb 100644
--- a/backend/open_webui/retrieval/loaders/main.py
+++ b/backend/open_webui/retrieval/loaders/main.py
@@ -14,7 +14,7 @@ from langchain_community.document_loaders import (
TextLoader,
UnstructuredEPubLoader,
UnstructuredExcelLoader,
- UnstructuredMarkdownLoader,
+ UnstructuredODTLoader,
UnstructuredPowerPointLoader,
UnstructuredRSTLoader,
UnstructuredXMLLoader,
@@ -226,7 +226,10 @@ class Loader:
def _is_text_file(self, file_ext: str, file_content_type: str) -> bool:
return file_ext in known_source_ext or (
- file_content_type and file_content_type.find("text/") >= 0
+ file_content_type
+ and file_content_type.find("text/") >= 0
+ # Avoid text/html files being detected as text
+ and not file_content_type.find("html") >= 0
)
def _get_loader(self, filename: str, file_content_type: str, file_path: str):
@@ -389,6 +392,8 @@ class Loader:
loader = UnstructuredPowerPointLoader(file_path)
elif file_ext == "msg":
loader = OutlookMessageLoader(file_path)
+ elif file_ext == "odt":
+ loader = UnstructuredODTLoader(file_path)
elif self._is_text_file(file_ext, file_content_type):
loader = TextLoader(file_path, autodetect_encoding=True)
else:
diff --git a/backend/open_webui/retrieval/loaders/mistral.py b/backend/open_webui/retrieval/loaders/mistral.py
index b00e9d7ce5..b7f2622f5e 100644
--- a/backend/open_webui/retrieval/loaders/mistral.py
+++ b/backend/open_webui/retrieval/loaders/mistral.py
@@ -507,6 +507,7 @@ class MistralLoader:
timeout=timeout,
headers={"User-Agent": "OpenWebUI-MistralLoader/2.0"},
raise_for_status=False, # We handle status codes manually
+ trust_env=True,
) as session:
yield session
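
trust_env=True makes the aiohttp session honour proxy-related environment variables (HTTP_PROXY, HTTPS_PROXY, NO_PROXY) and netrc credentials, which the loader previously ignored; a minimal illustration outside the loader:

    import aiohttp

    async def fetch(url: str) -> str:
        # trust_env=True -> requests are routed through HTTP_PROXY / HTTPS_PROXY
        # and NO_PROXY is respected; the default (False) ignores these variables.
        async with aiohttp.ClientSession(trust_env=True) as session:
            async with session.get(url) as resp:
                return await resp.text()
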
diff --git a/backend/open_webui/retrieval/utils.py b/backend/open_webui/retrieval/utils.py
index 683f42819b..c0ad2b765b 100644
--- a/backend/open_webui/retrieval/utils.py
+++ b/backend/open_webui/retrieval/utils.py
@@ -7,6 +7,7 @@ import hashlib
from concurrent.futures import ThreadPoolExecutor
import time
+from urllib.parse import quote
from huggingface_hub import snapshot_download
from langchain.retrievers import ContextualCompressionRetriever, EnsembleRetriever
from langchain_community.retrievers import BM25Retriever
@@ -17,8 +18,11 @@ from open_webui.retrieval.vector.factory import VECTOR_DB_CLIENT
from open_webui.models.users import UserModel
from open_webui.models.files import Files
+from open_webui.models.knowledge import Knowledges
+from open_webui.models.notes import Notes
from open_webui.retrieval.vector.main import GetResult
+from open_webui.utils.access_control import has_access
from open_webui.env import (
@@ -441,9 +445,9 @@ def get_embedding_function(
raise ValueError(f"Unknown embedding engine: {embedding_engine}")
-def get_sources_from_files(
+def get_sources_from_items(
request,
- files,
+ items,
queries,
embedding_function,
k,
@@ -453,159 +457,206 @@ def get_sources_from_files(
hybrid_bm25_weight,
hybrid_search,
full_context=False,
+ user: Optional[UserModel] = None,
):
log.debug(
- f"files: {files} {queries} {embedding_function} {reranking_function} {full_context}"
+ f"items: {items} {queries} {embedding_function} {reranking_function} {full_context}"
)
extracted_collections = []
- relevant_contexts = []
+ query_results = []
- for file in files:
+ for item in items:
+ query_result = None
+ collection_names = []
- context = None
- if file.get("docs"):
- # BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL
- context = {
- "documents": [[doc.get("content") for doc in file.get("docs")]],
- "metadatas": [[doc.get("metadata") for doc in file.get("docs")]],
- }
- elif file.get("context") == "full":
- # Manual Full Mode Toggle
- context = {
- "documents": [[file.get("file").get("data", {}).get("content")]],
- "metadatas": [[{"file_id": file.get("id"), "name": file.get("name")}]],
- }
- elif (
- file.get("type") != "web_search"
- and request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL
- ):
- # BYPASS_EMBEDDING_AND_RETRIEVAL
- if file.get("type") == "collection":
- file_ids = file.get("data", {}).get("file_ids", [])
+ if item.get("type") == "text":
+ # Raw Text
+ # Used during temporary chat file uploads
- documents = []
- metadatas = []
- for file_id in file_ids:
- file_object = Files.get_file_by_id(file_id)
-
- if file_object:
- documents.append(file_object.data.get("content", ""))
- metadatas.append(
- {
- "file_id": file_id,
- "name": file_object.filename,
- "source": file_object.filename,
- }
- )
-
- context = {
- "documents": [documents],
- "metadatas": [metadatas],
+ if item.get("file"):
+ # if item has file data, use it
+ query_result = {
+ "documents": [[item.get("file").get("data", {}).get("content")]],
+ "metadatas": [[item.get("file").get("data", {}).get("meta", {})]],
+ }
+ else:
+ # Fallback to item content
+ query_result = {
+ "documents": [[item.get("content")]],
+ "metadatas": [
+ [{"file_id": item.get("id"), "name": item.get("name")}]
+ ],
}
- elif file.get("id"):
- file_object = Files.get_file_by_id(file.get("id"))
- if file_object:
- context = {
- "documents": [[file_object.data.get("content", "")]],
+ elif item.get("type") == "note":
+ # Note Attached
+ note = Notes.get_note_by_id(item.get("id"))
+
+ if user.role == "admin" or has_access(user.id, "read", note.access_control):
+ # User has access to the note
+ query_result = {
+ "documents": [[note.data.get("content", {}).get("md", "")]],
+ "metadatas": [[{"file_id": note.id, "name": note.title}]],
+ }
+
+ elif item.get("type") == "file":
+ if (
+ item.get("context") == "full"
+ or request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL
+ ):
+ if item.get("file").get("data", {}):
+ # Manual Full Mode Toggle
+ # Used from chat file modal, we can assume that the file content will be available from item.get("file").get("data", {}).get("content")
+ query_result = {
+ "documents": [
+ [item.get("file").get("data", {}).get("content", "")]
+ ],
"metadatas": [
[
{
- "file_id": file.get("id"),
- "name": file_object.filename,
- "source": file_object.filename,
+ "file_id": item.get("id"),
+ "name": item.get("name"),
+ **item.get("file")
+ .get("data", {})
+ .get("metadata", {}),
}
]
],
}
- elif file.get("file").get("data"):
- context = {
- "documents": [[file.get("file").get("data", {}).get("content")]],
- "metadatas": [
- [file.get("file").get("data", {}).get("metadata", {})]
- ],
- }
- else:
- collection_names = []
- if file.get("type") == "collection":
- if file.get("legacy"):
- collection_names = file.get("collection_names", [])
+ elif item.get("id"):
+ file_object = Files.get_file_by_id(item.get("id"))
+ if file_object:
+ query_result = {
+ "documents": [[file_object.data.get("content", "")]],
+ "metadatas": [
+ [
+ {
+ "file_id": item.get("id"),
+ "name": file_object.filename,
+ "source": file_object.filename,
+ }
+ ]
+ ],
+ }
+ else:
+ # Fallback to collection names
+ if item.get("legacy"):
+ collection_names.append(f"{item['id']}")
else:
- collection_names.append(file["id"])
- elif file.get("collection_name"):
- collection_names.append(file["collection_name"])
- elif file.get("id"):
- if file.get("legacy"):
- collection_names.append(f"{file['id']}")
- else:
- collection_names.append(f"file-{file['id']}")
+ collection_names.append(f"file-{item['id']}")
+ elif item.get("type") == "collection":
+ if (
+ item.get("context") == "full"
+ or request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL
+ ):
+ # Manual Full Mode Toggle for Collection
+ knowledge_base = Knowledges.get_knowledge_by_id(item.get("id"))
+
+ if knowledge_base and (
+ user.role == "admin"
+ or has_access(user.id, "read", knowledge_base.access_control)
+ ):
+
+ file_ids = knowledge_base.data.get("file_ids", [])
+
+ documents = []
+ metadatas = []
+ for file_id in file_ids:
+ file_object = Files.get_file_by_id(file_id)
+
+ if file_object:
+ documents.append(file_object.data.get("content", ""))
+ metadatas.append(
+ {
+ "file_id": file_id,
+ "name": file_object.filename,
+ "source": file_object.filename,
+ }
+ )
+
+ query_result = {
+ "documents": [documents],
+ "metadatas": [metadatas],
+ }
+ else:
+ # Fallback to collection names
+ if item.get("legacy"):
+ collection_names = item.get("collection_names", [])
+ else:
+ collection_names.append(item["id"])
+
+ elif item.get("docs"):
+ # BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL
+ query_result = {
+ "documents": [[doc.get("content") for doc in item.get("docs")]],
+ "metadatas": [[doc.get("metadata") for doc in item.get("docs")]],
+ }
+ elif item.get("collection_name"):
+ # Direct Collection Name
+ collection_names.append(item["collection_name"])
+
+ # If query_result is still None, fall back to collection names
+ # and vector-search those collections
+ if query_result is None and collection_names:
collection_names = set(collection_names).difference(extracted_collections)
if not collection_names:
- log.debug(f"skipping {file} as it has already been extracted")
+ log.debug(f"skipping {item} as it has already been extracted")
continue
- if full_context:
- try:
- context = get_all_items_from_collections(collection_names)
- except Exception as e:
- log.exception(e)
-
- else:
- try:
- context = None
- if file.get("type") == "text":
- context = file["content"]
- else:
- if hybrid_search:
- try:
- context = query_collection_with_hybrid_search(
- collection_names=collection_names,
- queries=queries,
- embedding_function=embedding_function,
- k=k,
- reranking_function=reranking_function,
- k_reranker=k_reranker,
- r=r,
- hybrid_bm25_weight=hybrid_bm25_weight,
- )
- except Exception as e:
- log.debug(
- "Error when using hybrid search, using"
- " non hybrid search as fallback."
- )
-
- if (not hybrid_search) or (context is None):
- context = query_collection(
+ try:
+ if full_context:
+ query_result = get_all_items_from_collections(collection_names)
+ else:
+ query_result = None # Initialize to None
+ if hybrid_search:
+ try:
+ query_result = query_collection_with_hybrid_search(
collection_names=collection_names,
queries=queries,
embedding_function=embedding_function,
k=k,
+ reranking_function=reranking_function,
+ k_reranker=k_reranker,
+ r=r,
+ hybrid_bm25_weight=hybrid_bm25_weight,
)
- except Exception as e:
- log.exception(e)
+ except Exception as e:
+ log.debug(
+ "Error when using hybrid search, using non hybrid search as fallback."
+ )
+
+ # fallback to non-hybrid search
+ if (not hybrid_search) or (query_result is None):
+ query_result = query_collection(
+ collection_names=collection_names,
+ queries=queries,
+ embedding_function=embedding_function,
+ k=k,
+ )
+ except Exception as e:
+ log.exception(e)
extracted_collections.extend(collection_names)
- if context:
- if "data" in file:
- del file["data"]
-
- relevant_contexts.append({**context, "file": file})
+ if query_result:
+ if "data" in item:
+ del item["data"]
+ query_results.append({**query_result, "file": item})
sources = []
- for context in relevant_contexts:
+ for query_result in query_results:
try:
- if "documents" in context:
- if "metadatas" in context:
+ if "documents" in query_result:
+ if "metadatas" in query_result:
source = {
- "source": context["file"],
- "document": context["documents"][0],
- "metadata": context["metadatas"][0],
+ "source": query_result["file"],
+ "document": query_result["documents"][0],
+ "metadata": query_result["metadatas"][0],
}
- if "distances" in context and context["distances"]:
- source["distances"] = context["distances"][0]
+ if "distances" in query_result and query_result["distances"]:
+ source["distances"] = query_result["distances"][0]
sources.append(source)
except Exception as e:
@@ -678,7 +729,7 @@ def generate_openai_batch_embeddings(
"Authorization": f"Bearer {key}",
**(
{
- "X-OpenWebUI-User-Name": user.name,
+ "X-OpenWebUI-User-Name": quote(user.name, safe=" "),
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
@@ -727,7 +778,7 @@ def generate_azure_openai_batch_embeddings(
"api-key": key,
**(
{
- "X-OpenWebUI-User-Name": user.name,
+ "X-OpenWebUI-User-Name": quote(user.name, safe=" "),
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
@@ -777,7 +828,7 @@ def generate_ollama_batch_embeddings(
"Authorization": f"Bearer {key}",
**(
{
- "X-OpenWebUI-User-Name": user.name,
+ "X-OpenWebUI-User-Name": quote(user.name, safe=" "),
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
diff --git a/backend/open_webui/retrieval/vector/dbs/opensearch.py b/backend/open_webui/retrieval/vector/dbs/opensearch.py
index 60ef2d906c..7e16df3cfb 100644
--- a/backend/open_webui/retrieval/vector/dbs/opensearch.py
+++ b/backend/open_webui/retrieval/vector/dbs/opensearch.py
@@ -157,10 +157,10 @@ class OpenSearchClient(VectorDBBase):
for field, value in filter.items():
query_body["query"]["bool"]["filter"].append(
- {"match": {"metadata." + str(field): value}}
+ {"term": {"metadata." + str(field) + ".keyword": value}}
)
- size = limit if limit else 10
+ size = limit if limit else 10000
try:
result = self.client.search(
@@ -206,6 +206,7 @@ class OpenSearchClient(VectorDBBase):
for item in batch
]
bulk(self.client, actions)
+ self.client.indices.refresh(self._get_index_name(collection_name))
def upsert(self, collection_name: str, items: list[VectorItem]):
self._create_index_if_not_exists(
@@ -228,6 +229,7 @@ class OpenSearchClient(VectorDBBase):
for item in batch
]
bulk(self.client, actions)
+ self.client.indices.refresh(self._get_index_name(collection_name))
def delete(
self,
@@ -251,11 +253,12 @@ class OpenSearchClient(VectorDBBase):
}
for field, value in filter.items():
query_body["query"]["bool"]["filter"].append(
- {"match": {"metadata." + str(field): value}}
+ {"term": {"metadata." + str(field) + ".keyword": value}}
)
self.client.delete_by_query(
index=self._get_index_name(collection_name), body=query_body
)
+ self.client.indices.refresh(self._get_index_name(collection_name))
def reset(self):
indices = self.client.indices.get(index=f"{self.index_prefix}_*")
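
Switching the metadata filters from match to term on the .keyword sub-field gives exact, non-analyzed matching, and the explicit indices.refresh() calls make freshly written documents immediately visible to search. A sketch of the filter change (assuming OpenSearch's default dynamic mapping, which adds a .keyword sub-field to text fields):

    # analyzed full-text match (old) vs. exact keyword match (new)
    old_clause = {"match": {"metadata.file_id": "file-123"}}
    new_clause = {"term": {"metadata.file_id.keyword": "file-123"}}
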
diff --git a/backend/open_webui/retrieval/vector/dbs/qdrant.py b/backend/open_webui/retrieval/vector/dbs/qdrant.py
index dfe2979076..2276e713fc 100644
--- a/backend/open_webui/retrieval/vector/dbs/qdrant.py
+++ b/backend/open_webui/retrieval/vector/dbs/qdrant.py
@@ -18,6 +18,7 @@ from open_webui.config import (
QDRANT_ON_DISK,
QDRANT_GRPC_PORT,
QDRANT_PREFER_GRPC,
+ QDRANT_COLLECTION_PREFIX,
)
from open_webui.env import SRC_LOG_LEVELS
@@ -29,7 +30,7 @@ log.setLevel(SRC_LOG_LEVELS["RAG"])
class QdrantClient(VectorDBBase):
def __init__(self):
- self.collection_prefix = "open-webui"
+ self.collection_prefix = QDRANT_COLLECTION_PREFIX
self.QDRANT_URI = QDRANT_URI
self.QDRANT_API_KEY = QDRANT_API_KEY
self.QDRANT_ON_DISK = QDRANT_ON_DISK
@@ -86,6 +87,25 @@ class QdrantClient(VectorDBBase):
),
)
+ # Create payload indexes for efficient filtering
+ self.client.create_payload_index(
+ collection_name=collection_name_with_prefix,
+ field_name="metadata.hash",
+ field_schema=models.KeywordIndexParams(
+ type=models.KeywordIndexType.KEYWORD,
+ is_tenant=False,
+ on_disk=self.QDRANT_ON_DISK,
+ ),
+ )
+ self.client.create_payload_index(
+ collection_name=collection_name_with_prefix,
+ field_name="metadata.file_id",
+ field_schema=models.KeywordIndexParams(
+ type=models.KeywordIndexType.KEYWORD,
+ is_tenant=False,
+ on_disk=self.QDRANT_ON_DISK,
+ ),
+ )
log.info(f"collection {collection_name_with_prefix} successfully created!")
def _create_collection_if_not_exists(self, collection_name, dimension):
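
The collection prefix is now configurable, so multiple Open WebUI deployments can share a single Qdrant instance without colliding; a small sketch of the effect (the prefix value is illustrative):

    import os

    os.environ["QDRANT_COLLECTION_PREFIX"] = "tenant-a"
    # Collection names are then prefixed with "tenant-a_" (e.g. "tenant-a_memories")
    # instead of the default "open-webui_" prefix.
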
diff --git a/backend/open_webui/retrieval/vector/dbs/qdrant_multitenancy.py b/backend/open_webui/retrieval/vector/dbs/qdrant_multitenancy.py
index e83c437ef7..17c054ee50 100644
--- a/backend/open_webui/retrieval/vector/dbs/qdrant_multitenancy.py
+++ b/backend/open_webui/retrieval/vector/dbs/qdrant_multitenancy.py
@@ -1,5 +1,5 @@
import logging
-from typing import Optional, Tuple
+from typing import Optional, Tuple, List, Dict, Any
from urllib.parse import urlparse
import grpc
@@ -9,6 +9,7 @@ from open_webui.config import (
QDRANT_ON_DISK,
QDRANT_PREFER_GRPC,
QDRANT_URI,
+ QDRANT_COLLECTION_PREFIX,
)
from open_webui.env import SRC_LOG_LEVELS
from open_webui.retrieval.vector.main import (
@@ -23,14 +24,28 @@ from qdrant_client.http.models import PointStruct
from qdrant_client.models import models
NO_LIMIT = 999999999
+TENANT_ID_FIELD = "tenant_id"
+DEFAULT_DIMENSION = 384
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
+def _tenant_filter(tenant_id: str) -> models.FieldCondition:
+ return models.FieldCondition(
+ key=TENANT_ID_FIELD, match=models.MatchValue(value=tenant_id)
+ )
+
+
+def _metadata_filter(key: str, value: Any) -> models.FieldCondition:
+ return models.FieldCondition(
+ key=f"metadata.{key}", match=models.MatchValue(value=value)
+ )
+
+
class QdrantClient(VectorDBBase):
def __init__(self):
- self.collection_prefix = "open-webui"
+ self.collection_prefix = QDRANT_COLLECTION_PREFIX
self.QDRANT_URI = QDRANT_URI
self.QDRANT_API_KEY = QDRANT_API_KEY
self.QDRANT_ON_DISK = QDRANT_ON_DISK
@@ -38,24 +53,26 @@ class QdrantClient(VectorDBBase):
self.GRPC_PORT = QDRANT_GRPC_PORT
if not self.QDRANT_URI:
- self.client = None
- return
+ raise ValueError(
+ "QDRANT_URI is not set. Please configure it in the environment variables."
+ )
# Unified handling for either scheme
parsed = urlparse(self.QDRANT_URI)
host = parsed.hostname or self.QDRANT_URI
http_port = parsed.port or 6333 # default REST port
- if self.PREFER_GRPC:
- self.client = Qclient(
+ self.client = (
+ Qclient(
host=host,
port=http_port,
grpc_port=self.GRPC_PORT,
prefer_grpc=self.PREFER_GRPC,
api_key=self.QDRANT_API_KEY,
)
- else:
- self.client = Qclient(url=self.QDRANT_URI, api_key=self.QDRANT_API_KEY)
+ if self.PREFER_GRPC
+ else Qclient(url=self.QDRANT_URI, api_key=self.QDRANT_API_KEY)
+ )
# Main collection types for multi-tenancy
self.MEMORY_COLLECTION = f"{self.collection_prefix}_memories"
@@ -65,23 +82,13 @@ class QdrantClient(VectorDBBase):
self.HASH_BASED_COLLECTION = f"{self.collection_prefix}_hash-based"
def _result_to_get_result(self, points) -> GetResult:
- ids = []
- documents = []
- metadatas = []
-
+ ids, documents, metadatas = [], [], []
for point in points:
payload = point.payload
ids.append(point.id)
documents.append(payload["text"])
metadatas.append(payload["metadata"])
-
- return GetResult(
- **{
- "ids": [ids],
- "documents": [documents],
- "metadatas": [metadatas],
- }
- )
+ return GetResult(ids=[ids], documents=[documents], metadatas=[metadatas])
def _get_collection_and_tenant_id(self, collection_name: str) -> Tuple[str, str]:
"""
@@ -113,143 +120,47 @@ class QdrantClient(VectorDBBase):
else:
return self.KNOWLEDGE_COLLECTION, tenant_id
- def _extract_error_message(self, exception):
- """
- Extract error message from either HTTP or gRPC exceptions
-
- Returns:
- tuple: (status_code, error_message)
- """
- # Check if it's an HTTP exception
- if isinstance(exception, UnexpectedResponse):
- try:
- error_data = exception.structured()
- error_msg = error_data.get("status", {}).get("error", "")
- return exception.status_code, error_msg
- except Exception as inner_e:
- log.error(f"Failed to parse HTTP error: {inner_e}")
- return exception.status_code, str(exception)
-
- # Check if it's a gRPC exception
- elif isinstance(exception, grpc.RpcError):
- # Extract status code from gRPC error
- status_code = None
- if hasattr(exception, "code") and callable(exception.code):
- status_code = exception.code().value[0]
-
- # Extract error message
- error_msg = str(exception)
- if "details =" in error_msg:
- # Parse the details line which contains the actual error message
- try:
- details_line = [
- line.strip()
- for line in error_msg.split("\n")
- if "details =" in line
- ][0]
- error_msg = details_line.split("details =")[1].strip(' "')
- except (IndexError, AttributeError):
- # Fall back to full message if parsing fails
- pass
-
- return status_code, error_msg
-
- # For any other type of exception
- return None, str(exception)
-
- def _is_collection_not_found_error(self, exception):
- """
- Check if the exception is due to collection not found, supporting both HTTP and gRPC
- """
- status_code, error_msg = self._extract_error_message(exception)
-
- # HTTP error (404)
- if (
- status_code == 404
- and "Collection" in error_msg
- and "doesn't exist" in error_msg
- ):
- return True
-
- # gRPC error (NOT_FOUND status)
- if (
- isinstance(exception, grpc.RpcError)
- and exception.code() == grpc.StatusCode.NOT_FOUND
- ):
- return True
-
- return False
-
- def _is_dimension_mismatch_error(self, exception):
- """
- Check if the exception is due to dimension mismatch, supporting both HTTP and gRPC
- """
- status_code, error_msg = self._extract_error_message(exception)
-
- # Common patterns in both HTTP and gRPC
- return (
- "Vector dimension error" in error_msg
- or "dimensions mismatch" in error_msg
- or "invalid vector size" in error_msg
- )
-
- def _create_multi_tenant_collection_if_not_exists(
- self, mt_collection_name: str, dimension: int = 384
+ def _create_multi_tenant_collection(
+ self, mt_collection_name: str, dimension: int = DEFAULT_DIMENSION
):
"""
- Creates a collection with multi-tenancy configuration if it doesn't exist.
- Default dimension is set to 384 which corresponds to 'sentence-transformers/all-MiniLM-L6-v2'.
- When creating collections dynamically (insert/upsert), the actual vector dimensions will be used.
+ Creates a collection with multi-tenancy configuration and payload indexes for tenant_id and metadata fields.
"""
- try:
- # Try to create the collection directly - will fail if it already exists
- self.client.create_collection(
- collection_name=mt_collection_name,
- vectors_config=models.VectorParams(
- size=dimension,
- distance=models.Distance.COSINE,
- on_disk=self.QDRANT_ON_DISK,
- ),
- hnsw_config=models.HnswConfigDiff(
- payload_m=16, # Enable per-tenant indexing
- m=0,
- on_disk=self.QDRANT_ON_DISK,
- ),
- )
+ self.client.create_collection(
+ collection_name=mt_collection_name,
+ vectors_config=models.VectorParams(
+ size=dimension,
+ distance=models.Distance.COSINE,
+ on_disk=self.QDRANT_ON_DISK,
+ ),
+ )
+ log.info(
+ f"Multi-tenant collection {mt_collection_name} created with dimension {dimension}!"
+ )
- # Create tenant ID payload index
+ self.client.create_payload_index(
+ collection_name=mt_collection_name,
+ field_name=TENANT_ID_FIELD,
+ field_schema=models.KeywordIndexParams(
+ type=models.KeywordIndexType.KEYWORD,
+ is_tenant=True,
+ on_disk=self.QDRANT_ON_DISK,
+ ),
+ )
+
+ for field in ("metadata.hash", "metadata.file_id"):
self.client.create_payload_index(
collection_name=mt_collection_name,
- field_name="tenant_id",
+ field_name=field,
field_schema=models.KeywordIndexParams(
type=models.KeywordIndexType.KEYWORD,
- is_tenant=True,
on_disk=self.QDRANT_ON_DISK,
),
- wait=True,
)
- log.info(
- f"Multi-tenant collection {mt_collection_name} created with dimension {dimension}!"
- )
- except (UnexpectedResponse, grpc.RpcError) as e:
- # Check for the specific error indicating collection already exists
- status_code, error_msg = self._extract_error_message(e)
-
- # HTTP status code 409 or gRPC ALREADY_EXISTS
- if (isinstance(e, UnexpectedResponse) and status_code == 409) or (
- isinstance(e, grpc.RpcError)
- and e.code() == grpc.StatusCode.ALREADY_EXISTS
- ):
- if "already exists" in error_msg:
- log.debug(f"Collection {mt_collection_name} already exists")
- return
- # If it's not an already exists error, re-raise
- raise e
- except Exception as e:
- raise e
-
- def _create_points(self, items: list[VectorItem], tenant_id: str):
+ def _create_points(
+ self, items: List[VectorItem], tenant_id: str
+ ) -> List[PointStruct]:
"""
Create point structs from vector items with tenant ID.
"""
@@ -260,56 +171,42 @@ class QdrantClient(VectorDBBase):
payload={
"text": item["text"],
"metadata": item["metadata"],
- "tenant_id": tenant_id,
+ TENANT_ID_FIELD: tenant_id,
},
)
for item in items
]
+ def _ensure_collection(
+ self, mt_collection_name: str, dimension: int = DEFAULT_DIMENSION
+ ):
+ """
+ Ensure the collection exists and payload indexes are created for tenant_id and metadata fields.
+ """
+ if not self.client.collection_exists(collection_name=mt_collection_name):
+ self._create_multi_tenant_collection(mt_collection_name, dimension)
+
def has_collection(self, collection_name: str) -> bool:
"""
Check if a logical collection exists by checking for any points with the tenant ID.
"""
if not self.client:
return False
-
- # Map to multi-tenant collection and tenant ID
mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)
-
- # Create tenant filter
- tenant_filter = models.FieldCondition(
- key="tenant_id", match=models.MatchValue(value=tenant_id)
- )
-
- try:
- # Try directly querying - most of the time collection should exist
- response = self.client.query_points(
- collection_name=mt_collection,
- query_filter=models.Filter(must=[tenant_filter]),
- limit=1,
- )
-
- # Collection exists with this tenant ID if there are points
- return len(response.points) > 0
- except (UnexpectedResponse, grpc.RpcError) as e:
- if self._is_collection_not_found_error(e):
- log.debug(f"Collection {mt_collection} doesn't exist")
- return False
- else:
- # For other API errors, log and return False
- _, error_msg = self._extract_error_message(e)
- log.warning(f"Unexpected Qdrant error: {error_msg}")
- return False
- except Exception as e:
- # For any other errors, log and return False
- log.debug(f"Error checking collection {mt_collection}: {e}")
+ if not self.client.collection_exists(collection_name=mt_collection):
return False
+ tenant_filter = _tenant_filter(tenant_id)
+ count_result = self.client.count(
+ collection_name=mt_collection,
+ count_filter=models.Filter(must=[tenant_filter]),
+ )
+ return count_result.count > 0
def delete(
self,
collection_name: str,
- ids: Optional[list[str]] = None,
- filter: Optional[dict] = None,
+ ids: Optional[List[str]] = None,
+ filter: Optional[Dict[str, Any]] = None,
):
"""
Delete vectors by ID or filter from a collection with tenant isolation.
@@ -317,189 +214,76 @@ class QdrantClient(VectorDBBase):
if not self.client:
return None
- # Map to multi-tenant collection and tenant ID
mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)
+ if not self.client.collection_exists(collection_name=mt_collection):
+ log.debug(f"Collection {mt_collection} doesn't exist, nothing to delete")
+ return None
- # Create tenant filter
- tenant_filter = models.FieldCondition(
- key="tenant_id", match=models.MatchValue(value=tenant_id)
+ must_conditions = [_tenant_filter(tenant_id)]
+ should_conditions = []
+ if ids:
+ should_conditions = [_metadata_filter("id", id_value) for id_value in ids]
+ elif filter:
+ must_conditions += [_metadata_filter(k, v) for k, v in filter.items()]
+
+ return self.client.delete(
+ collection_name=mt_collection,
+ points_selector=models.FilterSelector(
+ filter=models.Filter(must=must_conditions, should=should_conditions)
+ ),
)
- must_conditions = [tenant_filter]
- should_conditions = []
-
- if ids:
- for id_value in ids:
- should_conditions.append(
- models.FieldCondition(
- key="metadata.id",
- match=models.MatchValue(value=id_value),
- ),
- )
- elif filter:
- for key, value in filter.items():
- must_conditions.append(
- models.FieldCondition(
- key=f"metadata.{key}",
- match=models.MatchValue(value=value),
- ),
- )
-
- try:
- # Try to delete directly - most of the time collection should exist
- update_result = self.client.delete(
- collection_name=mt_collection,
- points_selector=models.FilterSelector(
- filter=models.Filter(must=must_conditions, should=should_conditions)
- ),
- )
-
- return update_result
- except (UnexpectedResponse, grpc.RpcError) as e:
- if self._is_collection_not_found_error(e):
- log.debug(
- f"Collection {mt_collection} doesn't exist, nothing to delete"
- )
- return None
- else:
- # For other API errors, log and re-raise
- _, error_msg = self._extract_error_message(e)
- log.warning(f"Unexpected Qdrant error: {error_msg}")
- raise
- except Exception as e:
- # For non-Qdrant exceptions, re-raise
- raise
-
def search(
- self, collection_name: str, vectors: list[list[float | int]], limit: int
+ self, collection_name: str, vectors: List[List[float | int]], limit: int
) -> Optional[SearchResult]:
"""
Search for the nearest neighbor items based on the vectors with tenant isolation.
"""
- if not self.client:
+ if not self.client or not vectors:
return None
-
- # Map to multi-tenant collection and tenant ID
mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)
-
- # Get the vector dimension from the query vector
- dimension = len(vectors[0]) if vectors and len(vectors) > 0 else None
-
- try:
- # Try the search operation directly - most of the time collection should exist
-
- # Create tenant filter
- tenant_filter = models.FieldCondition(
- key="tenant_id", match=models.MatchValue(value=tenant_id)
- )
-
- # Ensure vector dimensions match the collection
- collection_dim = self.client.get_collection(
- mt_collection
- ).config.params.vectors.size
-
- if collection_dim != dimension:
- if collection_dim < dimension:
- vectors = [vector[:collection_dim] for vector in vectors]
- else:
- vectors = [
- vector + [0] * (collection_dim - dimension)
- for vector in vectors
- ]
-
- # Search with tenant filter
- prefetch_query = models.Prefetch(
- filter=models.Filter(must=[tenant_filter]),
- limit=NO_LIMIT,
- )
- query_response = self.client.query_points(
- collection_name=mt_collection,
- query=vectors[0],
- prefetch=prefetch_query,
- limit=limit,
- )
-
- get_result = self._result_to_get_result(query_response.points)
- return SearchResult(
- ids=get_result.ids,
- documents=get_result.documents,
- metadatas=get_result.metadatas,
- # qdrant distance is [-1, 1], normalize to [0, 1]
- distances=[
- [(point.score + 1.0) / 2.0 for point in query_response.points]
- ],
- )
- except (UnexpectedResponse, grpc.RpcError) as e:
- if self._is_collection_not_found_error(e):
- log.debug(
- f"Collection {mt_collection} doesn't exist, search returns None"
- )
- return None
- else:
- # For other API errors, log and re-raise
- _, error_msg = self._extract_error_message(e)
- log.warning(f"Unexpected Qdrant error during search: {error_msg}")
- raise
- except Exception as e:
- # For non-Qdrant exceptions, log and return None
- log.exception(f"Error searching collection '{collection_name}': {e}")
+ if not self.client.collection_exists(collection_name=mt_collection):
+ log.debug(f"Collection {mt_collection} doesn't exist, search returns None")
return None
- def query(self, collection_name: str, filter: dict, limit: Optional[int] = None):
+ tenant_filter = _tenant_filter(tenant_id)
+ query_response = self.client.query_points(
+ collection_name=mt_collection,
+ query=vectors[0],
+ limit=limit,
+ query_filter=models.Filter(must=[tenant_filter]),
+ )
+ get_result = self._result_to_get_result(query_response.points)
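+ # Qdrant similarity scores are in [-1, 1]; normalize to [0, 1] for the returned distances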
+ return SearchResult(
+ ids=get_result.ids,
+ documents=get_result.documents,
+ metadatas=get_result.metadatas,
+ distances=[[(point.score + 1.0) / 2.0 for point in query_response.points]],
+ )
+
+ def query(
+ self, collection_name: str, filter: Dict[str, Any], limit: Optional[int] = None
+ ):
"""
Query points with filters and tenant isolation.
"""
if not self.client:
return None
-
- # Map to multi-tenant collection and tenant ID
mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)
-
- # Set default limit if not provided
+ if not self.client.collection_exists(collection_name=mt_collection):
+ log.debug(f"Collection {mt_collection} doesn't exist, query returns None")
+ return None
if limit is None:
limit = NO_LIMIT
-
- # Create tenant filter
- tenant_filter = models.FieldCondition(
- key="tenant_id", match=models.MatchValue(value=tenant_id)
- )
-
- # Create metadata filters
- field_conditions = []
- for key, value in filter.items():
- field_conditions.append(
- models.FieldCondition(
- key=f"metadata.{key}", match=models.MatchValue(value=value)
- )
- )
-
- # Combine tenant filter with metadata filters
+ tenant_filter = _tenant_filter(tenant_id)
+ field_conditions = [_metadata_filter(k, v) for k, v in filter.items()]
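+ # AND the tenant filter with the metadata filters so results stay scoped to this tenant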
combined_filter = models.Filter(must=[tenant_filter, *field_conditions])
-
- try:
- # Try the query directly - most of the time collection should exist
- points = self.client.query_points(
- collection_name=mt_collection,
- query_filter=combined_filter,
- limit=limit,
- )
-
- return self._result_to_get_result(points.points)
- except (UnexpectedResponse, grpc.RpcError) as e:
- if self._is_collection_not_found_error(e):
- log.debug(
- f"Collection {mt_collection} doesn't exist, query returns None"
- )
- return None
- else:
- # For other API errors, log and re-raise
- _, error_msg = self._extract_error_message(e)
- log.warning(f"Unexpected Qdrant error during query: {error_msg}")
- raise
- except Exception as e:
- # For non-Qdrant exceptions, log and re-raise
- log.exception(f"Error querying collection '{collection_name}': {e}")
- return None
+ points = self.client.query_points(
+ collection_name=mt_collection,
+ query_filter=combined_filter,
+ limit=limit,
+ )
+ return self._result_to_get_result(points.points)
def get(self, collection_name: str) -> Optional[GetResult]:
"""
@@ -507,169 +291,36 @@ class QdrantClient(VectorDBBase):
"""
if not self.client:
return None
-
- # Map to multi-tenant collection and tenant ID
mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)
-
- # Create tenant filter
- tenant_filter = models.FieldCondition(
- key="tenant_id", match=models.MatchValue(value=tenant_id)
- )
-
- try:
- # Try to get points directly - most of the time collection should exist
- points = self.client.query_points(
- collection_name=mt_collection,
- query_filter=models.Filter(must=[tenant_filter]),
- limit=NO_LIMIT,
- )
-
- return self._result_to_get_result(points.points)
- except (UnexpectedResponse, grpc.RpcError) as e:
- if self._is_collection_not_found_error(e):
- log.debug(f"Collection {mt_collection} doesn't exist, get returns None")
- return None
- else:
- # For other API errors, log and re-raise
- _, error_msg = self._extract_error_message(e)
- log.warning(f"Unexpected Qdrant error during get: {error_msg}")
- raise
- except Exception as e:
- # For non-Qdrant exceptions, log and return None
- log.exception(f"Error getting collection '{collection_name}': {e}")
+ if not self.client.collection_exists(collection_name=mt_collection):
+ log.debug(f"Collection {mt_collection} doesn't exist, get returns None")
return None
-
- def _handle_operation_with_error_retry(
- self, operation_name, mt_collection, points, dimension
- ):
- """
- Private helper to handle common error cases for insert and upsert operations.
-
- Args:
- operation_name: 'insert' or 'upsert'
- mt_collection: The multi-tenant collection name
- points: The vector points to insert/upsert
- dimension: The dimension of the vectors
-
- Returns:
- The operation result (for upsert) or None (for insert)
- """
- try:
- if operation_name == "insert":
- self.client.upload_points(mt_collection, points)
- return None
- else: # upsert
- return self.client.upsert(mt_collection, points)
- except (UnexpectedResponse, grpc.RpcError) as e:
- # Handle collection not found
- if self._is_collection_not_found_error(e):
- log.info(
- f"Collection {mt_collection} doesn't exist. Creating it with dimension {dimension}."
- )
- # Create collection with correct dimensions from our vectors
- self._create_multi_tenant_collection_if_not_exists(
- mt_collection_name=mt_collection, dimension=dimension
- )
- # Try operation again - no need for dimension adjustment since we just created with correct dimensions
- if operation_name == "insert":
- self.client.upload_points(mt_collection, points)
- return None
- else: # upsert
- return self.client.upsert(mt_collection, points)
-
- # Handle dimension mismatch
- elif self._is_dimension_mismatch_error(e):
- # For dimension errors, the collection must exist, so get its configuration
- mt_collection_info = self.client.get_collection(mt_collection)
- existing_size = mt_collection_info.config.params.vectors.size
-
- log.info(
- f"Dimension mismatch: Collection {mt_collection} expects {existing_size}, got {dimension}"
- )
-
- if existing_size < dimension:
- # Truncate vectors to fit
- log.info(
- f"Truncating vectors from {dimension} to {existing_size} dimensions"
- )
- points = [
- PointStruct(
- id=point.id,
- vector=point.vector[:existing_size],
- payload=point.payload,
- )
- for point in points
- ]
- elif existing_size > dimension:
- # Pad vectors with zeros
- log.info(
- f"Padding vectors from {dimension} to {existing_size} dimensions with zeros"
- )
- points = [
- PointStruct(
- id=point.id,
- vector=point.vector
- + [0] * (existing_size - len(point.vector)),
- payload=point.payload,
- )
- for point in points
- ]
- # Try operation again with adjusted dimensions
- if operation_name == "insert":
- self.client.upload_points(mt_collection, points)
- return None
- else: # upsert
- return self.client.upsert(mt_collection, points)
- else:
- # Not a known error we can handle, log and re-raise
- _, error_msg = self._extract_error_message(e)
- log.warning(f"Unhandled Qdrant error: {error_msg}")
- raise
- except Exception as e:
- # For non-Qdrant exceptions, re-raise
- raise
-
- def insert(self, collection_name: str, items: list[VectorItem]):
- """
- Insert items with tenant ID.
- """
- if not self.client or not items:
- return None
-
- # Map to multi-tenant collection and tenant ID
- mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)
-
- # Get dimensions from the actual vectors
- dimension = len(items[0]["vector"]) if items else None
-
- # Create points with tenant ID
- points = self._create_points(items, tenant_id)
-
- # Handle the operation with error retry
- return self._handle_operation_with_error_retry(
- "insert", mt_collection, points, dimension
+ tenant_filter = _tenant_filter(tenant_id)
+ points = self.client.query_points(
+ collection_name=mt_collection,
+ query_filter=models.Filter(must=[tenant_filter]),
+ limit=NO_LIMIT,
)
+ return self._result_to_get_result(points.points)
- def upsert(self, collection_name: str, items: list[VectorItem]):
+ def upsert(self, collection_name: str, items: List[VectorItem]):
"""
Upsert items with tenant ID.
"""
if not self.client or not items:
return None
-
- # Map to multi-tenant collection and tenant ID
mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)
-
- # Get dimensions from the actual vectors
- dimension = len(items[0]["vector"]) if items else None
-
- # Create points with tenant ID
+ dimension = len(items[0]["vector"])
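+ # Make sure the multi-tenant collection exists with the expected dimension before uploading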
+ self._ensure_collection(mt_collection, dimension)
points = self._create_points(items, tenant_id)
+ self.client.upload_points(mt_collection, points)
+ return None
- # Handle the operation with error retry
- return self._handle_operation_with_error_retry(
- "upsert", mt_collection, points, dimension
- )
+ def insert(self, collection_name: str, items: List[VectorItem]):
+ """
+ Insert items with tenant ID.
+ """
+ return self.upsert(collection_name, items)
def reset(self):
"""
@@ -677,11 +328,9 @@ class QdrantClient(VectorDBBase):
"""
if not self.client:
return None
-
- collection_names = self.client.get_collections().collections
- for collection_name in collection_names:
- if collection_name.name.startswith(self.collection_prefix):
- self.client.delete_collection(collection_name=collection_name.name)
+ for collection in self.client.get_collections().collections:
+ if collection.name.startswith(self.collection_prefix):
+ self.client.delete_collection(collection_name=collection.name)
def delete_collection(self, collection_name: str):
"""
@@ -689,24 +338,13 @@ class QdrantClient(VectorDBBase):
"""
if not self.client:
return None
-
- # Map to multi-tenant collection and tenant ID
mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)
-
- tenant_filter = models.FieldCondition(
- key="tenant_id", match=models.MatchValue(value=tenant_id)
- )
-
- field_conditions = [tenant_filter]
-
- update_result = self.client.delete(
+ if not self.client.collection_exists(collection_name=mt_collection):
+ log.debug(f"Collection {mt_collection} doesn't exist, nothing to delete")
+ return None
+ self.client.delete(
collection_name=mt_collection,
points_selector=models.FilterSelector(
- filter=models.Filter(must=field_conditions)
+ filter=models.Filter(must=[_tenant_filter(tenant_id)])
),
)
-
- if self.client.get_collection(mt_collection).points_count == 0:
- self.client.delete_collection(mt_collection)
-
- return update_result
diff --git a/backend/open_webui/retrieval/web/brave.py b/backend/open_webui/retrieval/web/brave.py
index 3075db990f..7bea575620 100644
--- a/backend/open_webui/retrieval/web/brave.py
+++ b/backend/open_webui/retrieval/web/brave.py
@@ -36,7 +36,9 @@ def search_brave(
return [
SearchResult(
- link=result["url"], title=result.get("title"), snippet=result.get("snippet")
+ link=result["url"],
+ title=result.get("title"),
+ snippet=result.get("description"),
)
for result in results[:count]
]
diff --git a/backend/open_webui/retrieval/web/duckduckgo.py b/backend/open_webui/retrieval/web/duckduckgo.py
index bf8ae6880b..a32fc358ed 100644
--- a/backend/open_webui/retrieval/web/duckduckgo.py
+++ b/backend/open_webui/retrieval/web/duckduckgo.py
@@ -2,8 +2,8 @@ import logging
from typing import Optional
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
-from duckduckgo_search import DDGS
-from duckduckgo_search.exceptions import RatelimitException
+from ddgs import DDGS
+from ddgs.exceptions import RatelimitException
from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
diff --git a/backend/open_webui/routers/audio.py b/backend/open_webui/routers/audio.py
index 27634cec19..c63ad3bfe7 100644
--- a/backend/open_webui/routers/audio.py
+++ b/backend/open_webui/routers/audio.py
@@ -15,6 +15,7 @@ import aiohttp
import aiofiles
import requests
import mimetypes
+from urllib.parse import quote
from fastapi import (
Depends,
@@ -327,6 +328,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
log.exception(e)
raise HTTPException(status_code=400, detail="Invalid JSON payload")
+ r = None
if request.app.state.config.TTS_ENGINE == "openai":
payload["model"] = request.app.state.config.TTS_MODEL
@@ -335,7 +337,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
async with aiohttp.ClientSession(
timeout=timeout, trust_env=True
) as session:
- async with session.post(
+ r = await session.post(
url=f"{request.app.state.config.TTS_OPENAI_API_BASE_URL}/audio/speech",
json=payload,
headers={
@@ -343,7 +345,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
"Authorization": f"Bearer {request.app.state.config.TTS_OPENAI_API_KEY}",
**(
{
- "X-OpenWebUI-User-Name": user.name,
+ "X-OpenWebUI-User-Name": quote(user.name, safe=" "),
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
@@ -353,14 +355,15 @@ async def speech(request: Request, user=Depends(get_verified_user)):
),
},
ssl=AIOHTTP_CLIENT_SESSION_SSL,
- ) as r:
- r.raise_for_status()
+ )
- async with aiofiles.open(file_path, "wb") as f:
- await f.write(await r.read())
+ r.raise_for_status()
- async with aiofiles.open(file_body_path, "w") as f:
- await f.write(json.dumps(payload))
+ async with aiofiles.open(file_path, "wb") as f:
+ await f.write(await r.read())
+
+ async with aiofiles.open(file_body_path, "w") as f:
+ await f.write(json.dumps(payload))
return FileResponse(file_path)
@@ -368,18 +371,18 @@ async def speech(request: Request, user=Depends(get_verified_user)):
log.exception(e)
detail = None
- try:
- if r.status != 200:
- res = await r.json()
+ status_code = 500
+ detail = f"Open WebUI: Server Connection Error"
- if "error" in res:
- detail = f"External: {res['error'].get('message', '')}"
- except Exception:
- detail = f"External: {e}"
+ if r is not None:
+ status_code = r.status
+ res = await r.json()
+ if "error" in res:
+ detail = f"External: {res['error'].get('message', '')}"
raise HTTPException(
- status_code=getattr(r, "status", 500) if r else 500,
- detail=detail if detail else "Open WebUI: Server Connection Error",
+ status_code=status_code,
+ detail=detail,
)
elif request.app.state.config.TTS_ENGINE == "elevenlabs":
@@ -919,14 +922,18 @@ def transcription(
):
log.info(f"file.content_type: {file.content_type}")
- supported_content_types = request.app.state.config.STT_SUPPORTED_CONTENT_TYPES or [
- "audio/*",
- "video/webm",
- ]
+ stt_supported_content_types = getattr(
+ request.app.state.config, "STT_SUPPORTED_CONTENT_TYPES", []
+ )
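+ # Use the configured types only if the list has non-blank entries; otherwise fall back to the defaults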
if not any(
fnmatch(file.content_type, content_type)
- for content_type in supported_content_types
+ for content_type in (
+ stt_supported_content_types
+ if stt_supported_content_types
+ and any(t.strip() for t in stt_supported_content_types)
+ else ["audio/*", "video/webm"]
+ )
):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
diff --git a/backend/open_webui/routers/auths.py b/backend/open_webui/routers/auths.py
index 60a12db4b3..fd3c317ab9 100644
--- a/backend/open_webui/routers/auths.py
+++ b/backend/open_webui/routers/auths.py
@@ -669,12 +669,13 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
@router.get("/signout")
async def signout(request: Request, response: Response):
response.delete_cookie("token")
+ response.delete_cookie("oui-session")
if ENABLE_OAUTH_SIGNUP.value:
oauth_id_token = request.cookies.get("oauth_id_token")
if oauth_id_token:
try:
- async with ClientSession() as session:
+ async with ClientSession(trust_env=True) as session:
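+ # trust_env=True lets this request pick up HTTP(S)_PROXY settings from the environment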
async with session.get(OPENID_PROVIDER_URL.value) as resp:
if resp.status == 200:
openid_data = await resp.json()
@@ -686,7 +687,12 @@ async def signout(request: Request, response: Response):
status_code=200,
content={
"status": True,
- "redirect_url": f"{logout_url}?id_token_hint={oauth_id_token}",
+ "redirect_url": f"{logout_url}?id_token_hint={oauth_id_token}"
+ + (
+ f"&post_logout_redirect_uri={WEBUI_AUTH_SIGNOUT_REDIRECT_URL}"
+ if WEBUI_AUTH_SIGNOUT_REDIRECT_URL
+ else ""
+ ),
},
headers=response.headers,
)
diff --git a/backend/open_webui/routers/channels.py b/backend/open_webui/routers/channels.py
index 6da3f04cee..a4173fbd8d 100644
--- a/backend/open_webui/routers/channels.py
+++ b/backend/open_webui/routers/channels.py
@@ -40,10 +40,14 @@ router = APIRouter()
@router.get("/", response_model=list[ChannelModel])
async def get_channels(user=Depends(get_verified_user)):
+ return Channels.get_channels_by_user_id(user.id)
+
+
+@router.get("/list", response_model=list[ChannelModel])
+async def get_all_channels(user=Depends(get_verified_user)):
if user.role == "admin":
return Channels.get_channels()
- else:
- return Channels.get_channels_by_user_id(user.id)
+ return Channels.get_channels_by_user_id(user.id)
############################
diff --git a/backend/open_webui/routers/chats.py b/backend/open_webui/routers/chats.py
index 29b12ed676..13b6040102 100644
--- a/backend/open_webui/routers/chats.py
+++ b/backend/open_webui/routers/chats.py
@@ -684,8 +684,10 @@ async def archive_chat_by_id(id: str, user=Depends(get_verified_user)):
@router.post("/{id}/share", response_model=Optional[ChatResponse])
async def share_chat_by_id(request: Request, id: str, user=Depends(get_verified_user)):
- if not has_permission(
- user.id, "chat.share", request.app.state.config.USER_PERMISSIONS
+ if (user.role != "admin") and (
+ not has_permission(
+ user.id, "chat.share", request.app.state.config.USER_PERMISSIONS
+ )
):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
diff --git a/backend/open_webui/routers/configs.py b/backend/open_webui/routers/configs.py
index 8fe83eea92..cac218ccac 100644
--- a/backend/open_webui/routers/configs.py
+++ b/backend/open_webui/routers/configs.py
@@ -39,32 +39,39 @@ async def export_config(user=Depends(get_admin_user)):
############################
-# Direct Connections Config
+# Connections Config
############################
-class DirectConnectionsConfigForm(BaseModel):
+class ConnectionsConfigForm(BaseModel):
ENABLE_DIRECT_CONNECTIONS: bool
+ ENABLE_BASE_MODELS_CACHE: bool
-@router.get("/direct_connections", response_model=DirectConnectionsConfigForm)
-async def get_direct_connections_config(request: Request, user=Depends(get_admin_user)):
+@router.get("/connections", response_model=ConnectionsConfigForm)
+async def get_connections_config(request: Request, user=Depends(get_admin_user)):
return {
"ENABLE_DIRECT_CONNECTIONS": request.app.state.config.ENABLE_DIRECT_CONNECTIONS,
+ "ENABLE_BASE_MODELS_CACHE": request.app.state.config.ENABLE_BASE_MODELS_CACHE,
}
-@router.post("/direct_connections", response_model=DirectConnectionsConfigForm)
-async def set_direct_connections_config(
+@router.post("/connections", response_model=ConnectionsConfigForm)
+async def set_connections_config(
request: Request,
- form_data: DirectConnectionsConfigForm,
+ form_data: ConnectionsConfigForm,
user=Depends(get_admin_user),
):
request.app.state.config.ENABLE_DIRECT_CONNECTIONS = (
form_data.ENABLE_DIRECT_CONNECTIONS
)
+ request.app.state.config.ENABLE_BASE_MODELS_CACHE = (
+ form_data.ENABLE_BASE_MODELS_CACHE
+ )
+
return {
"ENABLE_DIRECT_CONNECTIONS": request.app.state.config.ENABLE_DIRECT_CONNECTIONS,
+ "ENABLE_BASE_MODELS_CACHE": request.app.state.config.ENABLE_BASE_MODELS_CACHE,
}
diff --git a/backend/open_webui/routers/files.py b/backend/open_webui/routers/files.py
index b9bb15c7b4..bdf5780fc4 100644
--- a/backend/open_webui/routers/files.py
+++ b/backend/open_webui/routers/files.py
@@ -155,17 +155,18 @@ def upload_file(
if process:
try:
if file.content_type:
- stt_supported_content_types = (
- request.app.state.config.STT_SUPPORTED_CONTENT_TYPES
- or [
- "audio/*",
- "video/webm",
- ]
+ stt_supported_content_types = getattr(
+ request.app.state.config, "STT_SUPPORTED_CONTENT_TYPES", []
)
if any(
fnmatch(file.content_type, content_type)
- for content_type in stt_supported_content_types
+ for content_type in (
+ stt_supported_content_types
+ if stt_supported_content_types
+ and any(t.strip() for t in stt_supported_content_types)
+ else ["audio/*", "video/webm"]
+ )
):
file_path = Storage.get_file(file_path)
result = transcribe(request, file_path, file_metadata)
diff --git a/backend/open_webui/routers/folders.py b/backend/open_webui/routers/folders.py
index 2c41c92854..edc9f85ff2 100644
--- a/backend/open_webui/routers/folders.py
+++ b/backend/open_webui/routers/folders.py
@@ -120,16 +120,14 @@ async def update_folder_name_by_id(
existing_folder = Folders.get_folder_by_parent_id_and_user_id_and_name(
folder.parent_id, user.id, form_data.name
)
- if existing_folder:
+ if existing_folder and existing_folder.id != id:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT("Folder already exists"),
)
try:
- folder = Folders.update_folder_name_by_id_and_user_id(
- id, user.id, form_data.name
- )
+ folder = Folders.update_folder_by_id_and_user_id(id, user.id, form_data)
return folder
except Exception as e:
diff --git a/backend/open_webui/routers/functions.py b/backend/open_webui/routers/functions.py
index 355093335a..96d8215fb3 100644
--- a/backend/open_webui/routers/functions.py
+++ b/backend/open_webui/routers/functions.py
@@ -105,7 +105,7 @@ async def load_function_from_url(
)
try:
- async with aiohttp.ClientSession() as session:
+ async with aiohttp.ClientSession(trust_env=True) as session:
async with session.get(
url, headers={"Content-Type": "application/json"}
) as resp:
diff --git a/backend/open_webui/routers/images.py b/backend/open_webui/routers/images.py
index 52686a5841..5cd07e3d54 100644
--- a/backend/open_webui/routers/images.py
+++ b/backend/open_webui/routers/images.py
@@ -8,6 +8,7 @@ import re
from pathlib import Path
from typing import Optional
+from urllib.parse import quote
import requests
from fastapi import APIRouter, Depends, HTTPException, Request, UploadFile
from open_webui.config import CACHE_DIR
@@ -302,8 +303,16 @@ async def update_image_config(
):
set_image_model(request, form_data.MODEL)
+ if form_data.IMAGE_SIZE == "auto" and form_data.MODEL != "gpt-image-1":
+ raise HTTPException(
+ status_code=400,
+ detail=ERROR_MESSAGES.INCORRECT_FORMAT(
+ " (auto is only allowed with gpt-image-1)."
+ ),
+ )
+
pattern = r"^\d+x\d+$"
- if re.match(pattern, form_data.IMAGE_SIZE):
+ if form_data.IMAGE_SIZE == "auto" or re.match(pattern, form_data.IMAGE_SIZE):
request.app.state.config.IMAGE_SIZE = form_data.IMAGE_SIZE
else:
raise HTTPException(
@@ -471,7 +480,14 @@ async def image_generations(
form_data: GenerateImageForm,
user=Depends(get_verified_user),
):
- width, height = tuple(map(int, request.app.state.config.IMAGE_SIZE.split("x")))
+ # If IMAGE_SIZE is 'auto', fall back to the 512x512 default width/height.
+ # This only matters when IMAGE_SIZE was previously set to 'auto' with an image
+ # model other than gpt-image-1, which update_image_config now rejects on save.
+ width, height = (
+ tuple(map(int, request.app.state.config.IMAGE_SIZE.split("x")))
+ if "x" in request.app.state.config.IMAGE_SIZE
+ else (512, 512)
+ )
r = None
try:
@@ -483,7 +499,7 @@ async def image_generations(
headers["Content-Type"] = "application/json"
if ENABLE_FORWARD_USER_INFO_HEADERS:
- headers["X-OpenWebUI-User-Name"] = user.name
+ headers["X-OpenWebUI-User-Name"] = quote(user.name, safe=" ")
headers["X-OpenWebUI-User-Id"] = user.id
headers["X-OpenWebUI-User-Email"] = user.email
headers["X-OpenWebUI-User-Role"] = user.role
diff --git a/backend/open_webui/routers/notes.py b/backend/open_webui/routers/notes.py
index 2cbbd331b5..0c4842909f 100644
--- a/backend/open_webui/routers/notes.py
+++ b/backend/open_webui/routers/notes.py
@@ -51,7 +51,14 @@ async def get_notes(request: Request, user=Depends(get_verified_user)):
return notes
-@router.get("/list", response_model=list[NoteUserResponse])
+class NoteTitleIdResponse(BaseModel):
+ id: str
+ title: str
+ updated_at: int
+ created_at: int
+
+
+@router.get("/list", response_model=list[NoteTitleIdResponse])
async def get_note_list(request: Request, user=Depends(get_verified_user)):
if user.role != "admin" and not has_permission(
@@ -63,13 +70,8 @@ async def get_note_list(request: Request, user=Depends(get_verified_user)):
)
notes = [
- NoteUserResponse(
- **{
- **note.model_dump(),
- "user": UserResponse(**Users.get_user_by_id(note.user_id).model_dump()),
- }
- )
- for note in Notes.get_notes_by_user_id(user.id, "read")
+ NoteTitleIdResponse(**note.model_dump())
+ for note in Notes.get_notes_by_user_id(user.id, "write")
]
return notes
diff --git a/backend/open_webui/routers/ollama.py b/backend/open_webui/routers/ollama.py
index 1353599374..000f4b48d2 100644
--- a/backend/open_webui/routers/ollama.py
+++ b/backend/open_webui/routers/ollama.py
@@ -16,6 +16,7 @@ from urllib.parse import urlparse
import aiohttp
from aiocache import cached
import requests
+from urllib.parse import quote
from open_webui.models.chats import Chats
from open_webui.models.users import UserModel
@@ -58,6 +59,7 @@ from open_webui.config import (
from open_webui.env import (
ENV,
SRC_LOG_LEVELS,
+ MODELS_CACHE_TTL,
AIOHTTP_CLIENT_SESSION_SSL,
AIOHTTP_CLIENT_TIMEOUT,
AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST,
@@ -87,7 +89,7 @@ async def send_get_request(url, key=None, user: UserModel = None):
**({"Authorization": f"Bearer {key}"} if key else {}),
**(
{
- "X-OpenWebUI-User-Name": user.name,
+ "X-OpenWebUI-User-Name": quote(user.name, safe=" "),
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
@@ -138,7 +140,7 @@ async def send_post_request(
**({"Authorization": f"Bearer {key}"} if key else {}),
**(
{
- "X-OpenWebUI-User-Name": user.name,
+ "X-OpenWebUI-User-Name": quote(user.name, safe=" "),
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
@@ -242,7 +244,7 @@ async def verify_connection(
**({"Authorization": f"Bearer {key}"} if key else {}),
**(
{
- "X-OpenWebUI-User-Name": user.name,
+ "X-OpenWebUI-User-Name": quote(user.name, safe=" "),
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
@@ -329,7 +331,7 @@ def merge_ollama_models_lists(model_lists):
return list(merged_models.values())
-@cached(ttl=1)
+@cached(ttl=MODELS_CACHE_TTL)
async def get_all_models(request: Request, user: UserModel = None):
log.info("get_all_models()")
if request.app.state.config.ENABLE_OLLAMA_API:
@@ -462,7 +464,7 @@ async def get_ollama_tags(
**({"Authorization": f"Bearer {key}"} if key else {}),
**(
{
- "X-OpenWebUI-User-Name": user.name,
+ "X-OpenWebUI-User-Name": quote(user.name, safe=" "),
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
@@ -634,7 +636,10 @@ async def get_ollama_versions(request: Request, url_idx: Optional[int] = None):
class ModelNameForm(BaseModel):
- name: str
+ model: Optional[str] = None
+ model_config = ConfigDict(
+ extra="allow",
+ )
@router.post("/api/unload")
@@ -643,10 +648,12 @@ async def unload_model(
form_data: ModelNameForm,
user=Depends(get_admin_user),
):
- model_name = form_data.name
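+ # Accept both the new "model" field and the legacy "name" field from older clients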
+ form_data = form_data.model_dump(exclude_none=True)
+ model_name = form_data.get("model", form_data.get("name"))
+
if not model_name:
raise HTTPException(
- status_code=400, detail="Missing 'name' of model to unload."
+ status_code=400, detail="Missing name of the model to unload."
)
# Refresh/load models if needed, get mapping from name to URLs
@@ -709,11 +716,14 @@ async def pull_model(
url_idx: int = 0,
user=Depends(get_admin_user),
):
+ form_data = form_data.model_dump(exclude_none=True)
+ form_data["model"] = form_data.get("model", form_data.get("name"))
+
url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}")
# Admin should be able to pull models from any source
- payload = {**form_data.model_dump(exclude_none=True), "insecure": True}
+ payload = {**form_data, "insecure": True}
return await send_post_request(
url=f"{url}/api/pull",
@@ -724,7 +734,7 @@ async def pull_model(
class PushModelForm(BaseModel):
- name: str
+ model: str
insecure: Optional[bool] = None
stream: Optional[bool] = None
@@ -741,12 +751,12 @@ async def push_model(
await get_all_models(request, user=user)
models = request.app.state.OLLAMA_MODELS
- if form_data.name in models:
- url_idx = models[form_data.name]["urls"][0]
+ if form_data.model in models:
+ url_idx = models[form_data.model]["urls"][0]
else:
raise HTTPException(
status_code=400,
- detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name),
+ detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
)
url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
@@ -824,7 +834,7 @@ async def copy_model(
**({"Authorization": f"Bearer {key}"} if key else {}),
**(
{
- "X-OpenWebUI-User-Name": user.name,
+ "X-OpenWebUI-User-Name": quote(user.name, safe=" "),
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
@@ -865,16 +875,21 @@ async def delete_model(
url_idx: Optional[int] = None,
user=Depends(get_admin_user),
):
+ form_data = form_data.model_dump(exclude_none=True)
+ form_data["model"] = form_data.get("model", form_data.get("name"))
+
+ model = form_data.get("model")
+
if url_idx is None:
await get_all_models(request, user=user)
models = request.app.state.OLLAMA_MODELS
- if form_data.name in models:
- url_idx = models[form_data.name]["urls"][0]
+ if model in models:
+ url_idx = models[model]["urls"][0]
else:
raise HTTPException(
status_code=400,
- detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name),
+ detail=ERROR_MESSAGES.MODEL_NOT_FOUND(model),
)
url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
@@ -884,13 +899,13 @@ async def delete_model(
r = requests.request(
method="DELETE",
url=f"{url}/api/delete",
- data=form_data.model_dump_json(exclude_none=True).encode(),
+ data=json.dumps(form_data).encode(),
headers={
"Content-Type": "application/json",
**({"Authorization": f"Bearer {key}"} if key else {}),
**(
{
- "X-OpenWebUI-User-Name": user.name,
+ "X-OpenWebUI-User-Name": quote(user.name, safe=" "),
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
@@ -926,16 +941,21 @@ async def delete_model(
async def show_model_info(
request: Request, form_data: ModelNameForm, user=Depends(get_verified_user)
):
+ form_data = form_data.model_dump(exclude_none=True)
+ form_data["model"] = form_data.get("model", form_data.get("name"))
+
await get_all_models(request, user=user)
models = request.app.state.OLLAMA_MODELS
- if form_data.name not in models:
+ model = form_data.get("model")
+
+ if model not in models:
raise HTTPException(
status_code=400,
- detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name),
+ detail=ERROR_MESSAGES.MODEL_NOT_FOUND(model),
)
- url_idx = random.choice(models[form_data.name]["urls"])
+ url_idx = random.choice(models[model]["urls"])
url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
key = get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS)
@@ -949,7 +969,7 @@ async def show_model_info(
**({"Authorization": f"Bearer {key}"} if key else {}),
**(
{
- "X-OpenWebUI-User-Name": user.name,
+ "X-OpenWebUI-User-Name": quote(user.name, safe=" "),
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
@@ -958,7 +978,7 @@ async def show_model_info(
else {}
),
},
- data=form_data.model_dump_json(exclude_none=True).encode(),
+ data=json.dumps(form_data).encode(),
)
r.raise_for_status()
@@ -1036,7 +1056,7 @@ async def embed(
**({"Authorization": f"Bearer {key}"} if key else {}),
**(
{
- "X-OpenWebUI-User-Name": user.name,
+ "X-OpenWebUI-User-Name": quote(user.name, safe=" "),
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
@@ -1123,7 +1143,7 @@ async def embeddings(
**({"Authorization": f"Bearer {key}"} if key else {}),
**(
{
- "X-OpenWebUI-User-Name": user.name,
+ "X-OpenWebUI-User-Name": quote(user.name, safe=" "),
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
diff --git a/backend/open_webui/routers/openai.py b/backend/open_webui/routers/openai.py
index 7649271fee..a759ec7eee 100644
--- a/backend/open_webui/routers/openai.py
+++ b/backend/open_webui/routers/openai.py
@@ -8,7 +8,7 @@ from typing import Literal, Optional, overload
import aiohttp
from aiocache import cached
import requests
-
+from urllib.parse import quote
from fastapi import Depends, FastAPI, HTTPException, Request, APIRouter
from fastapi.middleware.cors import CORSMiddleware
@@ -21,6 +21,7 @@ from open_webui.config import (
CACHE_DIR,
)
from open_webui.env import (
+ MODELS_CACHE_TTL,
AIOHTTP_CLIENT_SESSION_SSL,
AIOHTTP_CLIENT_TIMEOUT,
AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST,
@@ -66,7 +67,7 @@ async def send_get_request(url, key=None, user: UserModel = None):
**({"Authorization": f"Bearer {key}"} if key else {}),
**(
{
- "X-OpenWebUI-User-Name": user.name,
+ "X-OpenWebUI-User-Name": quote(user.name, safe=" "),
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
@@ -225,7 +226,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
),
**(
{
- "X-OpenWebUI-User-Name": user.name,
+ "X-OpenWebUI-User-Name": quote(user.name, safe=" "),
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
@@ -386,7 +387,7 @@ async def get_filtered_models(models, user):
return filtered_models
-@cached(ttl=1)
+@cached(ttl=MODELS_CACHE_TTL)
async def get_all_models(request: Request, user: UserModel) -> dict[str, list]:
log.info("get_all_models()")
@@ -478,7 +479,7 @@ async def get_models(
"Content-Type": "application/json",
**(
{
- "X-OpenWebUI-User-Name": user.name,
+ "X-OpenWebUI-User-Name": quote(user.name, safe=" "),
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
@@ -573,7 +574,7 @@ async def verify_connection(
"Content-Type": "application/json",
**(
{
- "X-OpenWebUI-User-Name": user.name,
+ "X-OpenWebUI-User-Name": quote(user.name, safe=" "),
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
@@ -633,13 +634,7 @@ async def verify_connection(
raise HTTPException(status_code=500, detail=error_detail)
-def convert_to_azure_payload(
- url,
- payload: dict,
-):
- model = payload.get("model", "")
-
- # Filter allowed parameters based on Azure OpenAI API
+def get_azure_allowed_params(api_version: str) -> set[str]:
allowed_params = {
"messages",
"temperature",
@@ -669,6 +664,23 @@ def convert_to_azure_payload(
"max_completion_tokens",
}
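+ # Azure API versions are ISO-dated strings, so a plain string comparison orders them chronologically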
+ try:
+ if api_version >= "2024-09-01-preview":
+ allowed_params.add("stream_options")
+ except ValueError:
+ log.debug(
+ f"Invalid API version {api_version} for Azure OpenAI. Defaulting to allowed parameters."
+ )
+
+ return allowed_params
+
+
+def convert_to_azure_payload(url, payload: dict, api_version: str):
+ model = payload.get("model", "")
+
+ # Filter allowed parameters based on Azure OpenAI API
+ allowed_params = get_azure_allowed_params(api_version)
+
# Special handling for o-series models
if model.startswith("o") and model.endswith("-mini"):
# Convert max_tokens to max_completion_tokens for o-series models
@@ -806,7 +818,7 @@ async def generate_chat_completion(
),
**(
{
- "X-OpenWebUI-User-Name": user.name,
+ "X-OpenWebUI-User-Name": quote(user.name, safe=" "),
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
@@ -817,8 +829,8 @@ async def generate_chat_completion(
}
if api_config.get("azure", False):
- request_url, payload = convert_to_azure_payload(url, payload)
- api_version = api_config.get("api_version", "") or "2023-03-15-preview"
+ api_version = api_config.get("api_version", "2023-03-15-preview")
+ request_url, payload = convert_to_azure_payload(url, payload, api_version)
headers["api-key"] = key
headers["api-version"] = api_version
request_url = f"{request_url}/chat/completions?api-version={api_version}"
@@ -924,7 +936,7 @@ async def embeddings(request: Request, form_data: dict, user):
"Content-Type": "application/json",
**(
{
- "X-OpenWebUI-User-Name": user.name,
+ "X-OpenWebUI-User-Name": quote(user.name, safe=" "),
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
@@ -996,7 +1008,7 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
"Content-Type": "application/json",
**(
{
- "X-OpenWebUI-User-Name": user.name,
+ "X-OpenWebUI-User-Name": quote(user.name, safe=" "),
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
@@ -1007,16 +1019,15 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
}
if api_config.get("azure", False):
+ api_version = api_config.get("api_version", "2023-03-15-preview")
headers["api-key"] = key
- headers["api-version"] = (
- api_config.get("api_version", "") or "2023-03-15-preview"
- )
+ headers["api-version"] = api_version
payload = json.loads(body)
- url, payload = convert_to_azure_payload(url, payload)
+ url, payload = convert_to_azure_payload(url, payload, api_version)
body = json.dumps(payload).encode()
- request_url = f"{url}/{path}?api-version={api_config.get('api_version', '2023-03-15-preview')}"
+ request_url = f"{url}/{path}?api-version={api_version}"
else:
headers["Authorization"] = f"Bearer {key}"
request_url = f"{url}/{path}"
diff --git a/backend/open_webui/routers/retrieval.py b/backend/open_webui/routers/retrieval.py
index ee6f99fbb5..34910f23ef 100644
--- a/backend/open_webui/routers/retrieval.py
+++ b/backend/open_webui/routers/retrieval.py
@@ -29,6 +29,7 @@ import tiktoken
from langchain.text_splitter import RecursiveCharacterTextSplitter, TokenTextSplitter
+from langchain_text_splitters import MarkdownHeaderTextSplitter
from langchain_core.documents import Document
from open_webui.models.files import FileModel, Files
@@ -1146,6 +1147,7 @@ def save_docs_to_vector_db(
chunk_overlap=request.app.state.config.CHUNK_OVERLAP,
add_start_index=True,
)
+ docs = text_splitter.split_documents(docs)
elif request.app.state.config.TEXT_SPLITTER == "token":
log.info(
f"Using token text splitter: {request.app.state.config.TIKTOKEN_ENCODING_NAME}"
@@ -1158,11 +1160,56 @@ def save_docs_to_vector_db(
chunk_overlap=request.app.state.config.CHUNK_OVERLAP,
add_start_index=True,
)
+ docs = text_splitter.split_documents(docs)
+ elif request.app.state.config.TEXT_SPLITTER == "markdown_header":
+ log.info("Using markdown header text splitter")
+
+ # Define headers to split on - covering most common markdown header levels
+ headers_to_split_on = [
+ ("#", "Header 1"),
+ ("##", "Header 2"),
+ ("###", "Header 3"),
+ ("####", "Header 4"),
+ ("#####", "Header 5"),
+ ("######", "Header 6"),
+ ]
+
+ markdown_splitter = MarkdownHeaderTextSplitter(
+ headers_to_split_on=headers_to_split_on,
+ strip_headers=False, # Keep headers in content for context
+ )
+
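+ # Two-stage split: first break each document on markdown headers, then re-chunk
+ # every header section with the configured chunk size and overlap.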
+ md_split_docs = []
+ for doc in docs:
+ md_header_splits = markdown_splitter.split_text(doc.page_content)
+ text_splitter = RecursiveCharacterTextSplitter(
+ chunk_size=request.app.state.config.CHUNK_SIZE,
+ chunk_overlap=request.app.state.config.CHUNK_OVERLAP,
+ add_start_index=True,
+ )
+ md_header_splits = text_splitter.split_documents(md_header_splits)
+
+ # Convert back to Document objects, preserving original metadata
+ for split_chunk in md_header_splits:
+ headings_list = []
+ # Extract header values in order based on headers_to_split_on
+ for _, header_meta_key_name in headers_to_split_on:
+ if header_meta_key_name in split_chunk.metadata:
+ headings_list.append(
+ split_chunk.metadata[header_meta_key_name]
+ )
+
+ md_split_docs.append(
+ Document(
+ page_content=split_chunk.page_content,
+ metadata={**doc.metadata, "headings": headings_list},
+ )
+ )
+
+ docs = md_split_docs
else:
raise ValueError(ERROR_MESSAGES.DEFAULT("Invalid text splitter"))
- docs = text_splitter.split_documents(docs)
-
if len(docs) == 0:
raise ValueError(ERROR_MESSAGES.EMPTY_CONTENT)
@@ -1747,6 +1794,16 @@ def search_web(request: Request, engine: str, query: str) -> list[SearchResult]:
)
else:
raise Exception("No TAVILY_API_KEY found in environment variables")
+ elif engine == "exa":
+ if request.app.state.config.EXA_API_KEY:
+ return search_exa(
+ request.app.state.config.EXA_API_KEY,
+ query,
+ request.app.state.config.WEB_SEARCH_RESULT_COUNT,
+ request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST,
+ )
+ else:
+ raise Exception("No EXA_API_KEY found in environment variables")
elif engine == "searchapi":
if request.app.state.config.SEARCHAPI_API_KEY:
return search_searchapi(
@@ -1784,6 +1841,13 @@ def search_web(request: Request, engine: str, query: str) -> list[SearchResult]:
request.app.state.config.WEB_SEARCH_RESULT_COUNT,
request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST,
)
+ elif engine == "exa":
+ return search_exa(
+ request.app.state.config.EXA_API_KEY,
+ query,
+ request.app.state.config.WEB_SEARCH_RESULT_COUNT,
+ request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST,
+ )
elif engine == "perplexity":
return search_perplexity(
request.app.state.config.PERPLEXITY_API_KEY,
diff --git a/backend/open_webui/routers/tools.py b/backend/open_webui/routers/tools.py
index f726368eba..41415bff04 100644
--- a/backend/open_webui/routers/tools.py
+++ b/backend/open_webui/routers/tools.py
@@ -153,7 +153,7 @@ async def load_tool_from_url(
)
try:
- async with aiohttp.ClientSession() as session:
+ async with aiohttp.ClientSession(trust_env=True) as session:
async with session.get(
url, headers={"Content-Type": "application/json"}
) as resp:
diff --git a/backend/open_webui/socket/main.py b/backend/open_webui/socket/main.py
index 35e40dccb2..5e7988a226 100644
--- a/backend/open_webui/socket/main.py
+++ b/backend/open_webui/socket/main.py
@@ -1,13 +1,18 @@
import asyncio
+import random
+
import socketio
import logging
import sys
import time
+from typing import Dict, Set
from redis import asyncio as aioredis
+import pycrdt as Y
from open_webui.models.users import Users, UserNameResponse
from open_webui.models.channels import Channels
from open_webui.models.chats import Chats
+from open_webui.models.notes import Notes, NoteUpdateForm
from open_webui.utils.redis import (
get_sentinels_from_env,
get_sentinel_url_from_env,
@@ -23,6 +28,10 @@ from open_webui.env import (
)
from open_webui.utils.auth import decode_token
from open_webui.socket.utils import RedisDict, RedisLock
+from open_webui.tasks import create_task, stop_item_tasks
+from open_webui.utils.redis import get_redis_connection
+from open_webui.utils.access_control import has_access, get_users_with_access
+
from open_webui.env import (
GLOBAL_LOG_LEVEL,
@@ -35,6 +44,14 @@ log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["SOCKET"])
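+ # Async Redis connection used by the handlers below to register and cancel debounced note-save tasks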
+REDIS = get_redis_connection(
+ redis_url=WEBSOCKET_REDIS_URL,
+ redis_sentinels=get_sentinels_from_env(
+ WEBSOCKET_SENTINEL_HOSTS, WEBSOCKET_SENTINEL_PORT
+ ),
+ async_mode=True,
+)
+
if WEBSOCKET_MANAGER == "redis":
if WEBSOCKET_SENTINEL_HOSTS:
mgr = socketio.AsyncRedisManager(
@@ -88,6 +105,10 @@ if WEBSOCKET_MANAGER == "redis":
redis_sentinels=redis_sentinels,
)
+ # TODO: Implement Yjs document management with Redis
+ DOCUMENTS = {}
+ DOCUMENT_USERS = {}
+
clean_up_lock = RedisLock(
redis_url=WEBSOCKET_REDIS_URL,
lock_name="usage_cleanup_lock",
@@ -101,14 +122,33 @@ else:
SESSION_POOL = {}
USER_POOL = {}
USAGE_POOL = {}
+
+ DOCUMENTS = {} # document_id -> Y.YDoc instance
+ DOCUMENT_USERS = {} # document_id -> set of user sids
aquire_func = release_func = renew_func = lambda: True
async def periodic_usage_pool_cleanup():
- if not aquire_func():
- log.debug("Usage pool cleanup lock already exists. Not running it.")
- return
- log.debug("Running periodic_usage_pool_cleanup")
+ max_retries = 2
+ retry_delay = random.uniform(
+ WEBSOCKET_REDIS_LOCK_TIMEOUT / 2, WEBSOCKET_REDIS_LOCK_TIMEOUT
+ )
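+ # Jitter the retry delay so concurrent workers don't retry the lock in lockstep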
+ for attempt in range(max_retries + 1):
+ if aquire_func():
+ break
+ else:
+ if attempt < max_retries:
+ log.debug(
+ f"Cleanup lock already exists. Retry {attempt + 1} after {retry_delay}s..."
+ )
+ await asyncio.sleep(retry_delay)
+ else:
+ log.warning(
+ "Failed to acquire cleanup lock after retries. Skipping cleanup."
+ )
+ return
+
+ log.debug("Running periodic_cleanup")
try:
while True:
if not renew_func():
@@ -298,6 +338,217 @@ async def channel_events(sid, data):
)
+@sio.on("yjs:document:join")
+async def yjs_document_join(sid, data):
+ """Handle user joining a document"""
+ user = SESSION_POOL.get(sid)
+
+ try:
+ document_id = data["document_id"]
+
+ if document_id.startswith("note:"):
+ note_id = document_id.split(":")[1]
+ note = Notes.get_note_by_id(note_id)
+ if not note:
+ log.error(f"Note {note_id} not found")
+ return
+
+ if (
+ user.get("role") != "admin"
+ and user.get("id") != note.user_id
+ and not has_access(
+ user.get("id"), type="read", access_control=note.access_control
+ )
+ ):
+ log.error(
+ f"User {user.get('id')} does not have access to note {note_id}"
+ )
+ return
+
+ user_id = data.get("user_id", sid)
+ user_name = data.get("user_name", "Anonymous")
+ user_color = data.get("user_color", "#000000")
+
+ log.info(f"User {user_id} joining document {document_id}")
+
+ # Initialize document if it doesn't exist
+ if document_id not in DOCUMENTS:
+ DOCUMENTS[document_id] = {
+ "ydoc": Y.Doc(), # Create actual Yjs document
+ "users": set(),
+ }
+ DOCUMENT_USERS[document_id] = set()
+
+ # Add user to document
+ DOCUMENTS[document_id]["users"].add(sid)
+ DOCUMENT_USERS[document_id].add(sid)
+
+ # Join Socket.IO room
+ await sio.enter_room(sid, f"doc_{document_id}")
+
+ # Send current document state as a proper Yjs update
+ ydoc = DOCUMENTS[document_id]["ydoc"]
+
+ # Encode the entire document state as an update
+ state_update = ydoc.get_update()
+ await sio.emit(
+ "yjs:document:state",
+ {
+ "document_id": document_id,
+ "state": list(state_update), # Convert bytes to list for JSON
+ },
+ room=sid,
+ )
+
+ # Notify other users about the new user
+ await sio.emit(
+ "yjs:user:joined",
+ {
+ "document_id": document_id,
+ "user_id": user_id,
+ "user_name": user_name,
+ "user_color": user_color,
+ },
+ room=f"doc_{document_id}",
+ skip_sid=sid,
+ )
+
+ log.info(f"User {user_id} successfully joined document {document_id}")
+
+ except Exception as e:
+ log.error(f"Error in yjs_document_join: {e}")
+ await sio.emit("error", {"message": "Failed to join document"}, room=sid)
+
+
+async def document_save_handler(document_id, data, user):
+ if document_id.startswith("note:"):
+ note_id = document_id.split(":")[1]
+ note = Notes.get_note_by_id(note_id)
+ if not note:
+ log.error(f"Note {note_id} not found")
+ return
+
+ if (
+ user.get("role") != "admin"
+ and user.get("id") != note.user_id
+ and not has_access(
+ user.get("id"), type="read", access_control=note.access_control
+ )
+ ):
+ log.error(f"User {user.get('id')} does not have access to note {note_id}")
+ return
+
+ Notes.update_note_by_id(note_id, NoteUpdateForm(data=data))
+
+
+@sio.on("yjs:document:update")
+async def yjs_document_update(sid, data):
+ """Handle Yjs document updates"""
+ try:
+ document_id = data["document_id"]
+ try:
+ await stop_item_tasks(REDIS, document_id)
+ except:
+ pass
+
+ user_id = data.get("user_id", sid)
+
+ update = data["update"] # List of bytes from frontend
+
+ if document_id not in DOCUMENTS:
+ log.warning(f"Document {document_id} not found")
+ return
+
+ # Apply the update to the server's Yjs document
+ ydoc = DOCUMENTS[document_id]["ydoc"]
+ update_bytes = bytes(update)
+
+ try:
+ ydoc.apply_update(update_bytes)
+ except Exception as e:
+ log.error(f"Failed to apply Yjs update: {e}")
+ return
+
+ # Broadcast update to all other users in the document
+ await sio.emit(
+ "yjs:document:update",
+ {
+ "document_id": document_id,
+ "user_id": user_id,
+ "update": update,
+ "socket_id": sid, # Add socket_id to match frontend filtering
+ },
+ room=f"doc_{document_id}",
+ skip_sid=sid,
+ )
+
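+ # Debounce persistence: stop_item_tasks above cancelled any pending save for this
+ # document, and this schedules a fresh save half a second out.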
+ async def debounced_save():
+ await asyncio.sleep(0.5)
+ await document_save_handler(
+ document_id, data.get("data", {}), SESSION_POOL.get(sid)
+ )
+
+ await create_task(REDIS, debounced_save(), document_id)
+
+ except Exception as e:
+ log.error(f"Error in yjs_document_update: {e}")
+
+
+@sio.on("yjs:document:leave")
+async def yjs_document_leave(sid, data):
+ """Handle user leaving a document"""
+ try:
+ document_id = data["document_id"]
+ user_id = data.get("user_id", sid)
+
+ log.info(f"User {user_id} leaving document {document_id}")
+
+ if document_id in DOCUMENTS:
+ DOCUMENTS[document_id]["users"].discard(sid)
+
+ if document_id in DOCUMENT_USERS:
+ DOCUMENT_USERS[document_id].discard(sid)
+
+ # Leave Socket.IO room
+ await sio.leave_room(sid, f"doc_{document_id}")
+
+ # Notify other users
+ await sio.emit(
+ "yjs:user:left",
+ {"document_id": document_id, "user_id": user_id},
+ room=f"doc_{document_id}",
+ )
+
+ if document_id in DOCUMENTS and not DOCUMENTS[document_id]["users"]:
+ # If no users left, clean up the document
+ log.info(f"Cleaning up document {document_id} as no users are left")
+ del DOCUMENTS[document_id]
+ del DOCUMENT_USERS[document_id]
+
+ except Exception as e:
+ log.error(f"Error in yjs_document_leave: {e}")
+
+
+@sio.on("yjs:awareness:update")
+async def yjs_awareness_update(sid, data):
+ """Handle awareness updates (cursors, selections, etc.)"""
+ try:
+ document_id = data["document_id"]
+ user_id = data.get("user_id", sid)
+ update = data["update"]
+
+ # Broadcast awareness update to all other users in the document
+ await sio.emit(
+ "yjs:awareness:update",
+ {"document_id": document_id, "user_id": user_id, "update": update},
+ room=f"doc_{document_id}",
+ skip_sid=sid,
+ )
+
+ except Exception as e:
+ log.error(f"Error in yjs_awareness_update: {e}")
+
+
@sio.event
async def disconnect(sid):
if sid in SESSION_POOL:
diff --git a/backend/open_webui/tasks.py b/backend/open_webui/tasks.py
index 2d3955f0a2..a4132d9cf6 100644
--- a/backend/open_webui/tasks.py
+++ b/backend/open_webui/tasks.py
@@ -3,25 +3,27 @@ import asyncio
from typing import Dict
from uuid import uuid4
import json
+import logging
from redis.asyncio import Redis
from fastapi import Request
from typing import Dict, List, Optional
+from open_webui.env import SRC_LOG_LEVELS
+
+
+log = logging.getLogger(__name__)
+log.setLevel(SRC_LOG_LEVELS["MAIN"])
+
# A dictionary to keep track of active tasks
tasks: Dict[str, asyncio.Task] = {}
-chat_tasks = {}
+item_tasks = {}
REDIS_TASKS_KEY = "open-webui:tasks"
-REDIS_CHAT_TASKS_KEY = "open-webui:tasks:chat"
+REDIS_ITEM_TASKS_KEY = "open-webui:tasks:item"
REDIS_PUBSUB_CHANNEL = "open-webui:tasks:commands"
-def is_redis(request: Request) -> bool:
- # Called everywhere a request is available to check Redis
- return hasattr(request.app.state, "redis") and (request.app.state.redis is not None)
-
-
async def redis_task_command_listener(app):
redis: Redis = app.state.redis
pubsub = redis.pubsub()
@@ -38,7 +40,7 @@ async def redis_task_command_listener(app):
if local_task:
local_task.cancel()
except Exception as e:
- print(f"Error handling distributed task command: {e}")
+ log.exception(f"Error handling distributed task command: {e}")
### ------------------------------
@@ -46,21 +48,21 @@ async def redis_task_command_listener(app):
### ------------------------------
-async def redis_save_task(redis: Redis, task_id: str, chat_id: Optional[str]):
+async def redis_save_task(redis: Redis, task_id: str, item_id: Optional[str]):
pipe = redis.pipeline()
- pipe.hset(REDIS_TASKS_KEY, task_id, chat_id or "")
- if chat_id:
- pipe.sadd(f"{REDIS_CHAT_TASKS_KEY}:{chat_id}", task_id)
+ pipe.hset(REDIS_TASKS_KEY, task_id, item_id or "")
+ if item_id:
+ pipe.sadd(f"{REDIS_ITEM_TASKS_KEY}:{item_id}", task_id)
await pipe.execute()
-async def redis_cleanup_task(redis: Redis, task_id: str, chat_id: Optional[str]):
+async def redis_cleanup_task(redis: Redis, task_id: str, item_id: Optional[str]):
pipe = redis.pipeline()
pipe.hdel(REDIS_TASKS_KEY, task_id)
- if chat_id:
- pipe.srem(f"{REDIS_CHAT_TASKS_KEY}:{chat_id}", task_id)
- if (await pipe.scard(f"{REDIS_CHAT_TASKS_KEY}:{chat_id}").execute())[-1] == 0:
- pipe.delete(f"{REDIS_CHAT_TASKS_KEY}:{chat_id}") # Remove if empty set
+ if item_id:
+ pipe.srem(f"{REDIS_ITEM_TASKS_KEY}:{item_id}", task_id)
+ if (await pipe.scard(f"{REDIS_ITEM_TASKS_KEY}:{item_id}").execute())[-1] == 0:
+ pipe.delete(f"{REDIS_ITEM_TASKS_KEY}:{item_id}") # Remove if empty set
await pipe.execute()
@@ -68,31 +70,31 @@ async def redis_list_tasks(redis: Redis) -> List[str]:
return list(await redis.hkeys(REDIS_TASKS_KEY))
-async def redis_list_chat_tasks(redis: Redis, chat_id: str) -> List[str]:
- return list(await redis.smembers(f"{REDIS_CHAT_TASKS_KEY}:{chat_id}"))
+async def redis_list_item_tasks(redis: Redis, item_id: str) -> List[str]:
+ return list(await redis.smembers(f"{REDIS_ITEM_TASKS_KEY}:{item_id}"))
async def redis_send_command(redis: Redis, command: dict):
await redis.publish(REDIS_PUBSUB_CHANNEL, json.dumps(command))
-async def cleanup_task(request, task_id: str, id=None):
+async def cleanup_task(redis, task_id: str, id=None):
"""
Remove a completed or canceled task from the global `tasks` dictionary.
"""
- if is_redis(request):
- await redis_cleanup_task(request.app.state.redis, task_id, id)
+ if redis:
+ await redis_cleanup_task(redis, task_id, id)
tasks.pop(task_id, None) # Remove the task if it exists
- # If an ID is provided, remove the task from the chat_tasks dictionary
- if id and task_id in chat_tasks.get(id, []):
- chat_tasks[id].remove(task_id)
- if not chat_tasks[id]: # If no tasks left for this ID, remove the entry
- chat_tasks.pop(id, None)
+ # If an ID is provided, remove the task from the item_tasks dictionary
+ if id and task_id in item_tasks.get(id, []):
+ item_tasks[id].remove(task_id)
+ if not item_tasks[id]: # If no tasks left for this ID, remove the entry
+ item_tasks.pop(id, None)
-async def create_task(request, coroutine, id=None):
+async def create_task(redis, coroutine, id=None):
"""
Create a new asyncio task and add it to the global task dictionary.
"""
@@ -101,48 +103,48 @@ async def create_task(request, coroutine, id=None):
# Add a done callback for cleanup
task.add_done_callback(
- lambda t: asyncio.create_task(cleanup_task(request, task_id, id))
+ lambda t: asyncio.create_task(cleanup_task(redis, task_id, id))
)
tasks[task_id] = task
# If an ID is provided, associate the task with that ID
- if chat_tasks.get(id):
- chat_tasks[id].append(task_id)
+ if item_tasks.get(id):
+ item_tasks[id].append(task_id)
else:
- chat_tasks[id] = [task_id]
+ item_tasks[id] = [task_id]
- if is_redis(request):
- await redis_save_task(request.app.state.redis, task_id, id)
+ if redis:
+ await redis_save_task(redis, task_id, id)
return task_id, task
-async def list_tasks(request):
+async def list_tasks(redis):
"""
List all currently active task IDs.
"""
- if is_redis(request):
- return await redis_list_tasks(request.app.state.redis)
+ if redis:
+ return await redis_list_tasks(redis)
return list(tasks.keys())
-async def list_task_ids_by_chat_id(request, id):
+async def list_task_ids_by_item_id(redis, id):
"""
List all tasks associated with a specific ID.
"""
- if is_redis(request):
- return await redis_list_chat_tasks(request.app.state.redis, id)
- return chat_tasks.get(id, [])
+ if redis:
+ return await redis_list_item_tasks(redis, id)
+ return item_tasks.get(id, [])
-async def stop_task(request, task_id: str):
+async def stop_task(redis, task_id: str):
"""
Cancel a running task and remove it from the global task list.
"""
- if is_redis(request):
+ if redis:
# PUBSUB: All instances check if they have this task, and stop if so.
await redis_send_command(
- request.app.state.redis,
+ redis,
{
"action": "stop",
"task_id": task_id,
@@ -151,7 +153,7 @@ async def stop_task(request, task_id: str):
# Optionally check if task_id still in Redis a few moments later for feedback?
return {"status": True, "message": f"Stop signal sent for {task_id}"}
- task = tasks.get(task_id)
+ task = tasks.pop(task_id, None)
if not task:
raise ValueError(f"Task with ID {task_id} not found.")
@@ -160,7 +162,22 @@ async def stop_task(request, task_id: str):
await task # Wait for the task to handle the cancellation
except asyncio.CancelledError:
# Task successfully canceled
- tasks.pop(task_id, None) # Remove it from the dictionary
return {"status": True, "message": f"Task {task_id} successfully stopped."}
return {"status": False, "message": f"Failed to stop task {task_id}."}
+
+
+async def stop_item_tasks(redis: Redis, item_id: str):
+ """
+ Stop all tasks associated with a specific item ID.
+ """
+ task_ids = await list_task_ids_by_item_id(redis, item_id)
+ if not task_ids:
+ return {"status": True, "message": f"No tasks found for item {item_id}."}
+
+ for task_id in task_ids:
+ result = await stop_task(redis, task_id)
+ if not result["status"]:
+ return result # Return the first failure
+
+ return {"status": True, "message": f"All tasks for item {item_id} stopped."}
diff --git a/backend/open_webui/utils/auth.py b/backend/open_webui/utils/auth.py
index 9befaf2a91..3262c803f3 100644
--- a/backend/open_webui/utils/auth.py
+++ b/backend/open_webui/utils/auth.py
@@ -74,31 +74,37 @@ def override_static(path: str, content: str):
def get_license_data(app, key):
- if key:
- try:
- res = requests.post(
- "https://api.openwebui.com/api/v1/license/",
- json={"key": key, "version": "1"},
- timeout=5,
+ def handler(u):
+ res = requests.post(
+ f"{u}/api/v1/license/",
+ json={"key": key, "version": "1"},
+ timeout=5,
+ )
+
+ if getattr(res, "ok", False):
+ payload = getattr(res, "json", lambda: {})()
+ for k, v in payload.items():
+ if k == "resources":
+ for p, c in v.items():
+ globals().get("override_static", lambda a, b: None)(p, c)
+ elif k == "count":
+ setattr(app.state, "USER_COUNT", v)
+ elif k == "name":
+ setattr(app.state, "WEBUI_NAME", v)
+ elif k == "metadata":
+ setattr(app.state, "LICENSE_METADATA", v)
+ return True
+ else:
+ log.error(
+ f"License: retrieval issue: {getattr(res, 'text', 'unknown error')}"
)
- if getattr(res, "ok", False):
- payload = getattr(res, "json", lambda: {})()
- for k, v in payload.items():
- if k == "resources":
- for p, c in v.items():
- globals().get("override_static", lambda a, b: None)(p, c)
- elif k == "count":
- setattr(app.state, "USER_COUNT", v)
- elif k == "name":
- setattr(app.state, "WEBUI_NAME", v)
- elif k == "metadata":
- setattr(app.state, "LICENSE_METADATA", v)
- return True
- else:
- log.error(
- f"License: retrieval issue: {getattr(res, 'text', 'unknown error')}"
- )
+ if key:
+ us = ["https://api.openwebui.com", "https://licenses.api.openwebui.com"]
+ try:
+ for u in us:
+ if handler(u):
+ return True
except Exception as ex:
log.exception(f"License: Uncaught Exception: {ex}")
return False
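
Because handler() logs failures itself and returns True only on a successful response, the call site only needs the boolean. A hedged sketch of how this is presumably wired at startup — the LICENSE_KEY name and the call site are assumptions, not shown in this diff:

    # Hypothetical startup wiring (assumed names; not part of this diff).
    import logging
    import os

    from open_webui.utils.auth import get_license_data

    log = logging.getLogger(__name__)
    LICENSE_KEY = os.environ.get("LICENSE_KEY", "")


    def apply_license(app) -> None:
        # get_license_data() tries each endpoint in order and returns a bool.
        if not get_license_data(app, LICENSE_KEY):
            log.info("License: no license applied; continuing with defaults")
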
diff --git a/backend/open_webui/utils/chat.py b/backend/open_webui/utils/chat.py
index 268c910e3e..83483f391b 100644
--- a/backend/open_webui/utils/chat.py
+++ b/backend/open_webui/utils/chat.py
@@ -419,7 +419,7 @@ async def chat_action(request: Request, action_id: str, form_data: dict, user: A
params[key] = value
if "__user__" in sig.parameters:
- __user__ = (user.model_dump() if isinstance(user, UserModel) else {},)
+ __user__ = user.model_dump() if isinstance(user, UserModel) else {}
try:
if hasattr(function_module, "UserValves"):
diff --git a/backend/open_webui/utils/logger.py b/backend/open_webui/utils/logger.py
index 2557610060..a21df2756b 100644
--- a/backend/open_webui/utils/logger.py
+++ b/backend/open_webui/utils/logger.py
@@ -5,7 +5,9 @@ from typing import TYPE_CHECKING
from loguru import logger
+
from open_webui.env import (
+ AUDIT_UVICORN_LOGGER_NAMES,
AUDIT_LOG_FILE_ROTATION_SIZE,
AUDIT_LOG_LEVEL,
AUDIT_LOGS_FILE_PATH,
@@ -128,11 +130,13 @@ def start_logger():
logging.basicConfig(
handlers=[InterceptHandler()], level=GLOBAL_LOG_LEVEL, force=True
)
+
for uvicorn_logger_name in ["uvicorn", "uvicorn.error"]:
uvicorn_logger = logging.getLogger(uvicorn_logger_name)
uvicorn_logger.setLevel(GLOBAL_LOG_LEVEL)
uvicorn_logger.handlers = []
- for uvicorn_logger_name in ["uvicorn.access"]:
+
+ for uvicorn_logger_name in AUDIT_UVICORN_LOGGER_NAMES:
uvicorn_logger = logging.getLogger(uvicorn_logger_name)
uvicorn_logger.setLevel(GLOBAL_LOG_LEVEL)
uvicorn_logger.handlers = [InterceptHandler()]
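
AUDIT_UVICORN_LOGGER_NAMES replaces the hard-coded ["uvicorn.access"] list; its definition lives in open_webui/env.py and is not part of this diff. A plausible sketch of how such an env-driven list could be parsed, purely for illustration:

    # Hypothetical parsing in open_webui/env.py (assumed; the real definition is not shown here).
    import os

    AUDIT_UVICORN_LOGGER_NAMES = [
        name.strip()
        for name in os.environ.get("AUDIT_UVICORN_LOGGER_NAMES", "uvicorn.access").split(",")
        if name.strip()
    ]
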
diff --git a/backend/open_webui/utils/middleware.py b/backend/open_webui/utils/middleware.py
index b1e69db264..003e97e84c 100644
--- a/backend/open_webui/utils/middleware.py
+++ b/backend/open_webui/utils/middleware.py
@@ -23,6 +23,7 @@ from starlette.responses import Response, StreamingResponse
from open_webui.models.chats import Chats
+from open_webui.models.folders import Folders
from open_webui.models.users import Users
from open_webui.socket.main import (
get_event_call,
@@ -56,7 +57,7 @@ from open_webui.models.users import UserModel
from open_webui.models.functions import Functions
from open_webui.models.models import Models
-from open_webui.retrieval.utils import get_sources_from_files
+from open_webui.retrieval.utils import get_sources_from_items
from open_webui.utils.chat import generate_chat_completion
@@ -248,30 +249,28 @@ async def chat_completion_tools_handler(
if tool_id
else f"{tool_function_name}"
)
- if tool.get("metadata", {}).get("citation", False) or tool.get(
- "direct", False
- ):
- # Citation is enabled for this tool
- sources.append(
- {
- "source": {
- "name": (f"TOOL:{tool_name}"),
- },
- "document": [tool_result],
- "metadata": [
- {
- "source": (f"TOOL:{tool_name}"),
- "parameters": tool_function_params,
- }
- ],
- }
- )
- else:
- # Citation is not enabled for this tool
- body["messages"] = add_or_update_user_message(
- f"\nTool `{tool_name}` Output: {tool_result}",
- body["messages"],
- )
+
+                    # Always record the tool output as a citation source (flagged via "tool_result")
+ sources.append(
+ {
+ "source": {
+ "name": (f"TOOL:{tool_name}"),
+ },
+ "document": [tool_result],
+ "metadata": [
+ {
+ "source": (f"TOOL:{tool_name}"),
+ "parameters": tool_function_params,
+ }
+ ],
+ "tool_result": True,
+ }
+ )
+                    # Also surface the tool output to the model by appending it to the user message
+ body["messages"] = add_or_update_user_message(
+ f"\nTool `{tool_name}` Output: {tool_result}",
+ body["messages"],
+ )
if (
tools[tool_function_name]
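
Net effect of this hunk: every tool result is now recorded as a source (flagged with "tool_result": True) and is also appended to the user message; that flag is what the citation loop later in this diff uses to exclude these entries from the context string. An illustrative sketch of the resulting entry shape, with made-up values:

    # Illustrative only — the tool name and parameters below are invented.
    tool_source = {
        "source": {"name": "TOOL:weather/get_forecast"},
        "document": ['{"temp_c": 21}'],  # raw tool output
        "metadata": [
            {
                "source": "TOOL:weather/get_forecast",
                "parameters": {"city": "Berlin"},  # arguments the tool was called with
            }
        ],
        "tool_result": True,  # later used to skip this entry when building the citation context
    }
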
@@ -640,14 +639,14 @@ async def chat_completion_files_handler(
queries = [get_last_user_message(body["messages"])]
try:
- # Offload get_sources_from_files to a separate thread
+ # Offload get_sources_from_items to a separate thread
loop = asyncio.get_running_loop()
with ThreadPoolExecutor() as executor:
sources = await loop.run_in_executor(
executor,
- lambda: get_sources_from_files(
+ lambda: get_sources_from_items(
request=request,
- files=files,
+ items=files,
queries=queries,
embedding_function=lambda query, prefix: request.app.state.EMBEDDING_FUNCTION(
query, prefix=prefix, user=user
@@ -659,6 +658,7 @@ async def chat_completion_files_handler(
hybrid_bm25_weight=request.app.state.config.HYBRID_BM25_WEIGHT,
hybrid_search=request.app.state.config.ENABLE_RAG_HYBRID_SEARCH,
full_context=request.app.state.config.RAG_FULL_CONTEXT,
+ user=user,
),
)
except Exception as e:
@@ -718,6 +718,10 @@ def apply_params_to_form_data(form_data, model):
async def process_chat_payload(request, form_data, user, metadata, model):
+ # Pipeline Inlet -> Filter Inlet -> Chat Memory -> Chat Web Search -> Chat Image Generation
+ # -> Chat Code Interpreter (Form Data Update) -> (Default) Chat Tools Function Calling
+ # -> Chat Files
+
form_data = apply_params_to_form_data(form_data, model)
log.debug(f"form_data: {form_data}")
@@ -752,6 +756,26 @@ async def process_chat_payload(request, form_data, user, metadata, model):
events = []
sources = []
+ # Folder "Project" handling
+    # Check whether the request has a chat_id and the chat lives inside a folder
+ chat_id = metadata.get("chat_id", None)
+ if chat_id and user:
+ chat = Chats.get_chat_by_id_and_user_id(chat_id, user.id)
+ if chat and chat.folder_id:
+ folder = Folders.get_folder_by_id_and_user_id(chat.folder_id, user.id)
+
+ if folder and folder.data:
+ if "system_prompt" in folder.data:
+ form_data["messages"] = add_or_update_system_message(
+ folder.data["system_prompt"], form_data["messages"]
+ )
+ if "files" in folder.data:
+ form_data["files"] = [
+ *folder.data["files"],
+ *form_data.get("files", []),
+ ]
+
+ # Model "Knowledge" handling
user_message = get_last_user_message(form_data["messages"])
model_knowledge = model.get("info", {}).get("meta", {}).get("knowledge", False)
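
The folder block above reads only two optional keys from folder.data. A hedged example of a payload it would act on — the exact file-entry schema comes from the folders/knowledge APIs and is assumed here, not defined by this diff:

    # Illustrative folder.data payload (schema assumed).
    folder_data = {
        # Merged into the conversation as the system message for every chat in the folder.
        "system_prompt": "You are the assistant for the Apollo project.",
        # Placed ahead of any files already attached to the request.
        "files": [
            {"type": "collection", "id": "apollo-project-knowledge"},
        ],
    }
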
@@ -804,7 +828,6 @@ async def process_chat_payload(request, form_data, user, metadata, model):
raise e
try:
-
filter_functions = [
Functions.get_function_by_id(filter_id)
for filter_id in get_sorted_filter_ids(
@@ -912,7 +935,6 @@ async def process_chat_payload(request, form_data, user, metadata, model):
request, form_data, extra_params, user, models, tools_dict
)
sources.extend(flags.get("sources", []))
-
except Exception as e:
log.exception(e)
@@ -925,55 +947,59 @@ async def process_chat_payload(request, form_data, user, metadata, model):
# If context is not empty, insert it into the messages
if len(sources) > 0:
context_string = ""
- citation_idx = {}
+ citation_idx_map = {}
+
for source in sources:
- if "document" in source:
- for doc_context, doc_meta in zip(
+ is_tool_result = source.get("tool_result", False)
+
+ if "document" in source and not is_tool_result:
+ for document_text, document_metadata in zip(
source["document"], source["metadata"]
):
source_name = source.get("source", {}).get("name", None)
- citation_id = (
- doc_meta.get("source", None)
+ source_id = (
+ document_metadata.get("source", None)
or source.get("source", {}).get("id", None)
or "N/A"
)
- if citation_id not in citation_idx:
- citation_idx[citation_id] = len(citation_idx) + 1
+
+ if source_id not in citation_idx_map:
+ citation_idx_map[source_id] = len(citation_idx_map) + 1
+
context_string += (
-                        f'<source id="{citation_idx[citation_id]}" name="{source_name}">'
-                        + f"{doc_context}"
-                        + "</source>\n"