diff --git a/backend/open_webui/config.py b/backend/open_webui/config.py
index a226dd6a55..26fbfa84c5 100644
--- a/backend/open_webui/config.py
+++ b/backend/open_webui/config.py
@@ -19,6 +19,7 @@ from open_webui.env import (
DATABASE_URL,
ENV,
REDIS_URL,
+ REDIS_KEY_PREFIX,
REDIS_SENTINEL_HOSTS,
REDIS_SENTINEL_PORT,
FRONTEND_BUILD_DIR,
@@ -211,11 +212,16 @@ class PersistentConfig(Generic[T]):
class AppConfig:
_state: dict[str, PersistentConfig]
_redis: Optional[redis.Redis] = None
+ _redis_key_prefix: str
def __init__(
- self, redis_url: Optional[str] = None, redis_sentinels: Optional[list] = []
+ self,
+ redis_url: Optional[str] = None,
+ redis_sentinels: Optional[list] = [],
+ redis_key_prefix: str = "open-webui",
):
super().__setattr__("_state", {})
+ super().__setattr__("_redis_key_prefix", redis_key_prefix)
if redis_url:
super().__setattr__(
"_redis",
@@ -230,7 +236,7 @@ class AppConfig:
self._state[key].save()
if self._redis:
- redis_key = f"open-webui:config:{key}"
+ redis_key = f"{self._redis_key_prefix}:config:{key}"
self._redis.set(redis_key, json.dumps(self._state[key].value))
def __getattr__(self, key):
@@ -239,7 +245,7 @@ class AppConfig:
# If Redis is available, check for an updated value
if self._redis:
- redis_key = f"open-webui:config:{key}"
+ redis_key = f"{self._redis_key_prefix}:config:{key}"
redis_value = self._redis.get(redis_key)
if redis_value is not None:
@@ -431,6 +437,12 @@ OAUTH_SCOPES = PersistentConfig(
os.environ.get("OAUTH_SCOPES", "openid email profile"),
)
+OAUTH_TIMEOUT = PersistentConfig(
+ "OAUTH_TIMEOUT",
+ "oauth.oidc.oauth_timeout",
+ os.environ.get("OAUTH_TIMEOUT", ""),
+)
+
OAUTH_CODE_CHALLENGE_METHOD = PersistentConfig(
"OAUTH_CODE_CHALLENGE_METHOD",
"oauth.oidc.code_challenge_method",
@@ -540,7 +552,14 @@ def load_oauth_providers():
client_id=GOOGLE_CLIENT_ID.value,
client_secret=GOOGLE_CLIENT_SECRET.value,
server_metadata_url="https://accounts.google.com/.well-known/openid-configuration",
- client_kwargs={"scope": GOOGLE_OAUTH_SCOPE.value},
+ client_kwargs={
+ "scope": GOOGLE_OAUTH_SCOPE.value,
+ **(
+ {"timeout": int(OAUTH_TIMEOUT.value)}
+ if OAUTH_TIMEOUT.value
+ else {}
+ ),
+ },
redirect_uri=GOOGLE_REDIRECT_URI.value,
)
@@ -563,6 +582,11 @@ def load_oauth_providers():
server_metadata_url=f"{MICROSOFT_CLIENT_LOGIN_BASE_URL.value}/{MICROSOFT_CLIENT_TENANT_ID.value}/v2.0/.well-known/openid-configuration?appid={MICROSOFT_CLIENT_ID.value}",
client_kwargs={
"scope": MICROSOFT_OAUTH_SCOPE.value,
+ **(
+ {"timeout": int(OAUTH_TIMEOUT.value)}
+ if OAUTH_TIMEOUT.value
+ else {}
+ ),
},
redirect_uri=MICROSOFT_REDIRECT_URI.value,
)
@@ -584,7 +608,14 @@ def load_oauth_providers():
authorize_url="https://github.com/login/oauth/authorize",
api_base_url="https://api.github.com",
userinfo_endpoint="https://api.github.com/user",
- client_kwargs={"scope": GITHUB_CLIENT_SCOPE.value},
+ client_kwargs={
+ "scope": GITHUB_CLIENT_SCOPE.value,
+ **(
+ {"timeout": int(OAUTH_TIMEOUT.value)}
+ if OAUTH_TIMEOUT.value
+ else {}
+ ),
+ },
redirect_uri=GITHUB_CLIENT_REDIRECT_URI.value,
)
@@ -603,6 +634,9 @@ def load_oauth_providers():
def oidc_oauth_register(client):
client_kwargs = {
"scope": OAUTH_SCOPES.value,
+ **(
+ {"timeout": int(OAUTH_TIMEOUT.value)} if OAUTH_TIMEOUT.value else {}
+ ),
}
if (
@@ -895,6 +929,18 @@ except Exception:
pass
OPENAI_API_BASE_URL = "https://api.openai.com/v1"
+
+####################################
+# MODELS
+####################################
+
+ENABLE_BASE_MODELS_CACHE = PersistentConfig(
+ "ENABLE_BASE_MODELS_CACHE",
+ "models.base_models_cache",
+ os.environ.get("ENABLE_BASE_MODELS_CACHE", "False").lower() == "true",
+)
+
+
####################################
# TOOL_SERVERS
####################################
@@ -1799,6 +1845,7 @@ QDRANT_GRPC_PORT = int(os.environ.get("QDRANT_GRPC_PORT", "6334"))
ENABLE_QDRANT_MULTITENANCY_MODE = (
os.environ.get("ENABLE_QDRANT_MULTITENANCY_MODE", "true").lower() == "true"
)
+QDRANT_COLLECTION_PREFIX = os.environ.get("QDRANT_COLLECTION_PREFIX", "open-webui")
# OpenSearch
OPENSEARCH_URI = os.environ.get("OPENSEARCH_URI", "https://localhost:9200")
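
# Reviewer note (not part of the patch): a minimal sketch of the two idioms
# this file now relies on. All names below are illustrative stand-ins.

# 1) Conditional kwarg spread: the timeout entry is only added when
#    OAUTH_TIMEOUT is set, so authlib never receives timeout=None.
oauth_timeout_value = "30"  # hypothetical OAUTH_TIMEOUT.value
client_kwargs = {
    "scope": "openid email profile",
    **({"timeout": int(oauth_timeout_value)} if oauth_timeout_value else {}),
}
assert client_kwargs["timeout"] == 30

# 2) Redis key namespacing: two instances sharing one Redis stay isolated
#    once each uses its own REDIS_KEY_PREFIX.
redis_key_prefix = "tenant-a"  # hypothetical REDIS_KEY_PREFIX
redis_key = f"{redis_key_prefix}:config:ENABLE_SIGNUP"
assert redis_key == "tenant-a:config:ENABLE_SIGNUP"
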
diff --git a/backend/open_webui/env.py b/backend/open_webui/env.py
index cd0136a5c8..4db919121a 100644
--- a/backend/open_webui/env.py
+++ b/backend/open_webui/env.py
@@ -267,6 +267,30 @@ else:
DATABASE_URL = os.environ.get("DATABASE_URL", f"sqlite:///{DATA_DIR}/webui.db")
+DATABASE_TYPE = os.environ.get("DATABASE_TYPE")
+DATABASE_USER = os.environ.get("DATABASE_USER")
+DATABASE_PASSWORD = os.environ.get("DATABASE_PASSWORD")
+
+DATABASE_CRED = ""
+if DATABASE_USER:
+ DATABASE_CRED += f"{DATABASE_USER}"
+if DATABASE_PASSWORD:
+ DATABASE_CRED += f":{DATABASE_PASSWORD}"
+if DATABASE_CRED:
+ DATABASE_CRED += "@"
+
+
+DB_VARS = {
+ "db_type": DATABASE_TYPE,
+ "db_host": os.environ.get("DATABASE_HOST"),
+ "db_port": os.environ.get("DATABASE_PORT"),
+ "db_name": os.environ.get("DATABASE_NAME"),
+}
+
+# Credentials are optional, so DATABASE_CRED is left out of the completeness
+# check; it already carries its trailing "@" when user/password are set.
+if all(DB_VARS.values()):
+ DATABASE_URL = f"{DB_VARS['db_type']}://{DATABASE_CRED}{DB_VARS['db_host']}:{DB_VARS['db_port']}/{DB_VARS['db_name']}"
+
# Replace the postgres:// with postgresql://
if "postgres://" in DATABASE_URL:
DATABASE_URL = DATABASE_URL.replace("postgres://", "postgresql://")
@@ -324,6 +348,7 @@ ENABLE_REALTIME_CHAT_SAVE = (
####################################
REDIS_URL = os.environ.get("REDIS_URL", "")
+REDIS_KEY_PREFIX = os.environ.get("REDIS_KEY_PREFIX", "open-webui")
REDIS_SENTINEL_HOSTS = os.environ.get("REDIS_SENTINEL_HOSTS", "")
REDIS_SENTINEL_PORT = os.environ.get("REDIS_SENTINEL_PORT", "26379")
@@ -399,10 +424,29 @@ ENABLE_COMPRESSION_MIDDLEWARE = (
os.environ.get("ENABLE_COMPRESSION_MIDDLEWARE", "True").lower() == "true"
)
+####################################
+# MODELS
+####################################
+
+MODELS_CACHE_TTL = os.environ.get("MODELS_CACHE_TTL", "1")
+if MODELS_CACHE_TTL == "":
+ MODELS_CACHE_TTL = None
+else:
+ try:
+ MODELS_CACHE_TTL = int(MODELS_CACHE_TTL)
+ except Exception:
+ MODELS_CACHE_TTL = 1
+
+
+####################################
+# WEBSOCKET SUPPORT
+####################################
+
ENABLE_WEBSOCKET_SUPPORT = (
os.environ.get("ENABLE_WEBSOCKET_SUPPORT", "True").lower() == "true"
)
+
WEBSOCKET_MANAGER = os.environ.get("WEBSOCKET_MANAGER", "")
WEBSOCKET_REDIS_URL = os.environ.get("WEBSOCKET_REDIS_URL", REDIS_URL)
@@ -509,11 +553,14 @@ else:
# OFFLINE_MODE
####################################
+ENABLE_VERSION_UPDATE_CHECK = (
+ os.environ.get("ENABLE_VERSION_UPDATE_CHECK", "true").lower() == "true"
+)
OFFLINE_MODE = os.environ.get("OFFLINE_MODE", "false").lower() == "true"
if OFFLINE_MODE:
os.environ["HF_HUB_OFFLINE"] = "1"
-
+ ENABLE_VERSION_UPDATE_CHECK = False
####################################
# AUDIT LOGGING
@@ -522,6 +569,14 @@ if OFFLINE_MODE:
AUDIT_LOGS_FILE_PATH = f"{DATA_DIR}/audit.log"
# Maximum size of a file before rotating into a new log file
AUDIT_LOG_FILE_ROTATION_SIZE = os.getenv("AUDIT_LOG_FILE_ROTATION_SIZE", "10MB")
+
+# Comma separated list of logger names to use for audit logging
+# Default is "uvicorn.access" which is the access log for Uvicorn
+# You can add more logger names to this list if you want to capture more logs
+AUDIT_UVICORN_LOGGER_NAMES = [
+ name.strip()
+ for name in os.getenv("AUDIT_UVICORN_LOGGER_NAMES", "uvicorn.access").split(",")
+ if name.strip()
+]
+
# METADATA | REQUEST | REQUEST_RESPONSE
AUDIT_LOG_LEVEL = os.getenv("AUDIT_LOG_LEVEL", "NONE").upper()
try:
@@ -559,6 +614,12 @@ OTEL_TRACES_SAMPLER = os.environ.get(
OTEL_BASIC_AUTH_USERNAME = os.environ.get("OTEL_BASIC_AUTH_USERNAME", "")
OTEL_BASIC_AUTH_PASSWORD = os.environ.get("OTEL_BASIC_AUTH_PASSWORD", "")
+
+OTEL_OTLP_SPAN_EXPORTER = os.environ.get(
+ "OTEL_OTLP_SPAN_EXPORTER", "grpc"
+).lower() # grpc or http
+
+
####################################
# TOOLS/FUNCTIONS PIP OPTIONS
####################################
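
# Reviewer note (not part of the patch): a worked example of the DATABASE_*
# assembly above, assuming the corrected form where DATABASE_CRED already
# carries its trailing "@" and credentials are optional.
def build_database_url(db_type, user, password, host, port, name):
    cred = ""
    if user:
        cred += user
    if password:
        cred += f":{password}"
    if cred:
        cred += "@"
    return f"{db_type}://{cred}{host}:{port}/{name}"

assert (
    build_database_url("postgresql", "webui", "s3cret", "db", "5432", "openwebui")
    == "postgresql://webui:s3cret@db:5432/openwebui"
)
assert (
    build_database_url("postgresql", None, None, "db", "5432", "openwebui")
    == "postgresql://db:5432/openwebui"
)
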
diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py
index 8e37d9e530..7f6a172f85 100644
--- a/backend/open_webui/main.py
+++ b/backend/open_webui/main.py
@@ -36,7 +36,6 @@ from fastapi import (
applications,
BackgroundTasks,
)
-
from fastapi.openapi.docs import get_swagger_ui_html
from fastapi.middleware.cors import CORSMiddleware
@@ -49,6 +48,7 @@ from starlette.exceptions import HTTPException as StarletteHTTPException
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.middleware.sessions import SessionMiddleware
from starlette.responses import Response, StreamingResponse
+from starlette.datastructures import Headers
from open_webui.utils import logger
@@ -116,6 +116,8 @@ from open_webui.config import (
OPENAI_API_CONFIGS,
# Direct Connections
ENABLE_DIRECT_CONNECTIONS,
+ # Model list
+ ENABLE_BASE_MODELS_CACHE,
# Thread pool size for FastAPI/AnyIO
THREAD_POOL_SIZE,
# Tool Server Configs
@@ -396,6 +398,7 @@ from open_webui.env import (
AUDIT_LOG_LEVEL,
CHANGELOG,
REDIS_URL,
+ REDIS_KEY_PREFIX,
REDIS_SENTINEL_HOSTS,
REDIS_SENTINEL_PORT,
GLOBAL_LOG_LEVEL,
@@ -415,7 +418,7 @@ from open_webui.env import (
ENABLE_WEBSOCKET_SUPPORT,
BYPASS_MODEL_ACCESS_CONTROL,
RESET_CONFIG_ON_START,
- OFFLINE_MODE,
+ ENABLE_VERSION_UPDATE_CHECK,
ENABLE_OTEL,
EXTERNAL_PWA_MANIFEST_URL,
AIOHTTP_CLIENT_SESSION_SSL,
@@ -534,6 +537,27 @@ async def lifespan(app: FastAPI):
asyncio.create_task(periodic_usage_pool_cleanup())
+ if app.state.config.ENABLE_BASE_MODELS_CACHE:
+ await get_all_models(
+ Request(
+ # Creating a mock request object to pass to get_all_models
+ {
+ "type": "http",
+ "asgi.version": "3.0",
+ "asgi.spec_version": "2.0",
+ "method": "GET",
+ "path": "/internal",
+ "query_string": b"",
+ "headers": Headers({}).raw,
+ "client": ("127.0.0.1", 12345),
+ "server": ("127.0.0.1", 80),
+ "scheme": "http",
+ "app": app,
+ }
+ ),
+ None,
+ )
+
yield
if hasattr(app.state, "redis_task_command_listener"):
@@ -554,6 +578,7 @@ app.state.instance_id = None
app.state.config = AppConfig(
redis_url=REDIS_URL,
redis_sentinels=get_sentinels_from_env(REDIS_SENTINEL_HOSTS, REDIS_SENTINEL_PORT),
+ redis_key_prefix=REDIS_KEY_PREFIX,
)
app.state.redis = None
@@ -616,6 +641,15 @@ app.state.TOOL_SERVERS = []
app.state.config.ENABLE_DIRECT_CONNECTIONS = ENABLE_DIRECT_CONNECTIONS
+########################################
+#
+# MODELS
+#
+########################################
+
+app.state.config.ENABLE_BASE_MODELS_CACHE = ENABLE_BASE_MODELS_CACHE
+app.state.BASE_MODELS = []
+
########################################
#
# WEBUI
@@ -1191,7 +1225,9 @@ if audit_level != AuditLevel.NONE:
@app.get("/api/models")
-async def get_models(request: Request, user=Depends(get_verified_user)):
+async def get_models(
+ request: Request, refresh: bool = False, user=Depends(get_verified_user)
+):
def get_filtered_models(models, user):
filtered_models = []
for model in models:
@@ -1215,7 +1251,7 @@ async def get_models(request: Request, user=Depends(get_verified_user)):
return filtered_models
- all_models = await get_all_models(request, user=user)
+ all_models = await get_all_models(request, refresh=refresh, user=user)
models = []
for model in all_models:
@@ -1471,7 +1507,7 @@ async def list_tasks_by_chat_id_endpoint(
task_ids = await list_task_ids_by_chat_id(request, chat_id)
- print(f"Task IDs for chat {chat_id}: {task_ids}")
+ log.debug(f"Task IDs for chat {chat_id}: {task_ids}")
return {"task_ids": task_ids}
@@ -1524,6 +1560,7 @@ async def get_app_config(request: Request):
"enable_signup": app.state.config.ENABLE_SIGNUP,
"enable_login_form": app.state.config.ENABLE_LOGIN_FORM,
"enable_websocket": ENABLE_WEBSOCKET_SUPPORT,
+ "enable_version_update_check": ENABLE_VERSION_UPDATE_CHECK,
**(
{
"enable_direct_connections": app.state.config.ENABLE_DIRECT_CONNECTIONS,
@@ -1629,9 +1666,9 @@ async def get_app_version():
@app.get("/api/version/updates")
async def get_app_latest_release_version(user=Depends(get_verified_user)):
- if OFFLINE_MODE:
+ if not ENABLE_VERSION_UPDATE_CHECK:
log.debug(
- f"Offline mode is enabled, returning current version as latest version"
+ f"Version update check is disabled, returning current version as latest version"
)
return {"current": VERSION, "latest": VERSION}
try:
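
# Reviewer note (not part of the patch): the lifespan hook warms the model
# cache by handing get_all_models() a synthetic request. A standalone sketch
# of the same trick -- the scope keys follow the ASGI spec, and "app" is what
# downstream code reads back via request.app:
from starlette.datastructures import Headers
from starlette.requests import Request

scope = {
    "type": "http",
    "asgi.version": "3.0",
    "asgi.spec_version": "2.0",
    "method": "GET",
    "path": "/internal",
    "query_string": b"",
    "headers": Headers({}).raw,
    "client": ("127.0.0.1", 12345),
    "server": ("127.0.0.1", 80),
    "scheme": "http",
    "app": object(),  # stand-in for the FastAPI app instance
}
request = Request(scope)
assert request.url.path == "/internal"
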
diff --git a/backend/open_webui/models/chats.py b/backend/open_webui/models/chats.py
index 0ac53a0233..c55aec6aaf 100644
--- a/backend/open_webui/models/chats.py
+++ b/backend/open_webui/models/chats.py
@@ -12,6 +12,7 @@ from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON
from sqlalchemy import or_, func, select, and_, text
from sqlalchemy.sql import exists
+from sqlalchemy.sql.expression import bindparam
####################
# Chat DB Schema
@@ -72,6 +73,8 @@ class ChatImportForm(ChatForm):
meta: Optional[dict] = {}
pinned: Optional[bool] = False
folder_id: Optional[str] = None
+ created_at: Optional[int] = None
+ updated_at: Optional[int] = None
class ChatTitleMessagesForm(BaseModel):
@@ -147,8 +150,16 @@ class ChatTable:
"meta": form_data.meta,
"pinned": form_data.pinned,
"folder_id": form_data.folder_id,
- "created_at": int(time.time()),
- "updated_at": int(time.time()),
+ "created_at": (
+ form_data.created_at
+ if form_data.created_at
+ else int(time.time())
+ ),
+ "updated_at": (
+ form_data.updated_at
+ if form_data.updated_at
+ else int(time.time())
+ ),
}
)
@@ -232,6 +243,10 @@ class ChatTable:
if chat is None:
return None
+ # Sanitize message content for null characters before upserting
+ if isinstance(message.get("content"), str):
+ message["content"] = message["content"].replace("\x00", "")
+
chat = chat.chat
history = chat.get("history", {})
@@ -580,7 +595,7 @@ class ChatTable:
"""
Filters chats based on a search query using Python, allowing pagination using skip and limit.
"""
- search_text = search_text.lower().strip()
+ search_text = search_text.replace("\u0000", "").lower().strip()
if not search_text:
return self.get_chat_list_by_user_id(
@@ -614,21 +629,19 @@ class ChatTable:
dialect_name = db.bind.dialect.name
if dialect_name == "sqlite":
# SQLite case: using JSON1 extension for JSON searching
+ sqlite_content_sql = (
+ "EXISTS ("
+ " SELECT 1 "
+ " FROM json_each(Chat.chat, '$.messages') AS message "
+ " WHERE LOWER(message.value->>'content') LIKE '%' || :content_key || '%'"
+ ")"
+ )
+ sqlite_content_clause = text(sqlite_content_sql)
query = query.filter(
- (
- Chat.title.ilike(
- f"%{search_text}%"
- ) # Case-insensitive search in title
- | text(
- """
- EXISTS (
- SELECT 1
- FROM json_each(Chat.chat, '$.messages') AS message
- WHERE LOWER(message.value->>'content') LIKE '%' || :search_text || '%'
- )
- """
- )
- ).params(search_text=search_text)
+ or_(
+ Chat.title.ilike(bindparam("title_key")),
+ sqlite_content_clause,
+ ).params(title_key=f"%{search_text}%", content_key=search_text)
)
# Check if there are any tags to filter, it should have all the tags
@@ -663,21 +676,19 @@ class ChatTable:
elif dialect_name == "postgresql":
# PostgreSQL relies on proper JSON query for search
+ postgres_content_sql = (
+ "EXISTS ("
+ " SELECT 1 "
+ " FROM json_array_elements(Chat.chat->'messages') AS message "
+ " WHERE LOWER(message->>'content') LIKE '%' || :content_key || '%'"
+ ")"
+ )
+ postgres_content_clause = text(postgres_content_sql)
query = query.filter(
- (
- Chat.title.ilike(
- f"%{search_text}%"
- ) # Case-insensitive search in title
- | text(
- """
- EXISTS (
- SELECT 1
- FROM json_array_elements(Chat.chat->'messages') AS message
- WHERE LOWER(message->>'content') LIKE '%' || :search_text || '%'
- )
- """
- )
- ).params(search_text=search_text)
+ or_(
+ Chat.title.ilike(bindparam("title_key")),
+ postgres_content_clause,
+ ).params(title_key=f"%{search_text}%", content_key=search_text)
)
# Check if there are any tags to filter, it should have all the tags
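
# Reviewer note (not part of the patch): why the rewrite binds two distinct
# parameters. The title filter needs the wildcard-wrapped value while the
# JSON clause needs the raw value, so a single shared :search_text cannot
# serve both. Minimal sketch with generic column names:
from sqlalchemy import column, or_, text
from sqlalchemy.sql.expression import bindparam

clause = or_(
    column("title").ilike(bindparam("title_key")),
    text("LOWER(content) LIKE '%' || :content_key || '%'"),
).params(title_key="%alpha beta%", content_key="alpha beta")
# .params() on the combined clause populates both bound values at once.
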
diff --git a/backend/open_webui/retrieval/loaders/mistral.py b/backend/open_webui/retrieval/loaders/mistral.py
index b00e9d7ce5..b7f2622f5e 100644
--- a/backend/open_webui/retrieval/loaders/mistral.py
+++ b/backend/open_webui/retrieval/loaders/mistral.py
@@ -507,6 +507,7 @@ class MistralLoader:
timeout=timeout,
headers={"User-Agent": "OpenWebUI-MistralLoader/2.0"},
raise_for_status=False, # We handle status codes manually
+ trust_env=True,
) as session:
yield session
diff --git a/backend/open_webui/retrieval/utils.py b/backend/open_webui/retrieval/utils.py
index 00dd683063..0a0f0dabab 100644
--- a/backend/open_webui/retrieval/utils.py
+++ b/backend/open_webui/retrieval/utils.py
@@ -460,20 +460,19 @@ def get_sources_from_files(
)
extracted_collections = []
- relevant_contexts = []
+ query_results = []
for file in files:
-
- context = None
+ query_result = None
if file.get("docs"):
# BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL
- context = {
+ query_result = {
"documents": [[doc.get("content") for doc in file.get("docs")]],
"metadatas": [[doc.get("metadata") for doc in file.get("docs")]],
}
elif file.get("context") == "full":
# Manual Full Mode Toggle
- context = {
+ query_result = {
"documents": [[file.get("file").get("data", {}).get("content")]],
"metadatas": [[{"file_id": file.get("id"), "name": file.get("name")}]],
}
@@ -500,7 +499,7 @@ def get_sources_from_files(
}
)
- context = {
+ query_result = {
"documents": [documents],
"metadatas": [metadatas],
}
@@ -508,7 +507,7 @@ def get_sources_from_files(
elif file.get("id"):
file_object = Files.get_file_by_id(file.get("id"))
if file_object:
- context = {
+ query_result = {
"documents": [[file_object.data.get("content", "")]],
"metadatas": [
[
@@ -521,7 +520,7 @@ def get_sources_from_files(
],
}
elif file.get("file").get("data"):
- context = {
+ query_result = {
"documents": [[file.get("file").get("data", {}).get("content")]],
"metadatas": [
[file.get("file").get("data", {}).get("metadata", {})]
@@ -549,19 +548,27 @@ def get_sources_from_files(
if full_context:
try:
- context = get_all_items_from_collections(collection_names)
+ query_result = get_all_items_from_collections(collection_names)
except Exception as e:
log.exception(e)
else:
try:
- context = None
+ query_result = None
if file.get("type") == "text":
- context = file["content"]
+ # Fallback for raw "text" attachments whose content is embedded in the request
+ query_result = {
+ "documents": [
+ [file.get("file").get("data", {}).get("content")]
+ ],
+ "metadatas": [
+ [file.get("file").get("data", {}).get("meta", {})]
+ ],
+ }
else:
if hybrid_search:
try:
- context = query_collection_with_hybrid_search(
+ query_result = query_collection_with_hybrid_search(
collection_names=collection_names,
queries=queries,
embedding_function=embedding_function,
@@ -577,8 +584,8 @@ def get_sources_from_files(
" non hybrid search as fallback."
)
- if (not hybrid_search) or (context is None):
- context = query_collection(
+ if (not hybrid_search) or (query_result is None):
+ query_result = query_collection(
collection_names=collection_names,
queries=queries,
embedding_function=embedding_function,
@@ -589,24 +596,24 @@ def get_sources_from_files(
extracted_collections.extend(collection_names)
- if context:
+ if query_result:
if "data" in file:
del file["data"]
- relevant_contexts.append({**context, "file": file})
+ query_results.append({**query_result, "file": file})
sources = []
- for context in relevant_contexts:
+ for query_result in query_results:
try:
- if "documents" in context:
- if "metadatas" in context:
+ if "documents" in query_result:
+ if "metadatas" in query_result:
source = {
- "source": context["file"],
- "document": context["documents"][0],
- "metadata": context["metadatas"][0],
+ "source": query_result["file"],
+ "document": query_result["documents"][0],
+ "metadata": query_result["metadatas"][0],
}
- if "distances" in context and context["distances"]:
- source["distances"] = context["distances"][0]
+ if "distances" in query_result and query_result["distances"]:
+ source["distances"] = query_result["distances"][0]
sources.append(source)
except Exception as e:
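
# Reviewer note (not part of the patch): the shape the query_result rename
# assumes -- parallel lists-of-lists, one inner list per query:
query_result = {
    "documents": [["chunk one", "chunk two"]],
    "metadatas": [[{"file_id": "f1"}, {"file_id": "f1"}]],
    "distances": [[0.12, 0.34]],  # only present for similarity searches
}
source = {
    "source": {"id": "f1", "name": "example.txt"},  # hypothetical file entry
    "document": query_result["documents"][0],
    "metadata": query_result["metadatas"][0],
}
if "distances" in query_result and query_result["distances"]:
    source["distances"] = query_result["distances"][0]
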
diff --git a/backend/open_webui/retrieval/vector/dbs/opensearch.py b/backend/open_webui/retrieval/vector/dbs/opensearch.py
index 60ef2d906c..7e16df3cfb 100644
--- a/backend/open_webui/retrieval/vector/dbs/opensearch.py
+++ b/backend/open_webui/retrieval/vector/dbs/opensearch.py
@@ -157,10 +157,10 @@ class OpenSearchClient(VectorDBBase):
for field, value in filter.items():
query_body["query"]["bool"]["filter"].append(
- {"match": {"metadata." + str(field): value}}
+ {"term": {"metadata." + str(field) + ".keyword": value}}
)
- size = limit if limit else 10
+ size = limit if limit else 10000
try:
result = self.client.search(
@@ -206,6 +206,7 @@ class OpenSearchClient(VectorDBBase):
for item in batch
]
bulk(self.client, actions)
+ self.client.indices.refresh(index=self._get_index_name(collection_name))
def upsert(self, collection_name: str, items: list[VectorItem]):
self._create_index_if_not_exists(
@@ -228,6 +229,7 @@ class OpenSearchClient(VectorDBBase):
for item in batch
]
bulk(self.client, actions)
+ self.client.indices.refresh(index=self._get_index_name(collection_name))
def delete(
self,
@@ -251,11 +253,12 @@ class OpenSearchClient(VectorDBBase):
}
for field, value in filter.items():
query_body["query"]["bool"]["filter"].append(
- {"match": {"metadata." + str(field): value}}
+ {"term": {"metadata." + str(field) + ".keyword": value}}
)
self.client.delete_by_query(
index=self._get_index_name(collection_name), body=query_body
)
+ self.client.indices.refresh(index=self._get_index_name(collection_name))
def reset(self):
indices = self.client.indices.get(index=f"{self.index_prefix}_*")
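
# Reviewer note (not part of the patch): why term replaces match here. A
# match query analyzes its input, so partial token matches can slip through;
# a term query against the .keyword sub-field compares the stored value
# exactly. Sketch of the two filter bodies:
match_filter = {"match": {"metadata.file_id": "abc-123"}}  # analyzed, token-based
term_filter = {"term": {"metadata.file_id.keyword": "abc-123"}}  # exact match
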
diff --git a/backend/open_webui/retrieval/vector/dbs/qdrant.py b/backend/open_webui/retrieval/vector/dbs/qdrant.py
index dfe2979076..2276e713fc 100644
--- a/backend/open_webui/retrieval/vector/dbs/qdrant.py
+++ b/backend/open_webui/retrieval/vector/dbs/qdrant.py
@@ -18,6 +18,7 @@ from open_webui.config import (
QDRANT_ON_DISK,
QDRANT_GRPC_PORT,
QDRANT_PREFER_GRPC,
+ QDRANT_COLLECTION_PREFIX,
)
from open_webui.env import SRC_LOG_LEVELS
@@ -29,7 +30,7 @@ log.setLevel(SRC_LOG_LEVELS["RAG"])
class QdrantClient(VectorDBBase):
def __init__(self):
- self.collection_prefix = "open-webui"
+ self.collection_prefix = QDRANT_COLLECTION_PREFIX
self.QDRANT_URI = QDRANT_URI
self.QDRANT_API_KEY = QDRANT_API_KEY
self.QDRANT_ON_DISK = QDRANT_ON_DISK
@@ -86,6 +87,25 @@ class QdrantClient(VectorDBBase):
),
)
+ # Create payload indexes for efficient filtering
+ self.client.create_payload_index(
+ collection_name=collection_name_with_prefix,
+ field_name="metadata.hash",
+ field_schema=models.KeywordIndexParams(
+ type=models.KeywordIndexType.KEYWORD,
+ is_tenant=False,
+ on_disk=self.QDRANT_ON_DISK,
+ ),
+ )
+ self.client.create_payload_index(
+ collection_name=collection_name_with_prefix,
+ field_name="metadata.file_id",
+ field_schema=models.KeywordIndexParams(
+ type=models.KeywordIndexType.KEYWORD,
+ is_tenant=False,
+ on_disk=self.QDRANT_ON_DISK,
+ ),
+ )
log.info(f"collection {collection_name_with_prefix} successfully created!")
def _create_collection_if_not_exists(self, collection_name, dimension):
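
# Reviewer note (not part of the patch): the payload indexes above exist to
# serve filters like the sketch below (assumes qdrant-client and a local
# Qdrant; endpoint and collection name are hypothetical). Aliased to avoid
# clashing with the wrapper class in this file:
from qdrant_client import QdrantClient as RawQdrantClient, models

client = RawQdrantClient(url="http://localhost:6333")
points, _next_page = client.scroll(
    collection_name="open-webui_example",
    scroll_filter=models.Filter(
        must=[
            models.FieldCondition(
                key="metadata.file_id",
                match=models.MatchValue(value="f1"),
            )
        ]
    ),
    limit=10,
)
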
diff --git a/backend/open_webui/retrieval/vector/dbs/qdrant_multitenancy.py b/backend/open_webui/retrieval/vector/dbs/qdrant_multitenancy.py
index 377b036247..df2c4e2431 100644
--- a/backend/open_webui/retrieval/vector/dbs/qdrant_multitenancy.py
+++ b/backend/open_webui/retrieval/vector/dbs/qdrant_multitenancy.py
@@ -9,6 +9,7 @@ from open_webui.config import (
QDRANT_ON_DISK,
QDRANT_PREFER_GRPC,
QDRANT_URI,
+ QDRANT_COLLECTION_PREFIX,
)
from open_webui.env import SRC_LOG_LEVELS
from open_webui.retrieval.vector.main import (
@@ -31,7 +32,7 @@ log.setLevel(SRC_LOG_LEVELS["RAG"])
class QdrantClient(VectorDBBase):
def __init__(self):
- self.collection_prefix = "open-webui"
+ self.collection_prefix = QDRANT_COLLECTION_PREFIX
self.QDRANT_URI = QDRANT_URI
self.QDRANT_API_KEY = QDRANT_API_KEY
self.QDRANT_ON_DISK = QDRANT_ON_DISK
diff --git a/backend/open_webui/retrieval/web/brave.py b/backend/open_webui/retrieval/web/brave.py
index 3075db990f..7bea575620 100644
--- a/backend/open_webui/retrieval/web/brave.py
+++ b/backend/open_webui/retrieval/web/brave.py
@@ -36,7 +36,9 @@ def search_brave(
return [
SearchResult(
- link=result["url"], title=result.get("title"), snippet=result.get("snippet")
+ link=result["url"],
+ title=result.get("title"),
+ snippet=result.get("description"),
)
for result in results[:count]
]
diff --git a/backend/open_webui/routers/auths.py b/backend/open_webui/routers/auths.py
index 106f3684a7..d90ebeb5e4 100644
--- a/backend/open_webui/routers/auths.py
+++ b/backend/open_webui/routers/auths.py
@@ -675,7 +675,7 @@ async def signout(request: Request, response: Response):
oauth_id_token = request.cookies.get("oauth_id_token")
if oauth_id_token:
try:
- async with ClientSession() as session:
+ async with ClientSession(trust_env=True) as session:
async with session.get(OPENID_PROVIDER_URL.value) as resp:
if resp.status == 200:
openid_data = await resp.json()
@@ -687,7 +687,7 @@ async def signout(request: Request, response: Response):
status_code=200,
content={
"status": True,
- "redirect_url": f"{logout_url}?id_token_hint={oauth_id_token}",
+ "redirect_url": f"{logout_url}?id_token_hint={oauth_id_token}" + (f"&post_logout_redirect_uri={WEBUI_AUTH_SIGNOUT_REDIRECT_URL}" if WEBUI_AUTH_SIGNOUT_REDIRECT_URL else ""),
},
headers=response.headers,
)
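
# Reviewer note (not part of the patch): shape of the resulting RP-initiated
# logout redirect, with hypothetical values:
logout_url = "https://idp.example.com/oidc/logout"
oauth_id_token = "eyJ..."  # elided
signout_redirect = "https://webui.example.com/auth"  # WEBUI_AUTH_SIGNOUT_REDIRECT_URL
redirect_url = f"{logout_url}?id_token_hint={oauth_id_token}" + (
    f"&post_logout_redirect_uri={signout_redirect}" if signout_redirect else ""
)
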
diff --git a/backend/open_webui/routers/chats.py b/backend/open_webui/routers/chats.py
index 29b12ed676..13b6040102 100644
--- a/backend/open_webui/routers/chats.py
+++ b/backend/open_webui/routers/chats.py
@@ -684,8 +684,10 @@ async def archive_chat_by_id(id: str, user=Depends(get_verified_user)):
@router.post("/{id}/share", response_model=Optional[ChatResponse])
async def share_chat_by_id(request: Request, id: str, user=Depends(get_verified_user)):
- if not has_permission(
- user.id, "chat.share", request.app.state.config.USER_PERMISSIONS
+ if (user.role != "admin") and (
+ not has_permission(
+ user.id, "chat.share", request.app.state.config.USER_PERMISSIONS
+ )
):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
diff --git a/backend/open_webui/routers/configs.py b/backend/open_webui/routers/configs.py
index 44b2ef40cf..a329584ca2 100644
--- a/backend/open_webui/routers/configs.py
+++ b/backend/open_webui/routers/configs.py
@@ -39,32 +39,39 @@ async def export_config(user=Depends(get_admin_user)):
############################
-# Direct Connections Config
+# Connections Config
############################
-class DirectConnectionsConfigForm(BaseModel):
+class ConnectionsConfigForm(BaseModel):
ENABLE_DIRECT_CONNECTIONS: bool
+ ENABLE_BASE_MODELS_CACHE: bool
-@router.get("/direct_connections", response_model=DirectConnectionsConfigForm)
-async def get_direct_connections_config(request: Request, user=Depends(get_admin_user)):
+@router.get("/connections", response_model=ConnectionsConfigForm)
+async def get_connections_config(request: Request, user=Depends(get_admin_user)):
return {
"ENABLE_DIRECT_CONNECTIONS": request.app.state.config.ENABLE_DIRECT_CONNECTIONS,
+ "ENABLE_BASE_MODELS_CACHE": request.app.state.config.ENABLE_BASE_MODELS_CACHE,
}
-@router.post("/direct_connections", response_model=DirectConnectionsConfigForm)
-async def set_direct_connections_config(
+@router.post("/connections", response_model=ConnectionsConfigForm)
+async def set_connections_config(
request: Request,
- form_data: DirectConnectionsConfigForm,
+ form_data: ConnectionsConfigForm,
user=Depends(get_admin_user),
):
request.app.state.config.ENABLE_DIRECT_CONNECTIONS = (
form_data.ENABLE_DIRECT_CONNECTIONS
)
+ request.app.state.config.ENABLE_BASE_MODELS_CACHE = (
+ form_data.ENABLE_BASE_MODELS_CACHE
+ )
+
return {
"ENABLE_DIRECT_CONNECTIONS": request.app.state.config.ENABLE_DIRECT_CONNECTIONS,
+ "ENABLE_BASE_MODELS_CACHE": request.app.state.config.ENABLE_BASE_MODELS_CACHE,
}
diff --git a/backend/open_webui/routers/functions.py b/backend/open_webui/routers/functions.py
index 355093335a..96d8215fb3 100644
--- a/backend/open_webui/routers/functions.py
+++ b/backend/open_webui/routers/functions.py
@@ -105,7 +105,7 @@ async def load_function_from_url(
)
try:
- async with aiohttp.ClientSession() as session:
+ async with aiohttp.ClientSession(trust_env=True) as session:
async with session.get(
url, headers={"Content-Type": "application/json"}
) as resp:
diff --git a/backend/open_webui/routers/images.py b/backend/open_webui/routers/images.py
index fc1c1b9a5a..2832a11577 100644
--- a/backend/open_webui/routers/images.py
+++ b/backend/open_webui/routers/images.py
@@ -303,10 +303,12 @@ async def update_image_config(
):
set_image_model(request, form_data.MODEL)
- if (form_data.IMAGE_SIZE == "auto" and form_data.MODEL != 'gpt-image-1'):
+ if form_data.IMAGE_SIZE == "auto" and form_data.MODEL != "gpt-image-1":
raise HTTPException(
status_code=400,
- detail=ERROR_MESSAGES.INCORRECT_FORMAT(" (auto is only allowed with gpt-image-1).")
+ detail=ERROR_MESSAGES.INCORRECT_FORMAT(
+ " (auto is only allowed with gpt-image-1)."
+ ),
)
pattern = r"^\d+x\d+$"
@@ -483,7 +485,7 @@ async def image_generations(
# image model other than gpt-image-1, which is warned about on settings save
width, height = (
tuple(map(int, request.app.state.config.IMAGE_SIZE.split("x")))
- if 'x' in request.app.state.config.IMAGE_SIZE
+ if "x" in request.app.state.config.IMAGE_SIZE
else (512, 512)
)
diff --git a/backend/open_webui/routers/ollama.py b/backend/open_webui/routers/ollama.py
index 9c1e1fdb00..3887106ad2 100644
--- a/backend/open_webui/routers/ollama.py
+++ b/backend/open_webui/routers/ollama.py
@@ -59,6 +59,7 @@ from open_webui.config import (
from open_webui.env import (
ENV,
SRC_LOG_LEVELS,
+ MODELS_CACHE_TTL,
AIOHTTP_CLIENT_SESSION_SSL,
AIOHTTP_CLIENT_TIMEOUT,
AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST,
@@ -330,7 +331,7 @@ def merge_ollama_models_lists(model_lists):
return list(merged_models.values())
-@cached(ttl=1)
+@cached(ttl=MODELS_CACHE_TTL)
async def get_all_models(request: Request, user: UserModel = None):
log.info("get_all_models()")
if request.app.state.config.ENABLE_OLLAMA_API:
diff --git a/backend/open_webui/routers/openai.py b/backend/open_webui/routers/openai.py
index ab35a673fc..a769c9a0c9 100644
--- a/backend/open_webui/routers/openai.py
+++ b/backend/open_webui/routers/openai.py
@@ -21,6 +21,7 @@ from open_webui.config import (
CACHE_DIR,
)
from open_webui.env import (
+ MODELS_CACHE_TTL,
AIOHTTP_CLIENT_SESSION_SSL,
AIOHTTP_CLIENT_TIMEOUT,
AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST,
@@ -386,7 +387,7 @@ async def get_filtered_models(models, user):
return filtered_models
-@cached(ttl=1)
+@cached(ttl=MODELS_CACHE_TTL)
async def get_all_models(request: Request, user: UserModel) -> dict[str, list]:
log.info("get_all_models()")
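
# Reviewer note (not part of the patch): semantics of the new TTL knob as
# parsed in env.py -- MODELS_CACHE_TTL="" yields ttl=None (cache until
# restart), an integer sets seconds, and 1 remains the default. Sketch with
# aiocache:
import asyncio
from aiocache import cached

@cached(ttl=30)  # e.g. MODELS_CACHE_TTL=30
async def fetch_models():
    return ["model-a", "model-b"]  # stands in for the upstream API calls

asyncio.run(fetch_models())
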
diff --git a/backend/open_webui/routers/retrieval.py b/backend/open_webui/routers/retrieval.py
index 6d888ca990..a851abc2e5 100644
--- a/backend/open_webui/routers/retrieval.py
+++ b/backend/open_webui/routers/retrieval.py
@@ -1794,6 +1794,13 @@ def search_web(request: Request, engine: str, query: str) -> list[SearchResult]:
request.app.state.config.WEB_SEARCH_RESULT_COUNT,
request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST,
)
+ elif engine == "exa":
+ return search_exa(
+ request.app.state.config.EXA_API_KEY,
+ query,
+ request.app.state.config.WEB_SEARCH_RESULT_COUNT,
+ request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST,
+ )
elif engine == "perplexity":
return search_perplexity(
request.app.state.config.PERPLEXITY_API_KEY,
diff --git a/backend/open_webui/routers/tools.py b/backend/open_webui/routers/tools.py
index f726368eba..41415bff04 100644
--- a/backend/open_webui/routers/tools.py
+++ b/backend/open_webui/routers/tools.py
@@ -153,7 +153,7 @@ async def load_tool_from_url(
)
try:
- async with aiohttp.ClientSession() as session:
+ async with aiohttp.ClientSession(trust_env=True) as session:
async with session.get(
url, headers={"Content-Type": "application/json"}
) as resp:
diff --git a/backend/open_webui/socket/main.py b/backend/open_webui/socket/main.py
index 35e40dccb2..96bcbcf1b5 100644
--- a/backend/open_webui/socket/main.py
+++ b/backend/open_webui/socket/main.py
@@ -1,4 +1,6 @@
import asyncio
+import random
+
import socketio
import logging
import sys
@@ -105,10 +107,26 @@ else:
async def periodic_usage_pool_cleanup():
- if not aquire_func():
- log.debug("Usage pool cleanup lock already exists. Not running it.")
- return
- log.debug("Running periodic_usage_pool_cleanup")
+ max_retries = 2
+ retry_delay = random.uniform(
+ WEBSOCKET_REDIS_LOCK_TIMEOUT / 2, WEBSOCKET_REDIS_LOCK_TIMEOUT
+ )
+ for attempt in range(max_retries + 1):
+ if aquire_func():
+ break
+ else:
+ if attempt < max_retries:
+ log.debug(
+ f"Cleanup lock already exists. Retry {attempt + 1} after {retry_delay}s..."
+ )
+ await asyncio.sleep(retry_delay)
+ else:
+ log.warning(
+ "Failed to acquire cleanup lock after retries. Skipping cleanup."
+ )
+ return
+
+ log.debug("Running periodic_cleanup")
try:
while True:
if not renew_func():
diff --git a/backend/open_webui/tasks.py b/backend/open_webui/tasks.py
index 2d3955f0a2..8ba66b2f57 100644
--- a/backend/open_webui/tasks.py
+++ b/backend/open_webui/tasks.py
@@ -3,10 +3,17 @@ import asyncio
from typing import Dict
from uuid import uuid4
import json
+import logging
from redis.asyncio import Redis
from fastapi import Request
from typing import Dict, List, Optional
+from open_webui.env import SRC_LOG_LEVELS
+
+
+log = logging.getLogger(__name__)
+log.setLevel(SRC_LOG_LEVELS["MAIN"])
+
# A dictionary to keep track of active tasks
tasks: Dict[str, asyncio.Task] = {}
chat_tasks = {}
@@ -38,7 +45,7 @@ async def redis_task_command_listener(app):
if local_task:
local_task.cancel()
except Exception as e:
- print(f"Error handling distributed task command: {e}")
+ log.exception(f"Error handling distributed task command: {e}")
### ------------------------------
diff --git a/backend/open_webui/utils/logger.py b/backend/open_webui/utils/logger.py
index 2557610060..a21df2756b 100644
--- a/backend/open_webui/utils/logger.py
+++ b/backend/open_webui/utils/logger.py
@@ -5,7 +5,9 @@ from typing import TYPE_CHECKING
from loguru import logger
+
from open_webui.env import (
+ AUDIT_UVICORN_LOGGER_NAMES,
AUDIT_LOG_FILE_ROTATION_SIZE,
AUDIT_LOG_LEVEL,
AUDIT_LOGS_FILE_PATH,
@@ -128,11 +130,13 @@ def start_logger():
logging.basicConfig(
handlers=[InterceptHandler()], level=GLOBAL_LOG_LEVEL, force=True
)
+
for uvicorn_logger_name in ["uvicorn", "uvicorn.error"]:
uvicorn_logger = logging.getLogger(uvicorn_logger_name)
uvicorn_logger.setLevel(GLOBAL_LOG_LEVEL)
uvicorn_logger.handlers = []
- for uvicorn_logger_name in ["uvicorn.access"]:
+
+ for uvicorn_logger_name in AUDIT_UVICORN_LOGGER_NAMES:
uvicorn_logger = logging.getLogger(uvicorn_logger_name)
uvicorn_logger.setLevel(GLOBAL_LOG_LEVEL)
uvicorn_logger.handlers = [InterceptHandler()]
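
# Reviewer note (not part of the patch): with the hard-coded list gone, extra
# uvicorn loggers can be routed through the audit handler via, e.g.
# AUDIT_UVICORN_LOGGER_NAMES="uvicorn.access,uvicorn.asgi". A standalone
# sketch of what the loop does per configured name:
import logging

class DemoInterceptHandler(logging.Handler):  # stand-in for InterceptHandler
    def emit(self, record):
        print("forwarded:", record.getMessage())

for name in ["uvicorn.access", "uvicorn.asgi"]:
    stdlib_logger = logging.getLogger(name)
    stdlib_logger.setLevel(logging.INFO)
    stdlib_logger.handlers = [DemoInterceptHandler()]
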
diff --git a/backend/open_webui/utils/middleware.py b/backend/open_webui/utils/middleware.py
index 6bad97b1f4..ff4f901b76 100644
--- a/backend/open_webui/utils/middleware.py
+++ b/backend/open_webui/utils/middleware.py
@@ -248,6 +248,7 @@ async def chat_completion_tools_handler(
if tool_id
else f"{tool_function_name}"
)
+
if tool.get("metadata", {}).get("citation", False) or tool.get(
"direct", False
):
@@ -718,6 +719,10 @@ def apply_params_to_form_data(form_data, model):
async def process_chat_payload(request, form_data, user, metadata, model):
+ # Pipeline Inlet -> Filter Inlet -> Chat Memory -> Chat Web Search -> Chat Image Generation
+ # -> Chat Code Interpreter (Form Data Update) -> (Default) Chat Tools Function Calling
+ # -> Chat Files
+
form_data = apply_params_to_form_data(form_data, model)
log.debug(f"form_data: {form_data}")
@@ -911,7 +916,6 @@ async def process_chat_payload(request, form_data, user, metadata, model):
request, form_data, extra_params, user, models, tools_dict
)
sources.extend(flags.get("sources", []))
-
except Exception as e:
log.exception(e)
@@ -924,24 +928,27 @@ async def process_chat_payload(request, form_data, user, metadata, model):
# If context is not empty, insert it into the messages
if len(sources) > 0:
context_string = ""
- citation_idx = {}
+ citation_idx_map = {}
+
for source in sources:
if "document" in source:
- for doc_context, doc_meta in zip(
+ for document_text, document_metadata in zip(
source["document"], source["metadata"]
):
source_name = source.get("source", {}).get("name", None)
- citation_id = (
- doc_meta.get("source", None)
+ source_id = (
+ document_metadata.get("source", None)
or source.get("source", {}).get("id", None)
or "N/A"
)
- if citation_id not in citation_idx:
- citation_idx[citation_id] = len(citation_idx) + 1
+
+ if source_id not in citation_idx_map:
+ citation_idx_map[source_id] = len(citation_idx_map) + 1
+
context_string += (
- f'
-