diff --git a/backend/open_webui/config.py b/backend/open_webui/config.py
index 898ac1b594..1beed9f21b 100644
--- a/backend/open_webui/config.py
+++ b/backend/open_webui/config.py
@@ -19,6 +19,7 @@ from open_webui.env import (
DATABASE_URL,
ENV,
REDIS_URL,
+ REDIS_KEY_PREFIX,
REDIS_SENTINEL_HOSTS,
REDIS_SENTINEL_PORT,
FRONTEND_BUILD_DIR,
@@ -211,11 +212,16 @@ class PersistentConfig(Generic[T]):
class AppConfig:
_state: dict[str, PersistentConfig]
_redis: Optional[redis.Redis] = None
+ _redis_key_prefix: str
def __init__(
- self, redis_url: Optional[str] = None, redis_sentinels: Optional[list] = []
+ self,
+ redis_url: Optional[str] = None,
+ redis_sentinels: Optional[list] = [],
+ redis_key_prefix: str = "open-webui",
):
super().__setattr__("_state", {})
+ super().__setattr__("_redis_key_prefix", redis_key_prefix)
if redis_url:
super().__setattr__(
"_redis",
@@ -230,7 +236,7 @@ class AppConfig:
self._state[key].save()
if self._redis:
- redis_key = f"open-webui:config:{key}"
+ redis_key = f"{self._redis_key_prefix}:config:{key}"
self._redis.set(redis_key, json.dumps(self._state[key].value))
def __getattr__(self, key):
@@ -239,7 +245,7 @@ class AppConfig:
# If Redis is available, check for an updated value
if self._redis:
- redis_key = f"open-webui:config:{key}"
+ redis_key = f"{self._redis_key_prefix}:config:{key}"
redis_value = self._redis.get(redis_key)
if redis_value is not None:
@@ -431,6 +437,12 @@ OAUTH_SCOPES = PersistentConfig(
os.environ.get("OAUTH_SCOPES", "openid email profile"),
)
+OAUTH_TIMEOUT = PersistentConfig(
+ "OAUTH_TIMEOUT",
+ "oauth.oidc.oauth_timeout",
+ os.environ.get("OAUTH_TIMEOUT", ""),
+)
+
OAUTH_CODE_CHALLENGE_METHOD = PersistentConfig(
"OAUTH_CODE_CHALLENGE_METHOD",
"oauth.oidc.code_challenge_method",
@@ -540,7 +552,14 @@ def load_oauth_providers():
client_id=GOOGLE_CLIENT_ID.value,
client_secret=GOOGLE_CLIENT_SECRET.value,
server_metadata_url="https://accounts.google.com/.well-known/openid-configuration",
- client_kwargs={"scope": GOOGLE_OAUTH_SCOPE.value},
+ client_kwargs={
+ "scope": GOOGLE_OAUTH_SCOPE.value,
+ **(
+ {"timeout": int(OAUTH_TIMEOUT.value)}
+ if OAUTH_TIMEOUT.value
+ else {}
+ ),
+ },
redirect_uri=GOOGLE_REDIRECT_URI.value,
)
@@ -563,6 +582,11 @@ def load_oauth_providers():
server_metadata_url=f"{MICROSOFT_CLIENT_LOGIN_BASE_URL.value}/{MICROSOFT_CLIENT_TENANT_ID.value}/v2.0/.well-known/openid-configuration?appid={MICROSOFT_CLIENT_ID.value}",
client_kwargs={
"scope": MICROSOFT_OAUTH_SCOPE.value,
+ **(
+ {"timeout": int(OAUTH_TIMEOUT.value)}
+ if OAUTH_TIMEOUT.value
+ else {}
+ ),
},
redirect_uri=MICROSOFT_REDIRECT_URI.value,
)
@@ -584,7 +608,14 @@ def load_oauth_providers():
authorize_url="https://github.com/login/oauth/authorize",
api_base_url="https://api.github.com",
userinfo_endpoint="https://api.github.com/user",
- client_kwargs={"scope": GITHUB_CLIENT_SCOPE.value},
+ client_kwargs={
+ "scope": GITHUB_CLIENT_SCOPE.value,
+ **(
+ {"timeout": int(OAUTH_TIMEOUT.value)}
+ if OAUTH_TIMEOUT.value
+ else {}
+ ),
+ },
redirect_uri=GITHUB_CLIENT_REDIRECT_URI.value,
)
@@ -603,6 +634,9 @@ def load_oauth_providers():
def oidc_oauth_register(client):
client_kwargs = {
"scope": OAUTH_SCOPES.value,
+ **(
+ {"timeout": int(OAUTH_TIMEOUT.value)} if OAUTH_TIMEOUT.value else {}
+ ),
}
if (
@@ -895,6 +929,18 @@ except Exception:
pass
OPENAI_API_BASE_URL = "https://api.openai.com/v1"
+
+####################################
+# MODELS
+####################################
+
+ENABLE_BASE_MODELS_CACHE = PersistentConfig(
+ "ENABLE_BASE_MODELS_CACHE",
+ "models.base_models_cache",
+ os.environ.get("ENABLE_BASE_MODELS_CACHE", "False").lower() == "true",
+)
+
+
####################################
# TOOL_SERVERS
####################################
@@ -1799,6 +1845,7 @@ QDRANT_GRPC_PORT = int(os.environ.get("QDRANT_GRPC_PORT", "6334"))
ENABLE_QDRANT_MULTITENANCY_MODE = (
os.environ.get("ENABLE_QDRANT_MULTITENANCY_MODE", "false").lower() == "true"
)
+QDRANT_COLLECTION_PREFIX = os.environ.get("QDRANT_COLLECTION_PREFIX", "open-webui")
# OpenSearch
OPENSEARCH_URI = os.environ.get("OPENSEARCH_URI", "https://localhost:9200")
diff --git a/backend/open_webui/env.py b/backend/open_webui/env.py
index 0f7b5611f5..dafb7be13a 100644
--- a/backend/open_webui/env.py
+++ b/backend/open_webui/env.py
@@ -199,6 +199,7 @@ CHANGELOG = changelog_json
SAFE_MODE = os.environ.get("SAFE_MODE", "false").lower() == "true"
+
####################################
# ENABLE_FORWARD_USER_INFO_HEADERS
####################################
@@ -272,15 +273,13 @@ if "postgres://" in DATABASE_URL:
DATABASE_SCHEMA = os.environ.get("DATABASE_SCHEMA", None)
-DATABASE_POOL_SIZE = os.environ.get("DATABASE_POOL_SIZE", 0)
+DATABASE_POOL_SIZE = os.environ.get("DATABASE_POOL_SIZE", None)
-if DATABASE_POOL_SIZE == "":
- DATABASE_POOL_SIZE = 0
-else:
+if DATABASE_POOL_SIZE != None:
try:
DATABASE_POOL_SIZE = int(DATABASE_POOL_SIZE)
except Exception:
- DATABASE_POOL_SIZE = 0
+ DATABASE_POOL_SIZE = None
DATABASE_POOL_MAX_OVERFLOW = os.environ.get("DATABASE_POOL_MAX_OVERFLOW", 0)
@@ -325,6 +324,7 @@ ENABLE_REALTIME_CHAT_SAVE = (
####################################
REDIS_URL = os.environ.get("REDIS_URL", "")
+REDIS_KEY_PREFIX = os.environ.get("REDIS_KEY_PREFIX", "open-webui")
REDIS_SENTINEL_HOSTS = os.environ.get("REDIS_SENTINEL_HOSTS", "")
REDIS_SENTINEL_PORT = os.environ.get("REDIS_SENTINEL_PORT", "26379")
@@ -396,10 +396,33 @@ WEBUI_AUTH_COOKIE_SECURE = (
if WEBUI_AUTH and WEBUI_SECRET_KEY == "":
raise ValueError(ERROR_MESSAGES.ENV_VAR_NOT_FOUND)
+ENABLE_COMPRESSION_MIDDLEWARE = (
+ os.environ.get("ENABLE_COMPRESSION_MIDDLEWARE", "True").lower() == "true"
+)
+
+####################################
+# MODELS
+####################################
+
+MODELS_CACHE_TTL = os.environ.get("MODELS_CACHE_TTL", "1")
+if MODELS_CACHE_TTL == "":
+ MODELS_CACHE_TTL = None
+else:
+ try:
+ MODELS_CACHE_TTL = int(MODELS_CACHE_TTL)
+ except Exception:
+ MODELS_CACHE_TTL = 1
+
+
+####################################
+# WEBSOCKET SUPPORT
+####################################
+
ENABLE_WEBSOCKET_SUPPORT = (
os.environ.get("ENABLE_WEBSOCKET_SUPPORT", "True").lower() == "true"
)
+
WEBSOCKET_MANAGER = os.environ.get("WEBSOCKET_MANAGER", "")
WEBSOCKET_REDIS_URL = os.environ.get("WEBSOCKET_REDIS_URL", REDIS_URL)
@@ -543,6 +566,9 @@ ENABLE_OTEL_METRICS = os.environ.get("ENABLE_OTEL_METRICS", "False").lower() ==
OTEL_EXPORTER_OTLP_ENDPOINT = os.environ.get(
"OTEL_EXPORTER_OTLP_ENDPOINT", "http://localhost:4317"
)
+OTEL_EXPORTER_OTLP_INSECURE = (
+ os.environ.get("OTEL_EXPORTER_OTLP_INSECURE", "False").lower() == "true"
+)
OTEL_SERVICE_NAME = os.environ.get("OTEL_SERVICE_NAME", "open-webui")
OTEL_RESOURCE_ATTRIBUTES = os.environ.get(
"OTEL_RESOURCE_ATTRIBUTES", ""
@@ -550,6 +576,14 @@ OTEL_RESOURCE_ATTRIBUTES = os.environ.get(
OTEL_TRACES_SAMPLER = os.environ.get(
"OTEL_TRACES_SAMPLER", "parentbased_always_on"
).lower()
+OTEL_BASIC_AUTH_USERNAME = os.environ.get("OTEL_BASIC_AUTH_USERNAME", "")
+OTEL_BASIC_AUTH_PASSWORD = os.environ.get("OTEL_BASIC_AUTH_PASSWORD", "")
+
+
+OTEL_OTLP_SPAN_EXPORTER = os.environ.get(
+ "OTEL_OTLP_SPAN_EXPORTER", "grpc"
+).lower() # grpc or http
+
####################################
# TOOLS/FUNCTIONS PIP OPTIONS
diff --git a/backend/open_webui/internal/db.py b/backend/open_webui/internal/db.py
index 840f571cc9..e1ffc1eb27 100644
--- a/backend/open_webui/internal/db.py
+++ b/backend/open_webui/internal/db.py
@@ -62,6 +62,9 @@ def handle_peewee_migration(DATABASE_URL):
except Exception as e:
log.error(f"Failed to initialize the database connection: {e}")
+ log.warning(
+ "Hint: If your database password contains special characters, you may need to URL-encode it."
+ )
raise
finally:
# Properly closing the database connection
@@ -81,20 +84,23 @@ if "sqlite" in SQLALCHEMY_DATABASE_URL:
SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False}
)
else:
- if DATABASE_POOL_SIZE > 0:
- engine = create_engine(
- SQLALCHEMY_DATABASE_URL,
- pool_size=DATABASE_POOL_SIZE,
- max_overflow=DATABASE_POOL_MAX_OVERFLOW,
- pool_timeout=DATABASE_POOL_TIMEOUT,
- pool_recycle=DATABASE_POOL_RECYCLE,
- pool_pre_ping=True,
- poolclass=QueuePool,
- )
+ if isinstance(DATABASE_POOL_SIZE, int):
+ if DATABASE_POOL_SIZE > 0:
+ engine = create_engine(
+ SQLALCHEMY_DATABASE_URL,
+ pool_size=DATABASE_POOL_SIZE,
+ max_overflow=DATABASE_POOL_MAX_OVERFLOW,
+ pool_timeout=DATABASE_POOL_TIMEOUT,
+ pool_recycle=DATABASE_POOL_RECYCLE,
+ pool_pre_ping=True,
+ poolclass=QueuePool,
+ )
+ else:
+ engine = create_engine(
+ SQLALCHEMY_DATABASE_URL, pool_pre_ping=True, poolclass=NullPool
+ )
else:
- engine = create_engine(
- SQLALCHEMY_DATABASE_URL, pool_pre_ping=True, poolclass=NullPool
- )
+ engine = create_engine(SQLALCHEMY_DATABASE_URL, pool_pre_ping=True)
SessionLocal = sessionmaker(
diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py
index 544756a6e8..96b04aaed8 100644
--- a/backend/open_webui/main.py
+++ b/backend/open_webui/main.py
@@ -36,7 +36,6 @@ from fastapi import (
applications,
BackgroundTasks,
)
-
from fastapi.openapi.docs import get_swagger_ui_html
from fastapi.middleware.cors import CORSMiddleware
@@ -49,6 +48,7 @@ from starlette.exceptions import HTTPException as StarletteHTTPException
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.middleware.sessions import SessionMiddleware
from starlette.responses import Response, StreamingResponse
+from starlette.datastructures import Headers
from open_webui.utils import logger
@@ -116,6 +116,8 @@ from open_webui.config import (
OPENAI_API_CONFIGS,
# Direct Connections
ENABLE_DIRECT_CONNECTIONS,
+ # Model list
+ ENABLE_BASE_MODELS_CACHE,
# Thread pool size for FastAPI/AnyIO
THREAD_POOL_SIZE,
# Tool Server Configs
@@ -396,6 +398,7 @@ from open_webui.env import (
AUDIT_LOG_LEVEL,
CHANGELOG,
REDIS_URL,
+ REDIS_KEY_PREFIX,
REDIS_SENTINEL_HOSTS,
REDIS_SENTINEL_PORT,
GLOBAL_LOG_LEVEL,
@@ -411,6 +414,7 @@ from open_webui.env import (
WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
WEBUI_AUTH_TRUSTED_NAME_HEADER,
WEBUI_AUTH_SIGNOUT_REDIRECT_URL,
+ ENABLE_COMPRESSION_MIDDLEWARE,
ENABLE_WEBSOCKET_SUPPORT,
BYPASS_MODEL_ACCESS_CONTROL,
RESET_CONFIG_ON_START,
@@ -533,6 +537,27 @@ async def lifespan(app: FastAPI):
asyncio.create_task(periodic_usage_pool_cleanup())
+ if app.state.config.ENABLE_BASE_MODELS_CACHE:
+ await get_all_models(
+ Request(
+ # Creating a mock request object to pass to get_all_models
+ {
+ "type": "http",
+ "asgi.version": "3.0",
+ "asgi.spec_version": "2.0",
+ "method": "GET",
+ "path": "/internal",
+ "query_string": b"",
+ "headers": Headers({}).raw,
+ "client": ("127.0.0.1", 12345),
+ "server": ("127.0.0.1", 80),
+ "scheme": "http",
+ "app": app,
+ }
+ ),
+ None,
+ )
+
yield
if hasattr(app.state, "redis_task_command_listener"):
@@ -553,6 +578,7 @@ app.state.instance_id = None
app.state.config = AppConfig(
redis_url=REDIS_URL,
redis_sentinels=get_sentinels_from_env(REDIS_SENTINEL_HOSTS, REDIS_SENTINEL_PORT),
+ redis_key_prefix=REDIS_KEY_PREFIX,
)
app.state.redis = None
@@ -615,6 +641,15 @@ app.state.TOOL_SERVERS = []
app.state.config.ENABLE_DIRECT_CONNECTIONS = ENABLE_DIRECT_CONNECTIONS
+########################################
+#
+# MODELS
+#
+########################################
+
+app.state.config.ENABLE_BASE_MODELS_CACHE = ENABLE_BASE_MODELS_CACHE
+app.state.BASE_MODELS = []
+
########################################
#
# WEBUI
@@ -1072,7 +1107,9 @@ class RedirectMiddleware(BaseHTTPMiddleware):
# Add the middleware to the app
-app.add_middleware(CompressMiddleware)
+if ENABLE_COMPRESSION_MIDDLEWARE:
+ app.add_middleware(CompressMiddleware)
+
app.add_middleware(RedirectMiddleware)
app.add_middleware(SecurityHeadersMiddleware)
@@ -1188,7 +1225,9 @@ if audit_level != AuditLevel.NONE:
@app.get("/api/models")
-async def get_models(request: Request, user=Depends(get_verified_user)):
+async def get_models(
+ request: Request, refresh: bool = False, user=Depends(get_verified_user)
+):
def get_filtered_models(models, user):
filtered_models = []
for model in models:
@@ -1212,7 +1251,7 @@ async def get_models(request: Request, user=Depends(get_verified_user)):
return filtered_models
- all_models = await get_all_models(request, user=user)
+ all_models = await get_all_models(request, refresh=refresh, user=user)
models = []
for model in all_models:
@@ -1507,6 +1546,7 @@ async def get_app_config(request: Request):
"name": app.state.WEBUI_NAME,
"version": VERSION,
"default_locale": str(DEFAULT_LOCALE),
+ "offline_mode": OFFLINE_MODE,
"oauth": {
"providers": {
name: config.get("name", name)
diff --git a/backend/open_webui/models/chats.py b/backend/open_webui/models/chats.py
index 0ac53a0233..9552d7f396 100644
--- a/backend/open_webui/models/chats.py
+++ b/backend/open_webui/models/chats.py
@@ -72,6 +72,8 @@ class ChatImportForm(ChatForm):
meta: Optional[dict] = {}
pinned: Optional[bool] = False
folder_id: Optional[str] = None
+ created_at: Optional[int] = None
+ updated_at: Optional[int] = None
class ChatTitleMessagesForm(BaseModel):
@@ -147,8 +149,16 @@ class ChatTable:
"meta": form_data.meta,
"pinned": form_data.pinned,
"folder_id": form_data.folder_id,
- "created_at": int(time.time()),
- "updated_at": int(time.time()),
+ "created_at": (
+ form_data.created_at
+ if form_data.created_at
+ else int(time.time())
+ ),
+ "updated_at": (
+ form_data.updated_at
+ if form_data.updated_at
+ else int(time.time())
+ ),
}
)
diff --git a/backend/open_webui/retrieval/loaders/main.py b/backend/open_webui/retrieval/loaders/main.py
index 8ac878fc22..a91496e8e8 100644
--- a/backend/open_webui/retrieval/loaders/main.py
+++ b/backend/open_webui/retrieval/loaders/main.py
@@ -14,7 +14,7 @@ from langchain_community.document_loaders import (
TextLoader,
UnstructuredEPubLoader,
UnstructuredExcelLoader,
- UnstructuredMarkdownLoader,
+ UnstructuredODTLoader,
UnstructuredPowerPointLoader,
UnstructuredRSTLoader,
UnstructuredXMLLoader,
@@ -389,6 +389,8 @@ class Loader:
loader = UnstructuredPowerPointLoader(file_path)
elif file_ext == "msg":
loader = OutlookMessageLoader(file_path)
+ elif file_ext == "odt":
+ loader = UnstructuredODTLoader(file_path)
elif self._is_text_file(file_ext, file_content_type):
loader = TextLoader(file_path, autodetect_encoding=True)
else:
diff --git a/backend/open_webui/retrieval/utils.py b/backend/open_webui/retrieval/utils.py
index 683f42819b..0a0f0dabab 100644
--- a/backend/open_webui/retrieval/utils.py
+++ b/backend/open_webui/retrieval/utils.py
@@ -7,6 +7,7 @@ import hashlib
from concurrent.futures import ThreadPoolExecutor
import time
+from urllib.parse import quote
from huggingface_hub import snapshot_download
from langchain.retrievers import ContextualCompressionRetriever, EnsembleRetriever
from langchain_community.retrievers import BM25Retriever
@@ -459,20 +460,19 @@ def get_sources_from_files(
)
extracted_collections = []
- relevant_contexts = []
+ query_results = []
for file in files:
-
- context = None
+ query_result = None
if file.get("docs"):
# BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL
- context = {
+ query_result = {
"documents": [[doc.get("content") for doc in file.get("docs")]],
"metadatas": [[doc.get("metadata") for doc in file.get("docs")]],
}
elif file.get("context") == "full":
# Manual Full Mode Toggle
- context = {
+ query_result = {
"documents": [[file.get("file").get("data", {}).get("content")]],
"metadatas": [[{"file_id": file.get("id"), "name": file.get("name")}]],
}
@@ -499,7 +499,7 @@ def get_sources_from_files(
}
)
- context = {
+ query_result = {
"documents": [documents],
"metadatas": [metadatas],
}
@@ -507,7 +507,7 @@ def get_sources_from_files(
elif file.get("id"):
file_object = Files.get_file_by_id(file.get("id"))
if file_object:
- context = {
+ query_result = {
"documents": [[file_object.data.get("content", "")]],
"metadatas": [
[
@@ -520,7 +520,7 @@ def get_sources_from_files(
],
}
elif file.get("file").get("data"):
- context = {
+ query_result = {
"documents": [[file.get("file").get("data", {}).get("content")]],
"metadatas": [
[file.get("file").get("data", {}).get("metadata", {})]
@@ -548,19 +548,27 @@ def get_sources_from_files(
if full_context:
try:
- context = get_all_items_from_collections(collection_names)
+ query_result = get_all_items_from_collections(collection_names)
except Exception as e:
log.exception(e)
else:
try:
- context = None
+ query_result = None
if file.get("type") == "text":
- context = file["content"]
+                            # Fallback for files attached as raw text: wrap the
+ query_result = {
+ "documents": [
+ [file.get("file").get("data", {}).get("content")]
+ ],
+ "metadatas": [
+ [file.get("file").get("data", {}).get("meta", {})]
+ ],
+ }
else:
if hybrid_search:
try:
- context = query_collection_with_hybrid_search(
+ query_result = query_collection_with_hybrid_search(
collection_names=collection_names,
queries=queries,
embedding_function=embedding_function,
@@ -576,8 +584,8 @@ def get_sources_from_files(
" non hybrid search as fallback."
)
- if (not hybrid_search) or (context is None):
- context = query_collection(
+ if (not hybrid_search) or (query_result is None):
+ query_result = query_collection(
collection_names=collection_names,
queries=queries,
embedding_function=embedding_function,
@@ -588,24 +596,24 @@ def get_sources_from_files(
extracted_collections.extend(collection_names)
- if context:
+ if query_result:
if "data" in file:
del file["data"]
- relevant_contexts.append({**context, "file": file})
+ query_results.append({**query_result, "file": file})
sources = []
- for context in relevant_contexts:
+ for query_result in query_results:
try:
- if "documents" in context:
- if "metadatas" in context:
+ if "documents" in query_result:
+ if "metadatas" in query_result:
source = {
- "source": context["file"],
- "document": context["documents"][0],
- "metadata": context["metadatas"][0],
+ "source": query_result["file"],
+ "document": query_result["documents"][0],
+ "metadata": query_result["metadatas"][0],
}
- if "distances" in context and context["distances"]:
- source["distances"] = context["distances"][0]
+ if "distances" in query_result and query_result["distances"]:
+ source["distances"] = query_result["distances"][0]
sources.append(source)
except Exception as e:
@@ -678,10 +686,10 @@ def generate_openai_batch_embeddings(
"Authorization": f"Bearer {key}",
**(
{
- "X-OpenWebUI-User-Name": user.name,
- "X-OpenWebUI-User-Id": user.id,
- "X-OpenWebUI-User-Email": user.email,
- "X-OpenWebUI-User-Role": user.role,
+ "X-OpenWebUI-User-Name": quote(user.name),
+ "X-OpenWebUI-User-Id": quote(user.id),
+ "X-OpenWebUI-User-Email": quote(user.email),
+ "X-OpenWebUI-User-Role": quote(user.role),
}
if ENABLE_FORWARD_USER_INFO_HEADERS and user
else {}
@@ -727,10 +735,10 @@ def generate_azure_openai_batch_embeddings(
"api-key": key,
**(
{
- "X-OpenWebUI-User-Name": user.name,
- "X-OpenWebUI-User-Id": user.id,
- "X-OpenWebUI-User-Email": user.email,
- "X-OpenWebUI-User-Role": user.role,
+ "X-OpenWebUI-User-Name": quote(user.name),
+ "X-OpenWebUI-User-Id": quote(user.id),
+ "X-OpenWebUI-User-Email": quote(user.email),
+ "X-OpenWebUI-User-Role": quote(user.role),
}
if ENABLE_FORWARD_USER_INFO_HEADERS and user
else {}
@@ -777,10 +785,10 @@ def generate_ollama_batch_embeddings(
"Authorization": f"Bearer {key}",
**(
{
- "X-OpenWebUI-User-Name": user.name,
- "X-OpenWebUI-User-Id": user.id,
- "X-OpenWebUI-User-Email": user.email,
- "X-OpenWebUI-User-Role": user.role,
+ "X-OpenWebUI-User-Name": quote(user.name),
+ "X-OpenWebUI-User-Id": quote(user.id),
+ "X-OpenWebUI-User-Email": quote(user.email),
+ "X-OpenWebUI-User-Role": quote(user.role),
}
if ENABLE_FORWARD_USER_INFO_HEADERS
else {}
diff --git a/backend/open_webui/retrieval/vector/dbs/opensearch.py b/backend/open_webui/retrieval/vector/dbs/opensearch.py
index 60ef2d906c..7e16df3cfb 100644
--- a/backend/open_webui/retrieval/vector/dbs/opensearch.py
+++ b/backend/open_webui/retrieval/vector/dbs/opensearch.py
@@ -157,10 +157,10 @@ class OpenSearchClient(VectorDBBase):
for field, value in filter.items():
query_body["query"]["bool"]["filter"].append(
- {"match": {"metadata." + str(field): value}}
+ {"term": {"metadata." + str(field) + ".keyword": value}}
)
- size = limit if limit else 10
+ size = limit if limit else 10000
try:
result = self.client.search(
@@ -206,6 +206,7 @@ class OpenSearchClient(VectorDBBase):
for item in batch
]
bulk(self.client, actions)
+ self.client.indices.refresh(self._get_index_name(collection_name))
def upsert(self, collection_name: str, items: list[VectorItem]):
self._create_index_if_not_exists(
@@ -228,6 +229,7 @@ class OpenSearchClient(VectorDBBase):
for item in batch
]
bulk(self.client, actions)
+ self.client.indices.refresh(self._get_index_name(collection_name))
def delete(
self,
@@ -251,11 +253,12 @@ class OpenSearchClient(VectorDBBase):
}
for field, value in filter.items():
query_body["query"]["bool"]["filter"].append(
- {"match": {"metadata." + str(field): value}}
+ {"term": {"metadata." + str(field) + ".keyword": value}}
)
self.client.delete_by_query(
index=self._get_index_name(collection_name), body=query_body
)
+ self.client.indices.refresh(self._get_index_name(collection_name))
def reset(self):
indices = self.client.indices.get(index=f"{self.index_prefix}_*")
diff --git a/backend/open_webui/retrieval/vector/dbs/qdrant.py b/backend/open_webui/retrieval/vector/dbs/qdrant.py
index dfe2979076..2276e713fc 100644
--- a/backend/open_webui/retrieval/vector/dbs/qdrant.py
+++ b/backend/open_webui/retrieval/vector/dbs/qdrant.py
@@ -18,6 +18,7 @@ from open_webui.config import (
QDRANT_ON_DISK,
QDRANT_GRPC_PORT,
QDRANT_PREFER_GRPC,
+ QDRANT_COLLECTION_PREFIX,
)
from open_webui.env import SRC_LOG_LEVELS
@@ -29,7 +30,7 @@ log.setLevel(SRC_LOG_LEVELS["RAG"])
class QdrantClient(VectorDBBase):
def __init__(self):
- self.collection_prefix = "open-webui"
+ self.collection_prefix = QDRANT_COLLECTION_PREFIX
self.QDRANT_URI = QDRANT_URI
self.QDRANT_API_KEY = QDRANT_API_KEY
self.QDRANT_ON_DISK = QDRANT_ON_DISK
@@ -86,6 +87,25 @@ class QdrantClient(VectorDBBase):
),
)
+ # Create payload indexes for efficient filtering
+ self.client.create_payload_index(
+ collection_name=collection_name_with_prefix,
+ field_name="metadata.hash",
+ field_schema=models.KeywordIndexParams(
+ type=models.KeywordIndexType.KEYWORD,
+ is_tenant=False,
+ on_disk=self.QDRANT_ON_DISK,
+ ),
+ )
+ self.client.create_payload_index(
+ collection_name=collection_name_with_prefix,
+ field_name="metadata.file_id",
+ field_schema=models.KeywordIndexParams(
+ type=models.KeywordIndexType.KEYWORD,
+ is_tenant=False,
+ on_disk=self.QDRANT_ON_DISK,
+ ),
+ )
log.info(f"collection {collection_name_with_prefix} successfully created!")
def _create_collection_if_not_exists(self, collection_name, dimension):
diff --git a/backend/open_webui/retrieval/vector/dbs/qdrant_multitenancy.py b/backend/open_webui/retrieval/vector/dbs/qdrant_multitenancy.py
index e83c437ef7..8f065ca5c8 100644
--- a/backend/open_webui/retrieval/vector/dbs/qdrant_multitenancy.py
+++ b/backend/open_webui/retrieval/vector/dbs/qdrant_multitenancy.py
@@ -9,6 +9,7 @@ from open_webui.config import (
QDRANT_ON_DISK,
QDRANT_PREFER_GRPC,
QDRANT_URI,
+ QDRANT_COLLECTION_PREFIX,
)
from open_webui.env import SRC_LOG_LEVELS
from open_webui.retrieval.vector.main import (
@@ -30,7 +31,7 @@ log.setLevel(SRC_LOG_LEVELS["RAG"])
class QdrantClient(VectorDBBase):
def __init__(self):
- self.collection_prefix = "open-webui"
+ self.collection_prefix = QDRANT_COLLECTION_PREFIX
self.QDRANT_URI = QDRANT_URI
self.QDRANT_API_KEY = QDRANT_API_KEY
self.QDRANT_ON_DISK = QDRANT_ON_DISK
@@ -228,6 +229,25 @@ class QdrantClient(VectorDBBase):
),
wait=True,
)
+ # Create payload indexes for efficient filtering on metadata.hash and metadata.file_id
+ self.client.create_payload_index(
+ collection_name=mt_collection_name,
+ field_name="metadata.hash",
+ field_schema=models.KeywordIndexParams(
+ type=models.KeywordIndexType.KEYWORD,
+ is_tenant=False,
+ on_disk=self.QDRANT_ON_DISK,
+ ),
+ )
+ self.client.create_payload_index(
+ collection_name=mt_collection_name,
+ field_name="metadata.file_id",
+ field_schema=models.KeywordIndexParams(
+ type=models.KeywordIndexType.KEYWORD,
+ is_tenant=False,
+ on_disk=self.QDRANT_ON_DISK,
+ ),
+ )
log.info(
f"Multi-tenant collection {mt_collection_name} created with dimension {dimension}!"
diff --git a/backend/open_webui/retrieval/web/brave.py b/backend/open_webui/retrieval/web/brave.py
index 3075db990f..7bea575620 100644
--- a/backend/open_webui/retrieval/web/brave.py
+++ b/backend/open_webui/retrieval/web/brave.py
@@ -36,7 +36,9 @@ def search_brave(
return [
SearchResult(
- link=result["url"], title=result.get("title"), snippet=result.get("snippet")
+ link=result["url"],
+ title=result.get("title"),
+ snippet=result.get("description"),
)
for result in results[:count]
]
diff --git a/backend/open_webui/routers/audio.py b/backend/open_webui/routers/audio.py
index 27634cec19..211f2ae859 100644
--- a/backend/open_webui/routers/audio.py
+++ b/backend/open_webui/routers/audio.py
@@ -15,6 +15,7 @@ import aiohttp
import aiofiles
import requests
import mimetypes
+from urllib.parse import quote
from fastapi import (
Depends,
@@ -343,10 +344,10 @@ async def speech(request: Request, user=Depends(get_verified_user)):
"Authorization": f"Bearer {request.app.state.config.TTS_OPENAI_API_KEY}",
**(
{
- "X-OpenWebUI-User-Name": user.name,
- "X-OpenWebUI-User-Id": user.id,
- "X-OpenWebUI-User-Email": user.email,
- "X-OpenWebUI-User-Role": user.role,
+ "X-OpenWebUI-User-Name": quote(user.name),
+ "X-OpenWebUI-User-Id": quote(user.id),
+ "X-OpenWebUI-User-Email": quote(user.email),
+ "X-OpenWebUI-User-Role": quote(user.role),
}
if ENABLE_FORWARD_USER_INFO_HEADERS
else {}
@@ -919,14 +920,18 @@ def transcription(
):
log.info(f"file.content_type: {file.content_type}")
- supported_content_types = request.app.state.config.STT_SUPPORTED_CONTENT_TYPES or [
- "audio/*",
- "video/webm",
- ]
+ stt_supported_content_types = getattr(
+ request.app.state.config, "STT_SUPPORTED_CONTENT_TYPES", []
+ )
if not any(
fnmatch(file.content_type, content_type)
- for content_type in supported_content_types
+ for content_type in (
+ stt_supported_content_types
+ if stt_supported_content_types
+ and any(t.strip() for t in stt_supported_content_types)
+ else ["audio/*", "video/webm"]
+ )
):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
diff --git a/backend/open_webui/routers/auths.py b/backend/open_webui/routers/auths.py
index 60a12db4b3..106f3684a7 100644
--- a/backend/open_webui/routers/auths.py
+++ b/backend/open_webui/routers/auths.py
@@ -669,6 +669,7 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
@router.get("/signout")
async def signout(request: Request, response: Response):
response.delete_cookie("token")
+ response.delete_cookie("oui-session")
if ENABLE_OAUTH_SIGNUP.value:
oauth_id_token = request.cookies.get("oauth_id_token")
diff --git a/backend/open_webui/routers/chats.py b/backend/open_webui/routers/chats.py
index 29b12ed676..13b6040102 100644
--- a/backend/open_webui/routers/chats.py
+++ b/backend/open_webui/routers/chats.py
@@ -684,8 +684,10 @@ async def archive_chat_by_id(id: str, user=Depends(get_verified_user)):
@router.post("/{id}/share", response_model=Optional[ChatResponse])
async def share_chat_by_id(request: Request, id: str, user=Depends(get_verified_user)):
- if not has_permission(
- user.id, "chat.share", request.app.state.config.USER_PERMISSIONS
+ if (user.role != "admin") and (
+ not has_permission(
+ user.id, "chat.share", request.app.state.config.USER_PERMISSIONS
+ )
):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
diff --git a/backend/open_webui/routers/configs.py b/backend/open_webui/routers/configs.py
index 44b2ef40cf..a329584ca2 100644
--- a/backend/open_webui/routers/configs.py
+++ b/backend/open_webui/routers/configs.py
@@ -39,32 +39,39 @@ async def export_config(user=Depends(get_admin_user)):
############################
-# Direct Connections Config
+# Connections Config
############################
-class DirectConnectionsConfigForm(BaseModel):
+class ConnectionsConfigForm(BaseModel):
ENABLE_DIRECT_CONNECTIONS: bool
+ ENABLE_BASE_MODELS_CACHE: bool
-@router.get("/direct_connections", response_model=DirectConnectionsConfigForm)
-async def get_direct_connections_config(request: Request, user=Depends(get_admin_user)):
+@router.get("/connections", response_model=ConnectionsConfigForm)
+async def get_connections_config(request: Request, user=Depends(get_admin_user)):
return {
"ENABLE_DIRECT_CONNECTIONS": request.app.state.config.ENABLE_DIRECT_CONNECTIONS,
+ "ENABLE_BASE_MODELS_CACHE": request.app.state.config.ENABLE_BASE_MODELS_CACHE,
}
-@router.post("/direct_connections", response_model=DirectConnectionsConfigForm)
-async def set_direct_connections_config(
+@router.post("/connections", response_model=ConnectionsConfigForm)
+async def set_connections_config(
request: Request,
- form_data: DirectConnectionsConfigForm,
+ form_data: ConnectionsConfigForm,
user=Depends(get_admin_user),
):
request.app.state.config.ENABLE_DIRECT_CONNECTIONS = (
form_data.ENABLE_DIRECT_CONNECTIONS
)
+ request.app.state.config.ENABLE_BASE_MODELS_CACHE = (
+ form_data.ENABLE_BASE_MODELS_CACHE
+ )
+
return {
"ENABLE_DIRECT_CONNECTIONS": request.app.state.config.ENABLE_DIRECT_CONNECTIONS,
+ "ENABLE_BASE_MODELS_CACHE": request.app.state.config.ENABLE_BASE_MODELS_CACHE,
}
diff --git a/backend/open_webui/routers/files.py b/backend/open_webui/routers/files.py
index b9bb15c7b4..bdf5780fc4 100644
--- a/backend/open_webui/routers/files.py
+++ b/backend/open_webui/routers/files.py
@@ -155,17 +155,18 @@ def upload_file(
if process:
try:
if file.content_type:
- stt_supported_content_types = (
- request.app.state.config.STT_SUPPORTED_CONTENT_TYPES
- or [
- "audio/*",
- "video/webm",
- ]
+ stt_supported_content_types = getattr(
+ request.app.state.config, "STT_SUPPORTED_CONTENT_TYPES", []
)
if any(
fnmatch(file.content_type, content_type)
- for content_type in stt_supported_content_types
+ for content_type in (
+ stt_supported_content_types
+ if stt_supported_content_types
+ and any(t.strip() for t in stt_supported_content_types)
+ else ["audio/*", "video/webm"]
+ )
):
file_path = Storage.get_file(file_path)
result = transcribe(request, file_path, file_metadata)
diff --git a/backend/open_webui/routers/images.py b/backend/open_webui/routers/images.py
index 52686a5841..2832a11577 100644
--- a/backend/open_webui/routers/images.py
+++ b/backend/open_webui/routers/images.py
@@ -8,6 +8,7 @@ import re
from pathlib import Path
from typing import Optional
+from urllib.parse import quote
import requests
from fastapi import APIRouter, Depends, HTTPException, Request, UploadFile
from open_webui.config import CACHE_DIR
@@ -302,8 +303,16 @@ async def update_image_config(
):
set_image_model(request, form_data.MODEL)
+ if form_data.IMAGE_SIZE == "auto" and form_data.MODEL != "gpt-image-1":
+ raise HTTPException(
+ status_code=400,
+ detail=ERROR_MESSAGES.INCORRECT_FORMAT(
+ " (auto is only allowed with gpt-image-1)."
+ ),
+ )
+
pattern = r"^\d+x\d+$"
- if re.match(pattern, form_data.IMAGE_SIZE):
+ if form_data.IMAGE_SIZE == "auto" or re.match(pattern, form_data.IMAGE_SIZE):
request.app.state.config.IMAGE_SIZE = form_data.IMAGE_SIZE
else:
raise HTTPException(
@@ -471,7 +480,14 @@ async def image_generations(
form_data: GenerateImageForm,
user=Depends(get_verified_user),
):
- width, height = tuple(map(int, request.app.state.config.IMAGE_SIZE.split("x")))
+    # If IMAGE_SIZE is "auto", fall back to the 512x512 default width/height.
+    # "auto" is only meaningful for gpt-image-1; the config endpoint rejects
+    # "auto" for other models, so this guard covers pre-existing saved configs.
+ width, height = (
+ tuple(map(int, request.app.state.config.IMAGE_SIZE.split("x")))
+ if "x" in request.app.state.config.IMAGE_SIZE
+ else (512, 512)
+ )
r = None
try:
@@ -483,10 +499,10 @@ async def image_generations(
headers["Content-Type"] = "application/json"
if ENABLE_FORWARD_USER_INFO_HEADERS:
- headers["X-OpenWebUI-User-Name"] = user.name
- headers["X-OpenWebUI-User-Id"] = user.id
- headers["X-OpenWebUI-User-Email"] = user.email
- headers["X-OpenWebUI-User-Role"] = user.role
+ headers["X-OpenWebUI-User-Name"] = quote(user.name)
+ headers["X-OpenWebUI-User-Id"] = quote(user.id)
+ headers["X-OpenWebUI-User-Email"] = quote(user.email)
+ headers["X-OpenWebUI-User-Role"] = quote(user.role)
data = {
"model": (
diff --git a/backend/open_webui/routers/ollama.py b/backend/open_webui/routers/ollama.py
index 1353599374..3887106ad2 100644
--- a/backend/open_webui/routers/ollama.py
+++ b/backend/open_webui/routers/ollama.py
@@ -16,6 +16,7 @@ from urllib.parse import urlparse
import aiohttp
from aiocache import cached
import requests
+from urllib.parse import quote
from open_webui.models.chats import Chats
from open_webui.models.users import UserModel
@@ -58,6 +59,7 @@ from open_webui.config import (
from open_webui.env import (
ENV,
SRC_LOG_LEVELS,
+ MODELS_CACHE_TTL,
AIOHTTP_CLIENT_SESSION_SSL,
AIOHTTP_CLIENT_TIMEOUT,
AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST,
@@ -87,10 +89,10 @@ async def send_get_request(url, key=None, user: UserModel = None):
**({"Authorization": f"Bearer {key}"} if key else {}),
**(
{
- "X-OpenWebUI-User-Name": user.name,
- "X-OpenWebUI-User-Id": user.id,
- "X-OpenWebUI-User-Email": user.email,
- "X-OpenWebUI-User-Role": user.role,
+ "X-OpenWebUI-User-Name": quote(user.name),
+ "X-OpenWebUI-User-Id": quote(user.id),
+ "X-OpenWebUI-User-Email": quote(user.email),
+ "X-OpenWebUI-User-Role": quote(user.role),
}
if ENABLE_FORWARD_USER_INFO_HEADERS and user
else {}
@@ -138,10 +140,10 @@ async def send_post_request(
**({"Authorization": f"Bearer {key}"} if key else {}),
**(
{
- "X-OpenWebUI-User-Name": user.name,
- "X-OpenWebUI-User-Id": user.id,
- "X-OpenWebUI-User-Email": user.email,
- "X-OpenWebUI-User-Role": user.role,
+ "X-OpenWebUI-User-Name": quote(user.name),
+ "X-OpenWebUI-User-Id": quote(user.id),
+ "X-OpenWebUI-User-Email": quote(user.email),
+ "X-OpenWebUI-User-Role": quote(user.role),
}
if ENABLE_FORWARD_USER_INFO_HEADERS and user
else {}
@@ -242,10 +244,10 @@ async def verify_connection(
**({"Authorization": f"Bearer {key}"} if key else {}),
**(
{
- "X-OpenWebUI-User-Name": user.name,
- "X-OpenWebUI-User-Id": user.id,
- "X-OpenWebUI-User-Email": user.email,
- "X-OpenWebUI-User-Role": user.role,
+ "X-OpenWebUI-User-Name": quote(user.name),
+ "X-OpenWebUI-User-Id": quote(user.id),
+ "X-OpenWebUI-User-Email": quote(user.email),
+ "X-OpenWebUI-User-Role": quote(user.role),
}
if ENABLE_FORWARD_USER_INFO_HEADERS and user
else {}
@@ -329,7 +331,7 @@ def merge_ollama_models_lists(model_lists):
return list(merged_models.values())
-@cached(ttl=1)
+@cached(ttl=MODELS_CACHE_TTL)
async def get_all_models(request: Request, user: UserModel = None):
log.info("get_all_models()")
if request.app.state.config.ENABLE_OLLAMA_API:
@@ -462,10 +464,10 @@ async def get_ollama_tags(
**({"Authorization": f"Bearer {key}"} if key else {}),
**(
{
- "X-OpenWebUI-User-Name": user.name,
- "X-OpenWebUI-User-Id": user.id,
- "X-OpenWebUI-User-Email": user.email,
- "X-OpenWebUI-User-Role": user.role,
+ "X-OpenWebUI-User-Name": quote(user.name),
+ "X-OpenWebUI-User-Id": quote(user.id),
+ "X-OpenWebUI-User-Email": quote(user.email),
+ "X-OpenWebUI-User-Role": quote(user.role),
}
if ENABLE_FORWARD_USER_INFO_HEADERS and user
else {}
@@ -824,10 +826,10 @@ async def copy_model(
**({"Authorization": f"Bearer {key}"} if key else {}),
**(
{
- "X-OpenWebUI-User-Name": user.name,
- "X-OpenWebUI-User-Id": user.id,
- "X-OpenWebUI-User-Email": user.email,
- "X-OpenWebUI-User-Role": user.role,
+ "X-OpenWebUI-User-Name": quote(user.name),
+ "X-OpenWebUI-User-Id": quote(user.id),
+ "X-OpenWebUI-User-Email": quote(user.email),
+ "X-OpenWebUI-User-Role": quote(user.role),
}
if ENABLE_FORWARD_USER_INFO_HEADERS and user
else {}
@@ -890,10 +892,10 @@ async def delete_model(
**({"Authorization": f"Bearer {key}"} if key else {}),
**(
{
- "X-OpenWebUI-User-Name": user.name,
- "X-OpenWebUI-User-Id": user.id,
- "X-OpenWebUI-User-Email": user.email,
- "X-OpenWebUI-User-Role": user.role,
+ "X-OpenWebUI-User-Name": quote(user.name),
+ "X-OpenWebUI-User-Id": quote(user.id),
+ "X-OpenWebUI-User-Email": quote(user.email),
+ "X-OpenWebUI-User-Role": quote(user.role),
}
if ENABLE_FORWARD_USER_INFO_HEADERS and user
else {}
@@ -949,10 +951,10 @@ async def show_model_info(
**({"Authorization": f"Bearer {key}"} if key else {}),
**(
{
- "X-OpenWebUI-User-Name": user.name,
- "X-OpenWebUI-User-Id": user.id,
- "X-OpenWebUI-User-Email": user.email,
- "X-OpenWebUI-User-Role": user.role,
+ "X-OpenWebUI-User-Name": quote(user.name),
+ "X-OpenWebUI-User-Id": quote(user.id),
+ "X-OpenWebUI-User-Email": quote(user.email),
+ "X-OpenWebUI-User-Role": quote(user.role),
}
if ENABLE_FORWARD_USER_INFO_HEADERS and user
else {}
@@ -1036,10 +1038,10 @@ async def embed(
**({"Authorization": f"Bearer {key}"} if key else {}),
**(
{
- "X-OpenWebUI-User-Name": user.name,
- "X-OpenWebUI-User-Id": user.id,
- "X-OpenWebUI-User-Email": user.email,
- "X-OpenWebUI-User-Role": user.role,
+ "X-OpenWebUI-User-Name": quote(user.name),
+ "X-OpenWebUI-User-Id": quote(user.id),
+ "X-OpenWebUI-User-Email": quote(user.email),
+ "X-OpenWebUI-User-Role": quote(user.role),
}
if ENABLE_FORWARD_USER_INFO_HEADERS and user
else {}
@@ -1123,10 +1125,10 @@ async def embeddings(
**({"Authorization": f"Bearer {key}"} if key else {}),
**(
{
- "X-OpenWebUI-User-Name": user.name,
- "X-OpenWebUI-User-Id": user.id,
- "X-OpenWebUI-User-Email": user.email,
- "X-OpenWebUI-User-Role": user.role,
+ "X-OpenWebUI-User-Name": quote(user.name),
+ "X-OpenWebUI-User-Id": quote(user.id),
+ "X-OpenWebUI-User-Email": quote(user.email),
+ "X-OpenWebUI-User-Role": quote(user.role),
}
if ENABLE_FORWARD_USER_INFO_HEADERS and user
else {}
diff --git a/backend/open_webui/routers/openai.py b/backend/open_webui/routers/openai.py
index 7649271fee..a769c9a0c9 100644
--- a/backend/open_webui/routers/openai.py
+++ b/backend/open_webui/routers/openai.py
@@ -8,7 +8,7 @@ from typing import Literal, Optional, overload
import aiohttp
from aiocache import cached
import requests
-
+from urllib.parse import quote
from fastapi import Depends, FastAPI, HTTPException, Request, APIRouter
from fastapi.middleware.cors import CORSMiddleware
@@ -21,6 +21,7 @@ from open_webui.config import (
CACHE_DIR,
)
from open_webui.env import (
+ MODELS_CACHE_TTL,
AIOHTTP_CLIENT_SESSION_SSL,
AIOHTTP_CLIENT_TIMEOUT,
AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST,
@@ -66,10 +67,10 @@ async def send_get_request(url, key=None, user: UserModel = None):
**({"Authorization": f"Bearer {key}"} if key else {}),
**(
{
- "X-OpenWebUI-User-Name": user.name,
- "X-OpenWebUI-User-Id": user.id,
- "X-OpenWebUI-User-Email": user.email,
- "X-OpenWebUI-User-Role": user.role,
+ "X-OpenWebUI-User-Name": quote(user.name),
+ "X-OpenWebUI-User-Id": quote(user.id),
+ "X-OpenWebUI-User-Email": quote(user.email),
+ "X-OpenWebUI-User-Role": quote(user.role),
}
if ENABLE_FORWARD_USER_INFO_HEADERS and user
else {}
@@ -225,10 +226,10 @@ async def speech(request: Request, user=Depends(get_verified_user)):
),
**(
{
- "X-OpenWebUI-User-Name": user.name,
- "X-OpenWebUI-User-Id": user.id,
- "X-OpenWebUI-User-Email": user.email,
- "X-OpenWebUI-User-Role": user.role,
+ "X-OpenWebUI-User-Name": quote(user.name),
+ "X-OpenWebUI-User-Id": quote(user.id),
+ "X-OpenWebUI-User-Email": quote(user.email),
+ "X-OpenWebUI-User-Role": quote(user.role),
}
if ENABLE_FORWARD_USER_INFO_HEADERS
else {}
@@ -386,7 +387,7 @@ async def get_filtered_models(models, user):
return filtered_models
-@cached(ttl=1)
+@cached(ttl=MODELS_CACHE_TTL)
async def get_all_models(request: Request, user: UserModel) -> dict[str, list]:
log.info("get_all_models()")
@@ -478,10 +479,10 @@ async def get_models(
"Content-Type": "application/json",
**(
{
- "X-OpenWebUI-User-Name": user.name,
- "X-OpenWebUI-User-Id": user.id,
- "X-OpenWebUI-User-Email": user.email,
- "X-OpenWebUI-User-Role": user.role,
+ "X-OpenWebUI-User-Name": quote(user.name),
+ "X-OpenWebUI-User-Id": quote(user.id),
+ "X-OpenWebUI-User-Email": quote(user.email),
+ "X-OpenWebUI-User-Role": quote(user.role),
}
if ENABLE_FORWARD_USER_INFO_HEADERS
else {}
@@ -573,10 +574,10 @@ async def verify_connection(
"Content-Type": "application/json",
**(
{
- "X-OpenWebUI-User-Name": user.name,
- "X-OpenWebUI-User-Id": user.id,
- "X-OpenWebUI-User-Email": user.email,
- "X-OpenWebUI-User-Role": user.role,
+ "X-OpenWebUI-User-Name": quote(user.name),
+ "X-OpenWebUI-User-Id": quote(user.id),
+ "X-OpenWebUI-User-Email": quote(user.email),
+ "X-OpenWebUI-User-Role": quote(user.role),
}
if ENABLE_FORWARD_USER_INFO_HEADERS
else {}
@@ -633,13 +634,7 @@ async def verify_connection(
raise HTTPException(status_code=500, detail=error_detail)
-def convert_to_azure_payload(
- url,
- payload: dict,
-):
- model = payload.get("model", "")
-
- # Filter allowed parameters based on Azure OpenAI API
+def get_azure_allowed_params(api_version: str) -> set[str]:
allowed_params = {
"messages",
"temperature",
@@ -669,6 +664,23 @@ def convert_to_azure_payload(
"max_completion_tokens",
}
+ try:
+ if api_version >= "2024-09-01-preview":
+ allowed_params.add("stream_options")
+ except ValueError:
+ log.debug(
+ f"Invalid API version {api_version} for Azure OpenAI. Defaulting to allowed parameters."
+ )
+
+ return allowed_params
+
+
+def convert_to_azure_payload(url, payload: dict, api_version: str):
+ model = payload.get("model", "")
+
+ # Filter allowed parameters based on Azure OpenAI API
+ allowed_params = get_azure_allowed_params(api_version)
+
# Special handling for o-series models
if model.startswith("o") and model.endswith("-mini"):
# Convert max_tokens to max_completion_tokens for o-series models
@@ -806,10 +818,10 @@ async def generate_chat_completion(
),
**(
{
- "X-OpenWebUI-User-Name": user.name,
- "X-OpenWebUI-User-Id": user.id,
- "X-OpenWebUI-User-Email": user.email,
- "X-OpenWebUI-User-Role": user.role,
+ "X-OpenWebUI-User-Name": quote(user.name),
+ "X-OpenWebUI-User-Id": quote(user.id),
+ "X-OpenWebUI-User-Email": quote(user.email),
+ "X-OpenWebUI-User-Role": quote(user.role),
}
if ENABLE_FORWARD_USER_INFO_HEADERS
else {}
@@ -817,8 +829,8 @@ async def generate_chat_completion(
}
if api_config.get("azure", False):
- request_url, payload = convert_to_azure_payload(url, payload)
- api_version = api_config.get("api_version", "") or "2023-03-15-preview"
+ api_version = api_config.get("api_version", "2023-03-15-preview")
+ request_url, payload = convert_to_azure_payload(url, payload, api_version)
headers["api-key"] = key
headers["api-version"] = api_version
request_url = f"{request_url}/chat/completions?api-version={api_version}"
@@ -924,10 +936,10 @@ async def embeddings(request: Request, form_data: dict, user):
"Content-Type": "application/json",
**(
{
- "X-OpenWebUI-User-Name": user.name,
- "X-OpenWebUI-User-Id": user.id,
- "X-OpenWebUI-User-Email": user.email,
- "X-OpenWebUI-User-Role": user.role,
+ "X-OpenWebUI-User-Name": quote(user.name),
+ "X-OpenWebUI-User-Id": quote(user.id),
+ "X-OpenWebUI-User-Email": quote(user.email),
+ "X-OpenWebUI-User-Role": quote(user.role),
}
if ENABLE_FORWARD_USER_INFO_HEADERS and user
else {}
@@ -996,10 +1008,10 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
"Content-Type": "application/json",
**(
{
- "X-OpenWebUI-User-Name": user.name,
- "X-OpenWebUI-User-Id": user.id,
- "X-OpenWebUI-User-Email": user.email,
- "X-OpenWebUI-User-Role": user.role,
+ "X-OpenWebUI-User-Name": quote(user.name),
+ "X-OpenWebUI-User-Id": quote(user.id),
+ "X-OpenWebUI-User-Email": quote(user.email),
+ "X-OpenWebUI-User-Role": quote(user.role),
}
if ENABLE_FORWARD_USER_INFO_HEADERS
else {}
@@ -1007,16 +1019,15 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
}
if api_config.get("azure", False):
+ api_version = api_config.get("api_version", "2023-03-15-preview")
headers["api-key"] = key
- headers["api-version"] = (
- api_config.get("api_version", "") or "2023-03-15-preview"
- )
+ headers["api-version"] = api_version
payload = json.loads(body)
- url, payload = convert_to_azure_payload(url, payload)
+ url, payload = convert_to_azure_payload(url, payload, api_version)
body = json.dumps(payload).encode()
- request_url = f"{url}/{path}?api-version={api_config.get('api_version', '2023-03-15-preview')}"
+ request_url = f"{url}/{path}?api-version={api_version}"
else:
headers["Authorization"] = f"Bearer {key}"
request_url = f"{url}/{path}"
diff --git a/backend/open_webui/routers/retrieval.py b/backend/open_webui/routers/retrieval.py
index ee6f99fbb5..a851abc2e5 100644
--- a/backend/open_webui/routers/retrieval.py
+++ b/backend/open_webui/routers/retrieval.py
@@ -1747,6 +1747,16 @@ def search_web(request: Request, engine: str, query: str) -> list[SearchResult]:
)
else:
raise Exception("No TAVILY_API_KEY found in environment variables")
+ elif engine == "exa":
+ if request.app.state.config.EXA_API_KEY:
+ return search_exa(
+ request.app.state.config.EXA_API_KEY,
+ query,
+ request.app.state.config.WEB_SEARCH_RESULT_COUNT,
+ request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST,
+ )
+ else:
+ raise Exception("No EXA_API_KEY found in environment variables")
elif engine == "searchapi":
if request.app.state.config.SEARCHAPI_API_KEY:
return search_searchapi(
@@ -1784,6 +1794,13 @@ def search_web(request: Request, engine: str, query: str) -> list[SearchResult]:
request.app.state.config.WEB_SEARCH_RESULT_COUNT,
request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST,
)
+ elif engine == "exa":
+ return search_exa(
+ request.app.state.config.EXA_API_KEY,
+ query,
+ request.app.state.config.WEB_SEARCH_RESULT_COUNT,
+ request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST,
+ )
elif engine == "perplexity":
return search_perplexity(
request.app.state.config.PERPLEXITY_API_KEY,
diff --git a/backend/open_webui/socket/main.py b/backend/open_webui/socket/main.py
index 35e40dccb2..96bcbcf1b5 100644
--- a/backend/open_webui/socket/main.py
+++ b/backend/open_webui/socket/main.py
@@ -1,4 +1,6 @@
import asyncio
+import random
+
import socketio
import logging
import sys
@@ -105,10 +107,26 @@ else:
async def periodic_usage_pool_cleanup():
- if not aquire_func():
- log.debug("Usage pool cleanup lock already exists. Not running it.")
- return
- log.debug("Running periodic_usage_pool_cleanup")
+ max_retries = 2
+ retry_delay = random.uniform(
+ WEBSOCKET_REDIS_LOCK_TIMEOUT / 2, WEBSOCKET_REDIS_LOCK_TIMEOUT
+ )
+ for attempt in range(max_retries + 1):
+ if aquire_func():
+ break
+ else:
+ if attempt < max_retries:
+ log.debug(
+ f"Cleanup lock already exists. Retry {attempt + 1} after {retry_delay}s..."
+ )
+ await asyncio.sleep(retry_delay)
+ else:
+ log.warning(
+ "Failed to acquire cleanup lock after retries. Skipping cleanup."
+ )
+ return
+
+ log.debug("Running periodic_cleanup")
try:
while True:
if not renew_func():
diff --git a/backend/open_webui/utils/chat.py b/backend/open_webui/utils/chat.py
index 268c910e3e..83483f391b 100644
--- a/backend/open_webui/utils/chat.py
+++ b/backend/open_webui/utils/chat.py
@@ -419,7 +419,7 @@ async def chat_action(request: Request, action_id: str, form_data: dict, user: A
params[key] = value
if "__user__" in sig.parameters:
- __user__ = (user.model_dump() if isinstance(user, UserModel) else {},)
+ __user__ = user.model_dump() if isinstance(user, UserModel) else {}
try:
if hasattr(function_module, "UserValves"):
diff --git a/backend/open_webui/utils/middleware.py b/backend/open_webui/utils/middleware.py
index b1e69db264..ff4f901b76 100644
--- a/backend/open_webui/utils/middleware.py
+++ b/backend/open_webui/utils/middleware.py
@@ -248,6 +248,7 @@ async def chat_completion_tools_handler(
if tool_id
else f"{tool_function_name}"
)
+
if tool.get("metadata", {}).get("citation", False) or tool.get(
"direct", False
):
@@ -718,6 +719,10 @@ def apply_params_to_form_data(form_data, model):
async def process_chat_payload(request, form_data, user, metadata, model):
+ # Pipeline Inlet -> Filter Inlet -> Chat Memory -> Chat Web Search -> Chat Image Generation
+ # -> Chat Code Interpreter (Form Data Update) -> (Default) Chat Tools Function Calling
+ # -> Chat Files
+
form_data = apply_params_to_form_data(form_data, model)
log.debug(f"form_data: {form_data}")
@@ -804,7 +809,6 @@ async def process_chat_payload(request, form_data, user, metadata, model):
raise e
try:
-
filter_functions = [
Functions.get_function_by_id(filter_id)
for filter_id in get_sorted_filter_ids(
@@ -912,7 +916,6 @@ async def process_chat_payload(request, form_data, user, metadata, model):
request, form_data, extra_params, user, models, tools_dict
)
sources.extend(flags.get("sources", []))
-
except Exception as e:
log.exception(e)
@@ -925,24 +928,27 @@ async def process_chat_payload(request, form_data, user, metadata, model):
# If context is not empty, insert it into the messages
if len(sources) > 0:
context_string = ""
- citation_idx = {}
+ citation_idx_map = {}
+
for source in sources:
if "document" in source:
- for doc_context, doc_meta in zip(
+ for document_text, document_metadata in zip(
source["document"], source["metadata"]
):
source_name = source.get("source", {}).get("name", None)
- citation_id = (
- doc_meta.get("source", None)
+ source_id = (
+ document_metadata.get("source", None)
or source.get("source", {}).get("id", None)
or "N/A"
)
- if citation_id not in citation_idx:
- citation_idx[citation_id] = len(citation_idx) + 1
+
+ if source_id not in citation_idx_map:
+ citation_idx_map[source_id] = len(citation_idx_map) + 1
+
context_string += (
- f'