feat: server-side OAuth token management system

Co-Authored-By: Classic298 <27028174+Classic298@users.noreply.github.com>
This commit is contained in:
Timothy Jaeryang Baek 2025-09-08 18:05:43 +04:00
parent 6d38ac41b6
commit 217f4daef0
10 changed files with 627 additions and 92 deletions

View file

@ -465,8 +465,17 @@ ENABLE_COMPRESSION_MIDDLEWARE = (
os.environ.get("ENABLE_COMPRESSION_MIDDLEWARE", "True").lower() == "true" os.environ.get("ENABLE_COMPRESSION_MIDDLEWARE", "True").lower() == "true"
) )
ENABLE_OAUTH_SESSION_TOKENS_COOKIES = ( ####################################
os.environ.get("ENABLE_OAUTH_SESSION_TOKENS_COOKIES", "True").lower() == "true" # OAUTH Configuration
####################################
ENABLE_OAUTH_ID_TOKEN_COOKIE = (
os.environ.get("ENABLE_OAUTH_ID_TOKEN_COOKIE", "True").lower() == "true"
)
OAUTH_SESSION_TOKEN_ENCRYPTION_KEY = os.environ.get(
"OAUTH_SESSION_TOKEN_ENCRYPTION_KEY", WEBUI_SECRET_KEY
) )

View file

@ -592,6 +592,7 @@ app = FastAPI(
) )
oauth_manager = OAuthManager(app) oauth_manager = OAuthManager(app)
app.state.oauth_manager = oauth_manager
app.state.instance_id = None app.state.instance_id = None
app.state.config = AppConfig( app.state.config = AppConfig(

View file

@ -0,0 +1,52 @@
"""Add oauth_session table
Revision ID: 38d63c18f30f
Revises: 3af16a1c9fb6
Create Date: 2025-09-08 14:19:59.583921
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = "38d63c18f30f"
down_revision: Union[str, None] = "3af16a1c9fb6"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# Create oauth_session table
op.create_table(
"oauth_session",
sa.Column("id", sa.Text(), nullable=False),
sa.Column("user_id", sa.Text(), nullable=False),
sa.Column("provider", sa.Text(), nullable=False),
sa.Column("token", sa.Text(), nullable=False),
sa.Column("expires_at", sa.BigInteger(), nullable=False),
sa.Column("created_at", sa.BigInteger(), nullable=False),
sa.Column("updated_at", sa.BigInteger(), nullable=False),
sa.PrimaryKeyConstraint("id"),
sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"),
)
# Create indexes for better performance
op.create_index("idx_oauth_session_user_id", "oauth_session", ["user_id"])
op.create_index("idx_oauth_session_expires_at", "oauth_session", ["expires_at"])
op.create_index(
"idx_oauth_session_user_provider", "oauth_session", ["user_id", "provider"]
)
def downgrade() -> None:
# Drop indexes first
op.drop_index("idx_oauth_session_user_provider", table_name="oauth_session")
op.drop_index("idx_oauth_session_expires_at", table_name="oauth_session")
op.drop_index("idx_oauth_session_user_id", table_name="oauth_session")
# Drop the table
op.drop_table("oauth_session")

View file

@ -0,0 +1,247 @@
import time
import logging
import uuid
from typing import Optional, List
import base64
import hashlib
import json
from cryptography.fernet import Fernet
from open_webui.internal.db import Base, get_db
from open_webui.env import SRC_LOG_LEVELS, OAUTH_SESSION_TOKEN_ENCRYPTION_KEY
from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Column, String, Text, Index
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
####################
# DB MODEL
####################
class OAuthSession(Base):
__tablename__ = "oauth_session"
id = Column(Text, primary_key=True)
user_id = Column(Text, nullable=False)
provider = Column(Text, nullable=False)
token = Column(
Text, nullable=False
) # JSON with access_token, id_token, refresh_token
expires_at = Column(BigInteger, nullable=False)
created_at = Column(BigInteger, nullable=False)
updated_at = Column(BigInteger, nullable=False)
# Add indexes for better performance
__table_args__ = (
Index("idx_oauth_session_user_id", "user_id"),
Index("idx_oauth_session_expires_at", "expires_at"),
Index("idx_oauth_session_user_provider", "user_id", "provider"),
)
class OAuthSessionModel(BaseModel):
id: str
user_id: str
provider: str
token: dict
expires_at: int # timestamp in epoch
created_at: int # timestamp in epoch
updated_at: int # timestamp in epoch
model_config = ConfigDict(from_attributes=True)
####################
# Forms
####################
class OAuthSessionResponse(BaseModel):
id: str
user_id: str
provider: str
expires_at: int
class OAuthSessionTable:
def __init__(self):
self.encryption_key = OAUTH_SESSION_TOKEN_ENCRYPTION_KEY
if not self.encryption_key:
raise Exception("OAUTH_SESSION_TOKEN_ENCRYPTION_KEY is not set")
# check if encryption key is in the right format for Fernet (32 url-safe base64-encoded bytes)
if len(self.encryption_key) != 44:
key_bytes = hashlib.sha256(self.encryption_key.encode()).digest()
self.encryption_key = base64.urlsafe_b64encode(key_bytes)
else:
self.encryption_key = self.encryption_key.encode()
try:
self.fernet = Fernet(self.encryption_key)
except Exception as e:
log.error(f"Error initializing Fernet with provided key: {e}")
raise
def _encrypt_token(self, token) -> str:
"""Encrypt OAuth tokens for storage"""
try:
token_json = json.dumps(token)
encrypted = self.fernet.encrypt(token_json.encode()).decode()
return encrypted
except Exception as e:
log.error(f"Error encrypting tokens: {e}")
raise
def _decrypt_token(self, token: str):
"""Decrypt OAuth tokens from storage"""
try:
decrypted = self.fernet.decrypt(token.encode()).decode()
return json.loads(decrypted)
except Exception as e:
log.error(f"Error decrypting tokens: {e}")
raise
def create_session(
self,
user_id: str,
provider: str,
token: dict,
) -> Optional[OAuthSessionModel]:
"""Create a new OAuth session"""
try:
with get_db() as db:
current_time = int(time.time())
id = str(uuid.uuid4())
result = OAuthSession(
**{
"id": id,
"user_id": user_id,
"provider": provider,
"token": self._encrypt_token(token),
"expires_at": token.get("expires_at"),
"created_at": current_time,
"updated_at": current_time,
}
)
db.add(result)
db.commit()
db.refresh(result)
if result:
result.token = token # Return decrypted token
return OAuthSessionModel.model_validate(result)
else:
return None
except Exception as e:
log.error(f"Error creating OAuth session: {e}")
return None
def get_session_by_id(self, session_id: str) -> Optional[OAuthSessionModel]:
"""Get OAuth session by ID"""
try:
with get_db() as db:
session = db.query(OAuthSession).filter_by(id=session_id).first()
if session:
session.token = self._decrypt_token(session.token)
return OAuthSessionModel.model_validate(session)
return None
except Exception as e:
log.error(f"Error getting OAuth session by ID: {e}")
return None
def get_session_by_id_and_user_id(
self, session_id: str, user_id: str
) -> Optional[OAuthSessionModel]:
"""Get OAuth session by ID and user ID"""
try:
with get_db() as db:
session = (
db.query(OAuthSession)
.filter_by(id=session_id, user_id=user_id)
.first()
)
if session:
session.token = self._decrypt_token(session.token)
return OAuthSessionModel.model_validate(session)
)
return None
except Exception as e:
log.error(f"Error getting OAuth session by ID: {e}")
return None
def get_sessions_by_user_id(self, user_id: str) -> List[OAuthSessionModel]:
"""Get all OAuth sessions for a user"""
try:
with get_db() as db:
sessions = db.query(OAuthSession).filter_by(user_id=user_id).all()
results = []
for session in sessions:
session.token = self._decrypt_token(session.token)
results.append(OAuthSessionModel.model_validate(session))
return results
except Exception as e:
log.error(f"Error getting OAuth sessions by user ID: {e}")
return []
def update_session_by_id(
self, session_id: str, token: dict
) -> Optional[OAuthSessionModel]:
"""Update OAuth session tokens"""
try:
with get_db() as db:
current_time = int(time.time())
db.query(OAuthSession).filter_by(id=session_id).update(
{
"token": self._encrypt_token(token),
"expires_at": token.get("expires_at"),
"updated_at": current_time,
}
)
db.commit()
session = db.query(OAuthSession).filter_by(id=session_id).first()
if session:
session.token = self._decrypt_token(session.token)
return OAuthSessionModel.model_validate(session)
return None
except Exception as e:
log.error(f"Error updating OAuth session tokens: {e}")
return None
def delete_session_by_id(self, session_id: str) -> bool:
"""Delete an OAuth session"""
try:
with get_db() as db:
result = db.query(OAuthSession).filter_by(id=session_id).delete()
db.commit()
return result > 0
except Exception as e:
log.error(f"Error deleting OAuth session: {e}")
return False
def delete_sessions_by_user_id(self, user_id: str) -> bool:
"""Delete all OAuth sessions for a user"""
try:
with get_db() as db:
result = db.query(OAuthSession).filter_by(user_id=user_id).delete()
db.commit()
return True
except Exception as e:
log.error(f"Error deleting OAuth sessions by user ID: {e}")
return False
OAuthSessions = OAuthSessionTable()

View file

@ -19,6 +19,7 @@ from open_webui.models.auths import (
) )
from open_webui.models.users import Users, UpdateProfileForm from open_webui.models.users import Users, UpdateProfileForm
from open_webui.models.groups import Groups from open_webui.models.groups import Groups
from open_webui.models.oauth_sessions import OAuthSessions
from open_webui.constants import ERROR_MESSAGES, WEBHOOK_MESSAGES from open_webui.constants import ERROR_MESSAGES, WEBHOOK_MESSAGES
from open_webui.env import ( from open_webui.env import (
@ -28,7 +29,6 @@ from open_webui.env import (
WEBUI_AUTH_TRUSTED_GROUPS_HEADER, WEBUI_AUTH_TRUSTED_GROUPS_HEADER,
WEBUI_AUTH_COOKIE_SAME_SITE, WEBUI_AUTH_COOKIE_SAME_SITE,
WEBUI_AUTH_COOKIE_SECURE, WEBUI_AUTH_COOKIE_SECURE,
ENABLE_OAUTH_SESSION_TOKENS_COOKIES,
WEBUI_AUTH_SIGNOUT_REDIRECT_URL, WEBUI_AUTH_SIGNOUT_REDIRECT_URL,
ENABLE_INITIAL_ADMIN_SIGNUP, ENABLE_INITIAL_ADMIN_SIGNUP,
SRC_LOG_LEVELS, SRC_LOG_LEVELS,
@ -678,24 +678,27 @@ async def signout(request: Request, response: Response):
response.delete_cookie("token") response.delete_cookie("token")
response.delete_cookie("oui-session") response.delete_cookie("oui-session")
if ENABLE_OAUTH_SIGNUP.value: oauth_session_id = request.cookies.get("oauth_session_id")
# TODO: update this to use oauth_session_tokens in User Object if oauth_session_id:
oauth_id_token = request.cookies.get("oauth_id_token") response.delete_cookie("oauth_session_id")
if oauth_id_token and OPENID_PROVIDER_URL.value: session = OAuthSessions.get_session_by_id(oauth_session_id)
oauth_server_metadata_url = (
request.app.state.oauth_manager.get_server_metadata_url(session.provider)
if session
else None
) or OPENID_PROVIDER_URL.value
if session and oauth_server_metadata_url:
oauth_id_token = session.token.get("id_token")
try: try:
async with ClientSession(trust_env=True) as session: async with ClientSession(trust_env=True) as session:
async with session.get(OPENID_PROVIDER_URL.value) as r: async with session.get(oauth_server_metadata_url) as r:
if r.status == 200: if r.status == 200:
openid_data = await r.json() openid_data = await r.json()
logout_url = openid_data.get("end_session_endpoint") logout_url = openid_data.get("end_session_endpoint")
if logout_url: if logout_url:
if ENABLE_OAUTH_SESSION_TOKENS_COOKIES:
response.delete_cookie("oauth_id_token")
response.delete_cookie("oauth_access_token")
response.delete_cookie("oauth_refresh_token")
return JSONResponse( return JSONResponse(
status_code=200, status_code=200,
content={ content={
@ -710,15 +713,14 @@ async def signout(request: Request, response: Response):
headers=response.headers, headers=response.headers,
) )
else: else:
raise HTTPException( raise Exception("Failed to fetch OpenID configuration")
status_code=r.status,
detail="Failed to fetch OpenID configuration",
)
except Exception as e: except Exception as e:
log.error(f"OpenID signout error: {str(e)}") log.error(f"OpenID signout error: {str(e)}")
raise HTTPException( raise HTTPException(
status_code=500, status_code=500,
detail="Failed to sign out from the OpenID provider.", detail="Failed to sign out from the OpenID provider.",
headers=response.headers,
) )
if WEBUI_AUTH_SIGNOUT_REDIRECT_URL: if WEBUI_AUTH_SIGNOUT_REDIRECT_URL:

View file

@ -261,6 +261,8 @@ def get_current_user(
return user return user
# auth by jwt token # auth by jwt token
try:
try: try:
data = decode_token(token) data = decode_token(token)
except Exception as e: except Exception as e:
@ -282,17 +284,6 @@ def get_current_user(
WEBUI_AUTH_TRUSTED_EMAIL_HEADER, "" WEBUI_AUTH_TRUSTED_EMAIL_HEADER, ""
).lower() ).lower()
if trusted_email and user.email != trusted_email: if trusted_email and user.email != trusted_email:
# Delete the token cookie
response.delete_cookie("token")
# Delete OAuth token if present
if request.cookies.get("oauth_id_token"):
response.delete_cookie("oauth_id_token")
if request.cookies.get("oauth_access_token"):
response.delete_cookie("oauth_access_token")
if request.cookies.get("oauth_refresh_token"):
response.delete_cookie("oauth_refresh_token")
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
detail="User mismatch. Please sign in again.", detail="User mismatch. Please sign in again.",
@ -309,13 +300,24 @@ def get_current_user(
# Refresh the user's last active timestamp asynchronously # Refresh the user's last active timestamp asynchronously
# to prevent blocking the request # to prevent blocking the request
if background_tasks: if background_tasks:
background_tasks.add_task(Users.update_user_last_active_by_id, user.id) background_tasks.add_task(
Users.update_user_last_active_by_id, user.id
)
return user return user
else: else:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.UNAUTHORIZED, detail=ERROR_MESSAGES.UNAUTHORIZED,
) )
except Exception as e:
# Delete the token cookie
if request.cookies.get("token"):
response.delete_cookie("token")
# Delete OAuth session if present
if request.cookies.get("oauth_session_id"):
response.delete_cookie("oauth_session_id")
raise e
def get_current_user_by_api_key(api_key: str): def get_current_user_by_api_key(api_key: str):

View file

@ -815,6 +815,14 @@ async def process_chat_payload(request, form_data, user, metadata, model):
event_emitter = get_event_emitter(metadata) event_emitter = get_event_emitter(metadata)
event_call = get_event_call(metadata) event_call = get_event_call(metadata)
oauth_token = None
try:
oauth_token = await request.app.state.oauth_manager.get_oauth_token(
user.id, request.cookies.get("oauth_session_id", None)
)
except Exception as e:
log.error(f"Error getting OAuth token: {e}")
extra_params = { extra_params = {
"__event_emitter__": event_emitter, "__event_emitter__": event_emitter,
"__event_call__": event_call, "__event_call__": event_call,
@ -822,6 +830,7 @@ async def process_chat_payload(request, form_data, user, metadata, model):
"__metadata__": metadata, "__metadata__": metadata,
"__request__": request, "__request__": request,
"__model__": model, "__model__": model,
"__oauth_token__": oauth_token,
} }
# Initialize events to store additional event to be sent to the client # Initialize events to store additional event to be sent to the client

View file

@ -4,9 +4,11 @@ import mimetypes
import sys import sys
import uuid import uuid
import json import json
from datetime import datetime, timedelta
import re import re
import fnmatch import fnmatch
import time
import aiohttp import aiohttp
from authlib.integrations.starlette_client import OAuth from authlib.integrations.starlette_client import OAuth
@ -17,8 +19,12 @@ from fastapi import (
) )
from starlette.responses import RedirectResponse from starlette.responses import RedirectResponse
from open_webui.models.auths import Auths from open_webui.models.auths import Auths
from open_webui.models.oauth_sessions import OAuthSessions
from open_webui.models.users import Users from open_webui.models.users import Users
from open_webui.models.groups import Groups, GroupModel, GroupUpdateForm, GroupForm from open_webui.models.groups import Groups, GroupModel, GroupUpdateForm, GroupForm
from open_webui.config import ( from open_webui.config import (
DEFAULT_USER_ROLE, DEFAULT_USER_ROLE,
@ -49,7 +55,7 @@ from open_webui.env import (
WEBUI_NAME, WEBUI_NAME,
WEBUI_AUTH_COOKIE_SAME_SITE, WEBUI_AUTH_COOKIE_SAME_SITE,
WEBUI_AUTH_COOKIE_SECURE, WEBUI_AUTH_COOKIE_SECURE,
ENABLE_OAUTH_SESSION_TOKENS_COOKIES, ENABLE_OAUTH_ID_TOKEN_COOKIE,
) )
from open_webui.utils.misc import parse_duration from open_webui.utils.misc import parse_duration
from open_webui.utils.auth import get_password_hash, create_token from open_webui.utils.auth import get_password_hash, create_token
@ -131,11 +137,187 @@ class OAuthManager:
def __init__(self, app): def __init__(self, app):
self.oauth = OAuth() self.oauth = OAuth()
self.app = app self.app = app
self._clients = {}
for _, provider_config in OAUTH_PROVIDERS.items(): for _, provider_config in OAUTH_PROVIDERS.items():
provider_config["register"](self.oauth) provider_config["register"](self.oauth)
def get_client(self, provider_name): def get_client(self, provider_name):
return self.oauth.create_client(provider_name) if provider_name not in self._clients:
self._clients[provider_name] = self.oauth.create_client(provider_name)
return self._clients[provider_name]
def get_server_metadata_url(self, provider_name):
if provider_name in self._clients:
client = self._clients[provider_name]
return (
client.server_metadata_url
if hasattr(client, "server_metadata_url")
else None
)
return None
def get_oauth_token(
self, user_id: str, session_id: str, force_refresh: bool = False
):
"""
Get a valid OAuth token for the user, automatically refreshing if needed.
Args:
user_id: The user ID
provider: Optional provider name. If None, gets the most recent session.
force_refresh: Force token refresh even if current token appears valid
Returns:
dict: OAuth token data with access_token, or None if no valid token available
"""
try:
# Get the OAuth session
session = OAuthSessions.get_session_by_id_and_user_id(session_id, user_id)
if not session:
log.warning(
f"No OAuth session found for user {user_id}, session {session_id}"
)
return None
if force_refresh or datetime.now() + timedelta(
minutes=5
) >= datetime.fromtimestamp(session.expires_at):
log.debug(
f"Token refresh needed for user {user_id}, provider {session.provider}"
)
refreshed_token = self._refresh_token(session)
if refreshed_token:
return refreshed_token
else:
log.warning(
f"Token refresh failed for user {user_id}, provider {session.provider}"
)
return None
return session.token
except Exception as e:
log.error(f"Error getting OAuth token for user {user_id}: {e}")
return None
async def _refresh_token(self, session) -> dict:
"""
Refresh an OAuth token if needed, with concurrency protection.
Args:
session: The OAuth session object
Returns:
dict: Refreshed token data, or None if refresh failed
"""
try:
# Perform the actual refresh
refreshed_token = await self._perform_token_refresh(session)
if refreshed_token:
# Update the session with new token data
session = OAuthSessions.update_session_by_id(
session.id, refreshed_token
)
log.info(f"Successfully refreshed token for session {session.id}")
return session.token
else:
log.error(f"Failed to refresh token for session {session.id}")
return None
except Exception as e:
log.error(f"Error refreshing token for session {session.id}: {e}")
return None
async def _perform_token_refresh(self, session) -> dict:
"""
Perform the actual OAuth token refresh.
Args:
session: The OAuth session object
Returns:
dict: New token data, or None if refresh failed
"""
provider = session.provider
token_data = session.token
if not token_data.get("refresh_token"):
log.warning(f"No refresh token available for session {session.id}")
return None
try:
client = self.get_client(provider)
if not client:
log.error(f"No OAuth client found for provider {provider}")
return None
token_endpoint = None
async with aiohttp.ClientSession(trust_env=True) as session_http:
async with session_http.get(client.gserver_metadata_url) as r:
if r.status == 200:
openid_data = await r.json()
token_endpoint = openid_data.get("token_endpoint")
else:
log.error(
f"Failed to fetch OpenID configuration for provider {provider}"
)
if not token_endpoint:
log.error(f"No token endpoint found for provider {provider}")
return None
# Prepare refresh request
refresh_data = {
"grant_type": "refresh_token",
"refresh_token": token_data["refresh_token"],
"client_id": client.client_id,
}
# Add client_secret if available (some providers require it)
if hasattr(client, "client_secret") and client.client_secret:
refresh_data["client_secret"] = client.client_secret
# Make refresh request
async with aiohttp.ClientSession(trust_env=True) as session_http:
async with session_http.post(
token_endpoint,
data=refresh_data,
headers={"Content-Type": "application/x-www-form-urlencoded"},
ssl=AIOHTTP_CLIENT_SESSION_SSL,
) as r:
if r.status == 200:
new_token_data = await r.json()
# Merge with existing token data (preserve refresh_token if not provided)
if "refresh_token" not in new_token_data:
new_token_data["refresh_token"] = token_data[
"refresh_token"
]
# Add timestamp for tracking
new_token_data["issued_at"] = datetime.now().timestamp()
# Calculate expires_at if we have expires_in
if (
"expires_in" in new_token_data
and "expires_at" not in new_token_data
):
new_token_data["expires_at"] = (
datetime.now().timestamp()
+ new_token_data["expires_in"]
)
log.debug(f"Token refresh successful for provider {provider}")
return new_token_data
else:
error_text = await r.text()
log.error(
f"Token refresh failed for provider {provider}: {r.status} - {error_text}"
)
return None
except Exception as e:
log.error(f"Exception during token refresh for provider {provider}: {e}")
return None
def get_user_role(self, user, user_data): def get_user_role(self, user, user_data):
user_count = Users.get_num_users() user_count = Users.get_num_users()
@ -624,33 +806,42 @@ class OAuthManager:
secure=WEBUI_AUTH_COOKIE_SECURE, secure=WEBUI_AUTH_COOKIE_SECURE,
) )
if ENABLE_OAUTH_SIGNUP.value: # Legacy cookies for compatibility with older frontend versions
if ENABLE_OAUTH_SESSION_TOKENS_COOKIES: if ENABLE_OAUTH_ID_TOKEN_COOKIE:
oauth_id_token = token.get("id_token")
response.set_cookie( response.set_cookie(
key="oauth_id_token", key="oauth_id_token",
value=oauth_id_token, value=token.get("id_token"),
httponly=True, httponly=True,
samesite=WEBUI_AUTH_COOKIE_SAME_SITE, samesite=WEBUI_AUTH_COOKIE_SAME_SITE,
secure=WEBUI_AUTH_COOKIE_SECURE, secure=WEBUI_AUTH_COOKIE_SECURE,
) )
oauth_access_token = token.get("access_token") try:
# Add timestamp for tracking
token["issued_at"] = datetime.now().timestamp()
# Calculate expires_at if we have expires_in
if "expires_in" in token and "expires_at" not in token:
token["expires_at"] = datetime.now().timestamp() + token["expires_in"]
session_id = await OAuthSessions.create_session(
user_id=user.id,
provider=provider,
token=token,
)
response.set_cookie( response.set_cookie(
key="oauth_access_token", key="oauth_session_id",
value=oauth_access_token, value=session_id,
httponly=True, httponly=True,
samesite=WEBUI_AUTH_COOKIE_SAME_SITE, samesite=WEBUI_AUTH_COOKIE_SAME_SITE,
secure=WEBUI_AUTH_COOKIE_SECURE, secure=WEBUI_AUTH_COOKIE_SECURE,
) )
oauth_refresh_token = token.get("refresh_token") log.info(
response.set_cookie( f"Stored OAuth session server-side for user {user.id}, provider {provider}"
key="oauth_refresh_token",
value=oauth_refresh_token,
httponly=True,
samesite=WEBUI_AUTH_COOKIE_SAME_SITE,
secure=WEBUI_AUTH_COOKIE_SECURE,
) )
except Exception as e:
log.error(f"Failed to store OAuth session server-side: {e}")
return response return response

View file

@ -129,6 +129,21 @@ async def get_tools(
headers["Authorization"] = ( headers["Authorization"] = (
f"Bearer {request.state.token.credentials}" f"Bearer {request.state.token.credentials}"
) )
elif auth_type == "oauth":
oauth_token = None
try:
oauth_token = (
await request.app.state.oauth_manager.get_oauth_token(
user.id,
request.cookies.get("oauth_session_id", None),
)
)
except Exception as e:
log.error(f"Error getting OAuth token: {e}")
headers["Authorization"] = (
f"Bearer {oauth_token.get('access_token', '')}"
)
elif auth_type == "request_headers": elif auth_type == "request_headers":
headers.update(dict(request.headers)) headers.update(dict(request.headers))

View file

@ -287,6 +287,7 @@
<option value="session">{$i18n.t('Session')}</option> <option value="session">{$i18n.t('Session')}</option>
{#if !direct} {#if !direct}
<option value="oauth">{$i18n.t('OAuth')}</option>
<option value="request_headers">{$i18n.t('Request Headers')}</option> <option value="request_headers">{$i18n.t('Request Headers')}</option>
{/if} {/if}
</select> </select>
@ -305,6 +306,12 @@
> >
{$i18n.t('Forwards system user session credentials to authenticate')} {$i18n.t('Forwards system user session credentials to authenticate')}
</div> </div>
{:else if auth_type === 'oauth'}
<div
class={`text-xs self-center translate-y-[1px] ${($settings?.highContrastMode ?? false) ? 'text-gray-800 dark:text-gray-100' : 'text-gray-500'}`}
>
{$i18n.t('Forwards user OAuth access token to authenticate')}
</div>
{:else if auth_type === 'request_headers'} {:else if auth_type === 'request_headers'}
<div <div
class={`text-xs self-center translate-y-[1px] ${($settings?.highContrastMode ?? false) ? 'text-gray-800 dark:text-gray-100' : 'text-gray-500'}`} class={`text-xs self-center translate-y-[1px] ${($settings?.highContrastMode ?? false) ? 'text-gray-800 dark:text-gray-100' : 'text-gray-500'}`}