diff --git a/backend/open_webui/config.py b/backend/open_webui/config.py index 2f5f34019d..7e5c35a451 100644 --- a/backend/open_webui/config.py +++ b/backend/open_webui/config.py @@ -2767,6 +2767,12 @@ WEB_SEARCH_TRUST_ENV = PersistentConfig( ) +OLLAMA_CLOUD_WEB_SEARCH_API_KEY = PersistentConfig( + "OLLAMA_CLOUD_WEB_SEARCH_API_KEY", + "rag.web.search.ollama_cloud_api_key", + os.getenv("OLLAMA_CLOUD_API_KEY", ""), +) + SEARXNG_QUERY_URL = PersistentConfig( "SEARXNG_QUERY_URL", "rag.web.search.searxng_query_url", diff --git a/backend/open_webui/env.py b/backend/open_webui/env.py index 243b8212a8..e02424f969 100644 --- a/backend/open_webui/env.py +++ b/backend/open_webui/env.py @@ -474,6 +474,10 @@ ENABLE_OAUTH_ID_TOKEN_COOKIE = ( os.environ.get("ENABLE_OAUTH_ID_TOKEN_COOKIE", "True").lower() == "true" ) +OAUTH_CLIENT_INFO_ENCRYPTION_KEY = os.environ.get( + "OAUTH_CLIENT_INFO_ENCRYPTION_KEY", WEBUI_SECRET_KEY +) + OAUTH_SESSION_TOKEN_ENCRYPTION_KEY = os.environ.get( "OAUTH_SESSION_TOKEN_ENCRYPTION_KEY", WEBUI_SECRET_KEY ) diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index 7ebb76cb58..c6930f9b99 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -50,6 +50,11 @@ from starlette.middleware.sessions import SessionMiddleware from starlette.responses import Response, StreamingResponse from starlette.datastructures import Headers +from starsessions import ( + SessionMiddleware as StarSessionsMiddleware, + SessionAutoloadMiddleware, +) +from starsessions.stores.redis import RedisStore from open_webui.utils import logger from open_webui.utils.audit import AuditLevel, AuditLoggingMiddleware @@ -269,6 +274,7 @@ from open_webui.config import ( WEB_SEARCH_CONCURRENT_REQUESTS, WEB_SEARCH_TRUST_ENV, WEB_SEARCH_DOMAIN_FILTER_LIST, + OLLAMA_CLOUD_WEB_SEARCH_API_KEY, JINA_API_KEY, SEARCHAPI_API_KEY, SEARCHAPI_ENGINE, @@ -467,7 +473,12 @@ from open_webui.utils.auth import ( get_verified_user, ) from open_webui.utils.plugin import install_tool_and_function_dependencies -from open_webui.utils.oauth import OAuthManager +from open_webui.utils.oauth import ( + OAuthManager, + OAuthClientManager, + decrypt_data, + OAuthClientInformationFull, +) from open_webui.utils.security_headers import SecurityHeadersMiddleware from open_webui.utils.redis import get_redis_connection @@ -597,9 +608,14 @@ app = FastAPI( lifespan=lifespan, ) +# For Open WebUI OIDC/OAuth2 oauth_manager = OAuthManager(app) app.state.oauth_manager = oauth_manager +# For Integrations +oauth_client_manager = OAuthClientManager(app) +app.state.oauth_client_manager = oauth_client_manager + app.state.instance_id = None app.state.config = AppConfig( redis_url=REDIS_URL, @@ -883,6 +899,8 @@ app.state.config.BYPASS_WEB_SEARCH_WEB_LOADER = BYPASS_WEB_SEARCH_WEB_LOADER app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION = ENABLE_GOOGLE_DRIVE_INTEGRATION app.state.config.ENABLE_ONEDRIVE_INTEGRATION = ENABLE_ONEDRIVE_INTEGRATION + +app.state.config.OLLAMA_CLOUD_WEB_SEARCH_API_KEY = OLLAMA_CLOUD_WEB_SEARCH_API_KEY app.state.config.SEARXNG_QUERY_URL = SEARXNG_QUERY_URL app.state.config.YACY_QUERY_URL = YACY_QUERY_URL app.state.config.YACY_USERNAME = YACY_USERNAME @@ -1873,14 +1891,78 @@ async def get_current_usage(user=Depends(get_verified_user)): # OAuth Login & Callback ############################ + +# Initialize OAuth client manager with any MCP tool servers using OAuth 2.1 +if len(app.state.config.TOOL_SERVER_CONNECTIONS) > 0: + for tool_server_connection in app.state.config.TOOL_SERVER_CONNECTIONS: + if tool_server_connection.get("type", "openapi") == "mcp": + server_id = tool_server_connection.get("info", {}).get("id") + auth_type = tool_server_connection.get("auth_type", "none") + if server_id and auth_type == "oauth_2.1": + oauth_client_info = tool_server_connection.get("info", {}).get( + "oauth_client_info", "" + ) + + oauth_client_info = decrypt_data(oauth_client_info) + app.state.oauth_client_manager.add_client( + f"mcp:{server_id}", OAuthClientInformationFull(**oauth_client_info) + ) + + # SessionMiddleware is used by authlib for oauth if len(OAUTH_PROVIDERS) > 0: - app.add_middleware( - SessionMiddleware, - secret_key=WEBUI_SECRET_KEY, - session_cookie="oui-session", - same_site=WEBUI_SESSION_COOKIE_SAME_SITE, - https_only=WEBUI_SESSION_COOKIE_SECURE, + try: + if REDIS_URL: + redis_session_store = RedisStore( + url=REDIS_URL, + prefix=( + f"{REDIS_KEY_PREFIX}:session:" if REDIS_KEY_PREFIX else "session:" + ), + ) + + app.add_middleware(SessionAutoloadMiddleware) + app.add_middleware( + StarSessionsMiddleware, + store=redis_session_store, + cookie_name="oui-session", + cookie_same_site=WEBUI_SESSION_COOKIE_SAME_SITE, + cookie_https_only=WEBUI_SESSION_COOKIE_SECURE, + ) + log.info("Using Redis for session") + else: + raise ValueError("No Redis URL provided") + except Exception as e: + app.add_middleware( + SessionMiddleware, + secret_key=WEBUI_SECRET_KEY, + session_cookie="oui-session", + same_site=WEBUI_SESSION_COOKIE_SAME_SITE, + https_only=WEBUI_SESSION_COOKIE_SECURE, + ) + + +@app.get("/oauth/clients/{client_id}/authorize") +async def oauth_client_authorize( + client_id: str, + request: Request, + response: Response, + user=Depends(get_verified_user), +): + return await oauth_client_manager.handle_authorize(request, client_id=client_id) + + +@app.get("/oauth/clients/{client_id}/callback") +async def oauth_client_callback( + client_id: str, + request: Request, + response: Response, + user=Depends(get_verified_user), +): + return await oauth_client_manager.handle_callback( + request, + client_id=client_id, + user_id=user.id if user else None, + response=response, ) @@ -1895,8 +1977,9 @@ async def oauth_login(provider: str, request: Request): # - This is considered insecure in general, as OAuth providers do not always verify email addresses # 3. If there is no user, and ENABLE_OAUTH_SIGNUP is true, create a user # - Email addresses are considered unique, so we fail registration if the email address is already taken -@app.get("/oauth/{provider}/callback") -async def oauth_callback(provider: str, request: Request, response: Response): +@app.get("/oauth/{provider}/callback") # Legacy endpoint +@app.get("/oauth/{provider}/login/callback") +async def oauth_login_callback(provider: str, request: Request, response: Response): return await oauth_manager.handle_callback(request, provider, response) diff --git a/backend/open_webui/models/channels.py b/backend/open_webui/models/channels.py index 92f238c3a0..e75266be78 100644 --- a/backend/open_webui/models/channels.py +++ b/backend/open_webui/models/channels.py @@ -57,6 +57,10 @@ class ChannelModel(BaseModel): #################### +class ChannelResponse(ChannelModel): + write_access: bool = False + + class ChannelForm(BaseModel): name: str description: Optional[str] = None diff --git a/backend/open_webui/models/chats.py b/backend/open_webui/models/chats.py index cadb5a3a79..97fd9b6256 100644 --- a/backend/open_webui/models/chats.py +++ b/backend/open_webui/models/chats.py @@ -492,11 +492,16 @@ class ChatTable: self, user_id: str, include_archived: bool = False, + include_folders: bool = False, skip: Optional[int] = None, limit: Optional[int] = None, ) -> list[ChatTitleIdResponse]: with get_db() as db: - query = db.query(Chat).filter_by(user_id=user_id).filter_by(folder_id=None) + query = db.query(Chat).filter_by(user_id=user_id) + + if not include_folders: + query = query.filter_by(folder_id=None) + query = query.filter(or_(Chat.pinned == False, Chat.pinned == None)) if not include_archived: @@ -943,6 +948,16 @@ class ChatTable: return count + def count_chats_by_folder_id_and_user_id(self, folder_id: str, user_id: str) -> int: + with get_db() as db: + query = db.query(Chat).filter_by(user_id=user_id) + + query = query.filter_by(folder_id=folder_id) + count = query.count() + + log.info(f"Count of chats for folder '{folder_id}': {count}") + return count + def delete_tag_by_id_and_user_id_and_tag_name( self, id: str, user_id: str, tag_name: str ) -> bool: diff --git a/backend/open_webui/models/files.py b/backend/open_webui/models/files.py index 57978225d4..bf07b5f86f 100644 --- a/backend/open_webui/models/files.py +++ b/backend/open_webui/models/files.py @@ -130,6 +130,17 @@ class FilesTable: except Exception: return None + def get_file_by_id_and_user_id(self, id: str, user_id: str) -> Optional[FileModel]: + with get_db() as db: + try: + file = db.query(File).filter_by(id=id, user_id=user_id).first() + if file: + return FileModel.model_validate(file) + else: + return None + except Exception: + return None + def get_file_metadata_by_id(self, id: str) -> Optional[FileMetadataResponse]: with get_db() as db: try: diff --git a/backend/open_webui/models/oauth_sessions.py b/backend/open_webui/models/oauth_sessions.py index 9fd5335ce5..81ce220384 100644 --- a/backend/open_webui/models/oauth_sessions.py +++ b/backend/open_webui/models/oauth_sessions.py @@ -176,6 +176,26 @@ class OAuthSessionTable: log.error(f"Error getting OAuth session by ID: {e}") return None + def get_session_by_provider_and_user_id( + self, provider: str, user_id: str + ) -> Optional[OAuthSessionModel]: + """Get OAuth session by provider and user ID""" + try: + with get_db() as db: + session = ( + db.query(OAuthSession) + .filter_by(provider=provider, user_id=user_id) + .first() + ) + if session: + session.token = self._decrypt_token(session.token) + return OAuthSessionModel.model_validate(session) + + return None + except Exception as e: + log.error(f"Error getting OAuth session by provider and user ID: {e}") + return None + def get_sessions_by_user_id(self, user_id: str) -> List[OAuthSessionModel]: """Get all OAuth sessions for a user""" try: diff --git a/backend/open_webui/models/tools.py b/backend/open_webui/models/tools.py index 3a47fa008d..48f84b3ac4 100644 --- a/backend/open_webui/models/tools.py +++ b/backend/open_webui/models/tools.py @@ -95,6 +95,8 @@ class ToolResponse(BaseModel): class ToolUserResponse(ToolResponse): user: Optional[UserResponse] = None + model_config = ConfigDict(extra="allow") + class ToolForm(BaseModel): id: str diff --git a/backend/open_webui/retrieval/web/ollama.py b/backend/open_webui/retrieval/web/ollama.py new file mode 100644 index 0000000000..a199a14389 --- /dev/null +++ b/backend/open_webui/retrieval/web/ollama.py @@ -0,0 +1,51 @@ +import logging +from dataclasses import dataclass +from typing import Optional + +import requests +from open_webui.env import SRC_LOG_LEVELS +from open_webui.retrieval.web.main import SearchResult + +log = logging.getLogger(__name__) +log.setLevel(SRC_LOG_LEVELS["RAG"]) + + +def search_ollama_cloud( + url: str, + api_key: str, + query: str, + count: int, + filter_list: Optional[list[str]] = None, +) -> list[SearchResult]: + """Search using Ollama Search API and return the results as a list of SearchResult objects. + + Args: + api_key (str): A Ollama Search API key + query (str): The query to search for + count (int): Number of results to return + filter_list (Optional[list[str]]): List of domains to filter results by + """ + log.info(f"Searching with Ollama for query: {query}") + + headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"} + payload = {"query": query, "max_results": count} + + try: + response = requests.post(f"{url}/api/web_search", headers=headers, json=payload) + response.raise_for_status() + data = response.json() + + results = data.get("results", []) + log.info(f"Found {len(results)} results") + + return [ + SearchResult( + link=result.get("url", ""), + title=result.get("title", ""), + snippet=result.get("content", ""), + ) + for result in results + ] + except Exception as e: + log.error(f"Error searching Ollama: {e}") + return [] diff --git a/backend/open_webui/routers/channels.py b/backend/open_webui/routers/channels.py index da52be6e79..e7b8366347 100644 --- a/backend/open_webui/routers/channels.py +++ b/backend/open_webui/routers/channels.py @@ -10,7 +10,13 @@ from pydantic import BaseModel from open_webui.socket.main import sio, get_user_ids_from_room from open_webui.models.users import Users, UserNameResponse -from open_webui.models.channels import Channels, ChannelModel, ChannelForm +from open_webui.models.groups import Groups +from open_webui.models.channels import ( + Channels, + ChannelModel, + ChannelForm, + ChannelResponse, +) from open_webui.models.messages import ( Messages, MessageModel, @@ -80,7 +86,7 @@ async def create_new_channel(form_data: ChannelForm, user=Depends(get_admin_user ############################ -@router.get("/{id}", response_model=Optional[ChannelModel]) +@router.get("/{id}", response_model=Optional[ChannelResponse]) async def get_channel_by_id(id: str, user=Depends(get_verified_user)): channel = Channels.get_channel_by_id(id) if not channel: @@ -95,7 +101,16 @@ async def get_channel_by_id(id: str, user=Depends(get_verified_user)): status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() ) - return ChannelModel(**channel.model_dump()) + write_access = has_access( + user.id, type="write", access_control=channel.access_control, strict=False + ) + + return ChannelResponse( + **{ + **channel.model_dump(), + "write_access": write_access or user.role == "admin", + } + ) ############################ @@ -275,6 +290,7 @@ async def model_response_handler(request, channel, message, user): ) thread_history = [] + images = [] message_users = {} for thread_message in thread_messages: @@ -303,6 +319,11 @@ async def model_response_handler(request, channel, message, user): f"{username}: {replace_mentions(thread_message.content)}" ) + thread_message_files = thread_message.data.get("files", []) + for file in thread_message_files: + if file.get("type", "") == "image": + images.append(file.get("url", "")) + system_message = { "role": "system", "content": f"You are {model.get('name', model_id)}, an AI assistant participating in a threaded conversation. Be helpful, concise, and conversational." @@ -313,14 +334,29 @@ async def model_response_handler(request, channel, message, user): ), } + content = f"{user.name if user else 'User'}: {message_content}" + if images: + content = [ + { + "type": "text", + "text": content, + }, + *[ + { + "type": "image_url", + "image_url": { + "url": image, + }, + } + for image in images + ], + ] + form_data = { "model": model_id, "messages": [ system_message, - { - "role": "user", - "content": f"{user.name if user else 'User'}: {message_content}", - }, + {"role": "user", "content": content}, ], "stream": False, } @@ -362,7 +398,7 @@ async def new_message_handler( ) if user.role != "admin" and not has_access( - user.id, type="read", access_control=channel.access_control + user.id, type="write", access_control=channel.access_control, strict=False ): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() @@ -658,7 +694,7 @@ async def add_reaction_to_message( ) if user.role != "admin" and not has_access( - user.id, type="read", access_control=channel.access_control + user.id, type="write", access_control=channel.access_control, strict=False ): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() @@ -724,7 +760,7 @@ async def remove_reaction_by_id_and_user_id_and_name( ) if user.role != "admin" and not has_access( - user.id, type="read", access_control=channel.access_control + user.id, type="write", access_control=channel.access_control, strict=False ): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() @@ -806,7 +842,9 @@ async def delete_message_by_id( if ( user.role != "admin" and message.user_id != user.id - and not has_access(user.id, type="read", access_control=channel.access_control) + and not has_access( + user.id, type="write", access_control=channel.access_control, strict=False + ) ): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() diff --git a/backend/open_webui/routers/chats.py b/backend/open_webui/routers/chats.py index 847368412e..788e355f2b 100644 --- a/backend/open_webui/routers/chats.py +++ b/backend/open_webui/routers/chats.py @@ -37,7 +37,9 @@ router = APIRouter() @router.get("/", response_model=list[ChatTitleIdResponse]) @router.get("/list", response_model=list[ChatTitleIdResponse]) def get_session_user_chat_list( - user=Depends(get_verified_user), page: Optional[int] = None + user=Depends(get_verified_user), + page: Optional[int] = None, + include_folders: Optional[bool] = False, ): try: if page is not None: @@ -45,10 +47,12 @@ def get_session_user_chat_list( skip = (page - 1) * limit return Chats.get_chat_title_id_list_by_user_id( - user.id, skip=skip, limit=limit + user.id, include_folders=include_folders, skip=skip, limit=limit ) else: - return Chats.get_chat_title_id_list_by_user_id(user.id) + return Chats.get_chat_title_id_list_by_user_id( + user.id, include_folders=include_folders + ) except Exception as e: log.exception(e) raise HTTPException( diff --git a/backend/open_webui/routers/configs.py b/backend/open_webui/routers/configs.py index 31d7bce404..d4b88032e2 100644 --- a/backend/open_webui/routers/configs.py +++ b/backend/open_webui/routers/configs.py @@ -1,6 +1,7 @@ import logging from fastapi import APIRouter, Depends, Request, HTTPException from pydantic import BaseModel, ConfigDict +import aiohttp from typing import Optional @@ -17,6 +18,14 @@ from open_webui.utils.mcp.client import MCPClient from open_webui.env import SRC_LOG_LEVELS +from open_webui.utils.oauth import ( + get_discovery_urls, + get_oauth_client_info_with_dynamic_client_registration, + encrypt_data, + decrypt_data, + OAuthClientInformationFull, +) +from mcp.shared.auth import OAuthMetadata router = APIRouter() @@ -86,6 +95,43 @@ async def set_connections_config( } +class OAuthClientRegistrationForm(BaseModel): + url: str + client_id: str + client_name: Optional[str] = None + + +@router.post("/oauth/clients/register") +async def register_oauth_client( + request: Request, + form_data: OAuthClientRegistrationForm, + type: Optional[str] = None, + user=Depends(get_admin_user), +): + try: + oauth_client_id = form_data.client_id + if type: + oauth_client_id = f"{type}:{form_data.client_id}" + + oauth_client_info = ( + await get_oauth_client_info_with_dynamic_client_registration( + request, oauth_client_id, form_data.url + ) + ) + return { + "status": True, + "oauth_client_info": encrypt_data( + oauth_client_info.model_dump(mode="json") + ), + } + except Exception as e: + log.debug(f"Failed to register OAuth client: {e}") + raise HTTPException( + status_code=400, + detail=f"Failed to register OAuth client", + ) + + ############################ # ToolServers Config ############################ @@ -122,8 +168,29 @@ async def set_tool_servers_config( request.app.state.config.TOOL_SERVER_CONNECTIONS = [ connection.model_dump() for connection in form_data.TOOL_SERVER_CONNECTIONS ] + await set_tool_servers(request) + for connection in request.app.state.config.TOOL_SERVER_CONNECTIONS: + server_type = connection.get("type", "openapi") + if server_type == "mcp": + server_id = connection.get("info", {}).get("id") + auth_type = connection.get("auth_type", "none") + if auth_type == "oauth_2.1" and server_id: + try: + oauth_client_info = connection.get("info", {}).get( + "oauth_client_info", "" + ) + oauth_client_info = decrypt_data(oauth_client_info) + + await request.app.state.oauth_client_manager.add_client( + f"{server_type}:{server_id}", + OAuthClientInformationFull(**oauth_client_info), + ) + except Exception as e: + log.debug(f"Failed to add OAuth client for MCP tool server: {e}") + continue + return { "TOOL_SERVER_CONNECTIONS": request.app.state.config.TOOL_SERVER_CONNECTIONS, } @@ -138,46 +205,79 @@ async def verify_tool_servers_config( """ try: if form_data.type == "mcp": - try: - client = MCPClient() - auth = None - headers = None + if form_data.auth_type == "oauth_2.1": + discovery_urls = get_discovery_urls(form_data.url) + async with aiohttp.ClientSession() as session: + async with session.get( + discovery_urls[0] + ) as oauth_server_metadata_response: + if oauth_server_metadata_response.status != 200: + raise HTTPException( + status_code=400, + detail=f"Failed to fetch OAuth 2.1 discovery document from {discovery_urls[0]}", + ) - token = None - if form_data.auth_type == "bearer": - token = form_data.key - elif form_data.auth_type == "session": - token = request.state.token.credentials - elif form_data.auth_type == "system_oauth": - try: - if request.cookies.get("oauth_session_id", None): - token = ( - await request.app.state.oauth_manager.get_oauth_token( + try: + oauth_server_metadata = OAuthMetadata.model_validate( + await oauth_server_metadata_response.json() + ) + return { + "status": True, + "oauth_server_metadata": oauth_server_metadata.model_dump( + mode="json" + ), + } + except Exception as e: + log.info( + f"Failed to parse OAuth 2.1 discovery document: {e}" + ) + raise HTTPException( + status_code=400, + detail=f"Failed to parse OAuth 2.1 discovery document from {discovery_urls[0]}", + ) + + raise HTTPException( + status_code=400, + detail=f"Failed to fetch OAuth 2.1 discovery document from {discovery_urls[0]}", + ) + else: + try: + client = MCPClient() + headers = None + + token = None + if form_data.auth_type == "bearer": + token = form_data.key + elif form_data.auth_type == "session": + token = request.state.token.credentials + elif form_data.auth_type == "system_oauth": + try: + if request.cookies.get("oauth_session_id", None): + token = await request.app.state.oauth_manager.get_oauth_token( user.id, request.cookies.get("oauth_session_id", None), ) - ) - except Exception as e: - pass + except Exception as e: + pass - if token: - headers = {"Authorization": f"Bearer {token}"} + if token: + headers = {"Authorization": f"Bearer {token}"} - await client.connect(form_data.url, auth=auth, headers=headers) - specs = await client.list_tool_specs() - return { - "status": True, - "specs": specs, - } - except Exception as e: - log.debug(f"Failed to create MCP client: {e}") - raise HTTPException( - status_code=400, - detail=f"Failed to create MCP client", - ) - finally: - if client: - await client.disconnect() + await client.connect(form_data.url, headers=headers) + specs = await client.list_tool_specs() + return { + "status": True, + "specs": specs, + } + except Exception as e: + log.debug(f"Failed to create MCP client: {e}") + raise HTTPException( + status_code=400, + detail=f"Failed to create MCP client", + ) + finally: + if client: + await client.disconnect() else: # openapi token = None if form_data.auth_type == "bearer": diff --git a/backend/open_webui/routers/folders.py b/backend/open_webui/routers/folders.py index 36dbfee5c5..ddee71ea4d 100644 --- a/backend/open_webui/routers/folders.py +++ b/backend/open_webui/routers/folders.py @@ -262,15 +262,15 @@ async def update_folder_is_expanded_by_id( async def delete_folder_by_id( request: Request, id: str, user=Depends(get_verified_user) ): - chat_delete_permission = has_permission( - user.id, "chat.delete", request.app.state.config.USER_PERMISSIONS - ) - - if user.role != "admin" and not chat_delete_permission: - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail=ERROR_MESSAGES.ACCESS_PROHIBITED, + if Chats.count_chats_by_folder_id_and_user_id(id, user.id): + chat_delete_permission = has_permission( + user.id, "chat.delete", request.app.state.config.USER_PERMISSIONS ) + if user.role != "admin" and not chat_delete_permission: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail=ERROR_MESSAGES.ACCESS_PROHIBITED, + ) folder = Folders.get_folder_by_id_and_user_id(id, user.id) if folder: diff --git a/backend/open_webui/routers/functions.py b/backend/open_webui/routers/functions.py index 202aa74ca4..c36e656d5f 100644 --- a/backend/open_webui/routers/functions.py +++ b/backend/open_webui/routers/functions.py @@ -431,8 +431,10 @@ async def update_function_valves_by_id( try: form_data = {k: v for k, v in form_data.items() if v is not None} valves = Valves(**form_data) - Functions.update_function_valves_by_id(id, valves.model_dump()) - return valves.model_dump() + + valves_dict = valves.model_dump(exclude_unset=True) + Functions.update_function_valves_by_id(id, valves_dict) + return valves_dict except Exception as e: log.exception(f"Error updating function values by id {id}: {e}") raise HTTPException( @@ -514,10 +516,11 @@ async def update_function_user_valves_by_id( try: form_data = {k: v for k, v in form_data.items() if v is not None} user_valves = UserValves(**form_data) + user_valves_dict = user_valves.model_dump(exclude_unset=True) Functions.update_user_valves_by_id_and_user_id( - id, user.id, user_valves.model_dump() + id, user.id, user_valves_dict ) - return user_valves.model_dump() + return user_valves_dict except Exception as e: log.exception(f"Error updating function user valves by id {id}: {e}") raise HTTPException( diff --git a/backend/open_webui/routers/images.py b/backend/open_webui/routers/images.py index 802a3e9924..059b3a23d7 100644 --- a/backend/open_webui/routers/images.py +++ b/backend/open_webui/routers/images.py @@ -514,6 +514,7 @@ async def image_generations( size = form_data.size width, height = tuple(map(int, size.split("x"))) + model = get_image_model(request) r = None try: @@ -531,11 +532,7 @@ async def image_generations( headers["X-OpenWebUI-User-Role"] = user.role data = { - "model": ( - request.app.state.config.IMAGE_GENERATION_MODEL - if request.app.state.config.IMAGE_GENERATION_MODEL != "" - else "dall-e-2" - ), + "model": model, "prompt": form_data.prompt, "n": form_data.n, "size": ( @@ -584,7 +581,6 @@ async def image_generations( headers["Content-Type"] = "application/json" headers["x-goog-api-key"] = request.app.state.config.IMAGES_GEMINI_API_KEY - model = get_image_model(request) data = { "instances": {"prompt": form_data.prompt}, "parameters": { @@ -640,7 +636,7 @@ async def image_generations( } ) res = await comfyui_generate_image( - request.app.state.config.IMAGE_GENERATION_MODEL, + model, form_data, user.id, request.app.state.config.COMFYUI_BASE_URL, diff --git a/backend/open_webui/routers/retrieval.py b/backend/open_webui/routers/retrieval.py index 0ddf824efa..73b3a22725 100644 --- a/backend/open_webui/routers/retrieval.py +++ b/backend/open_webui/routers/retrieval.py @@ -45,6 +45,7 @@ from open_webui.retrieval.loaders.youtube import YoutubeLoader # Web search engines from open_webui.retrieval.web.main import SearchResult from open_webui.retrieval.web.utils import get_web_loader +from open_webui.retrieval.web.ollama import search_ollama_cloud from open_webui.retrieval.web.brave import search_brave from open_webui.retrieval.web.kagi import search_kagi from open_webui.retrieval.web.mojeek import search_mojeek @@ -469,6 +470,7 @@ async def get_rag_config(request: Request, user=Depends(get_admin_user)): "WEB_SEARCH_DOMAIN_FILTER_LIST": request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST, "BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL": request.app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL, "BYPASS_WEB_SEARCH_WEB_LOADER": request.app.state.config.BYPASS_WEB_SEARCH_WEB_LOADER, + "OLLAMA_CLOUD_WEB_SEARCH_API_KEY": request.app.state.config.OLLAMA_CLOUD_WEB_SEARCH_API_KEY, "SEARXNG_QUERY_URL": request.app.state.config.SEARXNG_QUERY_URL, "YACY_QUERY_URL": request.app.state.config.YACY_QUERY_URL, "YACY_USERNAME": request.app.state.config.YACY_USERNAME, @@ -525,6 +527,7 @@ class WebConfig(BaseModel): WEB_SEARCH_DOMAIN_FILTER_LIST: Optional[List[str]] = [] BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL: Optional[bool] = None BYPASS_WEB_SEARCH_WEB_LOADER: Optional[bool] = None + OLLAMA_CLOUD_WEB_SEARCH_API_KEY: Optional[str] = None SEARXNG_QUERY_URL: Optional[str] = None YACY_QUERY_URL: Optional[str] = None YACY_USERNAME: Optional[str] = None @@ -988,6 +991,9 @@ async def update_rag_config( request.app.state.config.BYPASS_WEB_SEARCH_WEB_LOADER = ( form_data.web.BYPASS_WEB_SEARCH_WEB_LOADER ) + request.app.state.config.OLLAMA_CLOUD_WEB_SEARCH_API_KEY = ( + form_data.web.OLLAMA_CLOUD_WEB_SEARCH_API_KEY + ) request.app.state.config.SEARXNG_QUERY_URL = form_data.web.SEARXNG_QUERY_URL request.app.state.config.YACY_QUERY_URL = form_data.web.YACY_QUERY_URL request.app.state.config.YACY_USERNAME = form_data.web.YACY_USERNAME @@ -1139,6 +1145,7 @@ async def update_rag_config( "WEB_SEARCH_DOMAIN_FILTER_LIST": request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST, "BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL": request.app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL, "BYPASS_WEB_SEARCH_WEB_LOADER": request.app.state.config.BYPASS_WEB_SEARCH_WEB_LOADER, + "OLLAMA_CLOUD_WEB_SEARCH_API_KEY": request.app.state.config.OLLAMA_CLOUD_WEB_SEARCH_API_KEY, "SEARXNG_QUERY_URL": request.app.state.config.SEARXNG_QUERY_URL, "YACY_QUERY_URL": request.app.state.config.YACY_QUERY_URL, "YACY_USERNAME": request.app.state.config.YACY_USERNAME, @@ -1407,59 +1414,35 @@ def process_file( form_data: ProcessFileForm, user=Depends(get_verified_user), ): - try: + if user.role == "admin": file = Files.get_file_by_id(form_data.file_id) + else: + file = Files.get_file_by_id_and_user_id(form_data.file_id, user.id) - collection_name = form_data.collection_name + if file: + try: - if collection_name is None: - collection_name = f"file-{file.id}" + collection_name = form_data.collection_name - if form_data.content: - # Update the content in the file - # Usage: /files/{file_id}/data/content/update, /files/ (audio file upload pipeline) + if collection_name is None: + collection_name = f"file-{file.id}" - try: - # /files/{file_id}/data/content/update - VECTOR_DB_CLIENT.delete_collection(collection_name=f"file-{file.id}") - except: - # Audio file upload pipeline - pass + if form_data.content: + # Update the content in the file + # Usage: /files/{file_id}/data/content/update, /files/ (audio file upload pipeline) - docs = [ - Document( - page_content=form_data.content.replace("
", "\n"), - metadata={ - **file.meta, - "name": file.filename, - "created_by": file.user_id, - "file_id": file.id, - "source": file.filename, - }, - ) - ] - - text_content = form_data.content - elif form_data.collection_name: - # Check if the file has already been processed and save the content - # Usage: /knowledge/{id}/file/add, /knowledge/{id}/file/update - - result = VECTOR_DB_CLIENT.query( - collection_name=f"file-{file.id}", filter={"file_id": file.id} - ) - - if result is not None and len(result.ids[0]) > 0: - docs = [ - Document( - page_content=result.documents[0][idx], - metadata=result.metadatas[0][idx], + try: + # /files/{file_id}/data/content/update + VECTOR_DB_CLIENT.delete_collection( + collection_name=f"file-{file.id}" ) - for idx, id in enumerate(result.ids[0]) - ] - else: + except: + # Audio file upload pipeline + pass + docs = [ Document( - page_content=file.data.get("content", ""), + page_content=form_data.content.replace("
", "\n"), metadata={ **file.meta, "name": file.filename, @@ -1470,149 +1453,190 @@ def process_file( ) ] - text_content = file.data.get("content", "") - else: - # Process the file and save the content - # Usage: /files/ - file_path = file.path - if file_path: - file_path = Storage.get_file(file_path) - loader = Loader( - engine=request.app.state.config.CONTENT_EXTRACTION_ENGINE, - DATALAB_MARKER_API_KEY=request.app.state.config.DATALAB_MARKER_API_KEY, - DATALAB_MARKER_API_BASE_URL=request.app.state.config.DATALAB_MARKER_API_BASE_URL, - DATALAB_MARKER_ADDITIONAL_CONFIG=request.app.state.config.DATALAB_MARKER_ADDITIONAL_CONFIG, - DATALAB_MARKER_SKIP_CACHE=request.app.state.config.DATALAB_MARKER_SKIP_CACHE, - DATALAB_MARKER_FORCE_OCR=request.app.state.config.DATALAB_MARKER_FORCE_OCR, - DATALAB_MARKER_PAGINATE=request.app.state.config.DATALAB_MARKER_PAGINATE, - DATALAB_MARKER_STRIP_EXISTING_OCR=request.app.state.config.DATALAB_MARKER_STRIP_EXISTING_OCR, - DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION=request.app.state.config.DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION, - DATALAB_MARKER_FORMAT_LINES=request.app.state.config.DATALAB_MARKER_FORMAT_LINES, - DATALAB_MARKER_USE_LLM=request.app.state.config.DATALAB_MARKER_USE_LLM, - DATALAB_MARKER_OUTPUT_FORMAT=request.app.state.config.DATALAB_MARKER_OUTPUT_FORMAT, - EXTERNAL_DOCUMENT_LOADER_URL=request.app.state.config.EXTERNAL_DOCUMENT_LOADER_URL, - EXTERNAL_DOCUMENT_LOADER_API_KEY=request.app.state.config.EXTERNAL_DOCUMENT_LOADER_API_KEY, - TIKA_SERVER_URL=request.app.state.config.TIKA_SERVER_URL, - DOCLING_SERVER_URL=request.app.state.config.DOCLING_SERVER_URL, - DOCLING_PARAMS={ - "do_ocr": request.app.state.config.DOCLING_DO_OCR, - "force_ocr": request.app.state.config.DOCLING_FORCE_OCR, - "ocr_engine": request.app.state.config.DOCLING_OCR_ENGINE, - "ocr_lang": request.app.state.config.DOCLING_OCR_LANG, - "pdf_backend": request.app.state.config.DOCLING_PDF_BACKEND, - "table_mode": request.app.state.config.DOCLING_TABLE_MODE, - "pipeline": request.app.state.config.DOCLING_PIPELINE, - "do_picture_description": request.app.state.config.DOCLING_DO_PICTURE_DESCRIPTION, - "picture_description_mode": request.app.state.config.DOCLING_PICTURE_DESCRIPTION_MODE, - "picture_description_local": request.app.state.config.DOCLING_PICTURE_DESCRIPTION_LOCAL, - "picture_description_api": request.app.state.config.DOCLING_PICTURE_DESCRIPTION_API, - }, - PDF_EXTRACT_IMAGES=request.app.state.config.PDF_EXTRACT_IMAGES, - DOCUMENT_INTELLIGENCE_ENDPOINT=request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT, - DOCUMENT_INTELLIGENCE_KEY=request.app.state.config.DOCUMENT_INTELLIGENCE_KEY, - MISTRAL_OCR_API_KEY=request.app.state.config.MISTRAL_OCR_API_KEY, - ) - docs = loader.load( - file.filename, file.meta.get("content_type"), file_path + text_content = form_data.content + elif form_data.collection_name: + # Check if the file has already been processed and save the content + # Usage: /knowledge/{id}/file/add, /knowledge/{id}/file/update + + result = VECTOR_DB_CLIENT.query( + collection_name=f"file-{file.id}", filter={"file_id": file.id} ) - docs = [ - Document( - page_content=doc.page_content, - metadata={ - **doc.metadata, - "name": file.filename, - "created_by": file.user_id, - "file_id": file.id, - "source": file.filename, - }, - ) - for doc in docs - ] - else: - docs = [ - Document( - page_content=file.data.get("content", ""), - metadata={ - **file.meta, - "name": file.filename, - "created_by": file.user_id, - "file_id": file.id, - "source": file.filename, - }, - ) - ] - text_content = " ".join([doc.page_content for doc in docs]) - - log.debug(f"text_content: {text_content}") - Files.update_file_data_by_id( - file.id, - {"content": text_content}, - ) - hash = calculate_sha256_string(text_content) - Files.update_file_hash_by_id(file.id, hash) - - if request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL: - Files.update_file_data_by_id(file.id, {"status": "completed"}) - return { - "status": True, - "collection_name": None, - "filename": file.filename, - "content": text_content, - } - else: - try: - result = save_docs_to_vector_db( - request, - docs=docs, - collection_name=collection_name, - metadata={ - "file_id": file.id, - "name": file.filename, - "hash": hash, - }, - add=(True if form_data.collection_name else False), - user=user, - ) - log.info(f"added {len(docs)} items to collection {collection_name}") - - if result: - Files.update_file_metadata_by_id( - file.id, - { - "collection_name": collection_name, - }, - ) - - Files.update_file_data_by_id( - file.id, - {"status": "completed"}, - ) - - return { - "status": True, - "collection_name": collection_name, - "filename": file.filename, - "content": text_content, - } + if result is not None and len(result.ids[0]) > 0: + docs = [ + Document( + page_content=result.documents[0][idx], + metadata=result.metadatas[0][idx], + ) + for idx, id in enumerate(result.ids[0]) + ] else: - raise Exception("Error saving document to vector database") - except Exception as e: - raise e + docs = [ + Document( + page_content=file.data.get("content", ""), + metadata={ + **file.meta, + "name": file.filename, + "created_by": file.user_id, + "file_id": file.id, + "source": file.filename, + }, + ) + ] - except Exception as e: - log.exception(e) - if "No pandoc was found" in str(e): - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=ERROR_MESSAGES.PANDOC_NOT_INSTALLED, + text_content = file.data.get("content", "") + else: + # Process the file and save the content + # Usage: /files/ + file_path = file.path + if file_path: + file_path = Storage.get_file(file_path) + loader = Loader( + engine=request.app.state.config.CONTENT_EXTRACTION_ENGINE, + DATALAB_MARKER_API_KEY=request.app.state.config.DATALAB_MARKER_API_KEY, + DATALAB_MARKER_API_BASE_URL=request.app.state.config.DATALAB_MARKER_API_BASE_URL, + DATALAB_MARKER_ADDITIONAL_CONFIG=request.app.state.config.DATALAB_MARKER_ADDITIONAL_CONFIG, + DATALAB_MARKER_SKIP_CACHE=request.app.state.config.DATALAB_MARKER_SKIP_CACHE, + DATALAB_MARKER_FORCE_OCR=request.app.state.config.DATALAB_MARKER_FORCE_OCR, + DATALAB_MARKER_PAGINATE=request.app.state.config.DATALAB_MARKER_PAGINATE, + DATALAB_MARKER_STRIP_EXISTING_OCR=request.app.state.config.DATALAB_MARKER_STRIP_EXISTING_OCR, + DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION=request.app.state.config.DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION, + DATALAB_MARKER_FORMAT_LINES=request.app.state.config.DATALAB_MARKER_FORMAT_LINES, + DATALAB_MARKER_USE_LLM=request.app.state.config.DATALAB_MARKER_USE_LLM, + DATALAB_MARKER_OUTPUT_FORMAT=request.app.state.config.DATALAB_MARKER_OUTPUT_FORMAT, + EXTERNAL_DOCUMENT_LOADER_URL=request.app.state.config.EXTERNAL_DOCUMENT_LOADER_URL, + EXTERNAL_DOCUMENT_LOADER_API_KEY=request.app.state.config.EXTERNAL_DOCUMENT_LOADER_API_KEY, + TIKA_SERVER_URL=request.app.state.config.TIKA_SERVER_URL, + DOCLING_SERVER_URL=request.app.state.config.DOCLING_SERVER_URL, + DOCLING_PARAMS={ + "do_ocr": request.app.state.config.DOCLING_DO_OCR, + "force_ocr": request.app.state.config.DOCLING_FORCE_OCR, + "ocr_engine": request.app.state.config.DOCLING_OCR_ENGINE, + "ocr_lang": request.app.state.config.DOCLING_OCR_LANG, + "pdf_backend": request.app.state.config.DOCLING_PDF_BACKEND, + "table_mode": request.app.state.config.DOCLING_TABLE_MODE, + "pipeline": request.app.state.config.DOCLING_PIPELINE, + "do_picture_description": request.app.state.config.DOCLING_DO_PICTURE_DESCRIPTION, + "picture_description_mode": request.app.state.config.DOCLING_PICTURE_DESCRIPTION_MODE, + "picture_description_local": request.app.state.config.DOCLING_PICTURE_DESCRIPTION_LOCAL, + "picture_description_api": request.app.state.config.DOCLING_PICTURE_DESCRIPTION_API, + }, + PDF_EXTRACT_IMAGES=request.app.state.config.PDF_EXTRACT_IMAGES, + DOCUMENT_INTELLIGENCE_ENDPOINT=request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT, + DOCUMENT_INTELLIGENCE_KEY=request.app.state.config.DOCUMENT_INTELLIGENCE_KEY, + MISTRAL_OCR_API_KEY=request.app.state.config.MISTRAL_OCR_API_KEY, + ) + docs = loader.load( + file.filename, file.meta.get("content_type"), file_path + ) + + docs = [ + Document( + page_content=doc.page_content, + metadata={ + **doc.metadata, + "name": file.filename, + "created_by": file.user_id, + "file_id": file.id, + "source": file.filename, + }, + ) + for doc in docs + ] + else: + docs = [ + Document( + page_content=file.data.get("content", ""), + metadata={ + **file.meta, + "name": file.filename, + "created_by": file.user_id, + "file_id": file.id, + "source": file.filename, + }, + ) + ] + text_content = " ".join([doc.page_content for doc in docs]) + + log.debug(f"text_content: {text_content}") + Files.update_file_data_by_id( + file.id, + {"content": text_content}, ) - else: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=str(e), + hash = calculate_sha256_string(text_content) + Files.update_file_hash_by_id(file.id, hash) + + if request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL: + Files.update_file_data_by_id(file.id, {"status": "completed"}) + return { + "status": True, + "collection_name": None, + "filename": file.filename, + "content": text_content, + } + else: + try: + result = save_docs_to_vector_db( + request, + docs=docs, + collection_name=collection_name, + metadata={ + "file_id": file.id, + "name": file.filename, + "hash": hash, + }, + add=(True if form_data.collection_name else False), + user=user, + ) + log.info(f"added {len(docs)} items to collection {collection_name}") + + if result: + Files.update_file_metadata_by_id( + file.id, + { + "collection_name": collection_name, + }, + ) + + Files.update_file_data_by_id( + file.id, + {"status": "completed"}, + ) + + return { + "status": True, + "collection_name": collection_name, + "filename": file.filename, + "content": text_content, + } + else: + raise Exception("Error saving document to vector database") + except Exception as e: + raise e + + except Exception as e: + log.exception(e) + Files.update_file_data_by_id( + file.id, + {"status": "failed"}, ) + if "No pandoc was found" in str(e): + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.PANDOC_NOT_INSTALLED, + ) + else: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=str(e), + ) + + else: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND + ) + class ProcessTextForm(BaseModel): name: str @@ -1769,7 +1793,15 @@ def search_web(request: Request, engine: str, query: str) -> list[SearchResult]: """ # TODO: add playwright to search the web - if engine == "searxng": + if engine == "ollama_cloud": + return search_ollama_cloud( + "https://ollama.com", + request.app.state.config.OLLAMA_CLOUD_WEB_SEARCH_API_KEY, + query, + request.app.state.config.WEB_SEARCH_RESULT_COUNT, + request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST, + ) + elif engine == "searxng": if request.app.state.config.SEARXNG_QUERY_URL: return search_searxng( request.app.state.config.SEARXNG_QUERY_URL, diff --git a/backend/open_webui/routers/tools.py b/backend/open_webui/routers/tools.py index 71c7069fd3..eb66a86825 100644 --- a/backend/open_webui/routers/tools.py +++ b/backend/open_webui/routers/tools.py @@ -9,6 +9,7 @@ from pydantic import BaseModel, HttpUrl from fastapi import APIRouter, Depends, HTTPException, Request, status +from open_webui.models.oauth_sessions import OAuthSessions from open_webui.models.tools import ( ToolForm, ToolModel, @@ -41,7 +42,15 @@ router = APIRouter() @router.get("/", response_model=list[ToolUserResponse]) async def get_tools(request: Request, user=Depends(get_verified_user)): - tools = Tools.get_tools() + tools = [ + ToolUserResponse( + **{ + **tool.model_dump(), + "has_user_valves": "class UserValves(BaseModel):" in tool.content, + } + ) + for tool in Tools.get_tools() + ] # OpenAPI Tool Servers for server in await get_tool_servers(request): @@ -72,6 +81,20 @@ async def get_tools(request: Request, user=Depends(get_verified_user)): # MCP Tool Servers for server in request.app.state.config.TOOL_SERVER_CONNECTIONS: if server.get("type", "openapi") == "mcp": + server_id = server.get("info", {}).get("id") + auth_type = server.get("auth_type", "none") + + session_token = None + if auth_type == "oauth_2.1": + splits = server_id.split(":") + server_id = splits[-1] if len(splits) > 1 else server_id + + session_token = ( + await request.app.state.oauth_client_manager.get_oauth_token( + user.id, f"mcp:{server_id}" + ) + ) + tools.append( ToolUserResponse( **{ @@ -88,6 +111,13 @@ async def get_tools(request: Request, user=Depends(get_verified_user)): ), "updated_at": int(time.time()), "created_at": int(time.time()), + **( + { + "authenticated": session_token is not None, + } + if auth_type == "oauth_2.1" + else {} + ), } ) ) @@ -486,8 +516,9 @@ async def update_tools_valves_by_id( try: form_data = {k: v for k, v in form_data.items() if v is not None} valves = Valves(**form_data) - Tools.update_tool_valves_by_id(id, valves.model_dump()) - return valves.model_dump() + valves_dict = valves.model_dump(exclude_unset=True) + Tools.update_tool_valves_by_id(id, valves_dict) + return valves_dict except Exception as e: log.exception(f"Failed to update tool valves by id {id}: {e}") raise HTTPException( @@ -562,10 +593,11 @@ async def update_tools_user_valves_by_id( try: form_data = {k: v for k, v in form_data.items() if v is not None} user_valves = UserValves(**form_data) + user_valves_dict = user_valves.model_dump(exclude_unset=True) Tools.update_user_valves_by_id_and_user_id( - id, user.id, user_valves.model_dump() + id, user.id, user_valves_dict ) - return user_valves.model_dump() + return user_valves_dict except Exception as e: log.exception(f"Failed to update user valves by id {id}: {e}") raise HTTPException( diff --git a/backend/open_webui/utils/access_control.py b/backend/open_webui/utils/access_control.py index 6215a6ac22..af48bebfb4 100644 --- a/backend/open_webui/utils/access_control.py +++ b/backend/open_webui/utils/access_control.py @@ -110,9 +110,13 @@ def has_access( type: str = "write", access_control: Optional[dict] = None, user_group_ids: Optional[Set[str]] = None, + strict: bool = True, ) -> bool: if access_control is None: - return type == "read" + if strict: + return type == "read" + else: + return True if user_group_ids is None: user_groups = Groups.get_groups_by_member_id(user_id) diff --git a/backend/open_webui/utils/mcp/client.py b/backend/open_webui/utils/mcp/client.py index 2d352ead24..01df38886c 100644 --- a/backend/open_webui/utils/mcp/client.py +++ b/backend/open_webui/utils/mcp/client.py @@ -13,13 +13,9 @@ class MCPClient: self.session: Optional[ClientSession] = None self.exit_stack = AsyncExitStack() - async def connect( - self, url: str, headers: Optional[dict] = None, auth: Optional[any] = None - ): + async def connect(self, url: str, headers: Optional[dict] = None): try: - self._streams_context = streamablehttp_client( - url, headers=headers, auth=auth - ) + self._streams_context = streamablehttp_client(url, headers=headers) transport = await self.exit_stack.enter_async_context(self._streams_context) read_stream, write_stream, _ = transport diff --git a/backend/open_webui/utils/middleware.py b/backend/open_webui/utils/middleware.py index 52c78f199c..509f419b07 100644 --- a/backend/open_webui/utils/middleware.py +++ b/backend/open_webui/utils/middleware.py @@ -24,6 +24,7 @@ from fastapi.responses import HTMLResponse from starlette.responses import Response, StreamingResponse, JSONResponse +from open_webui.models.oauth_sessions import OAuthSessions from open_webui.models.chats import Chats from open_webui.models.folders import Folders from open_webui.models.users import Users @@ -1047,6 +1048,22 @@ async def process_chat_payload(request, form_data, user, metadata, model): headers["Authorization"] = ( f"Bearer {oauth_token.get('access_token', '')}" ) + elif auth_type == "oauth_2.1": + try: + splits = server_id.split(":") + server_id = splits[-1] if len(splits) > 1 else server_id + + oauth_token = await request.app.state.oauth_client_manager.get_oauth_token( + user.id, f"mcp:{server_id}" + ) + + if oauth_token: + headers["Authorization"] = ( + f"Bearer {oauth_token.get('access_token', '')}" + ) + except Exception as e: + log.error(f"Error getting OAuth token: {e}") + oauth_token = None mcp_client = MCPClient() await mcp_client.connect( @@ -1171,26 +1188,15 @@ async def process_chat_payload(request, form_data, user, metadata, model): raise Exception("No user message found") if context_string != "": - # Workaround for Ollama 2.0+ system prompt issue - # TODO: replace with add_or_update_system_message - if model.get("owned_by") == "ollama": - form_data["messages"] = prepend_to_first_user_message_content( - rag_template( - request.app.state.config.RAG_TEMPLATE, - context_string, - prompt, - ), - form_data["messages"], - ) - else: - form_data["messages"] = add_or_update_system_message( - rag_template( - request.app.state.config.RAG_TEMPLATE, - context_string, - prompt, - ), - form_data["messages"], - ) + form_data["messages"] = add_or_update_user_message( + rag_template( + request.app.state.config.RAG_TEMPLATE, + context_string, + prompt, + ), + form_data["messages"], + append=False, + ) # If there are citations, add them to the data_items sources = [ diff --git a/backend/open_webui/utils/misc.py b/backend/open_webui/utils/misc.py index 370cf26c48..e8cfa0d158 100644 --- a/backend/open_webui/utils/misc.py +++ b/backend/open_webui/utils/misc.py @@ -120,19 +120,20 @@ def pop_system_message(messages: list[dict]) -> tuple[Optional[dict], list[dict] return get_system_message(messages), remove_system_message(messages) -def prepend_to_first_user_message_content( - content: str, messages: list[dict] -) -> list[dict]: - for message in messages: - if message["role"] == "user": - if isinstance(message["content"], list): - for item in message["content"]: - if item["type"] == "text": - item["text"] = f"{content}\n{item['text']}" - else: - message["content"] = f"{content}\n{message['content']}" - break - return messages +def update_message_content(message: dict, content: str, append: bool = True) -> dict: + if isinstance(message["content"], list): + for item in message["content"]: + if item["type"] == "text": + if append: + item["text"] = f"{item['text']}\n{content}" + else: + item["text"] = f"{content}\n{item['text']}" + else: + if append: + message["content"] = f"{message['content']}\n{content}" + else: + message["content"] = f"{content}\n{message['content']}" + return message def add_or_update_system_message( @@ -148,10 +149,7 @@ def add_or_update_system_message( """ if messages and messages[0].get("role") == "system": - if append: - messages[0]["content"] = f"{messages[0]['content']}\n{content}" - else: - messages[0]["content"] = f"{content}\n{messages[0]['content']}" + messages[0] = update_message_content(messages[0], content, append) else: # Insert at the beginning messages.insert(0, {"role": "system", "content": content}) @@ -159,7 +157,7 @@ def add_or_update_system_message( return messages -def add_or_update_user_message(content: str, messages: list[dict]): +def add_or_update_user_message(content: str, messages: list[dict], append: bool = True): """ Adds a new user message at the end of the messages list or updates the existing user message at the end. @@ -170,7 +168,7 @@ def add_or_update_user_message(content: str, messages: list[dict]): """ if messages and messages[-1].get("role") == "user": - messages[-1]["content"] = f"{messages[-1]['content']}\n{content}" + messages[-1] = update_message_content(messages[-1], content, append) else: # Insert at the end messages.append({"role": "user", "content": content}) @@ -178,6 +176,16 @@ def add_or_update_user_message(content: str, messages: list[dict]): return messages +def prepend_to_first_user_message_content( + content: str, messages: list[dict] +) -> list[dict]: + for message in messages: + if message["role"] == "user": + message = update_message_content(message, content, append=False) + break + return messages + + def append_or_update_assistant_message(content: str, messages: list[dict]): """ Adds a new assistant message at the end of the messages list diff --git a/backend/open_webui/utils/oauth.py b/backend/open_webui/utils/oauth.py index ee3ba79990..9399241853 100644 --- a/backend/open_webui/utils/oauth.py +++ b/backend/open_webui/utils/oauth.py @@ -1,7 +1,9 @@ import base64 +import hashlib import logging import mimetypes import sys +import urllib import uuid import json from datetime import datetime, timedelta @@ -9,6 +11,9 @@ from datetime import datetime, timedelta import re import fnmatch import time +import secrets +from cryptography.fernet import Fernet + import aiohttp from authlib.integrations.starlette_client import OAuth @@ -18,6 +23,7 @@ from fastapi import ( status, ) from starlette.responses import RedirectResponse +from typing import Optional from open_webui.models.auths import Auths @@ -56,11 +62,27 @@ from open_webui.env import ( WEBUI_AUTH_COOKIE_SAME_SITE, WEBUI_AUTH_COOKIE_SECURE, ENABLE_OAUTH_ID_TOKEN_COOKIE, + OAUTH_CLIENT_INFO_ENCRYPTION_KEY, ) from open_webui.utils.misc import parse_duration from open_webui.utils.auth import get_password_hash, create_token from open_webui.utils.webhook import post_webhook +from mcp.shared.auth import ( + OAuthClientMetadata, + OAuthMetadata, +) + + +class OAuthClientInformationFull(OAuthClientMetadata): + issuer: Optional[str] = None # URL of the OAuth server that issued this client + + client_id: str + client_secret: str | None = None + client_id_issued_at: int | None = None + client_secret_expires_at: int | None = None + + from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL) @@ -89,6 +111,42 @@ auth_manager_config.JWT_EXPIRES_IN = JWT_EXPIRES_IN auth_manager_config.OAUTH_UPDATE_PICTURE_ON_LOGIN = OAUTH_UPDATE_PICTURE_ON_LOGIN +FERNET = None + +if len(OAUTH_CLIENT_INFO_ENCRYPTION_KEY) != 44: + key_bytes = hashlib.sha256(OAUTH_CLIENT_INFO_ENCRYPTION_KEY.encode()).digest() + OAUTH_CLIENT_INFO_ENCRYPTION_KEY = base64.urlsafe_b64encode(key_bytes) +else: + OAUTH_CLIENT_INFO_ENCRYPTION_KEY = OAUTH_CLIENT_INFO_ENCRYPTION_KEY.encode() + +try: + FERNET = Fernet(OAUTH_CLIENT_INFO_ENCRYPTION_KEY) +except Exception as e: + log.error(f"Error initializing Fernet with provided key: {e}") + raise + + +def encrypt_data(data) -> str: + """Encrypt data for storage""" + try: + data_json = json.dumps(data) + encrypted = FERNET.encrypt(data_json.encode()).decode() + return encrypted + except Exception as e: + log.error(f"Error encrypting data: {e}") + raise + + +def decrypt_data(data: str): + """Decrypt data from storage""" + try: + decrypted = FERNET.decrypt(data.encode()).decode() + return json.loads(decrypted) + except Exception as e: + log.error(f"Error decrypting data: {e}") + raise + + def is_in_blocked_groups(group_name: str, groups: list) -> bool: """ Check if a group name matches any blocked pattern. @@ -133,6 +191,412 @@ def is_in_blocked_groups(group_name: str, groups: list) -> bool: return False +def get_parsed_and_base_url(server_url) -> tuple[urllib.parse.ParseResult, str]: + parsed = urllib.parse.urlparse(server_url) + base_url = f"{parsed.scheme}://{parsed.netloc}" + return parsed, base_url + + +def get_discovery_urls(server_url) -> list[str]: + urls = [] + parsed, base_url = get_parsed_and_base_url(server_url) + + urls.append( + urllib.parse.urljoin(base_url, "/.well-known/oauth-authorization-server") + ) + urls.append(urllib.parse.urljoin(base_url, "/.well-known/openid-configuration")) + + return urls + + +# TODO: Some OAuth providers require Initial Access Tokens (IATs) for dynamic client registration. +# This is not currently supported. +async def get_oauth_client_info_with_dynamic_client_registration( + request, + client_id: str, + oauth_server_url: str, + oauth_server_key: Optional[str] = None, +) -> OAuthClientInformationFull: + try: + oauth_server_metadata = None + oauth_server_metadata_url = None + + redirect_base_url = ( + str(request.app.state.config.WEBUI_URL or request.base_url) + ).rstrip("/") + + oauth_client_metadata = OAuthClientMetadata( + client_name="Open WebUI", + redirect_uris=[f"{redirect_base_url}/oauth/clients/{client_id}/callback"], + grant_types=["authorization_code", "refresh_token"], + response_types=["code"], + token_endpoint_auth_method="client_secret_post", + ) + + # Attempt to fetch OAuth server metadata to get registration endpoint & scopes + discovery_urls = get_discovery_urls(oauth_server_url) + for url in discovery_urls: + async with aiohttp.ClientSession() as session: + async with session.get( + url, ssl=AIOHTTP_CLIENT_SESSION_SSL + ) as oauth_server_metadata_response: + if oauth_server_metadata_response.status == 200: + try: + oauth_server_metadata = OAuthMetadata.model_validate( + await oauth_server_metadata_response.json() + ) + oauth_server_metadata_url = url + if ( + oauth_client_metadata.scope is None + and oauth_server_metadata.scopes_supported is not None + ): + oauth_client_metadata.scope = " ".join( + oauth_server_metadata.scopes_supported + ) + break + except Exception as e: + log.error(f"Error parsing OAuth metadata from {url}: {e}") + continue + + registration_url = None + if oauth_server_metadata and oauth_server_metadata.registration_endpoint: + registration_url = str(oauth_server_metadata.registration_endpoint) + else: + _, base_url = get_parsed_and_base_url(oauth_server_url) + registration_url = urllib.parse.urljoin(base_url, "/register") + + registration_data = oauth_client_metadata.model_dump( + exclude_none=True, + mode="json", + by_alias=True, + ) + + # Perform dynamic client registration and return client info + async with aiohttp.ClientSession() as session: + async with session.post( + registration_url, json=registration_data, ssl=AIOHTTP_CLIENT_SESSION_SSL + ) as oauth_client_registration_response: + try: + registration_response_json = ( + await oauth_client_registration_response.json() + ) + oauth_client_info = OAuthClientInformationFull.model_validate( + { + **registration_response_json, + **{"issuer": oauth_server_metadata_url}, + } + ) + log.info( + f"Dynamic client registration successful at {registration_url}, client_id: {oauth_client_info.client_id}" + ) + return oauth_client_info + except Exception as e: + error_text = None + try: + error_text = await oauth_client_registration_response.text() + log.error( + f"Dynamic client registration failed at {registration_url}: {oauth_client_registration_response.status} - {error_text}" + ) + except Exception as e: + pass + + log.error(f"Error parsing client registration response: {e}") + raise Exception( + f"Dynamic client registration failed: {error_text}" + if error_text + else "Error parsing client registration response" + ) + raise Exception("Dynamic client registration failed") + except Exception as e: + log.error(f"Exception during dynamic client registration: {e}") + raise e + + +class OAuthClientManager: + def __init__(self, app): + self.oauth = OAuth() + self.app = app + self.clients = {} + + def add_client(self, client_id, oauth_client_info: OAuthClientInformationFull): + self.clients[client_id] = { + "client": self.oauth.register( + name=client_id, + client_id=oauth_client_info.client_id, + client_secret=oauth_client_info.client_secret, + client_kwargs=( + {"scope": oauth_client_info.scope} + if oauth_client_info.scope + else {} + ), + server_metadata_url=( + oauth_client_info.issuer if oauth_client_info.issuer else None + ), + ), + "client_info": oauth_client_info, + } + return self.clients[client_id] + + def remove_client(self, client_id): + if client_id in self.clients: + del self.clients[client_id] + log.info(f"Removed OAuth client {client_id}") + return True + + def get_client(self, client_id): + client = self.clients.get(client_id) + return client["client"] if client else None + + def get_client_info(self, client_id): + client = self.clients.get(client_id) + return client["client_info"] if client else None + + def get_server_metadata_url(self, client_id): + if client_id in self.clients: + client = self.clients[client_id] + return ( + client.server_metadata_url + if hasattr(client, "server_metadata_url") + else None + ) + return None + + async def get_oauth_token( + self, user_id: str, client_id: str, force_refresh: bool = False + ): + """ + Get a valid OAuth token for the user, automatically refreshing if needed. + + Args: + user_id: The user ID + client_id: The OAuth client ID (provider) + force_refresh: Force token refresh even if current token appears valid + + Returns: + dict: OAuth token data with access_token, or None if no valid token available + """ + try: + # Get the OAuth session + session = OAuthSessions.get_session_by_provider_and_user_id( + client_id, user_id + ) + if not session: + log.warning( + f"No OAuth session found for user {user_id}, client_id {client_id}" + ) + return None + + if force_refresh or datetime.now() + timedelta( + minutes=5 + ) >= datetime.fromtimestamp(session.expires_at): + log.debug( + f"Token refresh needed for user {user_id}, client_id {session.provider}" + ) + refreshed_token = await self._refresh_token(session) + if refreshed_token: + return refreshed_token + else: + log.warning( + f"Token refresh failed for user {user_id}, client_id {session.provider}, deleting session {session.id}" + ) + OAuthSessions.delete_session_by_id(session.id) + return None + return session.token + + except Exception as e: + log.error(f"Error getting OAuth token for user {user_id}: {e}") + return None + + async def _refresh_token(self, session) -> dict: + """ + Refresh an OAuth token if needed, with concurrency protection. + + Args: + session: The OAuth session object + + Returns: + dict: Refreshed token data, or None if refresh failed + """ + try: + # Perform the actual refresh + refreshed_token = await self._perform_token_refresh(session) + + if refreshed_token: + # Update the session with new token data + session = OAuthSessions.update_session_by_id( + session.id, refreshed_token + ) + log.info(f"Successfully refreshed token for session {session.id}") + return session.token + else: + log.error(f"Failed to refresh token for session {session.id}") + return None + + except Exception as e: + log.error(f"Error refreshing token for session {session.id}: {e}") + return None + + async def _perform_token_refresh(self, session) -> dict: + """ + Perform the actual OAuth token refresh. + + Args: + session: The OAuth session object + + Returns: + dict: New token data, or None if refresh failed + """ + client_id = session.provider + token_data = session.token + + if not token_data.get("refresh_token"): + log.warning(f"No refresh token available for session {session.id}") + return None + + try: + client = self.get_client(client_id) + if not client: + log.error(f"No OAuth client found for provider {client_id}") + return None + + token_endpoint = None + async with aiohttp.ClientSession(trust_env=True) as session_http: + async with session_http.get( + self.get_server_metadata_url(client_id) + ) as r: + if r.status == 200: + openid_data = await r.json() + token_endpoint = openid_data.get("token_endpoint") + else: + log.error( + f"Failed to fetch OpenID configuration for client_id {client_id}" + ) + if not token_endpoint: + log.error(f"No token endpoint found for client_id {client_id}") + return None + + # Prepare refresh request + refresh_data = { + "grant_type": "refresh_token", + "refresh_token": token_data["refresh_token"], + "client_id": client.client_id, + } + if hasattr(client, "client_secret") and client.client_secret: + refresh_data["client_secret"] = client.client_secret + + # Make refresh request + async with aiohttp.ClientSession(trust_env=True) as session_http: + async with session_http.post( + token_endpoint, + data=refresh_data, + headers={"Content-Type": "application/x-www-form-urlencoded"}, + ssl=AIOHTTP_CLIENT_SESSION_SSL, + ) as r: + if r.status == 200: + new_token_data = await r.json() + + # Merge with existing token data (preserve refresh_token if not provided) + if "refresh_token" not in new_token_data: + new_token_data["refresh_token"] = token_data[ + "refresh_token" + ] + + # Add timestamp for tracking + new_token_data["issued_at"] = datetime.now().timestamp() + + # Calculate expires_at if we have expires_in + if ( + "expires_in" in new_token_data + and "expires_at" not in new_token_data + ): + new_token_data["expires_at"] = int( + datetime.now().timestamp() + + new_token_data["expires_in"] + ) + + log.debug(f"Token refresh successful for client_id {client_id}") + return new_token_data + else: + error_text = await r.text() + log.error( + f"Token refresh failed for client_id {client_id}: {r.status} - {error_text}" + ) + return None + + except Exception as e: + log.error(f"Exception during token refresh for client_id {client_id}: {e}") + return None + + async def handle_authorize(self, request, client_id: str) -> RedirectResponse: + client = self.get_client(client_id) + if client is None: + raise HTTPException(404) + + client_info = self.get_client_info(client_id) + if client_info is None: + raise HTTPException(404) + + redirect_uri = ( + client_info.redirect_uris[0] if client_info.redirect_uris else None + ) + return await client.authorize_redirect(request, str(redirect_uri)) + + async def handle_callback(self, request, client_id: str, user_id: str, response): + client = self.get_client(client_id) + if client is None: + raise HTTPException(404) + + error_message = None + try: + token = await client.authorize_access_token(request) + if token: + try: + # Add timestamp for tracking + token["issued_at"] = datetime.now().timestamp() + + # Calculate expires_at if we have expires_in + if "expires_in" in token and "expires_at" not in token: + token["expires_at"] = ( + datetime.now().timestamp() + token["expires_in"] + ) + + # Clean up any existing sessions for this user/client_id first + sessions = OAuthSessions.get_sessions_by_user_id(user_id) + for session in sessions: + if session.provider == client_id: + OAuthSessions.delete_session_by_id(session.id) + + session = OAuthSessions.create_session( + user_id=user_id, + provider=client_id, + token=token, + ) + log.info( + f"Stored OAuth session server-side for user {user_id}, client_id {client_id}" + ) + except Exception as e: + error_message = "Failed to store OAuth session server-side" + log.error(f"Failed to store OAuth session server-side: {e}") + else: + error_message = "Failed to obtain OAuth token" + log.warning(error_message) + except Exception as e: + error_message = "OAuth callback error" + log.warning(f"OAuth callback error: {e}") + + redirect_url = ( + str(request.app.state.config.WEBUI_URL or request.base_url) + ).rstrip("/") + + if error_message: + log.debug(error_message) + redirect_url = f"{redirect_url}/?error={error_message}" + return RedirectResponse(url=redirect_url, headers=response.headers) + + response = RedirectResponse(url=redirect_url, headers=response.headers) + return response + + class OAuthManager: def __init__(self, app): self.oauth = OAuth() @@ -191,8 +655,10 @@ class OAuthManager: return refreshed_token else: log.warning( - f"Token refresh failed for user {user_id}, provider {session.provider}" + f"Token refresh failed for user {user_id}, provider {session.provider}, deleting session {session.id}" ) + OAuthSessions.delete_session_by_id(session.id) + return None return session.token @@ -252,9 +718,10 @@ class OAuthManager: log.error(f"No OAuth client found for provider {provider}") return None + server_metadata_url = self.get_server_metadata_url(provider) token_endpoint = None async with aiohttp.ClientSession(trust_env=True) as session_http: - async with session_http.get(client.gserver_metadata_url) as r: + async with session_http.get(server_metadata_url) as r: if r.status == 200: openid_data = await r.json() token_endpoint = openid_data.get("token_endpoint") @@ -301,7 +768,7 @@ class OAuthManager: "expires_in" in new_token_data and "expires_at" not in new_token_data ): - new_token_data["expires_at"] = ( + new_token_data["expires_at"] = int( datetime.now().timestamp() + new_token_data["expires_in"] ) @@ -574,7 +1041,7 @@ class OAuthManager: raise HTTPException(404) # If the provider has a custom redirect URL, use that, otherwise automatically generate one redirect_uri = OAUTH_PROVIDERS[provider].get("redirect_uri") or request.url_for( - "oauth_callback", provider=provider + "oauth_login_callback", provider=provider ) client = self.get_client(provider) if client is None: @@ -791,9 +1258,9 @@ class OAuthManager: else ERROR_MESSAGES.DEFAULT("Error during OAuth process") ) - redirect_base_url = str(request.app.state.config.WEBUI_URL or request.base_url) - if redirect_base_url.endswith("/"): - redirect_base_url = redirect_base_url[:-1] + redirect_base_url = ( + str(request.app.state.config.WEBUI_URL or request.base_url) + ).rstrip("/") redirect_url = f"{redirect_base_url}/auth" if error_message: diff --git a/backend/requirements.txt b/backend/requirements.txt index 1b14ac1429..23bb7710f0 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -9,13 +9,14 @@ python-jose==3.4.0 passlib[bcrypt]==1.7.4 cryptography -requests==2.32.4 +requests==2.32.5 aiohttp==3.12.15 async-timeout aiocache aiofiles starlette-compress==1.6.0 httpx[socks,http2,zstd,cli,brotli]==0.28.1 +starsessions[redis]==2.2.1 sqlalchemy==2.0.38 alembic==1.14.0 @@ -43,13 +44,13 @@ asgiref==3.8.1 # AI libraries openai anthropic -google-genai==1.32.0 +google-genai==1.38.0 google-generativeai==0.8.5 tiktoken mcp==1.14.1 -langchain==0.3.26 -langchain-community==0.3.27 +langchain==0.3.27 +langchain-community==0.3.29 fake-useragent==2.2.0 chromadb==1.0.20 diff --git a/pyproject.toml b/pyproject.toml index 09fcce07fb..aa2825fa69 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,7 +21,7 @@ dependencies = [ "PyJWT[crypto]==2.10.1", "authlib==1.6.3", - "requests==2.32.4", + "requests==2.32.5", "aiohttp==3.12.15", "async-timeout", "aiocache", @@ -51,11 +51,11 @@ dependencies = [ "openai", "anthropic", - "google-genai==1.32.0", + "google-genai==1.38.0", "google-generativeai==0.8.5", - "langchain==0.3.26", - "langchain-community==0.3.27", + "langchain==0.3.27", + "langchain-community==0.3.29", "fake-useragent==2.2.0", "chromadb==1.0.20", diff --git a/src/lib/apis/chats/index.ts b/src/lib/apis/chats/index.ts index b1e7d5f23b..59d8600771 100644 --- a/src/lib/apis/chats/index.ts +++ b/src/lib/apis/chats/index.ts @@ -77,7 +77,11 @@ export const importChat = async ( return res; }; -export const getChatList = async (token: string = '', page: number | null = null) => { +export const getChatList = async ( + token: string = '', + page: number | null = null, + include_folders: boolean = false +) => { let error = null; const searchParams = new URLSearchParams(); @@ -85,6 +89,10 @@ export const getChatList = async (token: string = '', page: number | null = null searchParams.append('page', `${page}`); } + if (include_folders) { + searchParams.append('include_folders', 'true'); + } + const res = await fetch(`${WEBUI_API_BASE_URL}/chats/?${searchParams.toString()}`, { method: 'GET', headers: { diff --git a/src/lib/apis/configs/index.ts b/src/lib/apis/configs/index.ts index ef983e63bf..c6cfdd2b2b 100644 --- a/src/lib/apis/configs/index.ts +++ b/src/lib/apis/configs/index.ts @@ -1,4 +1,4 @@ -import { WEBUI_API_BASE_URL } from '$lib/constants'; +import { WEBUI_API_BASE_URL, WEBUI_BASE_URL } from '$lib/constants'; import type { Banner } from '$lib/types'; export const importConfig = async (token: string, config) => { @@ -202,6 +202,52 @@ export const verifyToolServerConnection = async (token: string, connection: obje return res; }; +type RegisterOAuthClientForm = { + url: string; + client_id: string; + client_name?: string; +}; + +export const registerOAuthClient = async ( + token: string, + formData: RegisterOAuthClientForm, + type: null | string = null +) => { + let error = null; + + const searchParams = type ? `?type=${type}` : ''; + const res = await fetch(`${WEBUI_API_BASE_URL}/configs/oauth/clients/register${searchParams}`, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + Authorization: `Bearer ${token}` + }, + body: JSON.stringify({ + ...formData + }) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + console.error(err); + error = err.detail; + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const getOAuthClientAuthorizationUrl = (clientId: string, type: null | string = null) => { + const oauthClientId = type ? `${type}:${clientId}` : clientId; + return `${WEBUI_BASE_URL}/oauth/clients/${oauthClientId}/authorize`; +}; + export const getCodeExecutionConfig = async (token: string) => { let error = null; diff --git a/src/lib/components/AddToolServerModal.svelte b/src/lib/components/AddToolServerModal.svelte index 01c87010ef..114c05f834 100644 --- a/src/lib/components/AddToolServerModal.svelte +++ b/src/lib/components/AddToolServerModal.svelte @@ -13,7 +13,7 @@ import Switch from '$lib/components/common/Switch.svelte'; import Tags from './common/Tags.svelte'; import { getToolServerData } from '$lib/apis'; - import { verifyToolServerConnection } from '$lib/apis/configs'; + import { verifyToolServerConnection, registerOAuthClient } from '$lib/apis/configs'; import AccessControl from './workspace/common/AccessControl.svelte'; import Spinner from '$lib/components/common/Spinner.svelte'; import XMark from '$lib/components/icons/XMark.svelte'; @@ -41,10 +41,47 @@ let name = ''; let description = ''; - let enable = true; + let oauthClientInfo = null; + let enable = true; let loading = false; + const registerOAuthClientHandler = async () => { + if (url === '') { + toast.error($i18n.t('Please enter a valid URL')); + return; + } + + if (id === '') { + toast.error($i18n.t('Please enter a valid ID')); + return; + } + + const res = await registerOAuthClient( + localStorage.token, + { + url: url, + client_id: id + }, + 'mcp' + ).catch((err) => { + toast.error($i18n.t('Registration failed')); + return null; + }); + + if (res) { + toast.warning( + $i18n.t( + 'Please save the connection to persist the OAuth client information and do not change the ID' + ) + ); + toast.success($i18n.t('Registration successful')); + + console.debug('Registration successful', res); + oauthClientInfo = res?.oauth_client_info ?? null; + } + }; + const verifyHandler = async () => { if (url === '') { toast.error($i18n.t('Please enter a valid URL')); @@ -106,6 +143,12 @@ return; } + if (type === 'mcp' && auth_type === 'oauth_2.1' && !oauthClientInfo) { + toast.error($i18n.t('Please register the OAuth client')); + loading = false; + return; + } + const connection = { url, path, @@ -119,7 +162,8 @@ info: { id: id, name: name, - description: description + description: description, + ...(oauthClientInfo ? { oauth_client_info: oauthClientInfo } : {}) } }; @@ -139,6 +183,7 @@ id = ''; name = ''; description = ''; + oauthClientInfo = null; enable = true; accessControl = null; @@ -156,6 +201,7 @@ id = connection.info?.id ?? ''; name = connection.info?.name ?? ''; description = connection.info?.description ?? ''; + oauthClientInfo = connection.info?.oauth_client_info ?? null; enable = connection.config?.enable ?? true; accessControl = connection.config?.access_control ?? null; @@ -227,25 +273,6 @@ {/if} - {#if type === 'mcp'} -
- - {$i18n.t('Warning')}: - - {$i18n.t( - 'MCP support is experimental and its specification changes often, which can lead to incompatibilities. OpenAPI specification support is directly maintained by the Open WebUI team, making it the more reliable option for compatibility.' - )} - - {$i18n.t('Read more →')} -
- {/if} -
@@ -333,11 +360,52 @@
- +
+
+
+ {$i18n.t('Auth')} +
+
+ + {#if auth_type === 'oauth_2.1'} +
+
+ + + +
+ + {#if !oauthClientInfo} +
+ {$i18n.t('Not Registered')} +
+ {:else} +
+ {$i18n.t('Registered')} +
+ {/if} +
+ {/if} +
@@ -353,6 +421,9 @@ {#if !direct} + {#if type === 'mcp'} + + {/if} {/if}
@@ -382,6 +453,12 @@ > {$i18n.t('Forwards system user OAuth access token to authenticate')}
+ {:else if auth_type === 'oauth_2.1'} +
+ {$i18n.t('Uses ​OAuth 2.1 Dynamic Client Registration')} +
{/if}
@@ -470,6 +547,25 @@ {/if}
+ {#if type === 'mcp'} +
+ + {$i18n.t('Warning')}: + + {$i18n.t( + 'MCP support is experimental and its specification changes often, which can lead to incompatibilities. OpenAPI specification support is directly maintained by the Open WebUI team, making it the more reliable option for compatibility.' + )} + + {$i18n.t('Read more →')} +
+ {/if} +
{#if edit} +
+ +
+
{ + e.preventDefault(); + submitHandler(); + }} + > +
+ +
+ + + +
+ +
+
+
+
+ diff --git a/src/lib/components/chat/MessageInput/Commands/Knowledge.svelte b/src/lib/components/chat/MessageInput/Commands/Knowledge.svelte index 5a6ce96cc4..077f97d416 100644 --- a/src/lib/components/chat/MessageInput/Commands/Knowledge.svelte +++ b/src/lib/components/chat/MessageInput/Commands/Knowledge.svelte @@ -7,7 +7,7 @@ dayjs.extend(relativeTime); import { tick, getContext, onMount, onDestroy } from 'svelte'; - import { removeLastWordFromString, isValidHttpUrl } from '$lib/utils'; + import { removeLastWordFromString, isValidHttpUrl, isYoutubeUrl } from '$lib/utils'; import Tooltip from '$lib/components/common/Tooltip.svelte'; import DocumentPage from '$lib/components/icons/DocumentPage.svelte'; import Database from '$lib/components/icons/Database.svelte'; @@ -36,7 +36,7 @@ : items), ...(query.startsWith('http') - ? query.startsWith('https://www.youtube.com') || query.startsWith('https://youtu.be') + ? isYoutubeUrl(query) ? [{ type: 'youtube', name: query, description: query }] : [ { @@ -228,7 +228,7 @@ {/if} {/each} - {#if query.startsWith('https://www.youtube.com') || query.startsWith('https://youtu.be')} + {#if isYoutubeUrl(query)} + +
+ {/if} +
import hljs from 'highlight.js'; - + import { toast } from 'svelte-sonner'; import { getContext, onMount, tick, onDestroy } from 'svelte'; + import { config } from '$lib/stores'; + + import PyodideWorker from '$lib/workers/pyodide.worker?worker'; + import { executeCode } from '$lib/apis/utils'; import { copyToClipboard, renderMermaidDiagram } from '$lib/utils'; import 'highlight.js/styles/github-dark.min.css'; - import PyodideWorker from '$lib/workers/pyodide.worker?worker'; + import CodeEditor from '$lib/components/common/CodeEditor.svelte'; import SvgPanZoom from '$lib/components/common/SVGPanZoom.svelte'; - import { config } from '$lib/stores'; - import { executeCode } from '$lib/apis/utils'; - import { toast } from 'svelte-sonner'; + import ChevronUp from '$lib/components/icons/ChevronUp.svelte'; import ChevronUpDown from '$lib/components/icons/ChevronUpDown.svelte'; import CommandLine from '$lib/components/icons/CommandLine.svelte'; @@ -480,19 +482,17 @@ {#if !collapsed} {#if edit} - {#await import('$lib/components/common/CodeEditor.svelte') then { default: CodeEditor }} - { - saveCode(); - }} - onChange={(value) => { - _code = value; - }} - /> - {/await} + { + saveCode(); + }} + onChange={(value) => { + _code = value; + }} + /> {:else}
-	import { onDestroy, onMount } from 'svelte';
+	import { onDestroy, onMount, getContext } from 'svelte';
 	import panzoom, { type PanZoom } from 'panzoom';
 
 	import fileSaver from 'file-saver';
@@ -11,6 +11,8 @@
 	export let src = '';
 	export let alt = '';
 
+	const i18n = getContext('i18n');
+
 	let mounted = false;
 
 	let previewElement = null;
@@ -100,9 +102,10 @@
 
 							const mimeType = blob.type || 'image/png';
 							// create file name based on the MIME type, alt should be a valid file name with extension
-							const fileName = alt
-								? `${alt.replaceAll('.', '')}.${mimeType.split('/')[1]}`
-								: 'download.png';
+							const fileName = `${$i18n
+								.t('Generated Image')
+								.toLowerCase()
+								.replace(/ /g, '_')}.${mimeType.split('/')[1]}`;
 
 							// Use FileSaver to save the blob
 							saveAs(blob, fileName);
@@ -119,9 +122,10 @@
 									const blobWithType = new Blob([blob], { type: mimeType });
 
 									// create file name based on the MIME type, alt should be a valid file name with extension
-									const fileName = alt
-										? `${alt.replaceAll('.', '')}.${mimeType.split('/')[1]}`
-										: 'download.png';
+									const fileName = `${$i18n
+										.t('Generated Image')
+										.toLowerCase()
+										.replace(/ /g, '_')}.${mimeType.split('/')[1]}`;
 
 									// Use FileSaver to save the blob
 									saveAs(blobWithType, fileName);
@@ -146,9 +150,10 @@
 									const blobWithType = new Blob([blob], { type: mimeType });
 
 									// create file name based on the MIME type, alt should be a valid file name with extension
-									const fileName = alt
-										? `${alt.replaceAll('.', '')}.${mimeType.split('/')[1]}`
-										: 'download.png';
+									const fileName = `${$i18n
+										.t('Generated Image')
+										.toLowerCase()
+										.replace(/ /g, '_')}.${mimeType.split('/')[1]}`;
 
 									// Use FileSaver to save the blob
 									saveAs(blobWithType, fileName);
diff --git a/src/lib/components/common/RichTextInput.svelte b/src/lib/components/common/RichTextInput.svelte
index 0c46d0f296..b30beb0a7e 100644
--- a/src/lib/components/common/RichTextInput.svelte
+++ b/src/lib/components/common/RichTextInput.svelte
@@ -149,10 +149,15 @@
 	export let onChange = (e) => {};
 
 	// create a lowlight instance with all languages loaded
-	const lowlight = createLowlight(hljs.listLanguages().reduce((obj, lang) => {
-		obj[lang] = () => hljs.getLanguage(lang);
-		return obj;
-	}, {} as Record));
+	const lowlight = createLowlight(
+		hljs.listLanguages().reduce(
+			(obj, lang) => {
+				obj[lang] = () => hljs.getLanguage(lang);
+				return obj;
+			},
+			{} as Record
+		)
+	);
 
 	export let editor: Editor | null = null;
 
@@ -163,7 +168,7 @@
 	export let documentId = '';
 
 	export let className = 'input-prose';
-	export let placeholder = 'Type here...';
+	export let placeholder = $i18n.t('Type here...');
 	let _placeholder = placeholder;
 
 	$: if (placeholder !== _placeholder) {
@@ -501,9 +506,14 @@
 
 	export const focus = () => {
 		if (editor) {
-			editor.view.focus();
-			// Scroll to the current selection
-			editor.view.dispatch(editor.view.state.tr.scrollIntoView());
+			try {
+				editor.view?.focus();
+				// Scroll to the current selection
+				editor.view?.dispatch(editor.view.state.tr.scrollIntoView());
+			} catch (e) {
+				// sometimes focusing throws an error, ignore
+				console.warn('Error focusing editor', e);
+			}
 		}
 	};
 
@@ -679,7 +689,7 @@
 					link: link
 				}),
 				...(dragHandle ? [ListItemDragHandle] : []),
-				Placeholder.configure({ placeholder: () => _placeholder }),
+				Placeholder.configure({ placeholder: () => _placeholder, showOnlyWhenEditable: false }),
 				SelectionDecoration,
 
 				...(richText
@@ -1113,4 +1123,9 @@
 	
{/if} -
+
diff --git a/src/lib/components/layout/Overlay/AccountPending.svelte b/src/lib/components/layout/Overlay/AccountPending.svelte index 9197933b60..0c4dc8c2d5 100644 --- a/src/lib/components/layout/Overlay/AccountPending.svelte +++ b/src/lib/components/layout/Overlay/AccountPending.svelte @@ -1,4 +1,7 @@