mirror of
https://github.com/open-webui/open-webui.git
synced 2025-12-15 05:45:19 +00:00
feat: oauth2.1 mcp integration
This commit is contained in:
parent
972be4eda5
commit
77e971dd9f
10 changed files with 248 additions and 53 deletions
|
|
@ -473,7 +473,12 @@ from open_webui.utils.auth import (
|
||||||
get_verified_user,
|
get_verified_user,
|
||||||
)
|
)
|
||||||
from open_webui.utils.plugin import install_tool_and_function_dependencies
|
from open_webui.utils.plugin import install_tool_and_function_dependencies
|
||||||
from open_webui.utils.oauth import OAuthManager
|
from open_webui.utils.oauth import (
|
||||||
|
OAuthManager,
|
||||||
|
OAuthClientManager,
|
||||||
|
decrypt_data,
|
||||||
|
OAuthClientInformationFull,
|
||||||
|
)
|
||||||
from open_webui.utils.security_headers import SecurityHeadersMiddleware
|
from open_webui.utils.security_headers import SecurityHeadersMiddleware
|
||||||
from open_webui.utils.redis import get_redis_connection
|
from open_webui.utils.redis import get_redis_connection
|
||||||
|
|
||||||
|
|
@ -603,9 +608,14 @@ app = FastAPI(
|
||||||
lifespan=lifespan,
|
lifespan=lifespan,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# For Open WebUI OIDC/OAuth2
|
||||||
oauth_manager = OAuthManager(app)
|
oauth_manager = OAuthManager(app)
|
||||||
app.state.oauth_manager = oauth_manager
|
app.state.oauth_manager = oauth_manager
|
||||||
|
|
||||||
|
# For Integrations
|
||||||
|
oauth_client_manager = OAuthClientManager(app)
|
||||||
|
app.state.oauth_client_manager = oauth_client_manager
|
||||||
|
|
||||||
app.state.instance_id = None
|
app.state.instance_id = None
|
||||||
app.state.config = AppConfig(
|
app.state.config = AppConfig(
|
||||||
redis_url=REDIS_URL,
|
redis_url=REDIS_URL,
|
||||||
|
|
@ -1881,6 +1891,24 @@ async def get_current_usage(user=Depends(get_verified_user)):
|
||||||
# OAuth Login & Callback
|
# OAuth Login & Callback
|
||||||
############################
|
############################
|
||||||
|
|
||||||
|
|
||||||
|
# Initialize OAuth client manager with any MCP tool servers using OAuth 2.1
|
||||||
|
if len(app.state.config.TOOL_SERVER_CONNECTIONS) > 0:
|
||||||
|
for tool_server_connection in app.state.config.TOOL_SERVER_CONNECTIONS:
|
||||||
|
if tool_server_connection.get("type", "openapi") == "mcp":
|
||||||
|
server_id = tool_server_connection.get("info", {}).get("id")
|
||||||
|
auth_type = tool_server_connection.get("auth_type", "none")
|
||||||
|
if server_id and auth_type == "oauth_2.1":
|
||||||
|
oauth_client_info = tool_server_connection.get("info", {}).get(
|
||||||
|
"oauth_client_info"
|
||||||
|
)
|
||||||
|
|
||||||
|
oauth_client_info = decrypt_data(oauth_client_info)
|
||||||
|
app.state.oauth_client_manager.add_client(
|
||||||
|
f"mcp:{server_id}", OAuthClientInformationFull(**oauth_client_info)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# SessionMiddleware is used by authlib for oauth
|
# SessionMiddleware is used by authlib for oauth
|
||||||
if len(OAUTH_PROVIDERS) > 0:
|
if len(OAUTH_PROVIDERS) > 0:
|
||||||
try:
|
try:
|
||||||
|
|
@ -1913,6 +1941,31 @@ if len(OAUTH_PROVIDERS) > 0:
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/oauth/clients/{client_id}/authorize")
|
||||||
|
async def oauth_client_authorize(
|
||||||
|
client_id: str,
|
||||||
|
request: Request,
|
||||||
|
response: Response,
|
||||||
|
user=Depends(get_verified_user),
|
||||||
|
):
|
||||||
|
return await oauth_client_manager.handle_authorize(request, client_id=client_id)
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/oauth/clients/{client_id}/callback")
|
||||||
|
async def oauth_client_callback(
|
||||||
|
client_id: str,
|
||||||
|
request: Request,
|
||||||
|
response: Response,
|
||||||
|
user=Depends(get_verified_user),
|
||||||
|
):
|
||||||
|
return await oauth_client_manager.handle_callback(
|
||||||
|
request,
|
||||||
|
client_id=client_id,
|
||||||
|
user_id=user.id if user else None,
|
||||||
|
response=response,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@app.get("/oauth/{provider}/login")
|
@app.get("/oauth/{provider}/login")
|
||||||
async def oauth_login(provider: str, request: Request):
|
async def oauth_login(provider: str, request: Request):
|
||||||
return await oauth_manager.handle_login(request, provider)
|
return await oauth_manager.handle_login(request, provider)
|
||||||
|
|
@ -1924,8 +1977,9 @@ async def oauth_login(provider: str, request: Request):
|
||||||
# - This is considered insecure in general, as OAuth providers do not always verify email addresses
|
# - This is considered insecure in general, as OAuth providers do not always verify email addresses
|
||||||
# 3. If there is no user, and ENABLE_OAUTH_SIGNUP is true, create a user
|
# 3. If there is no user, and ENABLE_OAUTH_SIGNUP is true, create a user
|
||||||
# - Email addresses are considered unique, so we fail registration if the email address is already taken
|
# - Email addresses are considered unique, so we fail registration if the email address is already taken
|
||||||
@app.get("/oauth/{provider}/callback")
|
@app.get("/oauth/{provider}/callback") # Legacy endpoint
|
||||||
async def oauth_callback(provider: str, request: Request, response: Response):
|
@app.get("/oauth/{provider}/login/callback")
|
||||||
|
async def oauth_login_callback(provider: str, request: Request, response: Response):
|
||||||
return await oauth_manager.handle_callback(request, provider, response)
|
return await oauth_manager.handle_callback(request, provider, response)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -176,6 +176,26 @@ class OAuthSessionTable:
|
||||||
log.error(f"Error getting OAuth session by ID: {e}")
|
log.error(f"Error getting OAuth session by ID: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def get_session_by_provider_and_user_id(
|
||||||
|
self, provider: str, user_id: str
|
||||||
|
) -> Optional[OAuthSessionModel]:
|
||||||
|
"""Get OAuth session by provider and user ID"""
|
||||||
|
try:
|
||||||
|
with get_db() as db:
|
||||||
|
session = (
|
||||||
|
db.query(OAuthSession)
|
||||||
|
.filter_by(provider=provider, user_id=user_id)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
if session:
|
||||||
|
session.token = self._decrypt_token(session.token)
|
||||||
|
return OAuthSessionModel.model_validate(session)
|
||||||
|
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
log.error(f"Error getting OAuth session by provider and user ID: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
def get_sessions_by_user_id(self, user_id: str) -> List[OAuthSessionModel]:
|
def get_sessions_by_user_id(self, user_id: str) -> List[OAuthSessionModel]:
|
||||||
"""Get all OAuth sessions for a user"""
|
"""Get all OAuth sessions for a user"""
|
||||||
try:
|
try:
|
||||||
|
|
|
||||||
|
|
@ -21,7 +21,9 @@ from open_webui.env import SRC_LOG_LEVELS
|
||||||
from open_webui.utils.oauth import (
|
from open_webui.utils.oauth import (
|
||||||
get_discovery_urls,
|
get_discovery_urls,
|
||||||
get_oauth_client_info_with_dynamic_client_registration,
|
get_oauth_client_info_with_dynamic_client_registration,
|
||||||
encrypt_token,
|
encrypt_data,
|
||||||
|
decrypt_data,
|
||||||
|
OAuthClientInformationFull,
|
||||||
)
|
)
|
||||||
from mcp.shared.auth import OAuthMetadata
|
from mcp.shared.auth import OAuthMetadata
|
||||||
|
|
||||||
|
|
@ -103,17 +105,22 @@ class OAuthClientRegistrationForm(BaseModel):
|
||||||
async def register_oauth_client(
|
async def register_oauth_client(
|
||||||
request: Request,
|
request: Request,
|
||||||
form_data: OAuthClientRegistrationForm,
|
form_data: OAuthClientRegistrationForm,
|
||||||
|
type: Optional[str] = None,
|
||||||
user=Depends(get_admin_user),
|
user=Depends(get_admin_user),
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
|
oauth_client_id = form_data.client_id
|
||||||
|
if type:
|
||||||
|
oauth_client_id = f"{type}:{form_data.client_id}"
|
||||||
|
|
||||||
oauth_client_info = (
|
oauth_client_info = (
|
||||||
await get_oauth_client_info_with_dynamic_client_registration(
|
await get_oauth_client_info_with_dynamic_client_registration(
|
||||||
request, form_data.url
|
request, oauth_client_id, form_data.url
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
return {
|
return {
|
||||||
"status": True,
|
"status": True,
|
||||||
"oauth_client_info": encrypt_token(
|
"oauth_client_info": encrypt_data(
|
||||||
oauth_client_info.model_dump(mode="json")
|
oauth_client_info.model_dump(mode="json")
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
|
|
@ -161,8 +168,25 @@ async def set_tool_servers_config(
|
||||||
request.app.state.config.TOOL_SERVER_CONNECTIONS = [
|
request.app.state.config.TOOL_SERVER_CONNECTIONS = [
|
||||||
connection.model_dump() for connection in form_data.TOOL_SERVER_CONNECTIONS
|
connection.model_dump() for connection in form_data.TOOL_SERVER_CONNECTIONS
|
||||||
]
|
]
|
||||||
|
|
||||||
await set_tool_servers(request)
|
await set_tool_servers(request)
|
||||||
|
|
||||||
|
for connection in request.app.state.config.TOOL_SERVER_CONNECTIONS:
|
||||||
|
server_type = connection.get("type", "openapi")
|
||||||
|
if server_type == "mcp":
|
||||||
|
server_id = connection.get("info", {}).get("id")
|
||||||
|
auth_type = connection.get("auth_type", "none")
|
||||||
|
if auth_type == "oauth_2.1" and server_id:
|
||||||
|
try:
|
||||||
|
oauth_client_info = decrypt_data(oauth_client_info)
|
||||||
|
await request.app.state.oauth_client_manager.add_client(
|
||||||
|
f"{server_type}:{server_id}",
|
||||||
|
OAuthClientInformationFull(**oauth_client_info),
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
log.debug(f"Failed to add OAuth client for MCP tool server: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"TOOL_SERVER_CONNECTIONS": request.app.state.config.TOOL_SERVER_CONNECTIONS,
|
"TOOL_SERVER_CONNECTIONS": request.app.state.config.TOOL_SERVER_CONNECTIONS,
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -9,6 +9,7 @@ from pydantic import BaseModel, HttpUrl
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||||
|
|
||||||
|
|
||||||
|
from open_webui.models.oauth_sessions import OAuthSessions
|
||||||
from open_webui.models.tools import (
|
from open_webui.models.tools import (
|
||||||
ToolForm,
|
ToolForm,
|
||||||
ToolModel,
|
ToolModel,
|
||||||
|
|
@ -80,6 +81,24 @@ async def get_tools(request: Request, user=Depends(get_verified_user)):
|
||||||
# MCP Tool Servers
|
# MCP Tool Servers
|
||||||
for server in request.app.state.config.TOOL_SERVER_CONNECTIONS:
|
for server in request.app.state.config.TOOL_SERVER_CONNECTIONS:
|
||||||
if server.get("type", "openapi") == "mcp":
|
if server.get("type", "openapi") == "mcp":
|
||||||
|
server_id = server.get("info", {}).get("id")
|
||||||
|
auth_type = server.get("auth_type", "none")
|
||||||
|
|
||||||
|
session_token = None
|
||||||
|
if auth_type == "oauth_2.1":
|
||||||
|
splits = server_id.split(":")
|
||||||
|
server_id = splits[-1] if len(splits) > 1 else server_id
|
||||||
|
|
||||||
|
session_token = (
|
||||||
|
await request.app.state.oauth_client_manager.get_oauth_token(
|
||||||
|
user.id, f"mcp:{server_id}"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
print("User ID:", user.id)
|
||||||
|
print("Server ID:", server_id)
|
||||||
|
print("MCP Session Token:", session_token)
|
||||||
|
|
||||||
tools.append(
|
tools.append(
|
||||||
ToolUserResponse(
|
ToolUserResponse(
|
||||||
**{
|
**{
|
||||||
|
|
@ -96,6 +115,13 @@ async def get_tools(request: Request, user=Depends(get_verified_user)):
|
||||||
),
|
),
|
||||||
"updated_at": int(time.time()),
|
"updated_at": int(time.time()),
|
||||||
"created_at": int(time.time()),
|
"created_at": int(time.time()),
|
||||||
|
**(
|
||||||
|
{
|
||||||
|
"authenticated": session_token is not None,
|
||||||
|
}
|
||||||
|
if auth_type == "oauth_2.1"
|
||||||
|
else {}
|
||||||
|
),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -24,6 +24,7 @@ from fastapi.responses import HTMLResponse
|
||||||
from starlette.responses import Response, StreamingResponse, JSONResponse
|
from starlette.responses import Response, StreamingResponse, JSONResponse
|
||||||
|
|
||||||
|
|
||||||
|
from open_webui.models.oauth_sessions import OAuthSessions
|
||||||
from open_webui.models.chats import Chats
|
from open_webui.models.chats import Chats
|
||||||
from open_webui.models.folders import Folders
|
from open_webui.models.folders import Folders
|
||||||
from open_webui.models.users import Users
|
from open_webui.models.users import Users
|
||||||
|
|
@ -1047,6 +1048,22 @@ async def process_chat_payload(request, form_data, user, metadata, model):
|
||||||
headers["Authorization"] = (
|
headers["Authorization"] = (
|
||||||
f"Bearer {oauth_token.get('access_token', '')}"
|
f"Bearer {oauth_token.get('access_token', '')}"
|
||||||
)
|
)
|
||||||
|
elif auth_type == "oauth_2.1":
|
||||||
|
try:
|
||||||
|
splits = server_id.split(":")
|
||||||
|
server_id = splits[-1] if len(splits) > 1 else server_id
|
||||||
|
|
||||||
|
oauth_token = await request.app.state.oauth_client_manager.get_oauth_token(
|
||||||
|
user.id, f"mcp:{server_id}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if oauth_token:
|
||||||
|
headers["Authorization"] = (
|
||||||
|
f"Bearer {oauth_token.get('access_token', '')}"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
log.error(f"Error getting OAuth token: {e}")
|
||||||
|
oauth_token = None
|
||||||
|
|
||||||
mcp_client = MCPClient()
|
mcp_client = MCPClient()
|
||||||
await mcp_client.connect(
|
await mcp_client.connect(
|
||||||
|
|
|
||||||
|
|
@ -126,24 +126,24 @@ except Exception as e:
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
def encrypt_token(token) -> str:
|
def encrypt_data(data) -> str:
|
||||||
"""Encrypt OAuth tokens for storage"""
|
"""Encrypt data for storage"""
|
||||||
try:
|
try:
|
||||||
token_json = json.dumps(token)
|
data_json = json.dumps(data)
|
||||||
encrypted = FERNET.encrypt(token_json.encode()).decode()
|
encrypted = FERNET.encrypt(data_json.encode()).decode()
|
||||||
return encrypted
|
return encrypted
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.error(f"Error encrypting tokens: {e}")
|
log.error(f"Error encrypting data: {e}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
def decrypt_token(token: str):
|
def decrypt_data(data: str):
|
||||||
"""Decrypt OAuth tokens from storage"""
|
"""Decrypt data from storage"""
|
||||||
try:
|
try:
|
||||||
decrypted = FERNET.decrypt(token.encode()).decode()
|
decrypted = FERNET.decrypt(data.encode()).decode()
|
||||||
return json.loads(decrypted)
|
return json.loads(decrypted)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.error(f"Error decrypting tokens: {e}")
|
log.error(f"Error decrypting data: {e}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -212,7 +212,10 @@ def get_discovery_urls(server_url) -> list[str]:
|
||||||
# TODO: Some OAuth providers require Initial Access Tokens (IATs) for dynamic client registration.
|
# TODO: Some OAuth providers require Initial Access Tokens (IATs) for dynamic client registration.
|
||||||
# This is not currently supported.
|
# This is not currently supported.
|
||||||
async def get_oauth_client_info_with_dynamic_client_registration(
|
async def get_oauth_client_info_with_dynamic_client_registration(
|
||||||
request, oauth_server_url, oauth_server_key: Optional[str] = None
|
request,
|
||||||
|
client_id: str,
|
||||||
|
oauth_server_url: str,
|
||||||
|
oauth_server_key: Optional[str] = None,
|
||||||
) -> OAuthClientInformationFull:
|
) -> OAuthClientInformationFull:
|
||||||
try:
|
try:
|
||||||
oauth_server_metadata = None
|
oauth_server_metadata = None
|
||||||
|
|
@ -221,9 +224,10 @@ async def get_oauth_client_info_with_dynamic_client_registration(
|
||||||
redirect_base_url = (
|
redirect_base_url = (
|
||||||
str(request.app.state.config.WEBUI_URL or request.base_url)
|
str(request.app.state.config.WEBUI_URL or request.base_url)
|
||||||
).rstrip("/")
|
).rstrip("/")
|
||||||
|
|
||||||
oauth_client_metadata = OAuthClientMetadata(
|
oauth_client_metadata = OAuthClientMetadata(
|
||||||
client_name="Open WebUI",
|
client_name="Open WebUI",
|
||||||
redirect_uris=[f"{redirect_base_url}/oauth/callback"],
|
redirect_uris=[f"{redirect_base_url}/oauth/clients/{client_id}/callback"],
|
||||||
grant_types=["authorization_code", "refresh_token"],
|
grant_types=["authorization_code", "refresh_token"],
|
||||||
response_types=["code"],
|
response_types=["code"],
|
||||||
token_endpoint_auth_method="client_secret_post",
|
token_endpoint_auth_method="client_secret_post",
|
||||||
|
|
@ -315,23 +319,22 @@ class OAuthClientManager:
|
||||||
self.clients = {}
|
self.clients = {}
|
||||||
|
|
||||||
def add_client(self, client_id, oauth_client_info: OAuthClientInformationFull):
|
def add_client(self, client_id, oauth_client_info: OAuthClientInformationFull):
|
||||||
if client_id not in self.clients:
|
self.clients[client_id] = {
|
||||||
self.clients[client_id] = {
|
"client": self.oauth.register(
|
||||||
"client": self.oauth.register(
|
name=client_id,
|
||||||
name=client_id,
|
client_id=oauth_client_info.client_id,
|
||||||
client_id=oauth_client_info.client_id,
|
client_secret=oauth_client_info.client_secret,
|
||||||
client_secret=oauth_client_info.client_secret,
|
client_kwargs=(
|
||||||
client_kwargs=(
|
{"scope": oauth_client_info.scope}
|
||||||
{"scope": oauth_client_info.scope}
|
if oauth_client_info.scope
|
||||||
if oauth_client_info.scope
|
else {}
|
||||||
else {}
|
|
||||||
),
|
|
||||||
server_metadata_url=(
|
|
||||||
oauth_client_info.issuer if oauth_client_info.issuer else None
|
|
||||||
),
|
|
||||||
),
|
),
|
||||||
"client_info": oauth_client_info,
|
server_metadata_url=(
|
||||||
}
|
oauth_client_info.issuer if oauth_client_info.issuer else None
|
||||||
|
),
|
||||||
|
),
|
||||||
|
"client_info": oauth_client_info,
|
||||||
|
}
|
||||||
return self.clients[client_id]
|
return self.clients[client_id]
|
||||||
|
|
||||||
def remove_client(self, client_id):
|
def remove_client(self, client_id):
|
||||||
|
|
@ -359,7 +362,7 @@ class OAuthClientManager:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def get_oauth_token(
|
async def get_oauth_token(
|
||||||
self, user_id: str, session_id: str, force_refresh: bool = False
|
self, user_id: str, client_id: str, force_refresh: bool = False
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Get a valid OAuth token for the user, automatically refreshing if needed.
|
Get a valid OAuth token for the user, automatically refreshing if needed.
|
||||||
|
|
@ -374,10 +377,12 @@ class OAuthClientManager:
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# Get the OAuth session
|
# Get the OAuth session
|
||||||
session = OAuthSessions.get_session_by_id_and_user_id(session_id, user_id)
|
session = OAuthSessions.get_session_by_provider_and_user_id(
|
||||||
|
client_id, user_id
|
||||||
|
)
|
||||||
if not session:
|
if not session:
|
||||||
log.warning(
|
log.warning(
|
||||||
f"No OAuth session found for user {user_id}, session {session_id}"
|
f"No OAuth session found for user {user_id}, client_id {client_id}"
|
||||||
)
|
)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
@ -392,8 +397,9 @@ class OAuthClientManager:
|
||||||
return refreshed_token
|
return refreshed_token
|
||||||
else:
|
else:
|
||||||
log.warning(
|
log.warning(
|
||||||
f"Token refresh failed for user {user_id}, client_id {session.provider}"
|
f"Token refresh failed for user {user_id}, client_id {session.provider}, deleting session {session.id}"
|
||||||
)
|
)
|
||||||
|
OAuthSessions.delete_session_by_id(session.id)
|
||||||
return None
|
return None
|
||||||
return session.token
|
return session.token
|
||||||
|
|
||||||
|
|
@ -533,7 +539,7 @@ class OAuthClientManager:
|
||||||
redirect_uri = (
|
redirect_uri = (
|
||||||
client_info.redirect_uris[0] if client_info.redirect_uris else None
|
client_info.redirect_uris[0] if client_info.redirect_uris else None
|
||||||
)
|
)
|
||||||
return await client.authorize_redirect(request, redirect_uri)
|
return await client.authorize_redirect(request, str(redirect_uri))
|
||||||
|
|
||||||
async def handle_callback(self, request, client_id: str, user_id: str, response):
|
async def handle_callback(self, request, client_id: str, user_id: str, response):
|
||||||
client = self.get_client(client_id)
|
client = self.get_client(client_id)
|
||||||
|
|
@ -565,7 +571,6 @@ class OAuthClientManager:
|
||||||
provider=client_id,
|
provider=client_id,
|
||||||
token=token,
|
token=token,
|
||||||
)
|
)
|
||||||
|
|
||||||
log.info(
|
log.info(
|
||||||
f"Stored OAuth session server-side for user {user_id}, client_id {client_id}"
|
f"Stored OAuth session server-side for user {user_id}, client_id {client_id}"
|
||||||
)
|
)
|
||||||
|
|
@ -579,16 +584,17 @@ class OAuthClientManager:
|
||||||
error_message = "OAuth callback error"
|
error_message = "OAuth callback error"
|
||||||
log.warning(f"OAuth callback error: {e}")
|
log.warning(f"OAuth callback error: {e}")
|
||||||
|
|
||||||
redirect_base_url = (
|
redirect_url = (
|
||||||
str(request.app.state.config.WEBUI_URL or request.base_url)
|
str(request.app.state.config.WEBUI_URL or request.base_url)
|
||||||
).rstrip("/")
|
).rstrip("/")
|
||||||
redirect_url = f"{redirect_base_url}/auth"
|
|
||||||
|
|
||||||
if error_message:
|
if error_message:
|
||||||
redirect_url = f"{redirect_url}?error={error_message}"
|
log.debug(error_message)
|
||||||
|
redirect_url = f"{redirect_url}/?error={error_message}"
|
||||||
return RedirectResponse(url=redirect_url, headers=response.headers)
|
return RedirectResponse(url=redirect_url, headers=response.headers)
|
||||||
|
|
||||||
response = RedirectResponse(url=redirect_url, headers=response.headers)
|
response = RedirectResponse(url=redirect_url, headers=response.headers)
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
class OAuthManager:
|
class OAuthManager:
|
||||||
|
|
@ -649,8 +655,10 @@ class OAuthManager:
|
||||||
return refreshed_token
|
return refreshed_token
|
||||||
else:
|
else:
|
||||||
log.warning(
|
log.warning(
|
||||||
f"Token refresh failed for user {user_id}, provider {session.provider}"
|
f"Token refresh failed for user {user_id}, provider {session.provider}, deleting session {session.id}"
|
||||||
)
|
)
|
||||||
|
OAuthSessions.delete_session_by_id(session.id)
|
||||||
|
|
||||||
return None
|
return None
|
||||||
return session.token
|
return session.token
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
import { WEBUI_API_BASE_URL } from '$lib/constants';
|
import { WEBUI_API_BASE_URL, WEBUI_BASE_URL } from '$lib/constants';
|
||||||
import type { Banner } from '$lib/types';
|
import type { Banner } from '$lib/types';
|
||||||
|
|
||||||
export const importConfig = async (token: string, config) => {
|
export const importConfig = async (token: string, config) => {
|
||||||
|
|
@ -208,10 +208,15 @@ type RegisterOAuthClientForm = {
|
||||||
client_name?: string;
|
client_name?: string;
|
||||||
};
|
};
|
||||||
|
|
||||||
export const registerOAuthClient = async (token: string, formData: RegisterOAuthClientForm) => {
|
export const registerOAuthClient = async (
|
||||||
|
token: string,
|
||||||
|
formData: RegisterOAuthClientForm,
|
||||||
|
type: null | string = null
|
||||||
|
) => {
|
||||||
let error = null;
|
let error = null;
|
||||||
|
|
||||||
const res = await fetch(`${WEBUI_API_BASE_URL}/configs/oauth/clients/register`, {
|
const searchParams = type ? `?type=${type}` : '';
|
||||||
|
const res = await fetch(`${WEBUI_API_BASE_URL}/configs/oauth/clients/register${searchParams}`, {
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
headers: {
|
headers: {
|
||||||
'Content-Type': 'application/json',
|
'Content-Type': 'application/json',
|
||||||
|
|
@ -238,6 +243,11 @@ export const registerOAuthClient = async (token: string, formData: RegisterOAuth
|
||||||
return res;
|
return res;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
export const getOAuthClientAuthorizationUrl = (clientId: string, type: null | string = null) => {
|
||||||
|
const oauthClientId = type ? `${type}:${clientId}` : clientId;
|
||||||
|
return `${WEBUI_BASE_URL}/oauth/clients/${oauthClientId}/authorize`;
|
||||||
|
};
|
||||||
|
|
||||||
export const getCodeExecutionConfig = async (token: string) => {
|
export const getCodeExecutionConfig = async (token: string) => {
|
||||||
let error = null;
|
let error = null;
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -57,16 +57,26 @@
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
const res = await registerOAuthClient(localStorage.token, {
|
const res = await registerOAuthClient(
|
||||||
url: url,
|
localStorage.token,
|
||||||
client_id: id
|
{
|
||||||
}).catch((err) => {
|
url: url,
|
||||||
|
client_id: id
|
||||||
|
},
|
||||||
|
'mcp'
|
||||||
|
).catch((err) => {
|
||||||
toast.error($i18n.t('Registration failed'));
|
toast.error($i18n.t('Registration failed'));
|
||||||
return null;
|
return null;
|
||||||
});
|
});
|
||||||
|
|
||||||
if (res) {
|
if (res) {
|
||||||
|
toast.warning(
|
||||||
|
$i18n.t(
|
||||||
|
'Please save the connection to persist the OAuth client information and do not change the ID'
|
||||||
|
)
|
||||||
|
);
|
||||||
toast.success($i18n.t('Registration successful'));
|
toast.success($i18n.t('Registration successful'));
|
||||||
|
|
||||||
console.debug('Registration successful', res);
|
console.debug('Registration successful', res);
|
||||||
oauthClientInfo = res?.oauth_client_info ?? null;
|
oauthClientInfo = res?.oauth_client_info ?? null;
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -20,6 +20,8 @@
|
||||||
import ChevronRight from '$lib/components/icons/ChevronRight.svelte';
|
import ChevronRight from '$lib/components/icons/ChevronRight.svelte';
|
||||||
import ChevronLeft from '$lib/components/icons/ChevronLeft.svelte';
|
import ChevronLeft from '$lib/components/icons/ChevronLeft.svelte';
|
||||||
import ValvesModal from '$lib/components/workspace/common/ValvesModal.svelte';
|
import ValvesModal from '$lib/components/workspace/common/ValvesModal.svelte';
|
||||||
|
import { getOAuthClientAuthorizationUrl } from '$lib/apis/configs';
|
||||||
|
import { partition } from 'd3-hierarchy';
|
||||||
|
|
||||||
const i18n = getContext('i18n');
|
const i18n = getContext('i18n');
|
||||||
|
|
||||||
|
|
@ -321,11 +323,25 @@
|
||||||
|
|
||||||
{#each Object.keys(tools) as toolId}
|
{#each Object.keys(tools) as toolId}
|
||||||
<button
|
<button
|
||||||
class="flex w-full justify-between gap-2 items-center px-3 py-1.5 text-sm cursor-pointer rounded-xl hover:bg-gray-50 dark:hover:bg-gray-800/50"
|
class="relative flex w-full justify-between gap-2 items-center px-3 py-1.5 text-sm cursor-pointer rounded-xl hover:bg-gray-50 dark:hover:bg-gray-800/50"
|
||||||
on:click={() => {
|
on:click={(e) => {
|
||||||
tools[toolId].enabled = !tools[toolId].enabled;
|
if (!(tools[toolId]?.authenticated ?? true)) {
|
||||||
|
e.preventDefault();
|
||||||
|
|
||||||
|
let parts = toolId.split(':');
|
||||||
|
let serverId = parts?.at(-1) ?? toolId;
|
||||||
|
|
||||||
|
const authUrl = getOAuthClientAuthorizationUrl(serverId, 'mcp');
|
||||||
|
window.open(authUrl, '_blank', 'noopener');
|
||||||
|
} else {
|
||||||
|
tools[toolId].enabled = !tools[toolId].enabled;
|
||||||
|
}
|
||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
|
{#if !(tools[toolId]?.authenticated ?? true)}
|
||||||
|
<!-- make it slighly darker and not clickable -->
|
||||||
|
<div class="absolute inset-0 opacity-50 rounded-xl cursor-not-allowed z-10" />
|
||||||
|
{/if}
|
||||||
<div class="flex-1 truncate">
|
<div class="flex-1 truncate">
|
||||||
<div class="flex flex-1 gap-2 items-center">
|
<div class="flex flex-1 gap-2 items-center">
|
||||||
<Tooltip content={tools[toolId]?.name ?? ''} placement="top">
|
<Tooltip content={tools[toolId]?.name ?? ''} placement="top">
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,15 @@
|
||||||
<script lang="ts">
|
<script lang="ts">
|
||||||
|
import { onMount } from 'svelte';
|
||||||
|
import { toast } from 'svelte-sonner';
|
||||||
|
|
||||||
import Chat from '$lib/components/chat/Chat.svelte';
|
import Chat from '$lib/components/chat/Chat.svelte';
|
||||||
|
import { page } from '$app/stores';
|
||||||
|
|
||||||
|
onMount(() => {
|
||||||
|
if ($page.url.searchParams.get('error')) {
|
||||||
|
toast.error($page.url.searchParams.get('error') || 'An unknown error occurred.');
|
||||||
|
}
|
||||||
|
});
|
||||||
</script>
|
</script>
|
||||||
|
|
||||||
<Chat />
|
<Chat />
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue