diff --git a/CHANGELOG.md b/CHANGELOG.md
index 7e04470d0e..17ed1a98e9 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -5,6 +5,60 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
+## [0.6.31] - 2025-09-25
+
+### Added
+
+- 🔌 MCP (streamable HTTP) server support was added alongside existing OpenAPI server integration, allowing users to connect both server types through an improved server configuration interface. [#15932](https://github.com/open-webui/open-webui/issues/15932) [#16651](https://github.com/open-webui/open-webui/pull/16651), [Commit](https://github.com/open-webui/open-webui/commit/fd7385c3921eb59af76a26f4c475aedb38ce2406), [Commit](https://github.com/open-webui/open-webui/commit/777e81f7a8aca957a359d51df8388e5af4721a68), [Commit](https://github.com/open-webui/open-webui/commit/de7f7b3d855641450f8e5aac34fbae0665e0b80e), [Commit](https://github.com/open-webui/open-webui/commit/f1bbf3a91e4713039364b790e886e59b401572d0), [Commit](https://github.com/open-webui/open-webui/commit/c55afc42559c32a6f0c8beb0f1bb18e9360ab8af), [Commit](https://github.com/open-webui/open-webui/commit/61f20acf61f4fe30c0e5b0180949f6e1a8cf6524)
+- 🔐 OAuth 2.1 support was implemented for MCP server authentication, featuring secure automatic dynamic client registration, encrypted session management, and seamless authentication flows. [Commit](https://github.com/open-webui/open-webui/commit/972be4eda5a394c111e849075f94099c9c0dd9aa), [Commit](https://github.com/open-webui/open-webui/commit/77e971dd9fbeee806e2864e686df5ec75e82104b), [Commit](https://github.com/open-webui/open-webui/commit/879abd7feea3692a2f157da4a458d30f27217508), [Commit](https://github.com/open-webui/open-webui/commit/422d38fd114b1ebd8a7dbb114d64e14791e67d7a), [Docs:#709](https://github.com/open-webui/docs/pull/709)
+- 🛠️ External & Built-In Tools now support rich UI element embedding ([Docs](https://docs.openwebui.com/features/plugin/tools/development)), allowing tools to return HTML content and interactive iframes that display directly within chat conversations, with configurable security settings. [Commit](https://github.com/open-webui/open-webui/commit/07c5b25bc8b63173f406feb3ba183d375fedee6a), [Commit](https://github.com/open-webui/open-webui/commit/a5d8882bba7933a2c2c31c0a1405aba507c370bb), [Commit](https://github.com/open-webui/open-webui/commit/7be5b7f50f498de97359003609fc5993a172f084), [Commit](https://github.com/open-webui/open-webui/commit/a89ffccd7e96705a4a40e845289f4fcf9c4ae596)
+- 📝 Note editor now supports drag-and-drop reordering of list items with visual drag handles, making list organization more intuitive and efficient. [Commit](https://github.com/open-webui/open-webui/commit/e4e97e727e9b4971f1c363b1280ca3a101599d88), [Commit](https://github.com/open-webui/open-webui/commit/aeb5288a3c7a6e9e0a47b807cc52f870c1b7dbe6)
+- 🔍 Search modal was enhanced with quick action buttons for starting new conversations and creating notes, with intelligent content pre-population from search queries. [Commit](https://github.com/open-webui/open-webui/commit/aa6f63a335e172fec1dc94b2056541f52c1167a6), [Commit](https://github.com/open-webui/open-webui/commit/612a52d7bb7dbe9fa0bbbc8ac0a552d2b9801146), [Commit](https://github.com/open-webui/open-webui/commit/b03529b006f3148e895b1094584e1ab129ecac5b)
+- 🛠️ Tool user valve configuration interface was added to the integrations menu, displaying clickable gear icon buttons with tooltips for tools that support user-specific settings, making personal tool configurations easily accessible. [Commit](https://github.com/open-webui/open-webui/commit/27d61307cdce97ed11a05ec13fc300249d6022cd)
+- 👥 Channel access control was enhanced to require write permissions for posting, editing, and deleting messages, while read-only users can view content but cannot contribute. [#17543](https://github.com/open-webui/open-webui/pull/17543)
+- 💬 Channel models now support image processing, allowing AI assistants to view and analyze images shared in conversation threads. [Commit](https://github.com/open-webui/open-webui/commit/9f0010e234a6f40782a66021435d3c02b9c23639)
+- 🌐 Attach Webpage button was added to the message input menu, providing a user-friendly modal interface for attaching web content and YouTube videos as an alternative to the existing URL syntax. [#17534](https://github.com/open-webui/open-webui/pull/17534)
+- 🔐 Redis session storage support was added for OAuth redirects, providing better state handling in multi-pod Kubernetes deployments and resolving CSRF mismatch errors. [#17223](https://github.com/open-webui/open-webui/pull/17223), [#15373](https://github.com/open-webui/open-webui/issues/15373)
+- 🔍 Ollama Cloud web search integration was added as a new search engine option, providing access to web search functionality through Ollama's cloud infrastructure. [Commit](https://github.com/open-webui/open-webui/commit/e06489d92baca095b8f376fbef223298c7772579), [Commit](https://github.com/open-webui/open-webui/commit/4b6d34438bcfc45463dc7a9cb984794b32c1f0a1), [Commit](https://github.com/open-webui/open-webui/commit/05c46008da85357dc6890b846789dfaa59f4a520), [Commit](https://github.com/open-webui/open-webui/commit/fe65fe0b97ec5a8fff71592ff04a25c8e123d108), [Docs:#708](https://github.com/open-webui/docs/pull/708)
+- 🔍 Perplexity Web Search API integration was added as a new search engine option, providing access to Perplexity's new web search functionality. [#17756](https://github.com/open-webui/open-webui/issues/17756), [Commit](https://github.com/open-webui/open-webui/pull/17747/commits/7f411dd5cc1c29733216f79e99eeeed0406a2afe)
+- ☁️ OneDrive integration was improved to support separate client IDs for personal and business authentication, enabling both integrations to work simultaneously. [#17619](https://github.com/open-webui/open-webui/pull/17619), [Docs](https://docs.openwebui.com/tutorials/integrations/onedrive-sharepoint), [Docs](https://docs.openwebui.com/getting-started/env-configuration/#onedrive)
+- 📝 Pending user overlay content now supports markdown formatting, enabling rich text display for custom messages similar to banner functionality. [#17681](https://github.com/open-webui/open-webui/pull/17681)
+- 🎨 Image generation model selection was centralized to enable dynamic model override in function calls, allowing pipes and tools to specify different models than the global default while maintaining backward compatibility. [#17689](https://github.com/open-webui/open-webui/pull/17689)
+- 🎨 Interface design was modernized with updated visual styling, improved spacing, and refined component layouts across modals, sidebar, settings, and navigation elements. [Commit](https://github.com/open-webui/open-webui/commit/27a91cc80a24bda0a3a188bc3120a8ab57b00881), [Commit](https://github.com/open-webui/open-webui/commit/4ad743098615f9c58daa9df392f31109aeceeb16), [Commit](https://github.com/open-webui/open-webui/commit/fd7385c3921eb59af76a26f4c475aedb38ce2406)
+- 📊 Notes query performance was optimized through database-level filtering and separated access control logic, reducing memory usage and eliminating N+1 query problems for better scalability. [#17607](https://github.com/open-webui/open-webui/pull/17607) [Commit](https://github.com/open-webui/open-webui/pull/17747/commits/da661756fa7eec754270e6dd8c67cbf74a28a17f)
+- ⚡ Page loading performance was optimized by deferring API requests until components are actually opened, including ChangelogModal, ModelSelector, RecursiveFolder, ArchivedChatsModal, and SearchModal. [#17542](https://github.com/open-webui/open-webui/pull/17542), [#17555](https://github.com/open-webui/open-webui/pull/17555), [#17557](https://github.com/open-webui/open-webui/pull/17557), [#17541](https://github.com/open-webui/open-webui/pull/17541), [#17640](https://github.com/open-webui/open-webui/pull/17640)
+- ⚡ Bundle size was reduced by 1.58MB through optimized highlight.js language support, improving page loading speed and reducing bandwidth usage. [#17645](https://github.com/open-webui/open-webui/pull/17645)
+- ⚡ Editor collaboration functionality was refactored to reduce package size by 390KB and minimize compilation errors, improving build performance and reliability. [#17593](https://github.com/open-webui/open-webui/pull/17593)
+- ♿ User interface accessibility was enhanced through the addition of unique element IDs, improving targeting for testing, styling, and assistive technologies while providing better semantic markup for screen readers and accessibility tools. [#17746](https://github.com/open-webui/open-webui/pull/17746)
+- 🔄 Various improvements were implemented across the frontend and backend to enhance performance, stability, and security.
+- 🌐 Translations for Portuguese (Brazil), Chinese (Simplified and Traditional), Korean, Irish, Spanish, Finnish, French, Kabyle, Russian, and Catalan were enhanced and improved.
+
+### Fixed
+
+- 🛡️ SVG content security was enhanced by implementing DOMPurify sanitization to prevent XSS attacks through malicious SVG elements, ensuring safe rendering of user-generated SVG content. [Commit](https://github.com/open-webui/open-webui/pull/17747/commits/750a659a9fee7687e667d9d755e17b8a0c77d557)
+- ☁️ OneDrive attachment menu rendering issues were resolved by restructuring the submenu interface from dropdown to tabbed navigation, preventing menu items from being hidden or clipped due to overflow constraints. [#17554](https://github.com/open-webui/open-webui/issues/17554), [Commit](https://github.com/open-webui/open-webui/pull/17747/commits/90e4b49b881b644465831cc3028bb44f0f7a2196)
+- 💬 Attached conversation references now persist throughout the entire chat session, ensuring models can continue querying referenced conversations after multiple conversation turns. [#17750](https://github.com/open-webui/open-webui/issues/17750)
+- 🔍 Search modal text box focus issues after pinning or unpinning chats were resolved, allowing users to properly exit the search interface by clicking outside the text box. [#17743](https://github.com/open-webui/open-webui/issues/17743)
+- 🔍 Search function chat list is now properly updated in real-time when chats are created or deleted, eliminating stale search results and preview loading failures. [#17741](https://github.com/open-webui/open-webui/issues/17741)
+- 💬 Chat jitter and delayed code block expansion in multi-model sessions were resolved by reverting dynamic CodeEditor loading, restoring stable rendering behavior. [#17715](https://github.com/open-webui/open-webui/pull/17715), [#17684](https://github.com/open-webui/open-webui/issues/17684)
+- 📎 File upload handling was improved to properly recognize uploaded files even when no accompanying text message is provided, resolving issues where attachments were ignored in custom prompts. [#17492](https://github.com/open-webui/open-webui/issues/17492)
+- 💬 Chat conversation referencing within projects was restored by including foldered chats in the reference menu, allowing users to properly quote conversations from within their project scope. [#17530](https://github.com/open-webui/open-webui/issues/17530)
+- 🔍 RAG query generation is now skipped when all attached files are set to full context mode, preventing unnecessary retrieval operations and improving system efficiency. [#17744](https://github.com/open-webui/open-webui/pull/17744)
+- 💾 Memory leaks in file handling and HTTP connections are prevented through proper resource cleanup, ensuring stable memory usage during large file downloads and processing operations. [#17608](https://github.com/open-webui/open-webui/pull/17608)
+- 🔐 OAuth access token refresh errors are resolved by properly implementing async/await patterns, preventing "coroutine object has no attribute get" failures during token expiry. [#17585](https://github.com/open-webui/open-webui/issues/17585), [#17678](https://github.com/open-webui/open-webui/issues/17678)
+- ⚙️ Valve behavior was improved to properly handle default values and array types, ensuring only explicitly set values are persisted while maintaining consistent distinction between custom and default valve states. [#17664](https://github.com/open-webui/open-webui/pull/17664)
+- 🔍 Hybrid search functionality was enhanced to handle inconsistent parameter types and prevent failures when collection results are None, empty, or in unexpected formats. [#17617](https://github.com/open-webui/open-webui/pull/17617)
+- 📁 Empty folder deletion is now allowed regardless of chat deletion permission restrictions, resolving cases where users couldn't remove folders after deleting all contained chats. [#17683](https://github.com/open-webui/open-webui/pull/17683)
+- 📝 Rich text editor console errors were resolved by adding proper error handling when the TipTap editor view is not available or not yet mounted. [#17697](https://github.com/open-webui/open-webui/issues/17697)
+- 🗒️ Hidden models are now properly excluded from the notes section dropdown and default model selection, preventing users from accessing models they shouldn't see. [#17722](https://github.com/open-webui/open-webui/pull/17722)
+- 🖼️ AI-generated image download filenames now use a clean, translatable "Generated Image" format instead of potentially problematic response text, improving file management and compatibility. [#17721](https://github.com/open-webui/open-webui/pull/17721)
+- 🎨 Toggle switch display issues in the Integrations interface were fixed, preventing the hover background highlight from obscuring the switches. [#17564](https://github.com/open-webui/open-webui/issues/17564)
+
+### Changed
+
+- 👥 Channel permissions now require write access for posting, editing, and deleting messages; existing user groups default to read-only access, so admins must manually grant write permissions for full participation.
+- ☁️ OneDrive environment variable configuration was updated to use separate ONEDRIVE_CLIENT_ID_PERSONAL and ONEDRIVE_CLIENT_ID_BUSINESS variables for better client ID separation, while maintaining backward compatibility with the legacy ONEDRIVE_CLIENT_ID variable. [Docs](https://docs.openwebui.com/tutorials/integrations/onedrive-sharepoint), [Docs](https://docs.openwebui.com/getting-started/env-configuration/#onedrive)
+
## [0.6.30] - 2025-09-17
### Added
diff --git a/backend/open_webui/config.py b/backend/open_webui/config.py
index ca090efa22..7e5c35a451 100644
--- a/backend/open_webui/config.py
+++ b/backend/open_webui/config.py
@@ -222,10 +222,11 @@ class PersistentConfig(Generic[T]):
class AppConfig:
- _state: dict[str, PersistentConfig]
_redis: Union[redis.Redis, redis.cluster.RedisCluster] = None
_redis_key_prefix: str
+ _state: dict[str, PersistentConfig]
+
def __init__(
self,
redis_url: Optional[str] = None,
@@ -233,9 +234,8 @@ class AppConfig:
redis_cluster: Optional[bool] = False,
redis_key_prefix: str = "open-webui",
):
- super().__setattr__("_state", {})
- super().__setattr__("_redis_key_prefix", redis_key_prefix)
if redis_url:
+ super().__setattr__("_redis_key_prefix", redis_key_prefix)
super().__setattr__(
"_redis",
get_redis_connection(
@@ -246,6 +246,8 @@ class AppConfig:
),
)
+ super().__setattr__("_state", {})
+
def __setattr__(self, key, value):
if isinstance(value, PersistentConfig):
self._state[key] = value
@@ -2168,6 +2170,8 @@ ENABLE_ONEDRIVE_INTEGRATION = PersistentConfig(
"onedrive.enable",
os.getenv("ENABLE_ONEDRIVE_INTEGRATION", "False").lower() == "true",
)
+
+
ENABLE_ONEDRIVE_PERSONAL = (
os.environ.get("ENABLE_ONEDRIVE_PERSONAL", "True").lower() == "true"
)
@@ -2175,10 +2179,12 @@ ENABLE_ONEDRIVE_BUSINESS = (
os.environ.get("ENABLE_ONEDRIVE_BUSINESS", "True").lower() == "true"
)
-ONEDRIVE_CLIENT_ID = PersistentConfig(
- "ONEDRIVE_CLIENT_ID",
- "onedrive.client_id",
- os.environ.get("ONEDRIVE_CLIENT_ID", ""),
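+# The personal/business client IDs fall back to the legacy ONEDRIVE_CLIENT_ID
+# when unset, so existing single-app deployments keep working unchanged.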
+ONEDRIVE_CLIENT_ID = os.environ.get("ONEDRIVE_CLIENT_ID", "")
+ONEDRIVE_CLIENT_ID_PERSONAL = os.environ.get(
+ "ONEDRIVE_CLIENT_ID_PERSONAL", ONEDRIVE_CLIENT_ID
+)
+ONEDRIVE_CLIENT_ID_BUSINESS = os.environ.get(
+ "ONEDRIVE_CLIENT_ID_BUSINESS", ONEDRIVE_CLIENT_ID
)
ONEDRIVE_SHAREPOINT_URL = PersistentConfig(
@@ -2761,6 +2767,12 @@ WEB_SEARCH_TRUST_ENV = PersistentConfig(
)
+OLLAMA_CLOUD_WEB_SEARCH_API_KEY = PersistentConfig(
+ "OLLAMA_CLOUD_WEB_SEARCH_API_KEY",
+ "rag.web.search.ollama_cloud_api_key",
+ os.getenv("OLLAMA_CLOUD_API_KEY", ""),
+)
+
SEARXNG_QUERY_URL = PersistentConfig(
"SEARXNG_QUERY_URL",
"rag.web.search.searxng_query_url",
diff --git a/backend/open_webui/env.py b/backend/open_webui/env.py
index b4fdc97d82..e02424f969 100644
--- a/backend/open_webui/env.py
+++ b/backend/open_webui/env.py
@@ -474,6 +474,10 @@ ENABLE_OAUTH_ID_TOKEN_COOKIE = (
os.environ.get("ENABLE_OAUTH_ID_TOKEN_COOKIE", "True").lower() == "true"
)
+OAUTH_CLIENT_INFO_ENCRYPTION_KEY = os.environ.get(
+ "OAUTH_CLIENT_INFO_ENCRYPTION_KEY", WEBUI_SECRET_KEY
+)
+
OAUTH_SESSION_TOKEN_ENCRYPTION_KEY = os.environ.get(
"OAUTH_SESSION_TOKEN_ENCRYPTION_KEY", WEBUI_SECRET_KEY
)
@@ -547,16 +551,16 @@ else:
CHAT_RESPONSE_MAX_TOOL_CALL_RETRIES = os.environ.get(
- "CHAT_RESPONSE_MAX_TOOL_CALL_RETRIES", "10"
+ "CHAT_RESPONSE_MAX_TOOL_CALL_RETRIES", "30"
)
if CHAT_RESPONSE_MAX_TOOL_CALL_RETRIES == "":
- CHAT_RESPONSE_MAX_TOOL_CALL_RETRIES = 10
+ CHAT_RESPONSE_MAX_TOOL_CALL_RETRIES = 30
else:
try:
CHAT_RESPONSE_MAX_TOOL_CALL_RETRIES = int(CHAT_RESPONSE_MAX_TOOL_CALL_RETRIES)
except Exception:
- CHAT_RESPONSE_MAX_TOOL_CALL_RETRIES = 10
+ CHAT_RESPONSE_MAX_TOOL_CALL_RETRIES = 30
####################################
diff --git a/backend/open_webui/functions.py b/backend/open_webui/functions.py
index af6dd1ce1a..d102263cb3 100644
--- a/backend/open_webui/functions.py
+++ b/backend/open_webui/functions.py
@@ -239,7 +239,7 @@ async def generate_function_chat_completion(
oauth_token = None
try:
if request.cookies.get("oauth_session_id", None):
- oauth_token = request.app.state.oauth_manager.get_oauth_token(
+ oauth_token = await request.app.state.oauth_manager.get_oauth_token(
user.id,
request.cookies.get("oauth_session_id", None),
)
diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py
index 5630a58839..f38bd47109 100644
--- a/backend/open_webui/main.py
+++ b/backend/open_webui/main.py
@@ -50,6 +50,11 @@ from starlette.middleware.sessions import SessionMiddleware
from starlette.responses import Response, StreamingResponse
from starlette.datastructures import Headers
+from starsessions import (
+ SessionMiddleware as StarSessionsMiddleware,
+ SessionAutoloadMiddleware,
+)
+from starsessions.stores.redis import RedisStore
from open_webui.utils import logger
from open_webui.utils.audit import AuditLevel, AuditLoggingMiddleware
@@ -269,6 +274,7 @@ from open_webui.config import (
WEB_SEARCH_CONCURRENT_REQUESTS,
WEB_SEARCH_TRUST_ENV,
WEB_SEARCH_DOMAIN_FILTER_LIST,
+ OLLAMA_CLOUD_WEB_SEARCH_API_KEY,
JINA_API_KEY,
SEARCHAPI_API_KEY,
SEARCHAPI_ENGINE,
@@ -301,7 +307,8 @@ from open_webui.config import (
GOOGLE_DRIVE_CLIENT_ID,
GOOGLE_DRIVE_API_KEY,
ENABLE_ONEDRIVE_INTEGRATION,
- ONEDRIVE_CLIENT_ID,
+ ONEDRIVE_CLIENT_ID_PERSONAL,
+ ONEDRIVE_CLIENT_ID_BUSINESS,
ONEDRIVE_SHAREPOINT_URL,
ONEDRIVE_SHAREPOINT_TENANT_ID,
ENABLE_ONEDRIVE_PERSONAL,
@@ -466,7 +473,12 @@ from open_webui.utils.auth import (
get_verified_user,
)
from open_webui.utils.plugin import install_tool_and_function_dependencies
-from open_webui.utils.oauth import OAuthManager
+from open_webui.utils.oauth import (
+ OAuthManager,
+ OAuthClientManager,
+ decrypt_data,
+ OAuthClientInformationFull,
+)
from open_webui.utils.security_headers import SecurityHeadersMiddleware
from open_webui.utils.redis import get_redis_connection
@@ -596,9 +608,14 @@ app = FastAPI(
lifespan=lifespan,
)
+# For Open WebUI OIDC/OAuth2
oauth_manager = OAuthManager(app)
app.state.oauth_manager = oauth_manager
+# For Integrations
+oauth_client_manager = OAuthClientManager(app)
+app.state.oauth_client_manager = oauth_client_manager
+
app.state.instance_id = None
app.state.config = AppConfig(
redis_url=REDIS_URL,
@@ -882,6 +899,8 @@ app.state.config.BYPASS_WEB_SEARCH_WEB_LOADER = BYPASS_WEB_SEARCH_WEB_LOADER
app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION = ENABLE_GOOGLE_DRIVE_INTEGRATION
app.state.config.ENABLE_ONEDRIVE_INTEGRATION = ENABLE_ONEDRIVE_INTEGRATION
+
+app.state.config.OLLAMA_CLOUD_WEB_SEARCH_API_KEY = OLLAMA_CLOUD_WEB_SEARCH_API_KEY
app.state.config.SEARXNG_QUERY_URL = SEARXNG_QUERY_URL
app.state.config.YACY_QUERY_URL = YACY_QUERY_URL
app.state.config.YACY_USERNAME = YACY_USERNAME
@@ -1530,6 +1549,14 @@ async def chat_completion(
except:
pass
+ finally:
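+        # Ensure any MCP client connections opened for this request are closed.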
+ try:
+ if mcp_clients := metadata.get("mcp_clients"):
+ for client in mcp_clients:
+ await client.disconnect()
+ except Exception as e:
+ log.debug(f"Error cleaning up: {e}")
+ pass
if (
metadata.get("session_id")
@@ -1743,7 +1770,8 @@ async def get_app_config(request: Request):
"api_key": GOOGLE_DRIVE_API_KEY.value,
},
"onedrive": {
- "client_id": ONEDRIVE_CLIENT_ID.value,
+ "client_id_personal": ONEDRIVE_CLIENT_ID_PERSONAL,
+ "client_id_business": ONEDRIVE_CLIENT_ID_BUSINESS,
"sharepoint_url": ONEDRIVE_SHAREPOINT_URL.value,
"sharepoint_tenant_id": ONEDRIVE_SHAREPOINT_TENANT_ID.value,
},
@@ -1863,14 +1891,78 @@ async def get_current_usage(user=Depends(get_verified_user)):
# OAuth Login & Callback
############################
+
+# Initialize OAuth client manager with any MCP tool servers using OAuth 2.1
+if len(app.state.config.TOOL_SERVER_CONNECTIONS) > 0:
+ for tool_server_connection in app.state.config.TOOL_SERVER_CONNECTIONS:
+ if tool_server_connection.get("type", "openapi") == "mcp":
+ server_id = tool_server_connection.get("info", {}).get("id")
+ auth_type = tool_server_connection.get("auth_type", "none")
+ if server_id and auth_type == "oauth_2.1":
+ oauth_client_info = tool_server_connection.get("info", {}).get(
+ "oauth_client_info", ""
+ )
+
+ oauth_client_info = decrypt_data(oauth_client_info)
+ app.state.oauth_client_manager.add_client(
+ f"mcp:{server_id}", OAuthClientInformationFull(**oauth_client_info)
+ )
+
+
# SessionMiddleware is used by authlib for oauth
if len(OAUTH_PROVIDERS) > 0:
- app.add_middleware(
- SessionMiddleware,
- secret_key=WEBUI_SECRET_KEY,
- session_cookie="oui-session",
- same_site=WEBUI_SESSION_COOKIE_SAME_SITE,
- https_only=WEBUI_SESSION_COOKIE_SECURE,
+ try:
+ if REDIS_URL:
+ redis_session_store = RedisStore(
+ url=REDIS_URL,
+ prefix=(
+ f"{REDIS_KEY_PREFIX}:session:" if REDIS_KEY_PREFIX else "session:"
+ ),
+ )
+
+ app.add_middleware(SessionAutoloadMiddleware)
+ app.add_middleware(
+ StarSessionsMiddleware,
+ store=redis_session_store,
+ cookie_name="oui-session",
+ cookie_same_site=WEBUI_SESSION_COOKIE_SAME_SITE,
+ cookie_https_only=WEBUI_SESSION_COOKIE_SECURE,
+ )
+ log.info("Using Redis for session")
+ else:
+ raise ValueError("No Redis URL provided")
+ except Exception as e:
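+        # Fall back to the cookie-based SessionMiddleware when no Redis URL is
+        # configured or the Redis session store cannot be initialized.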
+ app.add_middleware(
+ SessionMiddleware,
+ secret_key=WEBUI_SECRET_KEY,
+ session_cookie="oui-session",
+ same_site=WEBUI_SESSION_COOKIE_SAME_SITE,
+ https_only=WEBUI_SESSION_COOKIE_SECURE,
+ )
+
+
+@app.get("/oauth/clients/{client_id}/authorize")
+async def oauth_client_authorize(
+ client_id: str,
+ request: Request,
+ response: Response,
+ user=Depends(get_verified_user),
+):
+ return await oauth_client_manager.handle_authorize(request, client_id=client_id)
+
+
+@app.get("/oauth/clients/{client_id}/callback")
+async def oauth_client_callback(
+ client_id: str,
+ request: Request,
+ response: Response,
+ user=Depends(get_verified_user),
+):
+ return await oauth_client_manager.handle_callback(
+ request,
+ client_id=client_id,
+ user_id=user.id if user else None,
+ response=response,
)
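+
+
+# Illustrative flow (sketch, not additional wiring): for an MCP tool server
+# registered with OAuth 2.1 the client is keyed as "mcp:<server_id>", so the
+# browser hits GET /oauth/clients/mcp:<server_id>/authorize, is redirected to
+# the provider's authorization endpoint, and returns via
+# GET /oauth/clients/mcp:<server_id>/callback, where the manager stores the
+# resulting token for the signed-in user.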
@@ -1885,8 +1977,9 @@ async def oauth_login(provider: str, request: Request):
# - This is considered insecure in general, as OAuth providers do not always verify email addresses
# 3. If there is no user, and ENABLE_OAUTH_SIGNUP is true, create a user
# - Email addresses are considered unique, so we fail registration if the email address is already taken
-@app.get("/oauth/{provider}/callback")
-async def oauth_callback(provider: str, request: Request, response: Response):
+@app.get("/oauth/{provider}/login/callback")
+@app.get("/oauth/{provider}/callback") # Legacy endpoint
+async def oauth_login_callback(provider: str, request: Request, response: Response):
return await oauth_manager.handle_callback(request, provider, response)
diff --git a/backend/open_webui/models/channels.py b/backend/open_webui/models/channels.py
index 92f238c3a0..e75266be78 100644
--- a/backend/open_webui/models/channels.py
+++ b/backend/open_webui/models/channels.py
@@ -57,6 +57,10 @@ class ChannelModel(BaseModel):
####################
+class ChannelResponse(ChannelModel):
+ write_access: bool = False
+
+
class ChannelForm(BaseModel):
name: str
description: Optional[str] = None
diff --git a/backend/open_webui/models/chats.py b/backend/open_webui/models/chats.py
index cadb5a3a79..97fd9b6256 100644
--- a/backend/open_webui/models/chats.py
+++ b/backend/open_webui/models/chats.py
@@ -492,11 +492,16 @@ class ChatTable:
self,
user_id: str,
include_archived: bool = False,
+ include_folders: bool = False,
skip: Optional[int] = None,
limit: Optional[int] = None,
) -> list[ChatTitleIdResponse]:
with get_db() as db:
- query = db.query(Chat).filter_by(user_id=user_id).filter_by(folder_id=None)
+ query = db.query(Chat).filter_by(user_id=user_id)
+
+ if not include_folders:
+ query = query.filter_by(folder_id=None)
+
query = query.filter(or_(Chat.pinned == False, Chat.pinned == None))
if not include_archived:
@@ -943,6 +948,16 @@ class ChatTable:
return count
+ def count_chats_by_folder_id_and_user_id(self, folder_id: str, user_id: str) -> int:
+ with get_db() as db:
+ query = db.query(Chat).filter_by(user_id=user_id)
+
+ query = query.filter_by(folder_id=folder_id)
+ count = query.count()
+
+ log.info(f"Count of chats for folder '{folder_id}': {count}")
+ return count
+
def delete_tag_by_id_and_user_id_and_tag_name(
self, id: str, user_id: str, tag_name: str
) -> bool:
diff --git a/backend/open_webui/models/files.py b/backend/open_webui/models/files.py
index 57978225d4..bf07b5f86f 100644
--- a/backend/open_webui/models/files.py
+++ b/backend/open_webui/models/files.py
@@ -130,6 +130,17 @@ class FilesTable:
except Exception:
return None
+ def get_file_by_id_and_user_id(self, id: str, user_id: str) -> Optional[FileModel]:
+ with get_db() as db:
+ try:
+ file = db.query(File).filter_by(id=id, user_id=user_id).first()
+ if file:
+ return FileModel.model_validate(file)
+ else:
+ return None
+ except Exception:
+ return None
+
def get_file_metadata_by_id(self, id: str) -> Optional[FileMetadataResponse]:
with get_db() as db:
try:
diff --git a/backend/open_webui/models/notes.py b/backend/open_webui/models/notes.py
index b61e820eae..f1b11f071e 100644
--- a/backend/open_webui/models/notes.py
+++ b/backend/open_webui/models/notes.py
@@ -2,6 +2,7 @@ import json
import time
import uuid
from typing import Optional
+from functools import lru_cache
from open_webui.internal.db import Base, get_db
from open_webui.models.groups import Groups
@@ -110,20 +111,72 @@ class NoteTable:
return [NoteModel.model_validate(note) for note in notes]
def get_notes_by_user_id(
+ self,
+ user_id: str,
+ skip: Optional[int] = None,
+ limit: Optional[int] = None,
+ ) -> list[NoteModel]:
+ with get_db() as db:
+ query = db.query(Note).filter(Note.user_id == user_id)
+ query = query.order_by(Note.updated_at.desc())
+
+ if skip is not None:
+ query = query.offset(skip)
+ if limit is not None:
+ query = query.limit(limit)
+
+ notes = query.all()
+ return [NoteModel.model_validate(note) for note in notes]
+
+ def get_notes_by_permission(
self,
user_id: str,
permission: str = "write",
skip: Optional[int] = None,
limit: Optional[int] = None,
) -> list[NoteModel]:
- notes = self.get_notes(skip=skip, limit=limit)
- user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user_id)}
- return [
- note
- for note in notes
- if note.user_id == user_id
- or has_access(user_id, permission, note.access_control, user_group_ids)
- ]
+ with get_db() as db:
+ user_groups = Groups.get_groups_by_member_id(user_id)
+ user_group_ids = {group.id for group in user_groups}
+
+ # Order newest-first. We stream to keep memory usage low.
+ query = (
+ db.query(Note)
+ .order_by(Note.updated_at.desc())
+ .execution_options(stream_results=True)
+ .yield_per(256)
+ )
+
+ results: list[NoteModel] = []
+ n_skipped = 0
+
+ for note in query:
+ # Fast-pass #1: owner
+ if note.user_id == user_id:
+ permitted = True
+ # Fast-pass #2: public/open
+ elif note.access_control is None:
+ # Technically this should mean public access for both read and write, but we'll only do read for now
+ # We might want to change this behavior later
+ permitted = permission == "read"
+ else:
+ permitted = has_access(
+ user_id, permission, note.access_control, user_group_ids
+ )
+
+ if not permitted:
+ continue
+
+ # Apply skip AFTER permission filtering so it counts only accessible notes
+ if skip and n_skipped < skip:
+ n_skipped += 1
+ continue
+
+ results.append(NoteModel.model_validate(note))
+ if limit is not None and len(results) >= limit:
+ break
+
+ return results
def get_note_by_id(self, id: str) -> Optional[NoteModel]:
with get_db() as db:
diff --git a/backend/open_webui/models/oauth_sessions.py b/backend/open_webui/models/oauth_sessions.py
index 9fd5335ce5..81ce220384 100644
--- a/backend/open_webui/models/oauth_sessions.py
+++ b/backend/open_webui/models/oauth_sessions.py
@@ -176,6 +176,26 @@ class OAuthSessionTable:
log.error(f"Error getting OAuth session by ID: {e}")
return None
+ def get_session_by_provider_and_user_id(
+ self, provider: str, user_id: str
+ ) -> Optional[OAuthSessionModel]:
+ """Get OAuth session by provider and user ID"""
+ try:
+ with get_db() as db:
+ session = (
+ db.query(OAuthSession)
+ .filter_by(provider=provider, user_id=user_id)
+ .first()
+ )
+ if session:
+ session.token = self._decrypt_token(session.token)
+ return OAuthSessionModel.model_validate(session)
+
+ return None
+ except Exception as e:
+ log.error(f"Error getting OAuth session by provider and user ID: {e}")
+ return None
+
def get_sessions_by_user_id(self, user_id: str) -> List[OAuthSessionModel]:
"""Get all OAuth sessions for a user"""
try:
diff --git a/backend/open_webui/models/tools.py b/backend/open_webui/models/tools.py
index 3a47fa008d..48f84b3ac4 100644
--- a/backend/open_webui/models/tools.py
+++ b/backend/open_webui/models/tools.py
@@ -95,6 +95,8 @@ class ToolResponse(BaseModel):
class ToolUserResponse(ToolResponse):
user: Optional[UserResponse] = None
+ model_config = ConfigDict(extra="allow")
+
class ToolForm(BaseModel):
id: str
diff --git a/backend/open_webui/retrieval/utils.py b/backend/open_webui/retrieval/utils.py
index aec8de6846..65da1592e1 100644
--- a/backend/open_webui/retrieval/utils.py
+++ b/backend/open_webui/retrieval/utils.py
@@ -127,7 +127,13 @@ def query_doc_with_hybrid_search(
hybrid_bm25_weight: float,
) -> dict:
try:
- if not collection_result.documents[0]:
+ if (
+ not collection_result
+ or not hasattr(collection_result, "documents")
+ or not collection_result.documents
+ or len(collection_result.documents) == 0
+ or not collection_result.documents[0]
+ ):
log.warning(f"query_doc_with_hybrid_search:no_docs {collection_name}")
return {"documents": [], "metadatas": [], "distances": []}
diff --git a/backend/open_webui/retrieval/web/ollama.py b/backend/open_webui/retrieval/web/ollama.py
new file mode 100644
index 0000000000..a199a14389
--- /dev/null
+++ b/backend/open_webui/retrieval/web/ollama.py
@@ -0,0 +1,51 @@
+import logging
+from dataclasses import dataclass
+from typing import Optional
+
+import requests
+from open_webui.env import SRC_LOG_LEVELS
+from open_webui.retrieval.web.main import SearchResult
+
+log = logging.getLogger(__name__)
+log.setLevel(SRC_LOG_LEVELS["RAG"])
+
+
+def search_ollama_cloud(
+ url: str,
+ api_key: str,
+ query: str,
+ count: int,
+ filter_list: Optional[list[str]] = None,
+) -> list[SearchResult]:
+ """Search using Ollama Search API and return the results as a list of SearchResult objects.
+
+ Args:
+        api_key (str): An Ollama Search API key
+ query (str): The query to search for
+ count (int): Number of results to return
+ filter_list (Optional[list[str]]): List of domains to filter results by
+ """
+ log.info(f"Searching with Ollama for query: {query}")
+
+ headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
+ payload = {"query": query, "max_results": count}
+
+ try:
+ response = requests.post(f"{url}/api/web_search", headers=headers, json=payload)
+ response.raise_for_status()
+ data = response.json()
+
+ results = data.get("results", [])
+ log.info(f"Found {len(results)} results")
+
+ return [
+ SearchResult(
+ link=result.get("url", ""),
+ title=result.get("title", ""),
+ snippet=result.get("content", ""),
+ )
+ for result in results
+ ]
+ except Exception as e:
+ log.error(f"Error searching Ollama: {e}")
+ return []
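+
+
+if __name__ == "__main__":
+    # Minimal manual smoke test (illustrative sketch only): the base URL and
+    # API key are read from placeholder environment variables here, not from
+    # Open WebUI's configuration.
+    import os
+
+    for result in search_ollama_cloud(
+        url=os.environ.get("OLLAMA_CLOUD_URL", "https://ollama.com"),
+        api_key=os.environ.get("OLLAMA_CLOUD_API_KEY", ""),
+        query="open webui",
+        count=3,
+    ):
+        print(result.link, "-", result.title)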
diff --git a/backend/open_webui/retrieval/web/perplexity_search.py b/backend/open_webui/retrieval/web/perplexity_search.py
new file mode 100644
index 0000000000..e3e0caa2b3
--- /dev/null
+++ b/backend/open_webui/retrieval/web/perplexity_search.py
@@ -0,0 +1,64 @@
+import logging
+from typing import Optional, Literal
+import requests
+
+from open_webui.retrieval.web.main import SearchResult, get_filtered_results
+from open_webui.env import SRC_LOG_LEVELS
+
+
+log = logging.getLogger(__name__)
+log.setLevel(SRC_LOG_LEVELS["RAG"])
+
+
+def search_perplexity_search(
+ api_key: str,
+ query: str,
+ count: int,
+ filter_list: Optional[list[str]] = None,
+) -> list[SearchResult]:
+ """Search using Perplexity API and return the results as a list of SearchResult objects.
+
+ Args:
+ api_key (str): A Perplexity API key
+ query (str): The query to search for
+ count (int): Maximum number of results to return
+ filter_list (Optional[list[str]]): List of domains to filter results
+
+ """
+
+ # Handle PersistentConfig object
+ if hasattr(api_key, "__str__"):
+ api_key = str(api_key)
+
+ try:
+ url = "https://api.perplexity.ai/search"
+
+ # Create payload for the API call
+ payload = {
+ "query": query,
+ "max_results": count,
+ }
+
+ headers = {
+ "Authorization": f"Bearer {api_key}",
+ "Content-Type": "application/json",
+ }
+
+ # Make the API request
+ response = requests.request("POST", url, json=payload, headers=headers)
+ # Parse the JSON response
+ json_response = response.json()
+
+ # Extract citations from the response
+ results = json_response.get("results", [])
+
+ return [
+ SearchResult(
+ link=result["url"], title=result["title"], snippet=result["snippet"]
+ )
+ for result in results
+ ]
+
+ except Exception as e:
+ log.error(f"Error searching with Perplexity Search API: {e}")
+ return []
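+
+
+if __name__ == "__main__":
+    # Minimal manual smoke test (illustrative sketch only); the API key is read
+    # from a placeholder environment variable rather than Open WebUI's config.
+    import os
+
+    for result in search_perplexity_search(
+        api_key=os.environ.get("PERPLEXITY_API_KEY", ""),
+        query="open webui",
+        count=3,
+    ):
+        print(result.link, "-", result.title)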
diff --git a/backend/open_webui/routers/channels.py b/backend/open_webui/routers/channels.py
index da52be6e79..e7b8366347 100644
--- a/backend/open_webui/routers/channels.py
+++ b/backend/open_webui/routers/channels.py
@@ -10,7 +10,13 @@ from pydantic import BaseModel
from open_webui.socket.main import sio, get_user_ids_from_room
from open_webui.models.users import Users, UserNameResponse
-from open_webui.models.channels import Channels, ChannelModel, ChannelForm
+from open_webui.models.groups import Groups
+from open_webui.models.channels import (
+ Channels,
+ ChannelModel,
+ ChannelForm,
+ ChannelResponse,
+)
from open_webui.models.messages import (
Messages,
MessageModel,
@@ -80,7 +86,7 @@ async def create_new_channel(form_data: ChannelForm, user=Depends(get_admin_user
############################
-@router.get("/{id}", response_model=Optional[ChannelModel])
+@router.get("/{id}", response_model=Optional[ChannelResponse])
async def get_channel_by_id(id: str, user=Depends(get_verified_user)):
channel = Channels.get_channel_by_id(id)
if not channel:
@@ -95,7 +101,16 @@ async def get_channel_by_id(id: str, user=Depends(get_verified_user)):
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
)
- return ChannelModel(**channel.model_dump())
+ write_access = has_access(
+ user.id, type="write", access_control=channel.access_control, strict=False
+ )
+
+ return ChannelResponse(
+ **{
+ **channel.model_dump(),
+ "write_access": write_access or user.role == "admin",
+ }
+ )
############################
@@ -275,6 +290,7 @@ async def model_response_handler(request, channel, message, user):
)
thread_history = []
+ images = []
message_users = {}
for thread_message in thread_messages:
@@ -303,6 +319,11 @@ async def model_response_handler(request, channel, message, user):
f"{username}: {replace_mentions(thread_message.content)}"
)
+ thread_message_files = thread_message.data.get("files", [])
+ for file in thread_message_files:
+ if file.get("type", "") == "image":
+ images.append(file.get("url", ""))
+
system_message = {
"role": "system",
"content": f"You are {model.get('name', model_id)}, an AI assistant participating in a threaded conversation. Be helpful, concise, and conversational."
@@ -313,14 +334,29 @@ async def model_response_handler(request, channel, message, user):
),
}
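+    # Build OpenAI-style multimodal content when images were attached in the thread.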
+ content = f"{user.name if user else 'User'}: {message_content}"
+ if images:
+ content = [
+ {
+ "type": "text",
+ "text": content,
+ },
+ *[
+ {
+ "type": "image_url",
+ "image_url": {
+ "url": image,
+ },
+ }
+ for image in images
+ ],
+ ]
+
form_data = {
"model": model_id,
"messages": [
system_message,
- {
- "role": "user",
- "content": f"{user.name if user else 'User'}: {message_content}",
- },
+ {"role": "user", "content": content},
],
"stream": False,
}
@@ -362,7 +398,7 @@ async def new_message_handler(
)
if user.role != "admin" and not has_access(
- user.id, type="read", access_control=channel.access_control
+ user.id, type="write", access_control=channel.access_control, strict=False
):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
@@ -658,7 +694,7 @@ async def add_reaction_to_message(
)
if user.role != "admin" and not has_access(
- user.id, type="read", access_control=channel.access_control
+ user.id, type="write", access_control=channel.access_control, strict=False
):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
@@ -724,7 +760,7 @@ async def remove_reaction_by_id_and_user_id_and_name(
)
if user.role != "admin" and not has_access(
- user.id, type="read", access_control=channel.access_control
+ user.id, type="write", access_control=channel.access_control, strict=False
):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
@@ -806,7 +842,9 @@ async def delete_message_by_id(
if (
user.role != "admin"
and message.user_id != user.id
- and not has_access(user.id, type="read", access_control=channel.access_control)
+ and not has_access(
+ user.id, type="write", access_control=channel.access_control, strict=False
+ )
):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
diff --git a/backend/open_webui/routers/chats.py b/backend/open_webui/routers/chats.py
index 847368412e..788e355f2b 100644
--- a/backend/open_webui/routers/chats.py
+++ b/backend/open_webui/routers/chats.py
@@ -37,7 +37,9 @@ router = APIRouter()
@router.get("/", response_model=list[ChatTitleIdResponse])
@router.get("/list", response_model=list[ChatTitleIdResponse])
def get_session_user_chat_list(
- user=Depends(get_verified_user), page: Optional[int] = None
+ user=Depends(get_verified_user),
+ page: Optional[int] = None,
+ include_folders: Optional[bool] = False,
):
try:
if page is not None:
@@ -45,10 +47,12 @@ def get_session_user_chat_list(
skip = (page - 1) * limit
return Chats.get_chat_title_id_list_by_user_id(
- user.id, skip=skip, limit=limit
+ user.id, include_folders=include_folders, skip=skip, limit=limit
)
else:
- return Chats.get_chat_title_id_list_by_user_id(user.id)
+ return Chats.get_chat_title_id_list_by_user_id(
+ user.id, include_folders=include_folders
+ )
except Exception as e:
log.exception(e)
raise HTTPException(
diff --git a/backend/open_webui/routers/configs.py b/backend/open_webui/routers/configs.py
index 8ce4e0d247..d4b88032e2 100644
--- a/backend/open_webui/routers/configs.py
+++ b/backend/open_webui/routers/configs.py
@@ -1,5 +1,7 @@
+import logging
from fastapi import APIRouter, Depends, Request, HTTPException
from pydantic import BaseModel, ConfigDict
+import aiohttp
from typing import Optional
@@ -12,10 +14,24 @@ from open_webui.utils.tools import (
get_tool_server_url,
set_tool_servers,
)
+from open_webui.utils.mcp.client import MCPClient
+from open_webui.env import SRC_LOG_LEVELS
+
+from open_webui.utils.oauth import (
+ get_discovery_urls,
+ get_oauth_client_info_with_dynamic_client_registration,
+ encrypt_data,
+ decrypt_data,
+ OAuthClientInformationFull,
+)
+from mcp.shared.auth import OAuthMetadata
router = APIRouter()
+log = logging.getLogger(__name__)
+log.setLevel(SRC_LOG_LEVELS["MAIN"])
+
############################
# ImportConfig
@@ -79,6 +95,43 @@ async def set_connections_config(
}
+class OAuthClientRegistrationForm(BaseModel):
+ url: str
+ client_id: str
+ client_name: Optional[str] = None
+
+
+@router.post("/oauth/clients/register")
+async def register_oauth_client(
+ request: Request,
+ form_data: OAuthClientRegistrationForm,
+ type: Optional[str] = None,
+ user=Depends(get_admin_user),
+):
+ try:
+ oauth_client_id = form_data.client_id
+ if type:
+ oauth_client_id = f"{type}:{form_data.client_id}"
+
+ oauth_client_info = (
+ await get_oauth_client_info_with_dynamic_client_registration(
+ request, oauth_client_id, form_data.url
+ )
+ )
+ return {
+ "status": True,
+ "oauth_client_info": encrypt_data(
+ oauth_client_info.model_dump(mode="json")
+ ),
+ }
+ except Exception as e:
+ log.debug(f"Failed to register OAuth client: {e}")
+ raise HTTPException(
+ status_code=400,
+ detail=f"Failed to register OAuth client",
+ )
+
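+# Illustrative request (sketch; the /api/v1/configs mount prefix is assumed and
+# the values are placeholders):
+#   POST /api/v1/configs/oauth/clients/register?type=mcp
+#   {"url": "https://mcp.example.com", "client_id": "<tool-server-id>"}
+# A successful response returns an encrypted "oauth_client_info" blob that is
+# stored on the tool server connection.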
+
############################
# ToolServers Config
############################
@@ -87,6 +140,7 @@ async def set_connections_config(
class ToolServerConnection(BaseModel):
url: str
path: str
+ type: Optional[str] = "openapi" # openapi, mcp
auth_type: Optional[str]
key: Optional[str]
config: Optional[dict]
@@ -114,8 +168,29 @@ async def set_tool_servers_config(
request.app.state.config.TOOL_SERVER_CONNECTIONS = [
connection.model_dump() for connection in form_data.TOOL_SERVER_CONNECTIONS
]
+
await set_tool_servers(request)
+ for connection in request.app.state.config.TOOL_SERVER_CONNECTIONS:
+ server_type = connection.get("type", "openapi")
+ if server_type == "mcp":
+ server_id = connection.get("info", {}).get("id")
+ auth_type = connection.get("auth_type", "none")
+ if auth_type == "oauth_2.1" and server_id:
+ try:
+ oauth_client_info = connection.get("info", {}).get(
+ "oauth_client_info", ""
+ )
+ oauth_client_info = decrypt_data(oauth_client_info)
+
+ await request.app.state.oauth_client_manager.add_client(
+ f"{server_type}:{server_id}",
+ OAuthClientInformationFull(**oauth_client_info),
+ )
+ except Exception as e:
+ log.debug(f"Failed to add OAuth client for MCP tool server: {e}")
+ continue
+
return {
"TOOL_SERVER_CONNECTIONS": request.app.state.config.TOOL_SERVER_CONNECTIONS,
}
@@ -129,19 +204,105 @@ async def verify_tool_servers_config(
Verify the connection to the tool server.
"""
try:
+ if form_data.type == "mcp":
+ if form_data.auth_type == "oauth_2.1":
+ discovery_urls = get_discovery_urls(form_data.url)
+ async with aiohttp.ClientSession() as session:
+ async with session.get(
+ discovery_urls[0]
+ ) as oauth_server_metadata_response:
+ if oauth_server_metadata_response.status != 200:
+ raise HTTPException(
+ status_code=400,
+ detail=f"Failed to fetch OAuth 2.1 discovery document from {discovery_urls[0]}",
+ )
- token = None
- if form_data.auth_type == "bearer":
- token = form_data.key
- elif form_data.auth_type == "session":
- token = request.state.token.credentials
+ try:
+ oauth_server_metadata = OAuthMetadata.model_validate(
+ await oauth_server_metadata_response.json()
+ )
+ return {
+ "status": True,
+ "oauth_server_metadata": oauth_server_metadata.model_dump(
+ mode="json"
+ ),
+ }
+ except Exception as e:
+ log.info(
+ f"Failed to parse OAuth 2.1 discovery document: {e}"
+ )
+ raise HTTPException(
+ status_code=400,
+ detail=f"Failed to parse OAuth 2.1 discovery document from {discovery_urls[0]}",
+ )
- url = get_tool_server_url(form_data.url, form_data.path)
- return await get_tool_server_data(token, url)
+ raise HTTPException(
+ status_code=400,
+ detail=f"Failed to fetch OAuth 2.1 discovery document from {discovery_urls[0]}",
+ )
+ else:
+ try:
+ client = MCPClient()
+ headers = None
+
+ token = None
+ if form_data.auth_type == "bearer":
+ token = form_data.key
+ elif form_data.auth_type == "session":
+ token = request.state.token.credentials
+ elif form_data.auth_type == "system_oauth":
+ try:
+ if request.cookies.get("oauth_session_id", None):
+ token = await request.app.state.oauth_manager.get_oauth_token(
+ user.id,
+ request.cookies.get("oauth_session_id", None),
+ )
+ except Exception as e:
+ pass
+
+ if token:
+ headers = {"Authorization": f"Bearer {token}"}
+
+ await client.connect(form_data.url, headers=headers)
+ specs = await client.list_tool_specs()
+ return {
+ "status": True,
+ "specs": specs,
+ }
+ except Exception as e:
+ log.debug(f"Failed to create MCP client: {e}")
+ raise HTTPException(
+ status_code=400,
+ detail=f"Failed to create MCP client",
+ )
+ finally:
+ if client:
+ await client.disconnect()
+ else: # openapi
+ token = None
+ if form_data.auth_type == "bearer":
+ token = form_data.key
+ elif form_data.auth_type == "session":
+ token = request.state.token.credentials
+ elif form_data.auth_type == "system_oauth":
+ try:
+ if request.cookies.get("oauth_session_id", None):
+ token = await request.app.state.oauth_manager.get_oauth_token(
+ user.id,
+ request.cookies.get("oauth_session_id", None),
+ )
+ except Exception as e:
+ pass
+
+ url = get_tool_server_url(form_data.url, form_data.path)
+ return await get_tool_server_data(token, url)
+ except HTTPException as e:
+ raise e
except Exception as e:
+ log.debug(f"Failed to connect to the tool server: {e}")
raise HTTPException(
status_code=400,
- detail=f"Failed to connect to the tool server: {str(e)}",
+ detail=f"Failed to connect to the tool server",
)
diff --git a/backend/open_webui/routers/folders.py b/backend/open_webui/routers/folders.py
index 36dbfee5c5..ddee71ea4d 100644
--- a/backend/open_webui/routers/folders.py
+++ b/backend/open_webui/routers/folders.py
@@ -262,15 +262,15 @@ async def update_folder_is_expanded_by_id(
async def delete_folder_by_id(
request: Request, id: str, user=Depends(get_verified_user)
):
- chat_delete_permission = has_permission(
- user.id, "chat.delete", request.app.state.config.USER_PERMISSIONS
- )
-
- if user.role != "admin" and not chat_delete_permission:
- raise HTTPException(
- status_code=status.HTTP_403_FORBIDDEN,
- detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
+ if Chats.count_chats_by_folder_id_and_user_id(id, user.id):
+ chat_delete_permission = has_permission(
+ user.id, "chat.delete", request.app.state.config.USER_PERMISSIONS
)
+ if user.role != "admin" and not chat_delete_permission:
+ raise HTTPException(
+ status_code=status.HTTP_403_FORBIDDEN,
+ detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
+ )
folder = Folders.get_folder_by_id_and_user_id(id, user.id)
if folder:
diff --git a/backend/open_webui/routers/functions.py b/backend/open_webui/routers/functions.py
index 202aa74ca4..c36e656d5f 100644
--- a/backend/open_webui/routers/functions.py
+++ b/backend/open_webui/routers/functions.py
@@ -431,8 +431,10 @@ async def update_function_valves_by_id(
try:
form_data = {k: v for k, v in form_data.items() if v is not None}
valves = Valves(**form_data)
- Functions.update_function_valves_by_id(id, valves.model_dump())
- return valves.model_dump()
+
+ valves_dict = valves.model_dump(exclude_unset=True)
+ Functions.update_function_valves_by_id(id, valves_dict)
+ return valves_dict
except Exception as e:
log.exception(f"Error updating function values by id {id}: {e}")
raise HTTPException(
@@ -514,10 +516,11 @@ async def update_function_user_valves_by_id(
try:
form_data = {k: v for k, v in form_data.items() if v is not None}
user_valves = UserValves(**form_data)
+ user_valves_dict = user_valves.model_dump(exclude_unset=True)
Functions.update_user_valves_by_id_and_user_id(
- id, user.id, user_valves.model_dump()
+ id, user.id, user_valves_dict
)
- return user_valves.model_dump()
+ return user_valves_dict
except Exception as e:
log.exception(f"Error updating function user valves by id {id}: {e}")
raise HTTPException(
diff --git a/backend/open_webui/routers/images.py b/backend/open_webui/routers/images.py
index 802a3e9924..059b3a23d7 100644
--- a/backend/open_webui/routers/images.py
+++ b/backend/open_webui/routers/images.py
@@ -514,6 +514,7 @@ async def image_generations(
size = form_data.size
width, height = tuple(map(int, size.split("x")))
+ model = get_image_model(request)
r = None
try:
@@ -531,11 +532,7 @@ async def image_generations(
headers["X-OpenWebUI-User-Role"] = user.role
data = {
- "model": (
- request.app.state.config.IMAGE_GENERATION_MODEL
- if request.app.state.config.IMAGE_GENERATION_MODEL != ""
- else "dall-e-2"
- ),
+ "model": model,
"prompt": form_data.prompt,
"n": form_data.n,
"size": (
@@ -584,7 +581,6 @@ async def image_generations(
headers["Content-Type"] = "application/json"
headers["x-goog-api-key"] = request.app.state.config.IMAGES_GEMINI_API_KEY
- model = get_image_model(request)
data = {
"instances": {"prompt": form_data.prompt},
"parameters": {
@@ -640,7 +636,7 @@ async def image_generations(
}
)
res = await comfyui_generate_image(
- request.app.state.config.IMAGE_GENERATION_MODEL,
+ model,
form_data,
user.id,
request.app.state.config.COMFYUI_BASE_URL,
diff --git a/backend/open_webui/routers/notes.py b/backend/open_webui/routers/notes.py
index dff7bc2e7f..0c420e4f12 100644
--- a/backend/open_webui/routers/notes.py
+++ b/backend/open_webui/routers/notes.py
@@ -48,7 +48,7 @@ async def get_notes(request: Request, user=Depends(get_verified_user)):
"user": UserResponse(**Users.get_user_by_id(note.user_id).model_dump()),
}
)
- for note in Notes.get_notes_by_user_id(user.id, "write")
+ for note in Notes.get_notes_by_permission(user.id, "write")
]
return notes
@@ -81,7 +81,9 @@ async def get_note_list(
notes = [
NoteTitleIdResponse(**note.model_dump())
- for note in Notes.get_notes_by_user_id(user.id, "write", skip=skip, limit=limit)
+ for note in Notes.get_notes_by_permission(
+ user.id, "write", skip=skip, limit=limit
+ )
]
return notes
diff --git a/backend/open_webui/routers/ollama.py b/backend/open_webui/routers/ollama.py
index 8dadf3523a..bf11ffa0dd 100644
--- a/backend/open_webui/routers/ollama.py
+++ b/backend/open_webui/routers/ollama.py
@@ -1694,25 +1694,27 @@ async def download_file_stream(
yield f'data: {{"progress": {progress}, "completed": {current_size}, "total": {total_size}}}\n\n'
if done:
- file.seek(0)
- chunk_size = 1024 * 1024 * 2
- hashed = calculate_sha256(file, chunk_size)
- file.seek(0)
+ file.close()
- url = f"{ollama_url}/api/blobs/sha256:{hashed}"
- response = requests.post(url, data=file)
+ with open(file_path, "rb") as file:
+ chunk_size = 1024 * 1024 * 2
+ hashed = calculate_sha256(file, chunk_size)
- if response.ok:
- res = {
- "done": done,
- "blob": f"sha256:{hashed}",
- "name": file_name,
- }
- os.remove(file_path)
+ url = f"{ollama_url}/api/blobs/sha256:{hashed}"
+ with requests.Session() as session:
+ response = session.post(url, data=file, timeout=30)
- yield f"data: {json.dumps(res)}\n\n"
- else:
- raise "Ollama: Could not create blob, Please try again."
+ if response.ok:
+ res = {
+ "done": done,
+ "blob": f"sha256:{hashed}",
+ "name": file_name,
+ }
+ os.remove(file_path)
+
+ yield f"data: {json.dumps(res)}\n\n"
+ else:
+ raise "Ollama: Could not create blob, Please try again."
# url = "https://huggingface.co/TheBloke/stablelm-zephyr-3b-GGUF/resolve/main/stablelm-zephyr-3b.Q2_K.gguf"
diff --git a/backend/open_webui/routers/openai.py b/backend/open_webui/routers/openai.py
index 3154be2ee6..e8865b90a0 100644
--- a/backend/open_webui/routers/openai.py
+++ b/backend/open_webui/routers/openai.py
@@ -121,7 +121,7 @@ def openai_reasoning_model_handler(payload):
return payload
-def get_headers_and_cookies(
+async def get_headers_and_cookies(
request: Request,
url,
key=None,
@@ -174,7 +174,7 @@ def get_headers_and_cookies(
oauth_token = None
try:
if request.cookies.get("oauth_session_id", None):
- oauth_token = request.app.state.oauth_manager.get_oauth_token(
+ oauth_token = await request.app.state.oauth_manager.get_oauth_token(
user.id,
request.cookies.get("oauth_session_id", None),
)
@@ -305,7 +305,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
request.app.state.config.OPENAI_API_CONFIGS.get(url, {}), # Legacy support
)
- headers, cookies = get_headers_and_cookies(
+ headers, cookies = await get_headers_and_cookies(
request, url, key, api_config, user=user
)
@@ -570,7 +570,7 @@ async def get_models(
timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST),
) as session:
try:
- headers, cookies = get_headers_and_cookies(
+ headers, cookies = await get_headers_and_cookies(
request, url, key, api_config, user=user
)
@@ -656,7 +656,7 @@ async def verify_connection(
timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST),
) as session:
try:
- headers, cookies = get_headers_and_cookies(
+ headers, cookies = await get_headers_and_cookies(
request, url, key, api_config, user=user
)
@@ -901,7 +901,7 @@ async def generate_chat_completion(
convert_logit_bias_input_to_json(payload["logit_bias"])
)
- headers, cookies = get_headers_and_cookies(
+ headers, cookies = await get_headers_and_cookies(
request, url, key, api_config, metadata, user=user
)
@@ -1010,7 +1010,9 @@ async def embeddings(request: Request, form_data: dict, user):
session = None
streaming = False
- headers, cookies = get_headers_and_cookies(request, url, key, api_config, user=user)
+ headers, cookies = await get_headers_and_cookies(
+ request, url, key, api_config, user=user
+ )
try:
session = aiohttp.ClientSession(trust_env=True)
r = await session.request(
@@ -1080,7 +1082,7 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
streaming = False
try:
- headers, cookies = get_headers_and_cookies(
+ headers, cookies = await get_headers_and_cookies(
request, url, key, api_config, user=user
)
diff --git a/backend/open_webui/routers/retrieval.py b/backend/open_webui/routers/retrieval.py
index 0ddf824efa..3681008c87 100644
--- a/backend/open_webui/routers/retrieval.py
+++ b/backend/open_webui/routers/retrieval.py
@@ -45,6 +45,8 @@ from open_webui.retrieval.loaders.youtube import YoutubeLoader
# Web search engines
from open_webui.retrieval.web.main import SearchResult
from open_webui.retrieval.web.utils import get_web_loader
+from open_webui.retrieval.web.ollama import search_ollama_cloud
+from open_webui.retrieval.web.perplexity_search import search_perplexity_search
from open_webui.retrieval.web.brave import search_brave
from open_webui.retrieval.web.kagi import search_kagi
from open_webui.retrieval.web.mojeek import search_mojeek
@@ -469,6 +471,7 @@ async def get_rag_config(request: Request, user=Depends(get_admin_user)):
"WEB_SEARCH_DOMAIN_FILTER_LIST": request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST,
"BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL": request.app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL,
"BYPASS_WEB_SEARCH_WEB_LOADER": request.app.state.config.BYPASS_WEB_SEARCH_WEB_LOADER,
+ "OLLAMA_CLOUD_WEB_SEARCH_API_KEY": request.app.state.config.OLLAMA_CLOUD_WEB_SEARCH_API_KEY,
"SEARXNG_QUERY_URL": request.app.state.config.SEARXNG_QUERY_URL,
"YACY_QUERY_URL": request.app.state.config.YACY_QUERY_URL,
"YACY_USERNAME": request.app.state.config.YACY_USERNAME,
@@ -525,6 +528,7 @@ class WebConfig(BaseModel):
WEB_SEARCH_DOMAIN_FILTER_LIST: Optional[List[str]] = []
BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL: Optional[bool] = None
BYPASS_WEB_SEARCH_WEB_LOADER: Optional[bool] = None
+ OLLAMA_CLOUD_WEB_SEARCH_API_KEY: Optional[str] = None
SEARXNG_QUERY_URL: Optional[str] = None
YACY_QUERY_URL: Optional[str] = None
YACY_USERNAME: Optional[str] = None
@@ -988,6 +992,9 @@ async def update_rag_config(
request.app.state.config.BYPASS_WEB_SEARCH_WEB_LOADER = (
form_data.web.BYPASS_WEB_SEARCH_WEB_LOADER
)
+ request.app.state.config.OLLAMA_CLOUD_WEB_SEARCH_API_KEY = (
+ form_data.web.OLLAMA_CLOUD_WEB_SEARCH_API_KEY
+ )
request.app.state.config.SEARXNG_QUERY_URL = form_data.web.SEARXNG_QUERY_URL
request.app.state.config.YACY_QUERY_URL = form_data.web.YACY_QUERY_URL
request.app.state.config.YACY_USERNAME = form_data.web.YACY_USERNAME
@@ -1139,6 +1146,7 @@ async def update_rag_config(
"WEB_SEARCH_DOMAIN_FILTER_LIST": request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST,
"BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL": request.app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL,
"BYPASS_WEB_SEARCH_WEB_LOADER": request.app.state.config.BYPASS_WEB_SEARCH_WEB_LOADER,
+ "OLLAMA_CLOUD_WEB_SEARCH_API_KEY": request.app.state.config.OLLAMA_CLOUD_WEB_SEARCH_API_KEY,
"SEARXNG_QUERY_URL": request.app.state.config.SEARXNG_QUERY_URL,
"YACY_QUERY_URL": request.app.state.config.YACY_QUERY_URL,
"YACY_USERNAME": request.app.state.config.YACY_USERNAME,
@@ -1407,59 +1415,35 @@ def process_file(
form_data: ProcessFileForm,
user=Depends(get_verified_user),
):
- try:
+ if user.role == "admin":
file = Files.get_file_by_id(form_data.file_id)
+ else:
+ file = Files.get_file_by_id_and_user_id(form_data.file_id, user.id)
- collection_name = form_data.collection_name
+ if file:
+ try:
- if collection_name is None:
- collection_name = f"file-{file.id}"
+ collection_name = form_data.collection_name
- if form_data.content:
- # Update the content in the file
- # Usage: /files/{file_id}/data/content/update, /files/ (audio file upload pipeline)
+ if collection_name is None:
+ collection_name = f"file-{file.id}"
- try:
- # /files/{file_id}/data/content/update
- VECTOR_DB_CLIENT.delete_collection(collection_name=f"file-{file.id}")
- except:
- # Audio file upload pipeline
- pass
+ if form_data.content:
+ # Update the content in the file
+ # Usage: /files/{file_id}/data/content/update, /files/ (audio file upload pipeline)
- docs = [
- Document(
- page_content=form_data.content.replace("<br/>", "\n"),
- metadata={
- **file.meta,
- "name": file.filename,
- "created_by": file.user_id,
- "file_id": file.id,
- "source": file.filename,
- },
- )
- ]
-
- text_content = form_data.content
- elif form_data.collection_name:
- # Check if the file has already been processed and save the content
- # Usage: /knowledge/{id}/file/add, /knowledge/{id}/file/update
-
- result = VECTOR_DB_CLIENT.query(
- collection_name=f"file-{file.id}", filter={"file_id": file.id}
- )
-
- if result is not None and len(result.ids[0]) > 0:
- docs = [
- Document(
- page_content=result.documents[0][idx],
- metadata=result.metadatas[0][idx],
+ try:
+ # /files/{file_id}/data/content/update
+ VECTOR_DB_CLIENT.delete_collection(
+ collection_name=f"file-{file.id}"
)
- for idx, id in enumerate(result.ids[0])
- ]
- else:
+ except:
+ # Audio file upload pipeline
+ pass
+
docs = [
Document(
- page_content=file.data.get("content", ""),
+ page_content=form_data.content.replace("<br/>", "\n"),
metadata={
**file.meta,
"name": file.filename,
@@ -1470,149 +1454,190 @@ def process_file(
)
]
- text_content = file.data.get("content", "")
- else:
- # Process the file and save the content
- # Usage: /files/
- file_path = file.path
- if file_path:
- file_path = Storage.get_file(file_path)
- loader = Loader(
- engine=request.app.state.config.CONTENT_EXTRACTION_ENGINE,
- DATALAB_MARKER_API_KEY=request.app.state.config.DATALAB_MARKER_API_KEY,
- DATALAB_MARKER_API_BASE_URL=request.app.state.config.DATALAB_MARKER_API_BASE_URL,
- DATALAB_MARKER_ADDITIONAL_CONFIG=request.app.state.config.DATALAB_MARKER_ADDITIONAL_CONFIG,
- DATALAB_MARKER_SKIP_CACHE=request.app.state.config.DATALAB_MARKER_SKIP_CACHE,
- DATALAB_MARKER_FORCE_OCR=request.app.state.config.DATALAB_MARKER_FORCE_OCR,
- DATALAB_MARKER_PAGINATE=request.app.state.config.DATALAB_MARKER_PAGINATE,
- DATALAB_MARKER_STRIP_EXISTING_OCR=request.app.state.config.DATALAB_MARKER_STRIP_EXISTING_OCR,
- DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION=request.app.state.config.DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION,
- DATALAB_MARKER_FORMAT_LINES=request.app.state.config.DATALAB_MARKER_FORMAT_LINES,
- DATALAB_MARKER_USE_LLM=request.app.state.config.DATALAB_MARKER_USE_LLM,
- DATALAB_MARKER_OUTPUT_FORMAT=request.app.state.config.DATALAB_MARKER_OUTPUT_FORMAT,
- EXTERNAL_DOCUMENT_LOADER_URL=request.app.state.config.EXTERNAL_DOCUMENT_LOADER_URL,
- EXTERNAL_DOCUMENT_LOADER_API_KEY=request.app.state.config.EXTERNAL_DOCUMENT_LOADER_API_KEY,
- TIKA_SERVER_URL=request.app.state.config.TIKA_SERVER_URL,
- DOCLING_SERVER_URL=request.app.state.config.DOCLING_SERVER_URL,
- DOCLING_PARAMS={
- "do_ocr": request.app.state.config.DOCLING_DO_OCR,
- "force_ocr": request.app.state.config.DOCLING_FORCE_OCR,
- "ocr_engine": request.app.state.config.DOCLING_OCR_ENGINE,
- "ocr_lang": request.app.state.config.DOCLING_OCR_LANG,
- "pdf_backend": request.app.state.config.DOCLING_PDF_BACKEND,
- "table_mode": request.app.state.config.DOCLING_TABLE_MODE,
- "pipeline": request.app.state.config.DOCLING_PIPELINE,
- "do_picture_description": request.app.state.config.DOCLING_DO_PICTURE_DESCRIPTION,
- "picture_description_mode": request.app.state.config.DOCLING_PICTURE_DESCRIPTION_MODE,
- "picture_description_local": request.app.state.config.DOCLING_PICTURE_DESCRIPTION_LOCAL,
- "picture_description_api": request.app.state.config.DOCLING_PICTURE_DESCRIPTION_API,
- },
- PDF_EXTRACT_IMAGES=request.app.state.config.PDF_EXTRACT_IMAGES,
- DOCUMENT_INTELLIGENCE_ENDPOINT=request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT,
- DOCUMENT_INTELLIGENCE_KEY=request.app.state.config.DOCUMENT_INTELLIGENCE_KEY,
- MISTRAL_OCR_API_KEY=request.app.state.config.MISTRAL_OCR_API_KEY,
- )
- docs = loader.load(
- file.filename, file.meta.get("content_type"), file_path
+ text_content = form_data.content
+ elif form_data.collection_name:
+ # Check if the file has already been processed and save the content
+ # Usage: /knowledge/{id}/file/add, /knowledge/{id}/file/update
+
+ result = VECTOR_DB_CLIENT.query(
+ collection_name=f"file-{file.id}", filter={"file_id": file.id}
)
- docs = [
- Document(
- page_content=doc.page_content,
- metadata={
- **doc.metadata,
- "name": file.filename,
- "created_by": file.user_id,
- "file_id": file.id,
- "source": file.filename,
- },
- )
- for doc in docs
- ]
- else:
- docs = [
- Document(
- page_content=file.data.get("content", ""),
- metadata={
- **file.meta,
- "name": file.filename,
- "created_by": file.user_id,
- "file_id": file.id,
- "source": file.filename,
- },
- )
- ]
- text_content = " ".join([doc.page_content for doc in docs])
-
- log.debug(f"text_content: {text_content}")
- Files.update_file_data_by_id(
- file.id,
- {"content": text_content},
- )
- hash = calculate_sha256_string(text_content)
- Files.update_file_hash_by_id(file.id, hash)
-
- if request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL:
- Files.update_file_data_by_id(file.id, {"status": "completed"})
- return {
- "status": True,
- "collection_name": None,
- "filename": file.filename,
- "content": text_content,
- }
- else:
- try:
- result = save_docs_to_vector_db(
- request,
- docs=docs,
- collection_name=collection_name,
- metadata={
- "file_id": file.id,
- "name": file.filename,
- "hash": hash,
- },
- add=(True if form_data.collection_name else False),
- user=user,
- )
- log.info(f"added {len(docs)} items to collection {collection_name}")
-
- if result:
- Files.update_file_metadata_by_id(
- file.id,
- {
- "collection_name": collection_name,
- },
- )
-
- Files.update_file_data_by_id(
- file.id,
- {"status": "completed"},
- )
-
- return {
- "status": True,
- "collection_name": collection_name,
- "filename": file.filename,
- "content": text_content,
- }
+ if result is not None and len(result.ids[0]) > 0:
+ docs = [
+ Document(
+ page_content=result.documents[0][idx],
+ metadata=result.metadatas[0][idx],
+ )
+ for idx, id in enumerate(result.ids[0])
+ ]
else:
- raise Exception("Error saving document to vector database")
- except Exception as e:
- raise e
+ docs = [
+ Document(
+ page_content=file.data.get("content", ""),
+ metadata={
+ **file.meta,
+ "name": file.filename,
+ "created_by": file.user_id,
+ "file_id": file.id,
+ "source": file.filename,
+ },
+ )
+ ]
- except Exception as e:
- log.exception(e)
- if "No pandoc was found" in str(e):
- raise HTTPException(
- status_code=status.HTTP_400_BAD_REQUEST,
- detail=ERROR_MESSAGES.PANDOC_NOT_INSTALLED,
+ text_content = file.data.get("content", "")
+ else:
+ # Process the file and save the content
+ # Usage: /files/
+ file_path = file.path
+ if file_path:
+ file_path = Storage.get_file(file_path)
+ loader = Loader(
+ engine=request.app.state.config.CONTENT_EXTRACTION_ENGINE,
+ DATALAB_MARKER_API_KEY=request.app.state.config.DATALAB_MARKER_API_KEY,
+ DATALAB_MARKER_API_BASE_URL=request.app.state.config.DATALAB_MARKER_API_BASE_URL,
+ DATALAB_MARKER_ADDITIONAL_CONFIG=request.app.state.config.DATALAB_MARKER_ADDITIONAL_CONFIG,
+ DATALAB_MARKER_SKIP_CACHE=request.app.state.config.DATALAB_MARKER_SKIP_CACHE,
+ DATALAB_MARKER_FORCE_OCR=request.app.state.config.DATALAB_MARKER_FORCE_OCR,
+ DATALAB_MARKER_PAGINATE=request.app.state.config.DATALAB_MARKER_PAGINATE,
+ DATALAB_MARKER_STRIP_EXISTING_OCR=request.app.state.config.DATALAB_MARKER_STRIP_EXISTING_OCR,
+ DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION=request.app.state.config.DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION,
+ DATALAB_MARKER_FORMAT_LINES=request.app.state.config.DATALAB_MARKER_FORMAT_LINES,
+ DATALAB_MARKER_USE_LLM=request.app.state.config.DATALAB_MARKER_USE_LLM,
+ DATALAB_MARKER_OUTPUT_FORMAT=request.app.state.config.DATALAB_MARKER_OUTPUT_FORMAT,
+ EXTERNAL_DOCUMENT_LOADER_URL=request.app.state.config.EXTERNAL_DOCUMENT_LOADER_URL,
+ EXTERNAL_DOCUMENT_LOADER_API_KEY=request.app.state.config.EXTERNAL_DOCUMENT_LOADER_API_KEY,
+ TIKA_SERVER_URL=request.app.state.config.TIKA_SERVER_URL,
+ DOCLING_SERVER_URL=request.app.state.config.DOCLING_SERVER_URL,
+ DOCLING_PARAMS={
+ "do_ocr": request.app.state.config.DOCLING_DO_OCR,
+ "force_ocr": request.app.state.config.DOCLING_FORCE_OCR,
+ "ocr_engine": request.app.state.config.DOCLING_OCR_ENGINE,
+ "ocr_lang": request.app.state.config.DOCLING_OCR_LANG,
+ "pdf_backend": request.app.state.config.DOCLING_PDF_BACKEND,
+ "table_mode": request.app.state.config.DOCLING_TABLE_MODE,
+ "pipeline": request.app.state.config.DOCLING_PIPELINE,
+ "do_picture_description": request.app.state.config.DOCLING_DO_PICTURE_DESCRIPTION,
+ "picture_description_mode": request.app.state.config.DOCLING_PICTURE_DESCRIPTION_MODE,
+ "picture_description_local": request.app.state.config.DOCLING_PICTURE_DESCRIPTION_LOCAL,
+ "picture_description_api": request.app.state.config.DOCLING_PICTURE_DESCRIPTION_API,
+ },
+ PDF_EXTRACT_IMAGES=request.app.state.config.PDF_EXTRACT_IMAGES,
+ DOCUMENT_INTELLIGENCE_ENDPOINT=request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT,
+ DOCUMENT_INTELLIGENCE_KEY=request.app.state.config.DOCUMENT_INTELLIGENCE_KEY,
+ MISTRAL_OCR_API_KEY=request.app.state.config.MISTRAL_OCR_API_KEY,
+ )
+ docs = loader.load(
+ file.filename, file.meta.get("content_type"), file_path
+ )
+
+ docs = [
+ Document(
+ page_content=doc.page_content,
+ metadata={
+ **doc.metadata,
+ "name": file.filename,
+ "created_by": file.user_id,
+ "file_id": file.id,
+ "source": file.filename,
+ },
+ )
+ for doc in docs
+ ]
+ else:
+ docs = [
+ Document(
+ page_content=file.data.get("content", ""),
+ metadata={
+ **file.meta,
+ "name": file.filename,
+ "created_by": file.user_id,
+ "file_id": file.id,
+ "source": file.filename,
+ },
+ )
+ ]
+ text_content = " ".join([doc.page_content for doc in docs])
+
+ log.debug(f"text_content: {text_content}")
+ Files.update_file_data_by_id(
+ file.id,
+ {"content": text_content},
)
- else:
- raise HTTPException(
- status_code=status.HTTP_400_BAD_REQUEST,
- detail=str(e),
+ hash = calculate_sha256_string(text_content)
+ Files.update_file_hash_by_id(file.id, hash)
+
+ if request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL:
+ Files.update_file_data_by_id(file.id, {"status": "completed"})
+ return {
+ "status": True,
+ "collection_name": None,
+ "filename": file.filename,
+ "content": text_content,
+ }
+ else:
+ try:
+ result = save_docs_to_vector_db(
+ request,
+ docs=docs,
+ collection_name=collection_name,
+ metadata={
+ "file_id": file.id,
+ "name": file.filename,
+ "hash": hash,
+ },
+ add=(True if form_data.collection_name else False),
+ user=user,
+ )
+ log.info(f"added {len(docs)} items to collection {collection_name}")
+
+ if result:
+ Files.update_file_metadata_by_id(
+ file.id,
+ {
+ "collection_name": collection_name,
+ },
+ )
+
+ Files.update_file_data_by_id(
+ file.id,
+ {"status": "completed"},
+ )
+
+ return {
+ "status": True,
+ "collection_name": collection_name,
+ "filename": file.filename,
+ "content": text_content,
+ }
+ else:
+ raise Exception("Error saving document to vector database")
+ except Exception as e:
+ raise e
+
+ except Exception as e:
+ log.exception(e)
+ Files.update_file_data_by_id(
+ file.id,
+ {"status": "failed"},
)
+ if "No pandoc was found" in str(e):
+ raise HTTPException(
+ status_code=status.HTTP_400_BAD_REQUEST,
+ detail=ERROR_MESSAGES.PANDOC_NOT_INSTALLED,
+ )
+ else:
+ raise HTTPException(
+ status_code=status.HTTP_400_BAD_REQUEST,
+ detail=str(e),
+ )
+
+ else:
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
+ )
+
class ProcessTextForm(BaseModel):
name: str
@@ -1769,7 +1794,25 @@ def search_web(request: Request, engine: str, query: str) -> list[SearchResult]:
"""
# TODO: add playwright to search the web
- if engine == "searxng":
+ if engine == "ollama_cloud":
+ return search_ollama_cloud(
+ "https://ollama.com",
+ request.app.state.config.OLLAMA_CLOUD_WEB_SEARCH_API_KEY,
+ query,
+ request.app.state.config.WEB_SEARCH_RESULT_COUNT,
+ request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST,
+ )
+ elif engine == "perplexity_search":
+ if request.app.state.config.PERPLEXITY_API_KEY:
+ return search_perplexity_search(
+ request.app.state.config.PERPLEXITY_API_KEY,
+ query,
+ request.app.state.config.WEB_SEARCH_RESULT_COUNT,
+ request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST,
+ )
+ else:
+ raise Exception("No PERPLEXITY_API_KEY found in environment variables")
+ elif engine == "searxng":
if request.app.state.config.SEARXNG_QUERY_URL:
return search_searxng(
request.app.state.config.SEARXNG_QUERY_URL,
diff --git a/backend/open_webui/routers/tools.py b/backend/open_webui/routers/tools.py
index 5f82e7f1bd..eb66a86825 100644
--- a/backend/open_webui/routers/tools.py
+++ b/backend/open_webui/routers/tools.py
@@ -9,6 +9,7 @@ from pydantic import BaseModel, HttpUrl
from fastapi import APIRouter, Depends, HTTPException, Request, status
+from open_webui.models.oauth_sessions import OAuthSessions
from open_webui.models.tools import (
ToolForm,
ToolModel,
@@ -41,8 +42,17 @@ router = APIRouter()
@router.get("/", response_model=list[ToolUserResponse])
async def get_tools(request: Request, user=Depends(get_verified_user)):
- tools = Tools.get_tools()
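+ # Expose has_user_valves so the UI can offer a per-user valve settings control for tools that define a UserValves class.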
+ tools = [
+ ToolUserResponse(
+ **{
+ **tool.model_dump(),
+ "has_user_valves": "class UserValves(BaseModel):" in tool.content,
+ }
+ )
+ for tool in Tools.get_tools()
+ ]
+ # OpenAPI Tool Servers
for server in await get_tool_servers(request):
tools.append(
ToolUserResponse(
@@ -68,6 +78,50 @@ async def get_tools(request: Request, user=Depends(get_verified_user)):
)
)
+ # MCP Tool Servers
+ for server in request.app.state.config.TOOL_SERVER_CONNECTIONS:
+ if server.get("type", "openapi") == "mcp":
+ server_id = server.get("info", {}).get("id")
+ auth_type = server.get("auth_type", "none")
+
+ session_token = None
+ if auth_type == "oauth_2.1":
+ splits = server_id.split(":")
+ server_id = splits[-1] if len(splits) > 1 else server_id
+
+ session_token = (
+ await request.app.state.oauth_client_manager.get_oauth_token(
+ user.id, f"mcp:{server_id}"
+ )
+ )
+
+ tools.append(
+ ToolUserResponse(
+ **{
+ "id": f"server:mcp:{server.get('info', {}).get('id')}",
+ "user_id": f"server:mcp:{server.get('info', {}).get('id')}",
+ "name": server.get("info", {}).get("name", "MCP Tool Server"),
+ "meta": {
+ "description": server.get("info", {}).get(
+ "description", ""
+ ),
+ },
+ "access_control": server.get("config", {}).get(
+ "access_control", None
+ ),
+ "updated_at": int(time.time()),
+ "created_at": int(time.time()),
+ **(
+ {
+ "authenticated": session_token is not None,
+ }
+ if auth_type == "oauth_2.1"
+ else {}
+ ),
+ }
+ )
+ )
+
if user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL:
# Admin can see all tools
return tools
@@ -462,8 +516,9 @@ async def update_tools_valves_by_id(
try:
form_data = {k: v for k, v in form_data.items() if v is not None}
valves = Valves(**form_data)
- Tools.update_tool_valves_by_id(id, valves.model_dump())
- return valves.model_dump()
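+ # exclude_unset keeps only the valve fields actually submitted, so omitted fields are not overwritten with defaults.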
+ valves_dict = valves.model_dump(exclude_unset=True)
+ Tools.update_tool_valves_by_id(id, valves_dict)
+ return valves_dict
except Exception as e:
log.exception(f"Failed to update tool valves by id {id}: {e}")
raise HTTPException(
@@ -538,10 +593,11 @@ async def update_tools_user_valves_by_id(
try:
form_data = {k: v for k, v in form_data.items() if v is not None}
user_valves = UserValves(**form_data)
+ user_valves_dict = user_valves.model_dump(exclude_unset=True)
Tools.update_user_valves_by_id_and_user_id(
- id, user.id, user_valves.model_dump()
+ id, user.id, user_valves_dict
)
- return user_valves.model_dump()
+ return user_valves_dict
except Exception as e:
log.exception(f"Failed to update user valves by id {id}: {e}")
raise HTTPException(
diff --git a/backend/open_webui/utils/access_control.py b/backend/open_webui/utils/access_control.py
index 6215a6ac22..af48bebfb4 100644
--- a/backend/open_webui/utils/access_control.py
+++ b/backend/open_webui/utils/access_control.py
@@ -110,9 +110,13 @@ def has_access(
type: str = "write",
access_control: Optional[dict] = None,
user_group_ids: Optional[Set[str]] = None,
+ strict: bool = True,
) -> bool:
if access_control is None:
- return type == "read"
+ if strict:
+ return type == "read"
+ else:
+ return True
if user_group_ids is None:
user_groups = Groups.get_groups_by_member_id(user_id)
diff --git a/backend/open_webui/utils/files.py b/backend/open_webui/utils/files.py
new file mode 100644
index 0000000000..b410cbab50
--- /dev/null
+++ b/backend/open_webui/utils/files.py
@@ -0,0 +1,97 @@
+from open_webui.routers.images import (
+ load_b64_image_data,
+ upload_image,
+)
+
+from fastapi import (
+ APIRouter,
+ Depends,
+ HTTPException,
+ Request,
+ UploadFile,
+)
+
+from open_webui.routers.files import upload_file_handler
+
+import mimetypes
+import base64
+import io
+
+
+def get_image_url_from_base64(request, base64_image_string, metadata, user):
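+ # Handles PNG data URIs only: decode the payload, store it via upload_image, and return the served image URL (None otherwise).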
+ if "data:image/png;base64" in base64_image_string:
+ image_url = ""
+ # Extract base64 image data from the line
+ image_data, content_type = load_b64_image_data(base64_image_string)
+ if image_data is not None:
+ image_url = upload_image(
+ request,
+ image_data,
+ content_type,
+ metadata,
+ user,
+ )
+ return image_url
+ return None
+
+
+def load_b64_audio_data(b64_str):
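+ # Split the optional data-URI header from the base64 payload; default to audio/wav when no header or content type is present.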
+ try:
+ if "," in b64_str:
+ header, b64_data = b64_str.split(",", 1)
+ else:
+ b64_data = b64_str
+ header = "data:audio/wav;base64"
+ audio_data = base64.b64decode(b64_data)
+ content_type = (
+ header.split(";")[0].split(":")[1] if ";" in header else "audio/wav"
+ )
+ return audio_data, content_type
+ except Exception as e:
+ print(f"Error decoding base64 audio data: {e}")
+ return None, None
+
+
+def upload_audio(request, audio_data, content_type, metadata, user):
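+ # Wrap the raw audio bytes in an UploadFile and store them without RAG processing (process=False); returns the file-content route URL.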
+ audio_format = mimetypes.guess_extension(content_type)
+ file = UploadFile(
+ file=io.BytesIO(audio_data),
+ filename=f"generated-{audio_format}", # will be converted to a unique ID on upload_file
+ headers={
+ "content-type": content_type,
+ },
+ )
+ file_item = upload_file_handler(
+ request,
+ file=file,
+ metadata=metadata,
+ process=False,
+ user=user,
+ )
+ url = request.app.url_path_for("get_file_content_by_id", id=file_item.id)
+ return url
+
+
+def get_audio_url_from_base64(request, base64_audio_string, metadata, user):
+ if "data:audio/wav;base64" in base64_audio_string:
+ audio_url = ""
+ # Extract base64 audio data from the line
+ audio_data, content_type = load_b64_audio_data(base64_audio_string)
+ if audio_data is not None:
+ audio_url = upload_audio(
+ request,
+ audio_data,
+ content_type,
+ metadata,
+ user,
+ )
+ return audio_url
+ return None
+
+
+def get_file_url_from_base64(request, base64_file_string, metadata, user):
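+ # Dispatch on the data-URI prefix; only PNG images and WAV audio are handled at the moment.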
+ if "data:image/png;base64" in base64_file_string:
+ return get_image_url_from_base64(request, base64_file_string, metadata, user)
+ elif "data:audio/wav;base64" in base64_file_string:
+ return get_audio_url_from_base64(request, base64_file_string, metadata, user)
+ return None
diff --git a/backend/open_webui/utils/mcp/client.py b/backend/open_webui/utils/mcp/client.py
new file mode 100644
index 0000000000..01df38886c
--- /dev/null
+++ b/backend/open_webui/utils/mcp/client.py
@@ -0,0 +1,110 @@
+import asyncio
+from typing import Optional
+from contextlib import AsyncExitStack
+
+from mcp import ClientSession
+from mcp.client.auth import OAuthClientProvider, TokenStorage
+from mcp.client.streamable_http import streamablehttp_client
+from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken
+
+
+class MCPClient:
+ def __init__(self):
+ self.session: Optional[ClientSession] = None
+ self.exit_stack = AsyncExitStack()
+
+ async def connect(self, url: str, headers: Optional[dict] = None):
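+ # Open the streamable HTTP transport and initialize an MCP ClientSession, registering both on the exit stack for cleanup.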
+ try:
+ self._streams_context = streamablehttp_client(url, headers=headers)
+
+ transport = await self.exit_stack.enter_async_context(self._streams_context)
+ read_stream, write_stream, _ = transport
+
+ self._session_context = ClientSession(
+ read_stream, write_stream
+ ) # pylint: disable=W0201
+
+ self.session = await self.exit_stack.enter_async_context(
+ self._session_context
+ )
+ await self.session.initialize()
+ except Exception as e:
+ await self.disconnect()
+ raise e
+
+ async def list_tool_specs(self) -> Optional[dict]:
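+ # Normalize each MCP tool into a spec dict with "name", "description" and "parameters" (the tool's inputSchema).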
+ if not self.session:
+ raise RuntimeError("MCP client is not connected.")
+
+ result = await self.session.list_tools()
+ tools = result.tools
+
+ tool_specs = []
+ for tool in tools:
+ name = tool.name
+ description = tool.description
+
+ inputSchema = tool.inputSchema
+
+ # TODO: handle outputSchema if needed
+ outputSchema = getattr(tool, "outputSchema", None)
+
+ tool_specs.append(
+ {"name": name, "description": description, "parameters": inputSchema}
+ )
+
+ return tool_specs
+
+ async def call_tool(
+ self, function_name: str, function_args: dict
+ ) -> Optional[dict]:
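+ # Invoke the named tool on the server; raise if the result reports isError, otherwise return its content payload.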
+ if not self.session:
+ raise RuntimeError("MCP client is not connected.")
+
+ result = await self.session.call_tool(function_name, function_args)
+ if not result:
+ raise Exception("No result returned from MCP tool call.")
+
+ result_dict = result.model_dump(mode="json")
+ result_content = result_dict.get("content", {})
+
+ if result.isError:
+ raise Exception(result_content)
+ else:
+ return result_content
+
+ async def list_resources(self, cursor: Optional[str] = None) -> Optional[dict]:
+ if not self.session:
+ raise RuntimeError("MCP client is not connected.")
+
+ result = await self.session.list_resources(cursor=cursor)
+ if not result:
+ raise Exception("No result returned from MCP list_resources call.")
+
+ result_dict = result.model_dump()
+ resources = result_dict.get("resources", [])
+
+ return resources
+
+ async def read_resource(self, uri: str) -> Optional[dict]:
+ if not self.session:
+ raise RuntimeError("MCP client is not connected.")
+
+ result = await self.session.read_resource(uri)
+ if not result:
+ raise Exception("No result returned from MCP read_resource call.")
+ result_dict = result.model_dump()
+
+ return result_dict
+
+ async def disconnect(self):
+ # Clean up and close the session
+ await self.exit_stack.aclose()
+
+ async def __aenter__(self):
+ await self.exit_stack.__aenter__()
+ return self
+
+ async def __aexit__(self, exc_type, exc_value, traceback):
+ await self.exit_stack.__aexit__(exc_type, exc_value, traceback)
+ await self.disconnect()
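+
+
+# Minimal usage sketch (the URL and token below are placeholders, not values from this repository):
+#   client = MCPClient()
+#   await client.connect("https://example.com/mcp", headers={"Authorization": "Bearer <token>"})
+#   specs = await client.list_tool_specs()
+#   result = await client.call_tool(specs[0]["name"], {})
+#   await client.disconnect()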
diff --git a/backend/open_webui/utils/middleware.py b/backend/open_webui/utils/middleware.py
index 97f19dcded..ff8c215607 100644
--- a/backend/open_webui/utils/middleware.py
+++ b/backend/open_webui/utils/middleware.py
@@ -20,9 +20,11 @@ from concurrent.futures import ThreadPoolExecutor
from fastapi import Request, HTTPException
+from fastapi.responses import HTMLResponse
from starlette.responses import Response, StreamingResponse, JSONResponse
+from open_webui.models.oauth_sessions import OAuthSessions
from open_webui.models.chats import Chats
from open_webui.models.folders import Folders
from open_webui.models.users import Users
@@ -52,6 +54,11 @@ from open_webui.routers.pipelines import (
from open_webui.routers.memories import query_memory, QueryMemoryForm
from open_webui.utils.webhook import post_webhook
+from open_webui.utils.files import (
+ get_audio_url_from_base64,
+ get_file_url_from_base64,
+ get_image_url_from_base64,
+)
from open_webui.models.users import UserModel
@@ -86,6 +93,7 @@ from open_webui.utils.filter import (
)
from open_webui.utils.code_interpreter import execute_code_jupyter
from open_webui.utils.payload import apply_system_prompt_to_body
+from open_webui.utils.mcp.client import MCPClient
from open_webui.config import (
@@ -144,12 +152,14 @@ async def chat_completion_tools_handler(
def get_tools_function_calling_payload(messages, task_model_id, content):
user_message = get_last_user_message(messages)
- history = "\n".join(
+
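+ # Use the last four messages in chronological order when building the history block.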
+ recent_messages = messages[-4:] if len(messages) > 4 else messages
+ chat_history = "\n".join(
f"{message['role'].upper()}: \"\"\"{message['content']}\"\"\""
- for message in messages[::-1][:4]
+ for message in recent_messages
)
- prompt = f"History:\n{history}\nQuery: {user_message}"
+ prompt = f"History:\n{chat_history}\nQuery: {user_message}"
return {
"model": task_model_id,
@@ -631,48 +641,53 @@ async def chat_completion_files_handler(
sources = []
if files := body.get("metadata", {}).get("files", None):
+ # Check if all files are in full context mode
+ all_full_context = all(item.get("context") == "full" for item in files)
+
queries = []
- try:
- queries_response = await generate_queries(
- request,
- {
- "model": body["model"],
- "messages": body["messages"],
- "type": "retrieval",
- },
- user,
- )
- queries_response = queries_response["choices"][0]["message"]["content"]
-
+ if not all_full_context:
try:
- bracket_start = queries_response.find("{")
- bracket_end = queries_response.rfind("}") + 1
+ queries_response = await generate_queries(
+ request,
+ {
+ "model": body["model"],
+ "messages": body["messages"],
+ "type": "retrieval",
+ },
+ user,
+ )
+ queries_response = queries_response["choices"][0]["message"]["content"]
- if bracket_start == -1 or bracket_end == -1:
- raise Exception("No JSON object found in the response")
+ try:
+ bracket_start = queries_response.find("{")
+ bracket_end = queries_response.rfind("}") + 1
- queries_response = queries_response[bracket_start:bracket_end]
- queries_response = json.loads(queries_response)
- except Exception as e:
- queries_response = {"queries": [queries_response]}
+ if bracket_start == -1 or bracket_end == -1:
+ raise Exception("No JSON object found in the response")
- queries = queries_response.get("queries", [])
- except:
- pass
+ queries_response = queries_response[bracket_start:bracket_end]
+ queries_response = json.loads(queries_response)
+ except Exception as e:
+ queries_response = {"queries": [queries_response]}
+
+ queries = queries_response.get("queries", [])
+ except:
+ pass
if len(queries) == 0:
queries = [get_last_user_message(body["messages"])]
- await __event_emitter__(
- {
- "type": "status",
- "data": {
- "action": "queries_generated",
- "queries": queries,
- "done": False,
- },
- }
- )
+ if not all_full_context:
+ await __event_emitter__(
+ {
+ "type": "status",
+ "data": {
+ "action": "queries_generated",
+ "queries": queries,
+ "done": False,
+ },
+ }
+ )
try:
# Offload get_sources_from_items to a separate thread
@@ -701,7 +716,8 @@ async def chat_completion_files_handler(
r=request.app.state.config.RELEVANCE_THRESHOLD,
hybrid_bm25_weight=request.app.state.config.HYBRID_BM25_WEIGHT,
hybrid_search=request.app.state.config.ENABLE_RAG_HYBRID_SEARCH,
- full_context=request.app.state.config.RAG_FULL_CONTEXT,
+ full_context=all_full_context
+ or request.app.state.config.RAG_FULL_CONTEXT,
user=user,
),
)
@@ -818,7 +834,7 @@ async def process_chat_payload(request, form_data, user, metadata, model):
oauth_token = None
try:
if request.cookies.get("oauth_session_id", None):
- oauth_token = request.app.state.oauth_manager.get_oauth_token(
+ oauth_token = await request.app.state.oauth_manager.get_oauth_token(
user.id,
request.cookies.get("oauth_session_id", None),
)
@@ -987,14 +1003,107 @@ async def process_chat_payload(request, form_data, user, metadata, model):
# Server side tools
tool_ids = metadata.get("tool_ids", None)
# Client side tools
- tool_servers = metadata.get("tool_servers", None)
+ direct_tool_servers = metadata.get("tool_servers", None)
log.debug(f"{tool_ids=}")
- log.debug(f"{tool_servers=}")
+ log.debug(f"{direct_tool_servers=}")
tools_dict = {}
+ mcp_clients = []
+ mcp_tools_dict = {}
+
if tool_ids:
+ for tool_id in tool_ids:
+ if tool_id.startswith("server:mcp:"):
+ try:
+ server_id = tool_id[len("server:mcp:") :]
+
+ mcp_server_connection = None
+ for (
+ server_connection
+ ) in request.app.state.config.TOOL_SERVER_CONNECTIONS:
+ if (
+ server_connection.get("type", "") == "mcp"
+ and server_connection.get("info", {}).get("id") == server_id
+ ):
+ mcp_server_connection = server_connection
+ break
+
+ if not mcp_server_connection:
+ log.error(f"MCP server with id {server_id} not found")
+ continue
+
+ auth_type = mcp_server_connection.get("auth_type", "")
+
+ headers = {}
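+ # Build the Authorization header for the MCP connection according to the server's configured auth type.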
+ if auth_type == "bearer":
+ headers["Authorization"] = (
+ f"Bearer {mcp_server_connection.get('key', '')}"
+ )
+ elif auth_type == "none":
+ # No authentication
+ pass
+ elif auth_type == "session":
+ headers["Authorization"] = (
+ f"Bearer {request.state.token.credentials}"
+ )
+ elif auth_type == "system_oauth":
+ oauth_token = extra_params.get("__oauth_token__", None)
+ if oauth_token:
+ headers["Authorization"] = (
+ f"Bearer {oauth_token.get('access_token', '')}"
+ )
+ elif auth_type == "oauth_2.1":
+ try:
+ splits = server_id.split(":")
+ server_id = splits[-1] if len(splits) > 1 else server_id
+
+ oauth_token = await request.app.state.oauth_client_manager.get_oauth_token(
+ user.id, f"mcp:{server_id}"
+ )
+
+ if oauth_token:
+ headers["Authorization"] = (
+ f"Bearer {oauth_token.get('access_token', '')}"
+ )
+ except Exception as e:
+ log.error(f"Error getting OAuth token: {e}")
+ oauth_token = None
+
+ mcp_client = MCPClient()
+ await mcp_client.connect(
+ url=mcp_server_connection.get("url", ""),
+ headers=headers if headers else None,
+ )
+
+ tool_specs = await mcp_client.list_tool_specs()
+ for tool_spec in tool_specs:
+
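+ # Factory binds the tool name per iteration so each closure calls its own MCP tool (avoids late-binding issues in the loop).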
+ def make_tool_function(function_name):
+ async def tool_function(**kwargs):
+ return await mcp_client.call_tool(
+ function_name,
+ function_args=kwargs,
+ )
+
+ return tool_function
+
+ tool_function = make_tool_function(tool_spec["name"])
+
+ mcp_tools_dict[tool_spec["name"]] = {
+ "spec": tool_spec,
+ "callable": tool_function,
+ "type": "mcp",
+ "client": mcp_client,
+ "direct": False,
+ }
+
+ mcp_clients.append(mcp_client)
+ except Exception as e:
+ log.debug(e)
+ continue
+
tools_dict = await get_tools(
request,
tool_ids,
@@ -1006,9 +1115,11 @@ async def process_chat_payload(request, form_data, user, metadata, model):
"__files__": metadata.get("files", []),
},
)
+ if mcp_tools_dict:
+ tools_dict = {**tools_dict, **mcp_tools_dict}
- if tool_servers:
- for tool_server in tool_servers:
+ if direct_tool_servers:
+ for tool_server in direct_tool_servers:
tool_specs = tool_server.pop("specs", [])
for tool in tool_specs:
@@ -1018,6 +1129,9 @@ async def process_chat_payload(request, form_data, user, metadata, model):
"server": tool_server,
}
+ if mcp_clients:
+ metadata["mcp_clients"] = mcp_clients
+
if tools_dict:
if metadata.get("params", {}).get("function_calling") == "native":
# If the function calling is native, then call the tools function calling handler
@@ -1026,6 +1140,7 @@ async def process_chat_payload(request, form_data, user, metadata, model):
{"type": "function", "function": tool.get("spec", {})}
for tool in tools_dict.values()
]
+
else:
# If the function calling is not native, then call the tools function calling handler
try:
@@ -1079,26 +1194,15 @@ async def process_chat_payload(request, form_data, user, metadata, model):
raise Exception("No user message found")
if context_string != "":
- # Workaround for Ollama 2.0+ system prompt issue
- # TODO: replace with add_or_update_system_message
- if model.get("owned_by") == "ollama":
- form_data["messages"] = prepend_to_first_user_message_content(
- rag_template(
- request.app.state.config.RAG_TEMPLATE,
- context_string,
- prompt,
- ),
- form_data["messages"],
- )
- else:
- form_data["messages"] = add_or_update_system_message(
- rag_template(
- request.app.state.config.RAG_TEMPLATE,
- context_string,
- prompt,
- ),
- form_data["messages"],
- )
+ form_data["messages"] = add_or_update_user_message(
+ rag_template(
+ request.app.state.config.RAG_TEMPLATE,
+ context_string,
+ prompt,
+ ),
+ form_data["messages"],
+ append=False,
+ )
# If there are citations, add them to the data_items
sources = [
@@ -1498,7 +1602,7 @@ async def process_chat_response(
oauth_token = None
try:
if request.cookies.get("oauth_session_id", None):
- oauth_token = request.app.state.oauth_manager.get_oauth_token(
+ oauth_token = await request.app.state.oauth_manager.get_oauth_token(
user.id,
request.cookies.get("oauth_session_id", None),
)
@@ -1581,7 +1685,8 @@ async def process_chat_response(
break
if tool_result is not None:
- tool_calls_display_content = f'{tool_calls_display_content}Tool Executed
\nTool Executed
\nExecuting...
\n