diff --git a/backend/open_webui/config.py b/backend/open_webui/config.py
index 2f5f34019d..7e5c35a451 100644
--- a/backend/open_webui/config.py
+++ b/backend/open_webui/config.py
@@ -2767,6 +2767,12 @@ WEB_SEARCH_TRUST_ENV = PersistentConfig(
)
+OLLAMA_CLOUD_WEB_SEARCH_API_KEY = PersistentConfig(
+ "OLLAMA_CLOUD_WEB_SEARCH_API_KEY",
+ "rag.web.search.ollama_cloud_api_key",
+ os.getenv("OLLAMA_CLOUD_API_KEY", ""),
+)
+
SEARXNG_QUERY_URL = PersistentConfig(
"SEARXNG_QUERY_URL",
"rag.web.search.searxng_query_url",
diff --git a/backend/open_webui/env.py b/backend/open_webui/env.py
index 243b8212a8..e02424f969 100644
--- a/backend/open_webui/env.py
+++ b/backend/open_webui/env.py
@@ -474,6 +474,10 @@ ENABLE_OAUTH_ID_TOKEN_COOKIE = (
os.environ.get("ENABLE_OAUTH_ID_TOKEN_COOKIE", "True").lower() == "true"
)
+OAUTH_CLIENT_INFO_ENCRYPTION_KEY = os.environ.get(
+ "OAUTH_CLIENT_INFO_ENCRYPTION_KEY", WEBUI_SECRET_KEY
+)
+
OAUTH_SESSION_TOKEN_ENCRYPTION_KEY = os.environ.get(
"OAUTH_SESSION_TOKEN_ENCRYPTION_KEY", WEBUI_SECRET_KEY
)
diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py
index 7ebb76cb58..c6930f9b99 100644
--- a/backend/open_webui/main.py
+++ b/backend/open_webui/main.py
@@ -50,6 +50,11 @@ from starlette.middleware.sessions import SessionMiddleware
from starlette.responses import Response, StreamingResponse
from starlette.datastructures import Headers
+from starsessions import (
+ SessionMiddleware as StarSessionsMiddleware,
+ SessionAutoloadMiddleware,
+)
+from starsessions.stores.redis import RedisStore
from open_webui.utils import logger
from open_webui.utils.audit import AuditLevel, AuditLoggingMiddleware
@@ -269,6 +274,7 @@ from open_webui.config import (
WEB_SEARCH_CONCURRENT_REQUESTS,
WEB_SEARCH_TRUST_ENV,
WEB_SEARCH_DOMAIN_FILTER_LIST,
+ OLLAMA_CLOUD_WEB_SEARCH_API_KEY,
JINA_API_KEY,
SEARCHAPI_API_KEY,
SEARCHAPI_ENGINE,
@@ -467,7 +473,12 @@ from open_webui.utils.auth import (
get_verified_user,
)
from open_webui.utils.plugin import install_tool_and_function_dependencies
-from open_webui.utils.oauth import OAuthManager
+from open_webui.utils.oauth import (
+ OAuthManager,
+ OAuthClientManager,
+ decrypt_data,
+ OAuthClientInformationFull,
+)
from open_webui.utils.security_headers import SecurityHeadersMiddleware
from open_webui.utils.redis import get_redis_connection
@@ -597,9 +608,14 @@ app = FastAPI(
lifespan=lifespan,
)
+# For Open WebUI OIDC/OAuth2
oauth_manager = OAuthManager(app)
app.state.oauth_manager = oauth_manager
+# For Integrations
+oauth_client_manager = OAuthClientManager(app)
+app.state.oauth_client_manager = oauth_client_manager
+
app.state.instance_id = None
app.state.config = AppConfig(
redis_url=REDIS_URL,
@@ -883,6 +899,8 @@ app.state.config.BYPASS_WEB_SEARCH_WEB_LOADER = BYPASS_WEB_SEARCH_WEB_LOADER
app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION = ENABLE_GOOGLE_DRIVE_INTEGRATION
app.state.config.ENABLE_ONEDRIVE_INTEGRATION = ENABLE_ONEDRIVE_INTEGRATION
+
+app.state.config.OLLAMA_CLOUD_WEB_SEARCH_API_KEY = OLLAMA_CLOUD_WEB_SEARCH_API_KEY
app.state.config.SEARXNG_QUERY_URL = SEARXNG_QUERY_URL
app.state.config.YACY_QUERY_URL = YACY_QUERY_URL
app.state.config.YACY_USERNAME = YACY_USERNAME
@@ -1873,14 +1891,78 @@ async def get_current_usage(user=Depends(get_verified_user)):
# OAuth Login & Callback
############################
+
+# Initialize OAuth client manager with any MCP tool servers using OAuth 2.1
+if len(app.state.config.TOOL_SERVER_CONNECTIONS) > 0:
+ for tool_server_connection in app.state.config.TOOL_SERVER_CONNECTIONS:
+ if tool_server_connection.get("type", "openapi") == "mcp":
+ server_id = tool_server_connection.get("info", {}).get("id")
+ auth_type = tool_server_connection.get("auth_type", "none")
+ if server_id and auth_type == "oauth_2.1":
+ oauth_client_info = tool_server_connection.get("info", {}).get(
+ "oauth_client_info", ""
+ )
+
+ oauth_client_info = decrypt_data(oauth_client_info)
+ app.state.oauth_client_manager.add_client(
+ f"mcp:{server_id}", OAuthClientInformationFull(**oauth_client_info)
+ )
+
+
# SessionMiddleware is used by authlib for oauth
if len(OAUTH_PROVIDERS) > 0:
- app.add_middleware(
- SessionMiddleware,
- secret_key=WEBUI_SECRET_KEY,
- session_cookie="oui-session",
- same_site=WEBUI_SESSION_COOKIE_SAME_SITE,
- https_only=WEBUI_SESSION_COOKIE_SECURE,
+ try:
+ if REDIS_URL:
+ redis_session_store = RedisStore(
+ url=REDIS_URL,
+ prefix=(
+ f"{REDIS_KEY_PREFIX}:session:" if REDIS_KEY_PREFIX else "session:"
+ ),
+ )
+
+ app.add_middleware(SessionAutoloadMiddleware)
+ app.add_middleware(
+ StarSessionsMiddleware,
+ store=redis_session_store,
+ cookie_name="oui-session",
+ cookie_same_site=WEBUI_SESSION_COOKIE_SAME_SITE,
+ cookie_https_only=WEBUI_SESSION_COOKIE_SECURE,
+ )
+ log.info("Using Redis for session")
+ else:
+ raise ValueError("No Redis URL provided")
+ except Exception as e:
+ app.add_middleware(
+ SessionMiddleware,
+ secret_key=WEBUI_SECRET_KEY,
+ session_cookie="oui-session",
+ same_site=WEBUI_SESSION_COOKIE_SAME_SITE,
+ https_only=WEBUI_SESSION_COOKIE_SECURE,
+ )
+
+
+@app.get("/oauth/clients/{client_id}/authorize")
+async def oauth_client_authorize(
+ client_id: str,
+ request: Request,
+ response: Response,
+ user=Depends(get_verified_user),
+):
+ return await oauth_client_manager.handle_authorize(request, client_id=client_id)
+
+
+@app.get("/oauth/clients/{client_id}/callback")
+async def oauth_client_callback(
+ client_id: str,
+ request: Request,
+ response: Response,
+ user=Depends(get_verified_user),
+):
+ return await oauth_client_manager.handle_callback(
+ request,
+ client_id=client_id,
+ user_id=user.id if user else None,
+ response=response,
)
@@ -1895,8 +1977,9 @@ async def oauth_login(provider: str, request: Request):
# - This is considered insecure in general, as OAuth providers do not always verify email addresses
# 3. If there is no user, and ENABLE_OAUTH_SIGNUP is true, create a user
# - Email addresses are considered unique, so we fail registration if the email address is already taken
-@app.get("/oauth/{provider}/callback")
-async def oauth_callback(provider: str, request: Request, response: Response):
+@app.get("/oauth/{provider}/callback") # Legacy endpoint
+@app.get("/oauth/{provider}/login/callback")
+async def oauth_login_callback(provider: str, request: Request, response: Response):
return await oauth_manager.handle_callback(request, provider, response)
diff --git a/backend/open_webui/models/channels.py b/backend/open_webui/models/channels.py
index 92f238c3a0..e75266be78 100644
--- a/backend/open_webui/models/channels.py
+++ b/backend/open_webui/models/channels.py
@@ -57,6 +57,10 @@ class ChannelModel(BaseModel):
####################
+class ChannelResponse(ChannelModel):
+ write_access: bool = False
+
+
class ChannelForm(BaseModel):
name: str
description: Optional[str] = None
diff --git a/backend/open_webui/models/chats.py b/backend/open_webui/models/chats.py
index cadb5a3a79..97fd9b6256 100644
--- a/backend/open_webui/models/chats.py
+++ b/backend/open_webui/models/chats.py
@@ -492,11 +492,16 @@ class ChatTable:
self,
user_id: str,
include_archived: bool = False,
+ include_folders: bool = False,
skip: Optional[int] = None,
limit: Optional[int] = None,
) -> list[ChatTitleIdResponse]:
with get_db() as db:
- query = db.query(Chat).filter_by(user_id=user_id).filter_by(folder_id=None)
+ query = db.query(Chat).filter_by(user_id=user_id)
+
+ if not include_folders:
+ query = query.filter_by(folder_id=None)
+
query = query.filter(or_(Chat.pinned == False, Chat.pinned == None))
if not include_archived:
@@ -943,6 +948,16 @@ class ChatTable:
return count
+ def count_chats_by_folder_id_and_user_id(self, folder_id: str, user_id: str) -> int:
+ with get_db() as db:
+ query = db.query(Chat).filter_by(user_id=user_id)
+
+ query = query.filter_by(folder_id=folder_id)
+ count = query.count()
+
+ log.info(f"Count of chats for folder '{folder_id}': {count}")
+ return count
+
def delete_tag_by_id_and_user_id_and_tag_name(
self, id: str, user_id: str, tag_name: str
) -> bool:
diff --git a/backend/open_webui/models/files.py b/backend/open_webui/models/files.py
index 57978225d4..bf07b5f86f 100644
--- a/backend/open_webui/models/files.py
+++ b/backend/open_webui/models/files.py
@@ -130,6 +130,17 @@ class FilesTable:
except Exception:
return None
+ def get_file_by_id_and_user_id(self, id: str, user_id: str) -> Optional[FileModel]:
+ with get_db() as db:
+ try:
+ file = db.query(File).filter_by(id=id, user_id=user_id).first()
+ if file:
+ return FileModel.model_validate(file)
+ else:
+ return None
+ except Exception:
+ return None
+
def get_file_metadata_by_id(self, id: str) -> Optional[FileMetadataResponse]:
with get_db() as db:
try:
diff --git a/backend/open_webui/models/oauth_sessions.py b/backend/open_webui/models/oauth_sessions.py
index 9fd5335ce5..81ce220384 100644
--- a/backend/open_webui/models/oauth_sessions.py
+++ b/backend/open_webui/models/oauth_sessions.py
@@ -176,6 +176,26 @@ class OAuthSessionTable:
log.error(f"Error getting OAuth session by ID: {e}")
return None
+ def get_session_by_provider_and_user_id(
+ self, provider: str, user_id: str
+ ) -> Optional[OAuthSessionModel]:
+ """Get OAuth session by provider and user ID"""
+ try:
+ with get_db() as db:
+ session = (
+ db.query(OAuthSession)
+ .filter_by(provider=provider, user_id=user_id)
+ .first()
+ )
+ if session:
+ session.token = self._decrypt_token(session.token)
+ return OAuthSessionModel.model_validate(session)
+
+ return None
+ except Exception as e:
+ log.error(f"Error getting OAuth session by provider and user ID: {e}")
+ return None
+
def get_sessions_by_user_id(self, user_id: str) -> List[OAuthSessionModel]:
"""Get all OAuth sessions for a user"""
try:
diff --git a/backend/open_webui/models/tools.py b/backend/open_webui/models/tools.py
index 3a47fa008d..48f84b3ac4 100644
--- a/backend/open_webui/models/tools.py
+++ b/backend/open_webui/models/tools.py
@@ -95,6 +95,8 @@ class ToolResponse(BaseModel):
class ToolUserResponse(ToolResponse):
user: Optional[UserResponse] = None
+ model_config = ConfigDict(extra="allow")
+
class ToolForm(BaseModel):
id: str
diff --git a/backend/open_webui/retrieval/web/ollama.py b/backend/open_webui/retrieval/web/ollama.py
new file mode 100644
index 0000000000..a199a14389
--- /dev/null
+++ b/backend/open_webui/retrieval/web/ollama.py
@@ -0,0 +1,51 @@
+import logging
+from dataclasses import dataclass
+from typing import Optional
+
+import requests
+from open_webui.env import SRC_LOG_LEVELS
+from open_webui.retrieval.web.main import SearchResult
+
+log = logging.getLogger(__name__)
+log.setLevel(SRC_LOG_LEVELS["RAG"])
+
+
+def search_ollama_cloud(
+ url: str,
+ api_key: str,
+ query: str,
+ count: int,
+ filter_list: Optional[list[str]] = None,
+) -> list[SearchResult]:
+ """Search using Ollama Search API and return the results as a list of SearchResult objects.
+
+ Args:
+ api_key (str): A Ollama Search API key
+ query (str): The query to search for
+ count (int): Number of results to return
+ filter_list (Optional[list[str]]): List of domains to filter results by
+ """
+ log.info(f"Searching with Ollama for query: {query}")
+
+ headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
+ payload = {"query": query, "max_results": count}
+
+ try:
+ response = requests.post(f"{url}/api/web_search", headers=headers, json=payload)
+ response.raise_for_status()
+ data = response.json()
+
+ results = data.get("results", [])
+ log.info(f"Found {len(results)} results")
+
+ return [
+ SearchResult(
+ link=result.get("url", ""),
+ title=result.get("title", ""),
+ snippet=result.get("content", ""),
+ )
+ for result in results
+ ]
+ except Exception as e:
+ log.error(f"Error searching Ollama: {e}")
+ return []
diff --git a/backend/open_webui/routers/channels.py b/backend/open_webui/routers/channels.py
index da52be6e79..e7b8366347 100644
--- a/backend/open_webui/routers/channels.py
+++ b/backend/open_webui/routers/channels.py
@@ -10,7 +10,13 @@ from pydantic import BaseModel
from open_webui.socket.main import sio, get_user_ids_from_room
from open_webui.models.users import Users, UserNameResponse
-from open_webui.models.channels import Channels, ChannelModel, ChannelForm
+from open_webui.models.groups import Groups
+from open_webui.models.channels import (
+ Channels,
+ ChannelModel,
+ ChannelForm,
+ ChannelResponse,
+)
from open_webui.models.messages import (
Messages,
MessageModel,
@@ -80,7 +86,7 @@ async def create_new_channel(form_data: ChannelForm, user=Depends(get_admin_user
############################
-@router.get("/{id}", response_model=Optional[ChannelModel])
+@router.get("/{id}", response_model=Optional[ChannelResponse])
async def get_channel_by_id(id: str, user=Depends(get_verified_user)):
channel = Channels.get_channel_by_id(id)
if not channel:
@@ -95,7 +101,16 @@ async def get_channel_by_id(id: str, user=Depends(get_verified_user)):
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
)
- return ChannelModel(**channel.model_dump())
+ write_access = has_access(
+ user.id, type="write", access_control=channel.access_control, strict=False
+ )
+
+ return ChannelResponse(
+ **{
+ **channel.model_dump(),
+ "write_access": write_access or user.role == "admin",
+ }
+ )
############################
@@ -275,6 +290,7 @@ async def model_response_handler(request, channel, message, user):
)
thread_history = []
+ images = []
message_users = {}
for thread_message in thread_messages:
@@ -303,6 +319,11 @@ async def model_response_handler(request, channel, message, user):
f"{username}: {replace_mentions(thread_message.content)}"
)
+ thread_message_files = thread_message.data.get("files", [])
+ for file in thread_message_files:
+ if file.get("type", "") == "image":
+ images.append(file.get("url", ""))
+
system_message = {
"role": "system",
"content": f"You are {model.get('name', model_id)}, an AI assistant participating in a threaded conversation. Be helpful, concise, and conversational."
@@ -313,14 +334,29 @@ async def model_response_handler(request, channel, message, user):
),
}
+ content = f"{user.name if user else 'User'}: {message_content}"
+ if images:
+ content = [
+ {
+ "type": "text",
+ "text": content,
+ },
+ *[
+ {
+ "type": "image_url",
+ "image_url": {
+ "url": image,
+ },
+ }
+ for image in images
+ ],
+ ]
+
form_data = {
"model": model_id,
"messages": [
system_message,
- {
- "role": "user",
- "content": f"{user.name if user else 'User'}: {message_content}",
- },
+ {"role": "user", "content": content},
],
"stream": False,
}
@@ -362,7 +398,7 @@ async def new_message_handler(
)
if user.role != "admin" and not has_access(
- user.id, type="read", access_control=channel.access_control
+ user.id, type="write", access_control=channel.access_control, strict=False
):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
@@ -658,7 +694,7 @@ async def add_reaction_to_message(
)
if user.role != "admin" and not has_access(
- user.id, type="read", access_control=channel.access_control
+ user.id, type="write", access_control=channel.access_control, strict=False
):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
@@ -724,7 +760,7 @@ async def remove_reaction_by_id_and_user_id_and_name(
)
if user.role != "admin" and not has_access(
- user.id, type="read", access_control=channel.access_control
+ user.id, type="write", access_control=channel.access_control, strict=False
):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
@@ -806,7 +842,9 @@ async def delete_message_by_id(
if (
user.role != "admin"
and message.user_id != user.id
- and not has_access(user.id, type="read", access_control=channel.access_control)
+ and not has_access(
+ user.id, type="write", access_control=channel.access_control, strict=False
+ )
):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
diff --git a/backend/open_webui/routers/chats.py b/backend/open_webui/routers/chats.py
index 847368412e..788e355f2b 100644
--- a/backend/open_webui/routers/chats.py
+++ b/backend/open_webui/routers/chats.py
@@ -37,7 +37,9 @@ router = APIRouter()
@router.get("/", response_model=list[ChatTitleIdResponse])
@router.get("/list", response_model=list[ChatTitleIdResponse])
def get_session_user_chat_list(
- user=Depends(get_verified_user), page: Optional[int] = None
+ user=Depends(get_verified_user),
+ page: Optional[int] = None,
+ include_folders: Optional[bool] = False,
):
try:
if page is not None:
@@ -45,10 +47,12 @@ def get_session_user_chat_list(
skip = (page - 1) * limit
return Chats.get_chat_title_id_list_by_user_id(
- user.id, skip=skip, limit=limit
+ user.id, include_folders=include_folders, skip=skip, limit=limit
)
else:
- return Chats.get_chat_title_id_list_by_user_id(user.id)
+ return Chats.get_chat_title_id_list_by_user_id(
+ user.id, include_folders=include_folders
+ )
except Exception as e:
log.exception(e)
raise HTTPException(
diff --git a/backend/open_webui/routers/configs.py b/backend/open_webui/routers/configs.py
index 31d7bce404..d4b88032e2 100644
--- a/backend/open_webui/routers/configs.py
+++ b/backend/open_webui/routers/configs.py
@@ -1,6 +1,7 @@
import logging
from fastapi import APIRouter, Depends, Request, HTTPException
from pydantic import BaseModel, ConfigDict
+import aiohttp
from typing import Optional
@@ -17,6 +18,14 @@ from open_webui.utils.mcp.client import MCPClient
from open_webui.env import SRC_LOG_LEVELS
+from open_webui.utils.oauth import (
+ get_discovery_urls,
+ get_oauth_client_info_with_dynamic_client_registration,
+ encrypt_data,
+ decrypt_data,
+ OAuthClientInformationFull,
+)
+from mcp.shared.auth import OAuthMetadata
router = APIRouter()
@@ -86,6 +95,43 @@ async def set_connections_config(
}
+class OAuthClientRegistrationForm(BaseModel):
+ url: str
+ client_id: str
+ client_name: Optional[str] = None
+
+
+@router.post("/oauth/clients/register")
+async def register_oauth_client(
+ request: Request,
+ form_data: OAuthClientRegistrationForm,
+ type: Optional[str] = None,
+ user=Depends(get_admin_user),
+):
+ try:
+ oauth_client_id = form_data.client_id
+ if type:
+ oauth_client_id = f"{type}:{form_data.client_id}"
+
+ oauth_client_info = (
+ await get_oauth_client_info_with_dynamic_client_registration(
+ request, oauth_client_id, form_data.url
+ )
+ )
+ return {
+ "status": True,
+ "oauth_client_info": encrypt_data(
+ oauth_client_info.model_dump(mode="json")
+ ),
+ }
+ except Exception as e:
+ log.debug(f"Failed to register OAuth client: {e}")
+ raise HTTPException(
+ status_code=400,
+ detail=f"Failed to register OAuth client",
+ )
+
+
############################
# ToolServers Config
############################
@@ -122,8 +168,29 @@ async def set_tool_servers_config(
request.app.state.config.TOOL_SERVER_CONNECTIONS = [
connection.model_dump() for connection in form_data.TOOL_SERVER_CONNECTIONS
]
+
await set_tool_servers(request)
+ for connection in request.app.state.config.TOOL_SERVER_CONNECTIONS:
+ server_type = connection.get("type", "openapi")
+ if server_type == "mcp":
+ server_id = connection.get("info", {}).get("id")
+ auth_type = connection.get("auth_type", "none")
+ if auth_type == "oauth_2.1" and server_id:
+ try:
+ oauth_client_info = connection.get("info", {}).get(
+ "oauth_client_info", ""
+ )
+ oauth_client_info = decrypt_data(oauth_client_info)
+
+ await request.app.state.oauth_client_manager.add_client(
+ f"{server_type}:{server_id}",
+ OAuthClientInformationFull(**oauth_client_info),
+ )
+ except Exception as e:
+ log.debug(f"Failed to add OAuth client for MCP tool server: {e}")
+ continue
+
return {
"TOOL_SERVER_CONNECTIONS": request.app.state.config.TOOL_SERVER_CONNECTIONS,
}
@@ -138,46 +205,79 @@ async def verify_tool_servers_config(
"""
try:
if form_data.type == "mcp":
- try:
- client = MCPClient()
- auth = None
- headers = None
+ if form_data.auth_type == "oauth_2.1":
+ discovery_urls = get_discovery_urls(form_data.url)
+ async with aiohttp.ClientSession() as session:
+ async with session.get(
+ discovery_urls[0]
+ ) as oauth_server_metadata_response:
+ if oauth_server_metadata_response.status != 200:
+ raise HTTPException(
+ status_code=400,
+ detail=f"Failed to fetch OAuth 2.1 discovery document from {discovery_urls[0]}",
+ )
- token = None
- if form_data.auth_type == "bearer":
- token = form_data.key
- elif form_data.auth_type == "session":
- token = request.state.token.credentials
- elif form_data.auth_type == "system_oauth":
- try:
- if request.cookies.get("oauth_session_id", None):
- token = (
- await request.app.state.oauth_manager.get_oauth_token(
+ try:
+ oauth_server_metadata = OAuthMetadata.model_validate(
+ await oauth_server_metadata_response.json()
+ )
+ return {
+ "status": True,
+ "oauth_server_metadata": oauth_server_metadata.model_dump(
+ mode="json"
+ ),
+ }
+ except Exception as e:
+ log.info(
+ f"Failed to parse OAuth 2.1 discovery document: {e}"
+ )
+ raise HTTPException(
+ status_code=400,
+ detail=f"Failed to parse OAuth 2.1 discovery document from {discovery_urls[0]}",
+ )
+
+ raise HTTPException(
+ status_code=400,
+ detail=f"Failed to fetch OAuth 2.1 discovery document from {discovery_urls[0]}",
+ )
+ else:
+ try:
+ client = MCPClient()
+ headers = None
+
+ token = None
+ if form_data.auth_type == "bearer":
+ token = form_data.key
+ elif form_data.auth_type == "session":
+ token = request.state.token.credentials
+ elif form_data.auth_type == "system_oauth":
+ try:
+ if request.cookies.get("oauth_session_id", None):
+ token = await request.app.state.oauth_manager.get_oauth_token(
user.id,
request.cookies.get("oauth_session_id", None),
)
- )
- except Exception as e:
- pass
+ except Exception as e:
+ pass
- if token:
- headers = {"Authorization": f"Bearer {token}"}
+ if token:
+ headers = {"Authorization": f"Bearer {token}"}
- await client.connect(form_data.url, auth=auth, headers=headers)
- specs = await client.list_tool_specs()
- return {
- "status": True,
- "specs": specs,
- }
- except Exception as e:
- log.debug(f"Failed to create MCP client: {e}")
- raise HTTPException(
- status_code=400,
- detail=f"Failed to create MCP client",
- )
- finally:
- if client:
- await client.disconnect()
+ await client.connect(form_data.url, headers=headers)
+ specs = await client.list_tool_specs()
+ return {
+ "status": True,
+ "specs": specs,
+ }
+ except Exception as e:
+ log.debug(f"Failed to create MCP client: {e}")
+ raise HTTPException(
+ status_code=400,
+ detail=f"Failed to create MCP client",
+ )
+ finally:
+ if client:
+ await client.disconnect()
else: # openapi
token = None
if form_data.auth_type == "bearer":
diff --git a/backend/open_webui/routers/folders.py b/backend/open_webui/routers/folders.py
index 36dbfee5c5..ddee71ea4d 100644
--- a/backend/open_webui/routers/folders.py
+++ b/backend/open_webui/routers/folders.py
@@ -262,15 +262,15 @@ async def update_folder_is_expanded_by_id(
async def delete_folder_by_id(
request: Request, id: str, user=Depends(get_verified_user)
):
- chat_delete_permission = has_permission(
- user.id, "chat.delete", request.app.state.config.USER_PERMISSIONS
- )
-
- if user.role != "admin" and not chat_delete_permission:
- raise HTTPException(
- status_code=status.HTTP_403_FORBIDDEN,
- detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
+ if Chats.count_chats_by_folder_id_and_user_id(id, user.id):
+ chat_delete_permission = has_permission(
+ user.id, "chat.delete", request.app.state.config.USER_PERMISSIONS
)
+ if user.role != "admin" and not chat_delete_permission:
+ raise HTTPException(
+ status_code=status.HTTP_403_FORBIDDEN,
+ detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
+ )
folder = Folders.get_folder_by_id_and_user_id(id, user.id)
if folder:
diff --git a/backend/open_webui/routers/functions.py b/backend/open_webui/routers/functions.py
index 202aa74ca4..c36e656d5f 100644
--- a/backend/open_webui/routers/functions.py
+++ b/backend/open_webui/routers/functions.py
@@ -431,8 +431,10 @@ async def update_function_valves_by_id(
try:
form_data = {k: v for k, v in form_data.items() if v is not None}
valves = Valves(**form_data)
- Functions.update_function_valves_by_id(id, valves.model_dump())
- return valves.model_dump()
+
+ valves_dict = valves.model_dump(exclude_unset=True)
+ Functions.update_function_valves_by_id(id, valves_dict)
+ return valves_dict
except Exception as e:
log.exception(f"Error updating function values by id {id}: {e}")
raise HTTPException(
@@ -514,10 +516,11 @@ async def update_function_user_valves_by_id(
try:
form_data = {k: v for k, v in form_data.items() if v is not None}
user_valves = UserValves(**form_data)
+ user_valves_dict = user_valves.model_dump(exclude_unset=True)
Functions.update_user_valves_by_id_and_user_id(
- id, user.id, user_valves.model_dump()
+ id, user.id, user_valves_dict
)
- return user_valves.model_dump()
+ return user_valves_dict
except Exception as e:
log.exception(f"Error updating function user valves by id {id}: {e}")
raise HTTPException(
diff --git a/backend/open_webui/routers/images.py b/backend/open_webui/routers/images.py
index 802a3e9924..059b3a23d7 100644
--- a/backend/open_webui/routers/images.py
+++ b/backend/open_webui/routers/images.py
@@ -514,6 +514,7 @@ async def image_generations(
size = form_data.size
width, height = tuple(map(int, size.split("x")))
+ model = get_image_model(request)
r = None
try:
@@ -531,11 +532,7 @@ async def image_generations(
headers["X-OpenWebUI-User-Role"] = user.role
data = {
- "model": (
- request.app.state.config.IMAGE_GENERATION_MODEL
- if request.app.state.config.IMAGE_GENERATION_MODEL != ""
- else "dall-e-2"
- ),
+ "model": model,
"prompt": form_data.prompt,
"n": form_data.n,
"size": (
@@ -584,7 +581,6 @@ async def image_generations(
headers["Content-Type"] = "application/json"
headers["x-goog-api-key"] = request.app.state.config.IMAGES_GEMINI_API_KEY
- model = get_image_model(request)
data = {
"instances": {"prompt": form_data.prompt},
"parameters": {
@@ -640,7 +636,7 @@ async def image_generations(
}
)
res = await comfyui_generate_image(
- request.app.state.config.IMAGE_GENERATION_MODEL,
+ model,
form_data,
user.id,
request.app.state.config.COMFYUI_BASE_URL,
diff --git a/backend/open_webui/routers/retrieval.py b/backend/open_webui/routers/retrieval.py
index 0ddf824efa..73b3a22725 100644
--- a/backend/open_webui/routers/retrieval.py
+++ b/backend/open_webui/routers/retrieval.py
@@ -45,6 +45,7 @@ from open_webui.retrieval.loaders.youtube import YoutubeLoader
# Web search engines
from open_webui.retrieval.web.main import SearchResult
from open_webui.retrieval.web.utils import get_web_loader
+from open_webui.retrieval.web.ollama import search_ollama_cloud
from open_webui.retrieval.web.brave import search_brave
from open_webui.retrieval.web.kagi import search_kagi
from open_webui.retrieval.web.mojeek import search_mojeek
@@ -469,6 +470,7 @@ async def get_rag_config(request: Request, user=Depends(get_admin_user)):
"WEB_SEARCH_DOMAIN_FILTER_LIST": request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST,
"BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL": request.app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL,
"BYPASS_WEB_SEARCH_WEB_LOADER": request.app.state.config.BYPASS_WEB_SEARCH_WEB_LOADER,
+ "OLLAMA_CLOUD_WEB_SEARCH_API_KEY": request.app.state.config.OLLAMA_CLOUD_WEB_SEARCH_API_KEY,
"SEARXNG_QUERY_URL": request.app.state.config.SEARXNG_QUERY_URL,
"YACY_QUERY_URL": request.app.state.config.YACY_QUERY_URL,
"YACY_USERNAME": request.app.state.config.YACY_USERNAME,
@@ -525,6 +527,7 @@ class WebConfig(BaseModel):
WEB_SEARCH_DOMAIN_FILTER_LIST: Optional[List[str]] = []
BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL: Optional[bool] = None
BYPASS_WEB_SEARCH_WEB_LOADER: Optional[bool] = None
+ OLLAMA_CLOUD_WEB_SEARCH_API_KEY: Optional[str] = None
SEARXNG_QUERY_URL: Optional[str] = None
YACY_QUERY_URL: Optional[str] = None
YACY_USERNAME: Optional[str] = None
@@ -988,6 +991,9 @@ async def update_rag_config(
request.app.state.config.BYPASS_WEB_SEARCH_WEB_LOADER = (
form_data.web.BYPASS_WEB_SEARCH_WEB_LOADER
)
+ request.app.state.config.OLLAMA_CLOUD_WEB_SEARCH_API_KEY = (
+ form_data.web.OLLAMA_CLOUD_WEB_SEARCH_API_KEY
+ )
request.app.state.config.SEARXNG_QUERY_URL = form_data.web.SEARXNG_QUERY_URL
request.app.state.config.YACY_QUERY_URL = form_data.web.YACY_QUERY_URL
request.app.state.config.YACY_USERNAME = form_data.web.YACY_USERNAME
@@ -1139,6 +1145,7 @@ async def update_rag_config(
"WEB_SEARCH_DOMAIN_FILTER_LIST": request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST,
"BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL": request.app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL,
"BYPASS_WEB_SEARCH_WEB_LOADER": request.app.state.config.BYPASS_WEB_SEARCH_WEB_LOADER,
+ "OLLAMA_CLOUD_WEB_SEARCH_API_KEY": request.app.state.config.OLLAMA_CLOUD_WEB_SEARCH_API_KEY,
"SEARXNG_QUERY_URL": request.app.state.config.SEARXNG_QUERY_URL,
"YACY_QUERY_URL": request.app.state.config.YACY_QUERY_URL,
"YACY_USERNAME": request.app.state.config.YACY_USERNAME,
@@ -1407,59 +1414,35 @@ def process_file(
form_data: ProcessFileForm,
user=Depends(get_verified_user),
):
- try:
+ if user.role == "admin":
file = Files.get_file_by_id(form_data.file_id)
+ else:
+ file = Files.get_file_by_id_and_user_id(form_data.file_id, user.id)
- collection_name = form_data.collection_name
+ if file:
+ try:
- if collection_name is None:
- collection_name = f"file-{file.id}"
+ collection_name = form_data.collection_name
- if form_data.content:
- # Update the content in the file
- # Usage: /files/{file_id}/data/content/update, /files/ (audio file upload pipeline)
+ if collection_name is None:
+ collection_name = f"file-{file.id}"
- try:
- # /files/{file_id}/data/content/update
- VECTOR_DB_CLIENT.delete_collection(collection_name=f"file-{file.id}")
- except:
- # Audio file upload pipeline
- pass
+ if form_data.content:
+ # Update the content in the file
+ # Usage: /files/{file_id}/data/content/update, /files/ (audio file upload pipeline)
- docs = [
- Document(
- page_content=form_data.content.replace("
", "\n"),
- metadata={
- **file.meta,
- "name": file.filename,
- "created_by": file.user_id,
- "file_id": file.id,
- "source": file.filename,
- },
- )
- ]
-
- text_content = form_data.content
- elif form_data.collection_name:
- # Check if the file has already been processed and save the content
- # Usage: /knowledge/{id}/file/add, /knowledge/{id}/file/update
-
- result = VECTOR_DB_CLIENT.query(
- collection_name=f"file-{file.id}", filter={"file_id": file.id}
- )
-
- if result is not None and len(result.ids[0]) > 0:
- docs = [
- Document(
- page_content=result.documents[0][idx],
- metadata=result.metadatas[0][idx],
+ try:
+ # /files/{file_id}/data/content/update
+ VECTOR_DB_CLIENT.delete_collection(
+ collection_name=f"file-{file.id}"
)
- for idx, id in enumerate(result.ids[0])
- ]
- else:
+ except:
+ # Audio file upload pipeline
+ pass
+
docs = [
Document(
- page_content=file.data.get("content", ""),
+ page_content=form_data.content.replace("
", "\n"),
metadata={
**file.meta,
"name": file.filename,
@@ -1470,149 +1453,190 @@ def process_file(
)
]
- text_content = file.data.get("content", "")
- else:
- # Process the file and save the content
- # Usage: /files/
- file_path = file.path
- if file_path:
- file_path = Storage.get_file(file_path)
- loader = Loader(
- engine=request.app.state.config.CONTENT_EXTRACTION_ENGINE,
- DATALAB_MARKER_API_KEY=request.app.state.config.DATALAB_MARKER_API_KEY,
- DATALAB_MARKER_API_BASE_URL=request.app.state.config.DATALAB_MARKER_API_BASE_URL,
- DATALAB_MARKER_ADDITIONAL_CONFIG=request.app.state.config.DATALAB_MARKER_ADDITIONAL_CONFIG,
- DATALAB_MARKER_SKIP_CACHE=request.app.state.config.DATALAB_MARKER_SKIP_CACHE,
- DATALAB_MARKER_FORCE_OCR=request.app.state.config.DATALAB_MARKER_FORCE_OCR,
- DATALAB_MARKER_PAGINATE=request.app.state.config.DATALAB_MARKER_PAGINATE,
- DATALAB_MARKER_STRIP_EXISTING_OCR=request.app.state.config.DATALAB_MARKER_STRIP_EXISTING_OCR,
- DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION=request.app.state.config.DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION,
- DATALAB_MARKER_FORMAT_LINES=request.app.state.config.DATALAB_MARKER_FORMAT_LINES,
- DATALAB_MARKER_USE_LLM=request.app.state.config.DATALAB_MARKER_USE_LLM,
- DATALAB_MARKER_OUTPUT_FORMAT=request.app.state.config.DATALAB_MARKER_OUTPUT_FORMAT,
- EXTERNAL_DOCUMENT_LOADER_URL=request.app.state.config.EXTERNAL_DOCUMENT_LOADER_URL,
- EXTERNAL_DOCUMENT_LOADER_API_KEY=request.app.state.config.EXTERNAL_DOCUMENT_LOADER_API_KEY,
- TIKA_SERVER_URL=request.app.state.config.TIKA_SERVER_URL,
- DOCLING_SERVER_URL=request.app.state.config.DOCLING_SERVER_URL,
- DOCLING_PARAMS={
- "do_ocr": request.app.state.config.DOCLING_DO_OCR,
- "force_ocr": request.app.state.config.DOCLING_FORCE_OCR,
- "ocr_engine": request.app.state.config.DOCLING_OCR_ENGINE,
- "ocr_lang": request.app.state.config.DOCLING_OCR_LANG,
- "pdf_backend": request.app.state.config.DOCLING_PDF_BACKEND,
- "table_mode": request.app.state.config.DOCLING_TABLE_MODE,
- "pipeline": request.app.state.config.DOCLING_PIPELINE,
- "do_picture_description": request.app.state.config.DOCLING_DO_PICTURE_DESCRIPTION,
- "picture_description_mode": request.app.state.config.DOCLING_PICTURE_DESCRIPTION_MODE,
- "picture_description_local": request.app.state.config.DOCLING_PICTURE_DESCRIPTION_LOCAL,
- "picture_description_api": request.app.state.config.DOCLING_PICTURE_DESCRIPTION_API,
- },
- PDF_EXTRACT_IMAGES=request.app.state.config.PDF_EXTRACT_IMAGES,
- DOCUMENT_INTELLIGENCE_ENDPOINT=request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT,
- DOCUMENT_INTELLIGENCE_KEY=request.app.state.config.DOCUMENT_INTELLIGENCE_KEY,
- MISTRAL_OCR_API_KEY=request.app.state.config.MISTRAL_OCR_API_KEY,
- )
- docs = loader.load(
- file.filename, file.meta.get("content_type"), file_path
+ text_content = form_data.content
+ elif form_data.collection_name:
+ # Check if the file has already been processed and save the content
+ # Usage: /knowledge/{id}/file/add, /knowledge/{id}/file/update
+
+ result = VECTOR_DB_CLIENT.query(
+ collection_name=f"file-{file.id}", filter={"file_id": file.id}
)
- docs = [
- Document(
- page_content=doc.page_content,
- metadata={
- **doc.metadata,
- "name": file.filename,
- "created_by": file.user_id,
- "file_id": file.id,
- "source": file.filename,
- },
- )
- for doc in docs
- ]
- else:
- docs = [
- Document(
- page_content=file.data.get("content", ""),
- metadata={
- **file.meta,
- "name": file.filename,
- "created_by": file.user_id,
- "file_id": file.id,
- "source": file.filename,
- },
- )
- ]
- text_content = " ".join([doc.page_content for doc in docs])
-
- log.debug(f"text_content: {text_content}")
- Files.update_file_data_by_id(
- file.id,
- {"content": text_content},
- )
- hash = calculate_sha256_string(text_content)
- Files.update_file_hash_by_id(file.id, hash)
-
- if request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL:
- Files.update_file_data_by_id(file.id, {"status": "completed"})
- return {
- "status": True,
- "collection_name": None,
- "filename": file.filename,
- "content": text_content,
- }
- else:
- try:
- result = save_docs_to_vector_db(
- request,
- docs=docs,
- collection_name=collection_name,
- metadata={
- "file_id": file.id,
- "name": file.filename,
- "hash": hash,
- },
- add=(True if form_data.collection_name else False),
- user=user,
- )
- log.info(f"added {len(docs)} items to collection {collection_name}")
-
- if result:
- Files.update_file_metadata_by_id(
- file.id,
- {
- "collection_name": collection_name,
- },
- )
-
- Files.update_file_data_by_id(
- file.id,
- {"status": "completed"},
- )
-
- return {
- "status": True,
- "collection_name": collection_name,
- "filename": file.filename,
- "content": text_content,
- }
+ if result is not None and len(result.ids[0]) > 0:
+ docs = [
+ Document(
+ page_content=result.documents[0][idx],
+ metadata=result.metadatas[0][idx],
+ )
+ for idx, id in enumerate(result.ids[0])
+ ]
else:
- raise Exception("Error saving document to vector database")
- except Exception as e:
- raise e
+ docs = [
+ Document(
+ page_content=file.data.get("content", ""),
+ metadata={
+ **file.meta,
+ "name": file.filename,
+ "created_by": file.user_id,
+ "file_id": file.id,
+ "source": file.filename,
+ },
+ )
+ ]
- except Exception as e:
- log.exception(e)
- if "No pandoc was found" in str(e):
- raise HTTPException(
- status_code=status.HTTP_400_BAD_REQUEST,
- detail=ERROR_MESSAGES.PANDOC_NOT_INSTALLED,
+ text_content = file.data.get("content", "")
+ else:
+ # Process the file and save the content
+ # Usage: /files/
+ file_path = file.path
+ if file_path:
+ file_path = Storage.get_file(file_path)
+ loader = Loader(
+ engine=request.app.state.config.CONTENT_EXTRACTION_ENGINE,
+ DATALAB_MARKER_API_KEY=request.app.state.config.DATALAB_MARKER_API_KEY,
+ DATALAB_MARKER_API_BASE_URL=request.app.state.config.DATALAB_MARKER_API_BASE_URL,
+ DATALAB_MARKER_ADDITIONAL_CONFIG=request.app.state.config.DATALAB_MARKER_ADDITIONAL_CONFIG,
+ DATALAB_MARKER_SKIP_CACHE=request.app.state.config.DATALAB_MARKER_SKIP_CACHE,
+ DATALAB_MARKER_FORCE_OCR=request.app.state.config.DATALAB_MARKER_FORCE_OCR,
+ DATALAB_MARKER_PAGINATE=request.app.state.config.DATALAB_MARKER_PAGINATE,
+ DATALAB_MARKER_STRIP_EXISTING_OCR=request.app.state.config.DATALAB_MARKER_STRIP_EXISTING_OCR,
+ DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION=request.app.state.config.DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION,
+ DATALAB_MARKER_FORMAT_LINES=request.app.state.config.DATALAB_MARKER_FORMAT_LINES,
+ DATALAB_MARKER_USE_LLM=request.app.state.config.DATALAB_MARKER_USE_LLM,
+ DATALAB_MARKER_OUTPUT_FORMAT=request.app.state.config.DATALAB_MARKER_OUTPUT_FORMAT,
+ EXTERNAL_DOCUMENT_LOADER_URL=request.app.state.config.EXTERNAL_DOCUMENT_LOADER_URL,
+ EXTERNAL_DOCUMENT_LOADER_API_KEY=request.app.state.config.EXTERNAL_DOCUMENT_LOADER_API_KEY,
+ TIKA_SERVER_URL=request.app.state.config.TIKA_SERVER_URL,
+ DOCLING_SERVER_URL=request.app.state.config.DOCLING_SERVER_URL,
+ DOCLING_PARAMS={
+ "do_ocr": request.app.state.config.DOCLING_DO_OCR,
+ "force_ocr": request.app.state.config.DOCLING_FORCE_OCR,
+ "ocr_engine": request.app.state.config.DOCLING_OCR_ENGINE,
+ "ocr_lang": request.app.state.config.DOCLING_OCR_LANG,
+ "pdf_backend": request.app.state.config.DOCLING_PDF_BACKEND,
+ "table_mode": request.app.state.config.DOCLING_TABLE_MODE,
+ "pipeline": request.app.state.config.DOCLING_PIPELINE,
+ "do_picture_description": request.app.state.config.DOCLING_DO_PICTURE_DESCRIPTION,
+ "picture_description_mode": request.app.state.config.DOCLING_PICTURE_DESCRIPTION_MODE,
+ "picture_description_local": request.app.state.config.DOCLING_PICTURE_DESCRIPTION_LOCAL,
+ "picture_description_api": request.app.state.config.DOCLING_PICTURE_DESCRIPTION_API,
+ },
+ PDF_EXTRACT_IMAGES=request.app.state.config.PDF_EXTRACT_IMAGES,
+ DOCUMENT_INTELLIGENCE_ENDPOINT=request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT,
+ DOCUMENT_INTELLIGENCE_KEY=request.app.state.config.DOCUMENT_INTELLIGENCE_KEY,
+ MISTRAL_OCR_API_KEY=request.app.state.config.MISTRAL_OCR_API_KEY,
+ )
+ docs = loader.load(
+ file.filename, file.meta.get("content_type"), file_path
+ )
+
+ docs = [
+ Document(
+ page_content=doc.page_content,
+ metadata={
+ **doc.metadata,
+ "name": file.filename,
+ "created_by": file.user_id,
+ "file_id": file.id,
+ "source": file.filename,
+ },
+ )
+ for doc in docs
+ ]
+ else:
+ docs = [
+ Document(
+ page_content=file.data.get("content", ""),
+ metadata={
+ **file.meta,
+ "name": file.filename,
+ "created_by": file.user_id,
+ "file_id": file.id,
+ "source": file.filename,
+ },
+ )
+ ]
+ text_content = " ".join([doc.page_content for doc in docs])
+
+ log.debug(f"text_content: {text_content}")
+ Files.update_file_data_by_id(
+ file.id,
+ {"content": text_content},
)
- else:
- raise HTTPException(
- status_code=status.HTTP_400_BAD_REQUEST,
- detail=str(e),
+ hash = calculate_sha256_string(text_content)
+ Files.update_file_hash_by_id(file.id, hash)
+
+ if request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL:
+ Files.update_file_data_by_id(file.id, {"status": "completed"})
+ return {
+ "status": True,
+ "collection_name": None,
+ "filename": file.filename,
+ "content": text_content,
+ }
+ else:
+ try:
+ result = save_docs_to_vector_db(
+ request,
+ docs=docs,
+ collection_name=collection_name,
+ metadata={
+ "file_id": file.id,
+ "name": file.filename,
+ "hash": hash,
+ },
+ add=(True if form_data.collection_name else False),
+ user=user,
+ )
+ log.info(f"added {len(docs)} items to collection {collection_name}")
+
+ if result:
+ Files.update_file_metadata_by_id(
+ file.id,
+ {
+ "collection_name": collection_name,
+ },
+ )
+
+ Files.update_file_data_by_id(
+ file.id,
+ {"status": "completed"},
+ )
+
+ return {
+ "status": True,
+ "collection_name": collection_name,
+ "filename": file.filename,
+ "content": text_content,
+ }
+ else:
+ raise Exception("Error saving document to vector database")
+ except Exception as e:
+ raise e
+
+ except Exception as e:
+ log.exception(e)
+ Files.update_file_data_by_id(
+ file.id,
+ {"status": "failed"},
)
+ if "No pandoc was found" in str(e):
+ raise HTTPException(
+ status_code=status.HTTP_400_BAD_REQUEST,
+ detail=ERROR_MESSAGES.PANDOC_NOT_INSTALLED,
+ )
+ else:
+ raise HTTPException(
+ status_code=status.HTTP_400_BAD_REQUEST,
+ detail=str(e),
+ )
+
+ else:
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
+ )
+
class ProcessTextForm(BaseModel):
name: str
@@ -1769,7 +1793,15 @@ def search_web(request: Request, engine: str, query: str) -> list[SearchResult]:
"""
# TODO: add playwright to search the web
- if engine == "searxng":
+ if engine == "ollama_cloud":
+ return search_ollama_cloud(
+ "https://ollama.com",
+ request.app.state.config.OLLAMA_CLOUD_WEB_SEARCH_API_KEY,
+ query,
+ request.app.state.config.WEB_SEARCH_RESULT_COUNT,
+ request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST,
+ )
+ elif engine == "searxng":
if request.app.state.config.SEARXNG_QUERY_URL:
return search_searxng(
request.app.state.config.SEARXNG_QUERY_URL,
diff --git a/backend/open_webui/routers/tools.py b/backend/open_webui/routers/tools.py
index 71c7069fd3..eb66a86825 100644
--- a/backend/open_webui/routers/tools.py
+++ b/backend/open_webui/routers/tools.py
@@ -9,6 +9,7 @@ from pydantic import BaseModel, HttpUrl
from fastapi import APIRouter, Depends, HTTPException, Request, status
+from open_webui.models.oauth_sessions import OAuthSessions
from open_webui.models.tools import (
ToolForm,
ToolModel,
@@ -41,7 +42,15 @@ router = APIRouter()
@router.get("/", response_model=list[ToolUserResponse])
async def get_tools(request: Request, user=Depends(get_verified_user)):
- tools = Tools.get_tools()
+ tools = [
+ ToolUserResponse(
+ **{
+ **tool.model_dump(),
+ "has_user_valves": "class UserValves(BaseModel):" in tool.content,
+ }
+ )
+ for tool in Tools.get_tools()
+ ]
# OpenAPI Tool Servers
for server in await get_tool_servers(request):
@@ -72,6 +81,20 @@ async def get_tools(request: Request, user=Depends(get_verified_user)):
# MCP Tool Servers
for server in request.app.state.config.TOOL_SERVER_CONNECTIONS:
if server.get("type", "openapi") == "mcp":
+ server_id = server.get("info", {}).get("id")
+ auth_type = server.get("auth_type", "none")
+
+ session_token = None
+ if auth_type == "oauth_2.1":
+ splits = server_id.split(":")
+ server_id = splits[-1] if len(splits) > 1 else server_id
+
+ session_token = (
+ await request.app.state.oauth_client_manager.get_oauth_token(
+ user.id, f"mcp:{server_id}"
+ )
+ )
+
tools.append(
ToolUserResponse(
**{
@@ -88,6 +111,13 @@ async def get_tools(request: Request, user=Depends(get_verified_user)):
),
"updated_at": int(time.time()),
"created_at": int(time.time()),
+ **(
+ {
+ "authenticated": session_token is not None,
+ }
+ if auth_type == "oauth_2.1"
+ else {}
+ ),
}
)
)
@@ -486,8 +516,9 @@ async def update_tools_valves_by_id(
try:
form_data = {k: v for k, v in form_data.items() if v is not None}
valves = Valves(**form_data)
- Tools.update_tool_valves_by_id(id, valves.model_dump())
- return valves.model_dump()
+ valves_dict = valves.model_dump(exclude_unset=True)
+ Tools.update_tool_valves_by_id(id, valves_dict)
+ return valves_dict
except Exception as e:
log.exception(f"Failed to update tool valves by id {id}: {e}")
raise HTTPException(
@@ -562,10 +593,11 @@ async def update_tools_user_valves_by_id(
try:
form_data = {k: v for k, v in form_data.items() if v is not None}
user_valves = UserValves(**form_data)
+ user_valves_dict = user_valves.model_dump(exclude_unset=True)
Tools.update_user_valves_by_id_and_user_id(
- id, user.id, user_valves.model_dump()
+ id, user.id, user_valves_dict
)
- return user_valves.model_dump()
+ return user_valves_dict
except Exception as e:
log.exception(f"Failed to update user valves by id {id}: {e}")
raise HTTPException(
diff --git a/backend/open_webui/utils/access_control.py b/backend/open_webui/utils/access_control.py
index 6215a6ac22..af48bebfb4 100644
--- a/backend/open_webui/utils/access_control.py
+++ b/backend/open_webui/utils/access_control.py
@@ -110,9 +110,13 @@ def has_access(
type: str = "write",
access_control: Optional[dict] = None,
user_group_ids: Optional[Set[str]] = None,
+ strict: bool = True,
) -> bool:
if access_control is None:
- return type == "read"
+ if strict:
+ return type == "read"
+ else:
+ return True
if user_group_ids is None:
user_groups = Groups.get_groups_by_member_id(user_id)
diff --git a/backend/open_webui/utils/mcp/client.py b/backend/open_webui/utils/mcp/client.py
index 2d352ead24..01df38886c 100644
--- a/backend/open_webui/utils/mcp/client.py
+++ b/backend/open_webui/utils/mcp/client.py
@@ -13,13 +13,9 @@ class MCPClient:
self.session: Optional[ClientSession] = None
self.exit_stack = AsyncExitStack()
- async def connect(
- self, url: str, headers: Optional[dict] = None, auth: Optional[any] = None
- ):
+ async def connect(self, url: str, headers: Optional[dict] = None):
try:
- self._streams_context = streamablehttp_client(
- url, headers=headers, auth=auth
- )
+ self._streams_context = streamablehttp_client(url, headers=headers)
transport = await self.exit_stack.enter_async_context(self._streams_context)
read_stream, write_stream, _ = transport
diff --git a/backend/open_webui/utils/middleware.py b/backend/open_webui/utils/middleware.py
index 52c78f199c..509f419b07 100644
--- a/backend/open_webui/utils/middleware.py
+++ b/backend/open_webui/utils/middleware.py
@@ -24,6 +24,7 @@ from fastapi.responses import HTMLResponse
from starlette.responses import Response, StreamingResponse, JSONResponse
+from open_webui.models.oauth_sessions import OAuthSessions
from open_webui.models.chats import Chats
from open_webui.models.folders import Folders
from open_webui.models.users import Users
@@ -1047,6 +1048,22 @@ async def process_chat_payload(request, form_data, user, metadata, model):
headers["Authorization"] = (
f"Bearer {oauth_token.get('access_token', '')}"
)
+ elif auth_type == "oauth_2.1":
+ try:
+ splits = server_id.split(":")
+ server_id = splits[-1] if len(splits) > 1 else server_id
+
+ oauth_token = await request.app.state.oauth_client_manager.get_oauth_token(
+ user.id, f"mcp:{server_id}"
+ )
+
+ if oauth_token:
+ headers["Authorization"] = (
+ f"Bearer {oauth_token.get('access_token', '')}"
+ )
+ except Exception as e:
+ log.error(f"Error getting OAuth token: {e}")
+ oauth_token = None
mcp_client = MCPClient()
await mcp_client.connect(
@@ -1171,26 +1188,15 @@ async def process_chat_payload(request, form_data, user, metadata, model):
raise Exception("No user message found")
if context_string != "":
- # Workaround for Ollama 2.0+ system prompt issue
- # TODO: replace with add_or_update_system_message
- if model.get("owned_by") == "ollama":
- form_data["messages"] = prepend_to_first_user_message_content(
- rag_template(
- request.app.state.config.RAG_TEMPLATE,
- context_string,
- prompt,
- ),
- form_data["messages"],
- )
- else:
- form_data["messages"] = add_or_update_system_message(
- rag_template(
- request.app.state.config.RAG_TEMPLATE,
- context_string,
- prompt,
- ),
- form_data["messages"],
- )
+ form_data["messages"] = add_or_update_user_message(
+ rag_template(
+ request.app.state.config.RAG_TEMPLATE,
+ context_string,
+ prompt,
+ ),
+ form_data["messages"],
+ append=False,
+ )
# If there are citations, add them to the data_items
sources = [
diff --git a/backend/open_webui/utils/misc.py b/backend/open_webui/utils/misc.py
index 370cf26c48..e8cfa0d158 100644
--- a/backend/open_webui/utils/misc.py
+++ b/backend/open_webui/utils/misc.py
@@ -120,19 +120,20 @@ def pop_system_message(messages: list[dict]) -> tuple[Optional[dict], list[dict]
return get_system_message(messages), remove_system_message(messages)
-def prepend_to_first_user_message_content(
- content: str, messages: list[dict]
-) -> list[dict]:
- for message in messages:
- if message["role"] == "user":
- if isinstance(message["content"], list):
- for item in message["content"]:
- if item["type"] == "text":
- item["text"] = f"{content}\n{item['text']}"
- else:
- message["content"] = f"{content}\n{message['content']}"
- break
- return messages
+def update_message_content(message: dict, content: str, append: bool = True) -> dict:
+ if isinstance(message["content"], list):
+ for item in message["content"]:
+ if item["type"] == "text":
+ if append:
+ item["text"] = f"{item['text']}\n{content}"
+ else:
+ item["text"] = f"{content}\n{item['text']}"
+ else:
+ if append:
+ message["content"] = f"{message['content']}\n{content}"
+ else:
+ message["content"] = f"{content}\n{message['content']}"
+ return message
def add_or_update_system_message(
@@ -148,10 +149,7 @@ def add_or_update_system_message(
"""
if messages and messages[0].get("role") == "system":
- if append:
- messages[0]["content"] = f"{messages[0]['content']}\n{content}"
- else:
- messages[0]["content"] = f"{content}\n{messages[0]['content']}"
+ messages[0] = update_message_content(messages[0], content, append)
else:
# Insert at the beginning
messages.insert(0, {"role": "system", "content": content})
@@ -159,7 +157,7 @@ def add_or_update_system_message(
return messages
-def add_or_update_user_message(content: str, messages: list[dict]):
+def add_or_update_user_message(content: str, messages: list[dict], append: bool = True):
"""
Adds a new user message at the end of the messages list
or updates the existing user message at the end.
@@ -170,7 +168,7 @@ def add_or_update_user_message(content: str, messages: list[dict]):
"""
if messages and messages[-1].get("role") == "user":
- messages[-1]["content"] = f"{messages[-1]['content']}\n{content}"
+ messages[-1] = update_message_content(messages[-1], content, append)
else:
# Insert at the end
messages.append({"role": "user", "content": content})
@@ -178,6 +176,16 @@ def add_or_update_user_message(content: str, messages: list[dict]):
return messages
+def prepend_to_first_user_message_content(
+ content: str, messages: list[dict]
+) -> list[dict]:
+ for message in messages:
+ if message["role"] == "user":
+ message = update_message_content(message, content, append=False)
+ break
+ return messages
+
+
def append_or_update_assistant_message(content: str, messages: list[dict]):
"""
Adds a new assistant message at the end of the messages list
diff --git a/backend/open_webui/utils/oauth.py b/backend/open_webui/utils/oauth.py
index ee3ba79990..9399241853 100644
--- a/backend/open_webui/utils/oauth.py
+++ b/backend/open_webui/utils/oauth.py
@@ -1,7 +1,9 @@
import base64
+import hashlib
import logging
import mimetypes
import sys
+import urllib
import uuid
import json
from datetime import datetime, timedelta
@@ -9,6 +11,9 @@ from datetime import datetime, timedelta
import re
import fnmatch
import time
+import secrets
+from cryptography.fernet import Fernet
+
import aiohttp
from authlib.integrations.starlette_client import OAuth
@@ -18,6 +23,7 @@ from fastapi import (
status,
)
from starlette.responses import RedirectResponse
+from typing import Optional
from open_webui.models.auths import Auths
@@ -56,11 +62,27 @@ from open_webui.env import (
WEBUI_AUTH_COOKIE_SAME_SITE,
WEBUI_AUTH_COOKIE_SECURE,
ENABLE_OAUTH_ID_TOKEN_COOKIE,
+ OAUTH_CLIENT_INFO_ENCRYPTION_KEY,
)
from open_webui.utils.misc import parse_duration
from open_webui.utils.auth import get_password_hash, create_token
from open_webui.utils.webhook import post_webhook
+from mcp.shared.auth import (
+ OAuthClientMetadata,
+ OAuthMetadata,
+)
+
+
+class OAuthClientInformationFull(OAuthClientMetadata):
+ issuer: Optional[str] = None # URL of the OAuth server that issued this client
+
+ client_id: str
+ client_secret: str | None = None
+ client_id_issued_at: int | None = None
+ client_secret_expires_at: int | None = None
+
+
from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL
logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
@@ -89,6 +111,42 @@ auth_manager_config.JWT_EXPIRES_IN = JWT_EXPIRES_IN
auth_manager_config.OAUTH_UPDATE_PICTURE_ON_LOGIN = OAUTH_UPDATE_PICTURE_ON_LOGIN
+FERNET = None
+
+if len(OAUTH_CLIENT_INFO_ENCRYPTION_KEY) != 44:
+ key_bytes = hashlib.sha256(OAUTH_CLIENT_INFO_ENCRYPTION_KEY.encode()).digest()
+ OAUTH_CLIENT_INFO_ENCRYPTION_KEY = base64.urlsafe_b64encode(key_bytes)
+else:
+ OAUTH_CLIENT_INFO_ENCRYPTION_KEY = OAUTH_CLIENT_INFO_ENCRYPTION_KEY.encode()
+
+try:
+ FERNET = Fernet(OAUTH_CLIENT_INFO_ENCRYPTION_KEY)
+except Exception as e:
+ log.error(f"Error initializing Fernet with provided key: {e}")
+ raise
+
+
+def encrypt_data(data) -> str:
+ """Encrypt data for storage"""
+ try:
+ data_json = json.dumps(data)
+ encrypted = FERNET.encrypt(data_json.encode()).decode()
+ return encrypted
+ except Exception as e:
+ log.error(f"Error encrypting data: {e}")
+ raise
+
+
+def decrypt_data(data: str):
+ """Decrypt data from storage"""
+ try:
+ decrypted = FERNET.decrypt(data.encode()).decode()
+ return json.loads(decrypted)
+ except Exception as e:
+ log.error(f"Error decrypting data: {e}")
+ raise
+
+
def is_in_blocked_groups(group_name: str, groups: list) -> bool:
"""
Check if a group name matches any blocked pattern.
@@ -133,6 +191,412 @@ def is_in_blocked_groups(group_name: str, groups: list) -> bool:
return False
+def get_parsed_and_base_url(server_url) -> tuple[urllib.parse.ParseResult, str]:
+ parsed = urllib.parse.urlparse(server_url)
+ base_url = f"{parsed.scheme}://{parsed.netloc}"
+ return parsed, base_url
+
+
+def get_discovery_urls(server_url) -> list[str]:
+ urls = []
+ parsed, base_url = get_parsed_and_base_url(server_url)
+
+ urls.append(
+ urllib.parse.urljoin(base_url, "/.well-known/oauth-authorization-server")
+ )
+ urls.append(urllib.parse.urljoin(base_url, "/.well-known/openid-configuration"))
+
+ return urls
+
+
+# TODO: Some OAuth providers require Initial Access Tokens (IATs) for dynamic client registration.
+# This is not currently supported.
+async def get_oauth_client_info_with_dynamic_client_registration(
+ request,
+ client_id: str,
+ oauth_server_url: str,
+ oauth_server_key: Optional[str] = None,
+) -> OAuthClientInformationFull:
+ try:
+ oauth_server_metadata = None
+ oauth_server_metadata_url = None
+
+ redirect_base_url = (
+ str(request.app.state.config.WEBUI_URL or request.base_url)
+ ).rstrip("/")
+
+ oauth_client_metadata = OAuthClientMetadata(
+ client_name="Open WebUI",
+ redirect_uris=[f"{redirect_base_url}/oauth/clients/{client_id}/callback"],
+ grant_types=["authorization_code", "refresh_token"],
+ response_types=["code"],
+ token_endpoint_auth_method="client_secret_post",
+ )
+
+ # Attempt to fetch OAuth server metadata to get registration endpoint & scopes
+ discovery_urls = get_discovery_urls(oauth_server_url)
+ for url in discovery_urls:
+ async with aiohttp.ClientSession() as session:
+ async with session.get(
+ url, ssl=AIOHTTP_CLIENT_SESSION_SSL
+ ) as oauth_server_metadata_response:
+ if oauth_server_metadata_response.status == 200:
+ try:
+ oauth_server_metadata = OAuthMetadata.model_validate(
+ await oauth_server_metadata_response.json()
+ )
+ oauth_server_metadata_url = url
+ if (
+ oauth_client_metadata.scope is None
+ and oauth_server_metadata.scopes_supported is not None
+ ):
+ oauth_client_metadata.scope = " ".join(
+ oauth_server_metadata.scopes_supported
+ )
+ break
+ except Exception as e:
+ log.error(f"Error parsing OAuth metadata from {url}: {e}")
+ continue
+
+ registration_url = None
+ if oauth_server_metadata and oauth_server_metadata.registration_endpoint:
+ registration_url = str(oauth_server_metadata.registration_endpoint)
+ else:
+ _, base_url = get_parsed_and_base_url(oauth_server_url)
+ registration_url = urllib.parse.urljoin(base_url, "/register")
+
+ registration_data = oauth_client_metadata.model_dump(
+ exclude_none=True,
+ mode="json",
+ by_alias=True,
+ )
+
+ # Perform dynamic client registration and return client info
+ async with aiohttp.ClientSession() as session:
+ async with session.post(
+ registration_url, json=registration_data, ssl=AIOHTTP_CLIENT_SESSION_SSL
+ ) as oauth_client_registration_response:
+ try:
+ registration_response_json = (
+ await oauth_client_registration_response.json()
+ )
+ oauth_client_info = OAuthClientInformationFull.model_validate(
+ {
+ **registration_response_json,
+ **{"issuer": oauth_server_metadata_url},
+ }
+ )
+ log.info(
+ f"Dynamic client registration successful at {registration_url}, client_id: {oauth_client_info.client_id}"
+ )
+ return oauth_client_info
+ except Exception as e:
+ error_text = None
+ try:
+ error_text = await oauth_client_registration_response.text()
+ log.error(
+ f"Dynamic client registration failed at {registration_url}: {oauth_client_registration_response.status} - {error_text}"
+ )
+ except Exception as e:
+ pass
+
+ log.error(f"Error parsing client registration response: {e}")
+ raise Exception(
+ f"Dynamic client registration failed: {error_text}"
+ if error_text
+ else "Error parsing client registration response"
+ )
+ raise Exception("Dynamic client registration failed")
+ except Exception as e:
+ log.error(f"Exception during dynamic client registration: {e}")
+ raise e
+
+
+class OAuthClientManager:
+ def __init__(self, app):
+ self.oauth = OAuth()
+ self.app = app
+ self.clients = {}
+
+ def add_client(self, client_id, oauth_client_info: OAuthClientInformationFull):
+ self.clients[client_id] = {
+ "client": self.oauth.register(
+ name=client_id,
+ client_id=oauth_client_info.client_id,
+ client_secret=oauth_client_info.client_secret,
+ client_kwargs=(
+ {"scope": oauth_client_info.scope}
+ if oauth_client_info.scope
+ else {}
+ ),
+ server_metadata_url=(
+ oauth_client_info.issuer if oauth_client_info.issuer else None
+ ),
+ ),
+ "client_info": oauth_client_info,
+ }
+ return self.clients[client_id]
+
+ def remove_client(self, client_id):
+ if client_id in self.clients:
+ del self.clients[client_id]
+ log.info(f"Removed OAuth client {client_id}")
+ return True
+
+ def get_client(self, client_id):
+ client = self.clients.get(client_id)
+ return client["client"] if client else None
+
+ def get_client_info(self, client_id):
+ client = self.clients.get(client_id)
+ return client["client_info"] if client else None
+
+ def get_server_metadata_url(self, client_id):
+ if client_id in self.clients:
+ client = self.clients[client_id]
+ return (
+ client.server_metadata_url
+ if hasattr(client, "server_metadata_url")
+ else None
+ )
+ return None
+
+ async def get_oauth_token(
+ self, user_id: str, client_id: str, force_refresh: bool = False
+ ):
+ """
+ Get a valid OAuth token for the user, automatically refreshing if needed.
+
+ Args:
+ user_id: The user ID
+ client_id: The OAuth client ID (provider)
+ force_refresh: Force token refresh even if current token appears valid
+
+ Returns:
+ dict: OAuth token data with access_token, or None if no valid token available
+ """
+ try:
+ # Get the OAuth session
+ session = OAuthSessions.get_session_by_provider_and_user_id(
+ client_id, user_id
+ )
+ if not session:
+ log.warning(
+ f"No OAuth session found for user {user_id}, client_id {client_id}"
+ )
+ return None
+
+ if force_refresh or datetime.now() + timedelta(
+ minutes=5
+ ) >= datetime.fromtimestamp(session.expires_at):
+ log.debug(
+ f"Token refresh needed for user {user_id}, client_id {session.provider}"
+ )
+ refreshed_token = await self._refresh_token(session)
+ if refreshed_token:
+ return refreshed_token
+ else:
+ log.warning(
+ f"Token refresh failed for user {user_id}, client_id {session.provider}, deleting session {session.id}"
+ )
+ OAuthSessions.delete_session_by_id(session.id)
+ return None
+ return session.token
+
+ except Exception as e:
+ log.error(f"Error getting OAuth token for user {user_id}: {e}")
+ return None
+
+ async def _refresh_token(self, session) -> dict:
+ """
+ Refresh an OAuth token if needed, with concurrency protection.
+
+ Args:
+ session: The OAuth session object
+
+ Returns:
+ dict: Refreshed token data, or None if refresh failed
+ """
+ try:
+ # Perform the actual refresh
+ refreshed_token = await self._perform_token_refresh(session)
+
+ if refreshed_token:
+ # Update the session with new token data
+ session = OAuthSessions.update_session_by_id(
+ session.id, refreshed_token
+ )
+ log.info(f"Successfully refreshed token for session {session.id}")
+ return session.token
+ else:
+ log.error(f"Failed to refresh token for session {session.id}")
+ return None
+
+ except Exception as e:
+ log.error(f"Error refreshing token for session {session.id}: {e}")
+ return None
+
+ async def _perform_token_refresh(self, session) -> dict:
+ """
+ Perform the actual OAuth token refresh.
+
+ Args:
+ session: The OAuth session object
+
+ Returns:
+ dict: New token data, or None if refresh failed
+ """
+ client_id = session.provider
+ token_data = session.token
+
+ if not token_data.get("refresh_token"):
+ log.warning(f"No refresh token available for session {session.id}")
+ return None
+
+ try:
+ client = self.get_client(client_id)
+ if not client:
+ log.error(f"No OAuth client found for provider {client_id}")
+ return None
+
+ token_endpoint = None
+ async with aiohttp.ClientSession(trust_env=True) as session_http:
+ async with session_http.get(
+ self.get_server_metadata_url(client_id)
+ ) as r:
+ if r.status == 200:
+ openid_data = await r.json()
+ token_endpoint = openid_data.get("token_endpoint")
+ else:
+ log.error(
+ f"Failed to fetch OpenID configuration for client_id {client_id}"
+ )
+ if not token_endpoint:
+ log.error(f"No token endpoint found for client_id {client_id}")
+ return None
+
+ # Prepare refresh request
+ refresh_data = {
+ "grant_type": "refresh_token",
+ "refresh_token": token_data["refresh_token"],
+ "client_id": client.client_id,
+ }
+ if hasattr(client, "client_secret") and client.client_secret:
+ refresh_data["client_secret"] = client.client_secret
+
+ # Make refresh request
+ async with aiohttp.ClientSession(trust_env=True) as session_http:
+ async with session_http.post(
+ token_endpoint,
+ data=refresh_data,
+ headers={"Content-Type": "application/x-www-form-urlencoded"},
+ ssl=AIOHTTP_CLIENT_SESSION_SSL,
+ ) as r:
+ if r.status == 200:
+ new_token_data = await r.json()
+
+ # Merge with existing token data (preserve refresh_token if not provided)
+ if "refresh_token" not in new_token_data:
+ new_token_data["refresh_token"] = token_data[
+ "refresh_token"
+ ]
+
+ # Add timestamp for tracking
+ new_token_data["issued_at"] = datetime.now().timestamp()
+
+ # Calculate expires_at if we have expires_in
+ if (
+ "expires_in" in new_token_data
+ and "expires_at" not in new_token_data
+ ):
+ new_token_data["expires_at"] = int(
+ datetime.now().timestamp()
+ + new_token_data["expires_in"]
+ )
+
+ log.debug(f"Token refresh successful for client_id {client_id}")
+ return new_token_data
+ else:
+ error_text = await r.text()
+ log.error(
+ f"Token refresh failed for client_id {client_id}: {r.status} - {error_text}"
+ )
+ return None
+
+ except Exception as e:
+ log.error(f"Exception during token refresh for client_id {client_id}: {e}")
+ return None
+
+ async def handle_authorize(self, request, client_id: str) -> RedirectResponse:
+ client = self.get_client(client_id)
+ if client is None:
+ raise HTTPException(404)
+
+ client_info = self.get_client_info(client_id)
+ if client_info is None:
+ raise HTTPException(404)
+
+ redirect_uri = (
+ client_info.redirect_uris[0] if client_info.redirect_uris else None
+ )
+ return await client.authorize_redirect(request, str(redirect_uri))
+
+ async def handle_callback(self, request, client_id: str, user_id: str, response):
+ client = self.get_client(client_id)
+ if client is None:
+ raise HTTPException(404)
+
+ error_message = None
+ try:
+ token = await client.authorize_access_token(request)
+ if token:
+ try:
+ # Add timestamp for tracking
+ token["issued_at"] = datetime.now().timestamp()
+
+ # Calculate expires_at if we have expires_in
+ if "expires_in" in token and "expires_at" not in token:
+ token["expires_at"] = (
+ datetime.now().timestamp() + token["expires_in"]
+ )
+
+ # Clean up any existing sessions for this user/client_id first
+ sessions = OAuthSessions.get_sessions_by_user_id(user_id)
+ for session in sessions:
+ if session.provider == client_id:
+ OAuthSessions.delete_session_by_id(session.id)
+
+ session = OAuthSessions.create_session(
+ user_id=user_id,
+ provider=client_id,
+ token=token,
+ )
+ log.info(
+ f"Stored OAuth session server-side for user {user_id}, client_id {client_id}"
+ )
+ except Exception as e:
+ error_message = "Failed to store OAuth session server-side"
+ log.error(f"Failed to store OAuth session server-side: {e}")
+ else:
+ error_message = "Failed to obtain OAuth token"
+ log.warning(error_message)
+ except Exception as e:
+ error_message = "OAuth callback error"
+ log.warning(f"OAuth callback error: {e}")
+
+ redirect_url = (
+ str(request.app.state.config.WEBUI_URL or request.base_url)
+ ).rstrip("/")
+
+ if error_message:
+ log.debug(error_message)
+ redirect_url = f"{redirect_url}/?error={error_message}"
+ return RedirectResponse(url=redirect_url, headers=response.headers)
+
+ response = RedirectResponse(url=redirect_url, headers=response.headers)
+ return response
+
+
class OAuthManager:
def __init__(self, app):
self.oauth = OAuth()
@@ -191,8 +655,10 @@ class OAuthManager:
return refreshed_token
else:
log.warning(
- f"Token refresh failed for user {user_id}, provider {session.provider}"
+ f"Token refresh failed for user {user_id}, provider {session.provider}, deleting session {session.id}"
)
+ OAuthSessions.delete_session_by_id(session.id)
+
return None
return session.token
@@ -252,9 +718,10 @@ class OAuthManager:
log.error(f"No OAuth client found for provider {provider}")
return None
+ server_metadata_url = self.get_server_metadata_url(provider)
token_endpoint = None
async with aiohttp.ClientSession(trust_env=True) as session_http:
- async with session_http.get(client.gserver_metadata_url) as r:
+ async with session_http.get(server_metadata_url) as r:
if r.status == 200:
openid_data = await r.json()
token_endpoint = openid_data.get("token_endpoint")
@@ -301,7 +768,7 @@ class OAuthManager:
"expires_in" in new_token_data
and "expires_at" not in new_token_data
):
- new_token_data["expires_at"] = (
+ new_token_data["expires_at"] = int(
datetime.now().timestamp()
+ new_token_data["expires_in"]
)
@@ -574,7 +1041,7 @@ class OAuthManager:
raise HTTPException(404)
# If the provider has a custom redirect URL, use that, otherwise automatically generate one
redirect_uri = OAUTH_PROVIDERS[provider].get("redirect_uri") or request.url_for(
- "oauth_callback", provider=provider
+ "oauth_login_callback", provider=provider
)
client = self.get_client(provider)
if client is None:
@@ -791,9 +1258,9 @@ class OAuthManager:
else ERROR_MESSAGES.DEFAULT("Error during OAuth process")
)
- redirect_base_url = str(request.app.state.config.WEBUI_URL or request.base_url)
- if redirect_base_url.endswith("/"):
- redirect_base_url = redirect_base_url[:-1]
+ redirect_base_url = (
+ str(request.app.state.config.WEBUI_URL or request.base_url)
+ ).rstrip("/")
redirect_url = f"{redirect_base_url}/auth"
if error_message:
diff --git a/backend/requirements.txt b/backend/requirements.txt
index 1b14ac1429..23bb7710f0 100644
--- a/backend/requirements.txt
+++ b/backend/requirements.txt
@@ -9,13 +9,14 @@ python-jose==3.4.0
passlib[bcrypt]==1.7.4
cryptography
-requests==2.32.4
+requests==2.32.5
aiohttp==3.12.15
async-timeout
aiocache
aiofiles
starlette-compress==1.6.0
httpx[socks,http2,zstd,cli,brotli]==0.28.1
+starsessions[redis]==2.2.1
sqlalchemy==2.0.38
alembic==1.14.0
@@ -43,13 +44,13 @@ asgiref==3.8.1
# AI libraries
openai
anthropic
-google-genai==1.32.0
+google-genai==1.38.0
google-generativeai==0.8.5
tiktoken
mcp==1.14.1
-langchain==0.3.26
-langchain-community==0.3.27
+langchain==0.3.27
+langchain-community==0.3.29
fake-useragent==2.2.0
chromadb==1.0.20
diff --git a/pyproject.toml b/pyproject.toml
index 09fcce07fb..aa2825fa69 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -21,7 +21,7 @@ dependencies = [
"PyJWT[crypto]==2.10.1",
"authlib==1.6.3",
- "requests==2.32.4",
+ "requests==2.32.5",
"aiohttp==3.12.15",
"async-timeout",
"aiocache",
@@ -51,11 +51,11 @@ dependencies = [
"openai",
"anthropic",
- "google-genai==1.32.0",
+ "google-genai==1.38.0",
"google-generativeai==0.8.5",
- "langchain==0.3.26",
- "langchain-community==0.3.27",
+ "langchain==0.3.27",
+ "langchain-community==0.3.29",
"fake-useragent==2.2.0",
"chromadb==1.0.20",
diff --git a/src/lib/apis/chats/index.ts b/src/lib/apis/chats/index.ts
index b1e7d5f23b..59d8600771 100644
--- a/src/lib/apis/chats/index.ts
+++ b/src/lib/apis/chats/index.ts
@@ -77,7 +77,11 @@ export const importChat = async (
return res;
};
-export const getChatList = async (token: string = '', page: number | null = null) => {
+export const getChatList = async (
+ token: string = '',
+ page: number | null = null,
+ include_folders: boolean = false
+) => {
let error = null;
const searchParams = new URLSearchParams();
@@ -85,6 +89,10 @@ export const getChatList = async (token: string = '', page: number | null = null
searchParams.append('page', `${page}`);
}
+ if (include_folders) {
+ searchParams.append('include_folders', 'true');
+ }
+
const res = await fetch(`${WEBUI_API_BASE_URL}/chats/?${searchParams.toString()}`, {
method: 'GET',
headers: {
diff --git a/src/lib/apis/configs/index.ts b/src/lib/apis/configs/index.ts
index ef983e63bf..c6cfdd2b2b 100644
--- a/src/lib/apis/configs/index.ts
+++ b/src/lib/apis/configs/index.ts
@@ -1,4 +1,4 @@
-import { WEBUI_API_BASE_URL } from '$lib/constants';
+import { WEBUI_API_BASE_URL, WEBUI_BASE_URL } from '$lib/constants';
import type { Banner } from '$lib/types';
export const importConfig = async (token: string, config) => {
@@ -202,6 +202,52 @@ export const verifyToolServerConnection = async (token: string, connection: obje
return res;
};
+type RegisterOAuthClientForm = {
+ url: string;
+ client_id: string;
+ client_name?: string;
+};
+
+export const registerOAuthClient = async (
+ token: string,
+ formData: RegisterOAuthClientForm,
+ type: null | string = null
+) => {
+ let error = null;
+
+ const searchParams = type ? `?type=${type}` : '';
+ const res = await fetch(`${WEBUI_API_BASE_URL}/configs/oauth/clients/register${searchParams}`, {
+ method: 'POST',
+ headers: {
+ 'Content-Type': 'application/json',
+ Authorization: `Bearer ${token}`
+ },
+ body: JSON.stringify({
+ ...formData
+ })
+ })
+ .then(async (res) => {
+ if (!res.ok) throw await res.json();
+ return res.json();
+ })
+ .catch((err) => {
+ console.error(err);
+ error = err.detail;
+ return null;
+ });
+
+ if (error) {
+ throw error;
+ }
+
+ return res;
+};
+
+export const getOAuthClientAuthorizationUrl = (clientId: string, type: null | string = null) => {
+ const oauthClientId = type ? `${type}:${clientId}` : clientId;
+ return `${WEBUI_BASE_URL}/oauth/clients/${oauthClientId}/authorize`;
+};
+
export const getCodeExecutionConfig = async (token: string) => {
let error = null;
diff --git a/src/lib/components/AddToolServerModal.svelte b/src/lib/components/AddToolServerModal.svelte
index 01c87010ef..114c05f834 100644
--- a/src/lib/components/AddToolServerModal.svelte
+++ b/src/lib/components/AddToolServerModal.svelte
@@ -13,7 +13,7 @@
import Switch from '$lib/components/common/Switch.svelte';
import Tags from './common/Tags.svelte';
import { getToolServerData } from '$lib/apis';
- import { verifyToolServerConnection } from '$lib/apis/configs';
+ import { verifyToolServerConnection, registerOAuthClient } from '$lib/apis/configs';
import AccessControl from './workspace/common/AccessControl.svelte';
import Spinner from '$lib/components/common/Spinner.svelte';
import XMark from '$lib/components/icons/XMark.svelte';
@@ -41,10 +41,47 @@
let name = '';
let description = '';
- let enable = true;
+ let oauthClientInfo = null;
+ let enable = true;
let loading = false;
+ const registerOAuthClientHandler = async () => {
+ if (url === '') {
+ toast.error($i18n.t('Please enter a valid URL'));
+ return;
+ }
+
+ if (id === '') {
+ toast.error($i18n.t('Please enter a valid ID'));
+ return;
+ }
+
+ const res = await registerOAuthClient(
+ localStorage.token,
+ {
+ url: url,
+ client_id: id
+ },
+ 'mcp'
+ ).catch((err) => {
+ toast.error($i18n.t('Registration failed'));
+ return null;
+ });
+
+ if (res) {
+ toast.warning(
+ $i18n.t(
+ 'Please save the connection to persist the OAuth client information and do not change the ID'
+ )
+ );
+ toast.success($i18n.t('Registration successful'));
+
+ console.debug('Registration successful', res);
+ oauthClientInfo = res?.oauth_client_info ?? null;
+ }
+ };
+
const verifyHandler = async () => {
if (url === '') {
toast.error($i18n.t('Please enter a valid URL'));
@@ -106,6 +143,12 @@
return;
}
+ if (type === 'mcp' && auth_type === 'oauth_2.1' && !oauthClientInfo) {
+ toast.error($i18n.t('Please register the OAuth client'));
+ loading = false;
+ return;
+ }
+
const connection = {
url,
path,
@@ -119,7 +162,8 @@
info: {
id: id,
name: name,
- description: description
+ description: description,
+ ...(oauthClientInfo ? { oauth_client_info: oauthClientInfo } : {})
}
};
@@ -139,6 +183,7 @@
id = '';
name = '';
description = '';
+ oauthClientInfo = null;
enable = true;
accessControl = null;
@@ -156,6 +201,7 @@
id = connection.info?.id ?? '';
name = connection.info?.name ?? '';
description = connection.info?.description ?? '';
+ oauthClientInfo = connection.info?.oauth_client_info ?? null;
enable = connection.config?.enable ?? true;
accessControl = connection.config?.access_control ?? null;
@@ -227,25 +273,6 @@
{/if}
- {#if type === 'mcp'}
-