From 217f4daef09b36d3d4cc4681e11d3ebd9984a1a5 Mon Sep 17 00:00:00 2001 From: Timothy Jaeryang Baek Date: Mon, 8 Sep 2025 18:05:43 +0400 Subject: [PATCH] feat: server-side OAuth token management system Co-Authored-By: Classic298 <27028174+Classic298@users.noreply.github.com> --- backend/open_webui/env.py | 13 +- backend/open_webui/main.py | 1 + .../38d63c18f30f_add_oauth_session_table.py | 52 ++++ backend/open_webui/models/oauth_sessions.py | 247 ++++++++++++++++++ backend/open_webui/routers/auths.py | 32 +-- backend/open_webui/utils/auth.py | 96 +++---- backend/open_webui/utils/middleware.py | 9 + backend/open_webui/utils/oauth.py | 247 ++++++++++++++++-- backend/open_webui/utils/tools.py | 15 ++ src/lib/components/AddServerModal.svelte | 7 + 10 files changed, 627 insertions(+), 92 deletions(-) create mode 100644 backend/open_webui/migrations/versions/38d63c18f30f_add_oauth_session_table.py create mode 100644 backend/open_webui/models/oauth_sessions.py diff --git a/backend/open_webui/env.py b/backend/open_webui/env.py index f72d827afc..b4fdc97d82 100644 --- a/backend/open_webui/env.py +++ b/backend/open_webui/env.py @@ -465,8 +465,17 @@ ENABLE_COMPRESSION_MIDDLEWARE = ( os.environ.get("ENABLE_COMPRESSION_MIDDLEWARE", "True").lower() == "true" ) -ENABLE_OAUTH_SESSION_TOKENS_COOKIES = ( - os.environ.get("ENABLE_OAUTH_SESSION_TOKENS_COOKIES", "True").lower() == "true" +#################################### +# OAUTH Configuration +#################################### + + +ENABLE_OAUTH_ID_TOKEN_COOKIE = ( + os.environ.get("ENABLE_OAUTH_ID_TOKEN_COOKIE", "True").lower() == "true" +) + +OAUTH_SESSION_TOKEN_ENCRYPTION_KEY = os.environ.get( + "OAUTH_SESSION_TOKEN_ENCRYPTION_KEY", WEBUI_SECRET_KEY ) diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index 7decfcd83b..de7dcae086 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -592,6 +592,7 @@ app = FastAPI( ) oauth_manager = OAuthManager(app) +app.state.oauth_manager = oauth_manager app.state.instance_id = None app.state.config = AppConfig( diff --git a/backend/open_webui/migrations/versions/38d63c18f30f_add_oauth_session_table.py b/backend/open_webui/migrations/versions/38d63c18f30f_add_oauth_session_table.py new file mode 100644 index 0000000000..8ead6db6d4 --- /dev/null +++ b/backend/open_webui/migrations/versions/38d63c18f30f_add_oauth_session_table.py @@ -0,0 +1,52 @@ +"""Add oauth_session table + +Revision ID: 38d63c18f30f +Revises: 3af16a1c9fb6 +Create Date: 2025-09-08 14:19:59.583921 + +""" + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = "38d63c18f30f" +down_revision: Union[str, None] = "3af16a1c9fb6" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # Create oauth_session table + op.create_table( + "oauth_session", + sa.Column("id", sa.Text(), nullable=False), + sa.Column("user_id", sa.Text(), nullable=False), + sa.Column("provider", sa.Text(), nullable=False), + sa.Column("token", sa.Text(), nullable=False), + sa.Column("expires_at", sa.BigInteger(), nullable=False), + sa.Column("created_at", sa.BigInteger(), nullable=False), + sa.Column("updated_at", sa.BigInteger(), nullable=False), + sa.PrimaryKeyConstraint("id"), + sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"), + ) + + # Create indexes for better performance + op.create_index("idx_oauth_session_user_id", "oauth_session", ["user_id"]) + op.create_index("idx_oauth_session_expires_at", "oauth_session", ["expires_at"]) + op.create_index( + "idx_oauth_session_user_provider", "oauth_session", ["user_id", "provider"] + ) + + +def downgrade() -> None: + # Drop indexes first + op.drop_index("idx_oauth_session_user_provider", table_name="oauth_session") + op.drop_index("idx_oauth_session_expires_at", table_name="oauth_session") + op.drop_index("idx_oauth_session_user_id", table_name="oauth_session") + + # Drop the table + op.drop_table("oauth_session") diff --git a/backend/open_webui/models/oauth_sessions.py b/backend/open_webui/models/oauth_sessions.py new file mode 100644 index 0000000000..b0b5aa29a6 --- /dev/null +++ b/backend/open_webui/models/oauth_sessions.py @@ -0,0 +1,247 @@ +import time +import logging +import uuid +from typing import Optional, List +import base64 +import hashlib +import json + +from cryptography.fernet import Fernet + +from open_webui.internal.db import Base, get_db +from open_webui.env import SRC_LOG_LEVELS, OAUTH_SESSION_TOKEN_ENCRYPTION_KEY + +from pydantic import BaseModel, ConfigDict +from sqlalchemy import BigInteger, Column, String, Text, Index + +log = logging.getLogger(__name__) +log.setLevel(SRC_LOG_LEVELS["MODELS"]) + +#################### +# DB MODEL +#################### + + +class OAuthSession(Base): + __tablename__ = "oauth_session" + + id = Column(Text, primary_key=True) + user_id = Column(Text, nullable=False) + provider = Column(Text, nullable=False) + token = Column( + Text, nullable=False + ) # JSON with access_token, id_token, refresh_token + expires_at = Column(BigInteger, nullable=False) + created_at = Column(BigInteger, nullable=False) + updated_at = Column(BigInteger, nullable=False) + + # Add indexes for better performance + __table_args__ = ( + Index("idx_oauth_session_user_id", "user_id"), + Index("idx_oauth_session_expires_at", "expires_at"), + Index("idx_oauth_session_user_provider", "user_id", "provider"), + ) + + +class OAuthSessionModel(BaseModel): + id: str + user_id: str + provider: str + token: dict + expires_at: int # timestamp in epoch + created_at: int # timestamp in epoch + updated_at: int # timestamp in epoch + + model_config = ConfigDict(from_attributes=True) + + +#################### +# Forms +#################### + + +class OAuthSessionResponse(BaseModel): + id: str + user_id: str + provider: str + expires_at: int + + +class OAuthSessionTable: + def __init__(self): + self.encryption_key = OAUTH_SESSION_TOKEN_ENCRYPTION_KEY + if not self.encryption_key: + raise Exception("OAUTH_SESSION_TOKEN_ENCRYPTION_KEY is not set") + + # check if encryption key is in the right format for Fernet (32 url-safe base64-encoded bytes) + if len(self.encryption_key) != 44: + key_bytes = hashlib.sha256(self.encryption_key.encode()).digest() + self.encryption_key = base64.urlsafe_b64encode(key_bytes) + else: + self.encryption_key = self.encryption_key.encode() + + try: + self.fernet = Fernet(self.encryption_key) + except Exception as e: + log.error(f"Error initializing Fernet with provided key: {e}") + raise + + def _encrypt_token(self, token) -> str: + """Encrypt OAuth tokens for storage""" + try: + token_json = json.dumps(token) + encrypted = self.fernet.encrypt(token_json.encode()).decode() + return encrypted + except Exception as e: + log.error(f"Error encrypting tokens: {e}") + raise + + def _decrypt_token(self, token: str): + """Decrypt OAuth tokens from storage""" + try: + decrypted = self.fernet.decrypt(token.encode()).decode() + return json.loads(decrypted) + except Exception as e: + log.error(f"Error decrypting tokens: {e}") + raise + + def create_session( + self, + user_id: str, + provider: str, + token: dict, + ) -> Optional[OAuthSessionModel]: + """Create a new OAuth session""" + try: + with get_db() as db: + current_time = int(time.time()) + id = str(uuid.uuid4()) + + result = OAuthSession( + **{ + "id": id, + "user_id": user_id, + "provider": provider, + "token": self._encrypt_token(token), + "expires_at": token.get("expires_at"), + "created_at": current_time, + "updated_at": current_time, + } + ) + + db.add(result) + db.commit() + db.refresh(result) + + if result: + result.token = token # Return decrypted token + return OAuthSessionModel.model_validate(result) + else: + return None + except Exception as e: + log.error(f"Error creating OAuth session: {e}") + return None + + def get_session_by_id(self, session_id: str) -> Optional[OAuthSessionModel]: + """Get OAuth session by ID""" + try: + with get_db() as db: + session = db.query(OAuthSession).filter_by(id=session_id).first() + if session: + session.token = self._decrypt_token(session.token) + return OAuthSessionModel.model_validate(session) + + return None + except Exception as e: + log.error(f"Error getting OAuth session by ID: {e}") + return None + + def get_session_by_id_and_user_id( + self, session_id: str, user_id: str + ) -> Optional[OAuthSessionModel]: + """Get OAuth session by ID and user ID""" + try: + with get_db() as db: + session = ( + db.query(OAuthSession) + .filter_by(id=session_id, user_id=user_id) + .first() + ) + if session: + session.token = self._decrypt_token(session.token) + return OAuthSessionModel.model_validate(session) + ) + return None + except Exception as e: + log.error(f"Error getting OAuth session by ID: {e}") + return None + + def get_sessions_by_user_id(self, user_id: str) -> List[OAuthSessionModel]: + """Get all OAuth sessions for a user""" + try: + with get_db() as db: + sessions = db.query(OAuthSession).filter_by(user_id=user_id).all() + + + results = [] + for session in sessions: + session.token = self._decrypt_token(session.token) + results.append(OAuthSessionModel.model_validate(session)) + + return results + + except Exception as e: + log.error(f"Error getting OAuth sessions by user ID: {e}") + return [] + + def update_session_by_id( + self, session_id: str, token: dict + ) -> Optional[OAuthSessionModel]: + """Update OAuth session tokens""" + try: + with get_db() as db: + current_time = int(time.time()) + + db.query(OAuthSession).filter_by(id=session_id).update( + { + "token": self._encrypt_token(token), + "expires_at": token.get("expires_at"), + "updated_at": current_time, + } + ) + db.commit() + session = db.query(OAuthSession).filter_by(id=session_id).first() + + if session: + session.token = self._decrypt_token(session.token) + return OAuthSessionModel.model_validate(session) + + return None + except Exception as e: + log.error(f"Error updating OAuth session tokens: {e}") + return None + + def delete_session_by_id(self, session_id: str) -> bool: + """Delete an OAuth session""" + try: + with get_db() as db: + result = db.query(OAuthSession).filter_by(id=session_id).delete() + db.commit() + return result > 0 + except Exception as e: + log.error(f"Error deleting OAuth session: {e}") + return False + + def delete_sessions_by_user_id(self, user_id: str) -> bool: + """Delete all OAuth sessions for a user""" + try: + with get_db() as db: + result = db.query(OAuthSession).filter_by(user_id=user_id).delete() + db.commit() + return True + except Exception as e: + log.error(f"Error deleting OAuth sessions by user ID: {e}") + return False + + +OAuthSessions = OAuthSessionTable() diff --git a/backend/open_webui/routers/auths.py b/backend/open_webui/routers/auths.py index 524edf373d..d044b4a168 100644 --- a/backend/open_webui/routers/auths.py +++ b/backend/open_webui/routers/auths.py @@ -19,6 +19,7 @@ from open_webui.models.auths import ( ) from open_webui.models.users import Users, UpdateProfileForm from open_webui.models.groups import Groups +from open_webui.models.oauth_sessions import OAuthSessions from open_webui.constants import ERROR_MESSAGES, WEBHOOK_MESSAGES from open_webui.env import ( @@ -28,7 +29,6 @@ from open_webui.env import ( WEBUI_AUTH_TRUSTED_GROUPS_HEADER, WEBUI_AUTH_COOKIE_SAME_SITE, WEBUI_AUTH_COOKIE_SECURE, - ENABLE_OAUTH_SESSION_TOKENS_COOKIES, WEBUI_AUTH_SIGNOUT_REDIRECT_URL, ENABLE_INITIAL_ADMIN_SIGNUP, SRC_LOG_LEVELS, @@ -678,24 +678,27 @@ async def signout(request: Request, response: Response): response.delete_cookie("token") response.delete_cookie("oui-session") - if ENABLE_OAUTH_SIGNUP.value: - # TODO: update this to use oauth_session_tokens in User Object - oauth_id_token = request.cookies.get("oauth_id_token") + oauth_session_id = request.cookies.get("oauth_session_id") + if oauth_session_id: + response.delete_cookie("oauth_session_id") - if oauth_id_token and OPENID_PROVIDER_URL.value: + session = OAuthSessions.get_session_by_id(oauth_session_id) + oauth_server_metadata_url = ( + request.app.state.oauth_manager.get_server_metadata_url(session.provider) + if session + else None + ) or OPENID_PROVIDER_URL.value + + if session and oauth_server_metadata_url: + oauth_id_token = session.token.get("id_token") try: async with ClientSession(trust_env=True) as session: - async with session.get(OPENID_PROVIDER_URL.value) as r: + async with session.get(oauth_server_metadata_url) as r: if r.status == 200: openid_data = await r.json() logout_url = openid_data.get("end_session_endpoint") if logout_url: - if ENABLE_OAUTH_SESSION_TOKENS_COOKIES: - response.delete_cookie("oauth_id_token") - response.delete_cookie("oauth_access_token") - response.delete_cookie("oauth_refresh_token") - return JSONResponse( status_code=200, content={ @@ -710,15 +713,14 @@ async def signout(request: Request, response: Response): headers=response.headers, ) else: - raise HTTPException( - status_code=r.status, - detail="Failed to fetch OpenID configuration", - ) + raise Exception("Failed to fetch OpenID configuration") + except Exception as e: log.error(f"OpenID signout error: {str(e)}") raise HTTPException( status_code=500, detail="Failed to sign out from the OpenID provider.", + headers=response.headers, ) if WEBUI_AUTH_SIGNOUT_REDIRECT_URL: diff --git a/backend/open_webui/utils/auth.py b/backend/open_webui/utils/auth.py index 33b377ad03..19994bafbd 100644 --- a/backend/open_webui/utils/auth.py +++ b/backend/open_webui/utils/auth.py @@ -261,61 +261,63 @@ def get_current_user( return user # auth by jwt token - try: - data = decode_token(token) - except Exception as e: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Invalid token", - ) - if data is not None and "id" in data: - user = Users.get_user_by_id(data["id"]) - if user is None: + try: + try: + data = decode_token(token) + except Exception as e: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, - detail=ERROR_MESSAGES.INVALID_TOKEN, + detail="Invalid token", ) - else: - if WEBUI_AUTH_TRUSTED_EMAIL_HEADER: - trusted_email = request.headers.get( - WEBUI_AUTH_TRUSTED_EMAIL_HEADER, "" - ).lower() - if trusted_email and user.email != trusted_email: - # Delete the token cookie - response.delete_cookie("token") - # Delete OAuth token if present - if request.cookies.get("oauth_id_token"): - response.delete_cookie("oauth_id_token") - if request.cookies.get("oauth_access_token"): - response.delete_cookie("oauth_access_token") - if request.cookies.get("oauth_refresh_token"): - response.delete_cookie("oauth_refresh_token") + if data is not None and "id" in data: + user = Users.get_user_by_id(data["id"]) + if user is None: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.INVALID_TOKEN, + ) + else: + if WEBUI_AUTH_TRUSTED_EMAIL_HEADER: + trusted_email = request.headers.get( + WEBUI_AUTH_TRUSTED_EMAIL_HEADER, "" + ).lower() + if trusted_email and user.email != trusted_email: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="User mismatch. Please sign in again.", + ) - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="User mismatch. Please sign in again.", + # Add user info to current span + current_span = trace.get_current_span() + if current_span: + current_span.set_attribute("client.user.id", user.id) + current_span.set_attribute("client.user.email", user.email) + current_span.set_attribute("client.user.role", user.role) + current_span.set_attribute("client.auth.type", "jwt") + + # Refresh the user's last active timestamp asynchronously + # to prevent blocking the request + if background_tasks: + background_tasks.add_task( + Users.update_user_last_active_by_id, user.id ) + return user + else: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.UNAUTHORIZED, + ) + except Exception as e: + # Delete the token cookie + if request.cookies.get("token"): + response.delete_cookie("token") + # Delete OAuth session if present + if request.cookies.get("oauth_session_id"): + response.delete_cookie("oauth_session_id") - # Add user info to current span - current_span = trace.get_current_span() - if current_span: - current_span.set_attribute("client.user.id", user.id) - current_span.set_attribute("client.user.email", user.email) - current_span.set_attribute("client.user.role", user.role) - current_span.set_attribute("client.auth.type", "jwt") - - # Refresh the user's last active timestamp asynchronously - # to prevent blocking the request - if background_tasks: - background_tasks.add_task(Users.update_user_last_active_by_id, user.id) - return user - else: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail=ERROR_MESSAGES.UNAUTHORIZED, - ) + raise e def get_current_user_by_api_key(api_key: str): diff --git a/backend/open_webui/utils/middleware.py b/backend/open_webui/utils/middleware.py index 463f52d0af..27b0f11290 100644 --- a/backend/open_webui/utils/middleware.py +++ b/backend/open_webui/utils/middleware.py @@ -815,6 +815,14 @@ async def process_chat_payload(request, form_data, user, metadata, model): event_emitter = get_event_emitter(metadata) event_call = get_event_call(metadata) + oauth_token = None + try: + oauth_token = await request.app.state.oauth_manager.get_oauth_token( + user.id, request.cookies.get("oauth_session_id", None) + ) + except Exception as e: + log.error(f"Error getting OAuth token: {e}") + extra_params = { "__event_emitter__": event_emitter, "__event_call__": event_call, @@ -822,6 +830,7 @@ async def process_chat_payload(request, form_data, user, metadata, model): "__metadata__": metadata, "__request__": request, "__model__": model, + "__oauth_token__": oauth_token, } # Initialize events to store additional event to be sent to the client diff --git a/backend/open_webui/utils/oauth.py b/backend/open_webui/utils/oauth.py index 4411f40e3b..55ee3eee54 100644 --- a/backend/open_webui/utils/oauth.py +++ b/backend/open_webui/utils/oauth.py @@ -4,9 +4,11 @@ import mimetypes import sys import uuid import json +from datetime import datetime, timedelta import re import fnmatch +import time import aiohttp from authlib.integrations.starlette_client import OAuth @@ -17,8 +19,12 @@ from fastapi import ( ) from starlette.responses import RedirectResponse + from open_webui.models.auths import Auths +from open_webui.models.oauth_sessions import OAuthSessions from open_webui.models.users import Users + + from open_webui.models.groups import Groups, GroupModel, GroupUpdateForm, GroupForm from open_webui.config import ( DEFAULT_USER_ROLE, @@ -49,7 +55,7 @@ from open_webui.env import ( WEBUI_NAME, WEBUI_AUTH_COOKIE_SAME_SITE, WEBUI_AUTH_COOKIE_SECURE, - ENABLE_OAUTH_SESSION_TOKENS_COOKIES, + ENABLE_OAUTH_ID_TOKEN_COOKIE, ) from open_webui.utils.misc import parse_duration from open_webui.utils.auth import get_password_hash, create_token @@ -131,11 +137,187 @@ class OAuthManager: def __init__(self, app): self.oauth = OAuth() self.app = app + + self._clients = {} for _, provider_config in OAUTH_PROVIDERS.items(): provider_config["register"](self.oauth) def get_client(self, provider_name): - return self.oauth.create_client(provider_name) + if provider_name not in self._clients: + self._clients[provider_name] = self.oauth.create_client(provider_name) + return self._clients[provider_name] + + def get_server_metadata_url(self, provider_name): + if provider_name in self._clients: + client = self._clients[provider_name] + return ( + client.server_metadata_url + if hasattr(client, "server_metadata_url") + else None + ) + return None + + def get_oauth_token( + self, user_id: str, session_id: str, force_refresh: bool = False + ): + """ + Get a valid OAuth token for the user, automatically refreshing if needed. + + Args: + user_id: The user ID + provider: Optional provider name. If None, gets the most recent session. + force_refresh: Force token refresh even if current token appears valid + + Returns: + dict: OAuth token data with access_token, or None if no valid token available + """ + try: + # Get the OAuth session + session = OAuthSessions.get_session_by_id_and_user_id(session_id, user_id) + if not session: + log.warning( + f"No OAuth session found for user {user_id}, session {session_id}" + ) + return None + + if force_refresh or datetime.now() + timedelta( + minutes=5 + ) >= datetime.fromtimestamp(session.expires_at): + log.debug( + f"Token refresh needed for user {user_id}, provider {session.provider}" + ) + refreshed_token = self._refresh_token(session) + if refreshed_token: + return refreshed_token + else: + log.warning( + f"Token refresh failed for user {user_id}, provider {session.provider}" + ) + return None + return session.token + + except Exception as e: + log.error(f"Error getting OAuth token for user {user_id}: {e}") + return None + + async def _refresh_token(self, session) -> dict: + """ + Refresh an OAuth token if needed, with concurrency protection. + + Args: + session: The OAuth session object + + Returns: + dict: Refreshed token data, or None if refresh failed + """ + try: + # Perform the actual refresh + refreshed_token = await self._perform_token_refresh(session) + + if refreshed_token: + # Update the session with new token data + session = OAuthSessions.update_session_by_id( + session.id, refreshed_token + ) + log.info(f"Successfully refreshed token for session {session.id}") + return session.token + else: + log.error(f"Failed to refresh token for session {session.id}") + return None + + except Exception as e: + log.error(f"Error refreshing token for session {session.id}: {e}") + return None + + async def _perform_token_refresh(self, session) -> dict: + """ + Perform the actual OAuth token refresh. + + Args: + session: The OAuth session object + + Returns: + dict: New token data, or None if refresh failed + """ + provider = session.provider + token_data = session.token + + if not token_data.get("refresh_token"): + log.warning(f"No refresh token available for session {session.id}") + return None + + try: + client = self.get_client(provider) + if not client: + log.error(f"No OAuth client found for provider {provider}") + return None + + token_endpoint = None + async with aiohttp.ClientSession(trust_env=True) as session_http: + async with session_http.get(client.gserver_metadata_url) as r: + if r.status == 200: + openid_data = await r.json() + token_endpoint = openid_data.get("token_endpoint") + else: + log.error( + f"Failed to fetch OpenID configuration for provider {provider}" + ) + if not token_endpoint: + log.error(f"No token endpoint found for provider {provider}") + return None + + # Prepare refresh request + refresh_data = { + "grant_type": "refresh_token", + "refresh_token": token_data["refresh_token"], + "client_id": client.client_id, + } + # Add client_secret if available (some providers require it) + if hasattr(client, "client_secret") and client.client_secret: + refresh_data["client_secret"] = client.client_secret + + # Make refresh request + async with aiohttp.ClientSession(trust_env=True) as session_http: + async with session_http.post( + token_endpoint, + data=refresh_data, + headers={"Content-Type": "application/x-www-form-urlencoded"}, + ssl=AIOHTTP_CLIENT_SESSION_SSL, + ) as r: + if r.status == 200: + new_token_data = await r.json() + + # Merge with existing token data (preserve refresh_token if not provided) + if "refresh_token" not in new_token_data: + new_token_data["refresh_token"] = token_data[ + "refresh_token" + ] + + # Add timestamp for tracking + new_token_data["issued_at"] = datetime.now().timestamp() + + # Calculate expires_at if we have expires_in + if ( + "expires_in" in new_token_data + and "expires_at" not in new_token_data + ): + new_token_data["expires_at"] = ( + datetime.now().timestamp() + + new_token_data["expires_in"] + ) + + log.debug(f"Token refresh successful for provider {provider}") + return new_token_data + else: + error_text = await r.text() + log.error( + f"Token refresh failed for provider {provider}: {r.status} - {error_text}" + ) + return None + + except Exception as e: + log.error(f"Exception during token refresh for provider {provider}: {e}") + return None def get_user_role(self, user, user_data): user_count = Users.get_num_users() @@ -624,33 +806,42 @@ class OAuthManager: secure=WEBUI_AUTH_COOKIE_SECURE, ) - if ENABLE_OAUTH_SIGNUP.value: - if ENABLE_OAUTH_SESSION_TOKENS_COOKIES: - oauth_id_token = token.get("id_token") - response.set_cookie( - key="oauth_id_token", - value=oauth_id_token, - httponly=True, - samesite=WEBUI_AUTH_COOKIE_SAME_SITE, - secure=WEBUI_AUTH_COOKIE_SECURE, - ) + # Legacy cookies for compatibility with older frontend versions + if ENABLE_OAUTH_ID_TOKEN_COOKIE: + response.set_cookie( + key="oauth_id_token", + value=token.get("id_token"), + httponly=True, + samesite=WEBUI_AUTH_COOKIE_SAME_SITE, + secure=WEBUI_AUTH_COOKIE_SECURE, + ) - oauth_access_token = token.get("access_token") - response.set_cookie( - key="oauth_access_token", - value=oauth_access_token, - httponly=True, - samesite=WEBUI_AUTH_COOKIE_SAME_SITE, - secure=WEBUI_AUTH_COOKIE_SECURE, - ) + try: + # Add timestamp for tracking + token["issued_at"] = datetime.now().timestamp() - oauth_refresh_token = token.get("refresh_token") - response.set_cookie( - key="oauth_refresh_token", - value=oauth_refresh_token, - httponly=True, - samesite=WEBUI_AUTH_COOKIE_SAME_SITE, - secure=WEBUI_AUTH_COOKIE_SECURE, - ) + # Calculate expires_at if we have expires_in + if "expires_in" in token and "expires_at" not in token: + token["expires_at"] = datetime.now().timestamp() + token["expires_in"] + + session_id = await OAuthSessions.create_session( + user_id=user.id, + provider=provider, + token=token, + ) + + response.set_cookie( + key="oauth_session_id", + value=session_id, + httponly=True, + samesite=WEBUI_AUTH_COOKIE_SAME_SITE, + secure=WEBUI_AUTH_COOKIE_SECURE, + ) + + log.info( + f"Stored OAuth session server-side for user {user.id}, provider {provider}" + ) + except Exception as e: + log.error(f"Failed to store OAuth session server-side: {e}") return response diff --git a/backend/open_webui/utils/tools.py b/backend/open_webui/utils/tools.py index d3ea432019..f0e889d023 100644 --- a/backend/open_webui/utils/tools.py +++ b/backend/open_webui/utils/tools.py @@ -129,6 +129,21 @@ async def get_tools( headers["Authorization"] = ( f"Bearer {request.state.token.credentials}" ) + elif auth_type == "oauth": + oauth_token = None + try: + oauth_token = ( + await request.app.state.oauth_manager.get_oauth_token( + user.id, + request.cookies.get("oauth_session_id", None), + ) + ) + except Exception as e: + log.error(f"Error getting OAuth token: {e}") + + headers["Authorization"] = ( + f"Bearer {oauth_token.get('access_token', '')}" + ) elif auth_type == "request_headers": headers.update(dict(request.headers)) diff --git a/src/lib/components/AddServerModal.svelte b/src/lib/components/AddServerModal.svelte index 6fad62bc15..8951696c74 100644 --- a/src/lib/components/AddServerModal.svelte +++ b/src/lib/components/AddServerModal.svelte @@ -287,6 +287,7 @@ {#if !direct} + {/if} @@ -305,6 +306,12 @@ > {$i18n.t('Forwards system user session credentials to authenticate')} + {:else if auth_type === 'oauth'} +
+ {$i18n.t('Forwards user OAuth access token to authenticate')} +
{:else if auth_type === 'request_headers'}