mirror of
https://github.com/open-webui/open-webui.git
synced 2025-12-12 04:15:25 +00:00
feat: server-side OAuth token management system
Co-Authored-By: Classic298 <27028174+Classic298@users.noreply.github.com>
This commit is contained in:
parent
6d38ac41b6
commit
217f4daef0
10 changed files with 627 additions and 92 deletions
|
|
@ -465,8 +465,17 @@ ENABLE_COMPRESSION_MIDDLEWARE = (
|
|||
os.environ.get("ENABLE_COMPRESSION_MIDDLEWARE", "True").lower() == "true"
|
||||
)
|
||||
|
||||
ENABLE_OAUTH_SESSION_TOKENS_COOKIES = (
|
||||
os.environ.get("ENABLE_OAUTH_SESSION_TOKENS_COOKIES", "True").lower() == "true"
|
||||
####################################
|
||||
# OAUTH Configuration
|
||||
####################################
|
||||
|
||||
|
||||
ENABLE_OAUTH_ID_TOKEN_COOKIE = (
|
||||
os.environ.get("ENABLE_OAUTH_ID_TOKEN_COOKIE", "True").lower() == "true"
|
||||
)
|
||||
|
||||
OAUTH_SESSION_TOKEN_ENCRYPTION_KEY = os.environ.get(
|
||||
"OAUTH_SESSION_TOKEN_ENCRYPTION_KEY", WEBUI_SECRET_KEY
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -592,6 +592,7 @@ app = FastAPI(
|
|||
)
|
||||
|
||||
oauth_manager = OAuthManager(app)
|
||||
app.state.oauth_manager = oauth_manager
|
||||
|
||||
app.state.instance_id = None
|
||||
app.state.config = AppConfig(
|
||||
|
|
|
|||
|
|
@ -0,0 +1,52 @@
|
|||
"""Add oauth_session table
|
||||
|
||||
Revision ID: 38d63c18f30f
|
||||
Revises: 3af16a1c9fb6
|
||||
Create Date: 2025-09-08 14:19:59.583921
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "38d63c18f30f"
|
||||
down_revision: Union[str, None] = "3af16a1c9fb6"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Create oauth_session table
|
||||
op.create_table(
|
||||
"oauth_session",
|
||||
sa.Column("id", sa.Text(), nullable=False),
|
||||
sa.Column("user_id", sa.Text(), nullable=False),
|
||||
sa.Column("provider", sa.Text(), nullable=False),
|
||||
sa.Column("token", sa.Text(), nullable=False),
|
||||
sa.Column("expires_at", sa.BigInteger(), nullable=False),
|
||||
sa.Column("created_at", sa.BigInteger(), nullable=False),
|
||||
sa.Column("updated_at", sa.BigInteger(), nullable=False),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"),
|
||||
)
|
||||
|
||||
# Create indexes for better performance
|
||||
op.create_index("idx_oauth_session_user_id", "oauth_session", ["user_id"])
|
||||
op.create_index("idx_oauth_session_expires_at", "oauth_session", ["expires_at"])
|
||||
op.create_index(
|
||||
"idx_oauth_session_user_provider", "oauth_session", ["user_id", "provider"]
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Drop indexes first
|
||||
op.drop_index("idx_oauth_session_user_provider", table_name="oauth_session")
|
||||
op.drop_index("idx_oauth_session_expires_at", table_name="oauth_session")
|
||||
op.drop_index("idx_oauth_session_user_id", table_name="oauth_session")
|
||||
|
||||
# Drop the table
|
||||
op.drop_table("oauth_session")
|
||||
247
backend/open_webui/models/oauth_sessions.py
Normal file
247
backend/open_webui/models/oauth_sessions.py
Normal file
|
|
@ -0,0 +1,247 @@
|
|||
import time
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Optional, List
|
||||
import base64
|
||||
import hashlib
|
||||
import json
|
||||
|
||||
from cryptography.fernet import Fernet
|
||||
|
||||
from open_webui.internal.db import Base, get_db
|
||||
from open_webui.env import SRC_LOG_LEVELS, OAUTH_SESSION_TOKEN_ENCRYPTION_KEY
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from sqlalchemy import BigInteger, Column, String, Text, Index
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["MODELS"])
|
||||
|
||||
####################
|
||||
# DB MODEL
|
||||
####################
|
||||
|
||||
|
||||
class OAuthSession(Base):
|
||||
__tablename__ = "oauth_session"
|
||||
|
||||
id = Column(Text, primary_key=True)
|
||||
user_id = Column(Text, nullable=False)
|
||||
provider = Column(Text, nullable=False)
|
||||
token = Column(
|
||||
Text, nullable=False
|
||||
) # JSON with access_token, id_token, refresh_token
|
||||
expires_at = Column(BigInteger, nullable=False)
|
||||
created_at = Column(BigInteger, nullable=False)
|
||||
updated_at = Column(BigInteger, nullable=False)
|
||||
|
||||
# Add indexes for better performance
|
||||
__table_args__ = (
|
||||
Index("idx_oauth_session_user_id", "user_id"),
|
||||
Index("idx_oauth_session_expires_at", "expires_at"),
|
||||
Index("idx_oauth_session_user_provider", "user_id", "provider"),
|
||||
)
|
||||
|
||||
|
||||
class OAuthSessionModel(BaseModel):
|
||||
id: str
|
||||
user_id: str
|
||||
provider: str
|
||||
token: dict
|
||||
expires_at: int # timestamp in epoch
|
||||
created_at: int # timestamp in epoch
|
||||
updated_at: int # timestamp in epoch
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
####################
|
||||
# Forms
|
||||
####################
|
||||
|
||||
|
||||
class OAuthSessionResponse(BaseModel):
|
||||
id: str
|
||||
user_id: str
|
||||
provider: str
|
||||
expires_at: int
|
||||
|
||||
|
||||
class OAuthSessionTable:
|
||||
def __init__(self):
|
||||
self.encryption_key = OAUTH_SESSION_TOKEN_ENCRYPTION_KEY
|
||||
if not self.encryption_key:
|
||||
raise Exception("OAUTH_SESSION_TOKEN_ENCRYPTION_KEY is not set")
|
||||
|
||||
# check if encryption key is in the right format for Fernet (32 url-safe base64-encoded bytes)
|
||||
if len(self.encryption_key) != 44:
|
||||
key_bytes = hashlib.sha256(self.encryption_key.encode()).digest()
|
||||
self.encryption_key = base64.urlsafe_b64encode(key_bytes)
|
||||
else:
|
||||
self.encryption_key = self.encryption_key.encode()
|
||||
|
||||
try:
|
||||
self.fernet = Fernet(self.encryption_key)
|
||||
except Exception as e:
|
||||
log.error(f"Error initializing Fernet with provided key: {e}")
|
||||
raise
|
||||
|
||||
def _encrypt_token(self, token) -> str:
|
||||
"""Encrypt OAuth tokens for storage"""
|
||||
try:
|
||||
token_json = json.dumps(token)
|
||||
encrypted = self.fernet.encrypt(token_json.encode()).decode()
|
||||
return encrypted
|
||||
except Exception as e:
|
||||
log.error(f"Error encrypting tokens: {e}")
|
||||
raise
|
||||
|
||||
def _decrypt_token(self, token: str):
|
||||
"""Decrypt OAuth tokens from storage"""
|
||||
try:
|
||||
decrypted = self.fernet.decrypt(token.encode()).decode()
|
||||
return json.loads(decrypted)
|
||||
except Exception as e:
|
||||
log.error(f"Error decrypting tokens: {e}")
|
||||
raise
|
||||
|
||||
def create_session(
|
||||
self,
|
||||
user_id: str,
|
||||
provider: str,
|
||||
token: dict,
|
||||
) -> Optional[OAuthSessionModel]:
|
||||
"""Create a new OAuth session"""
|
||||
try:
|
||||
with get_db() as db:
|
||||
current_time = int(time.time())
|
||||
id = str(uuid.uuid4())
|
||||
|
||||
result = OAuthSession(
|
||||
**{
|
||||
"id": id,
|
||||
"user_id": user_id,
|
||||
"provider": provider,
|
||||
"token": self._encrypt_token(token),
|
||||
"expires_at": token.get("expires_at"),
|
||||
"created_at": current_time,
|
||||
"updated_at": current_time,
|
||||
}
|
||||
)
|
||||
|
||||
db.add(result)
|
||||
db.commit()
|
||||
db.refresh(result)
|
||||
|
||||
if result:
|
||||
result.token = token # Return decrypted token
|
||||
return OAuthSessionModel.model_validate(result)
|
||||
else:
|
||||
return None
|
||||
except Exception as e:
|
||||
log.error(f"Error creating OAuth session: {e}")
|
||||
return None
|
||||
|
||||
def get_session_by_id(self, session_id: str) -> Optional[OAuthSessionModel]:
|
||||
"""Get OAuth session by ID"""
|
||||
try:
|
||||
with get_db() as db:
|
||||
session = db.query(OAuthSession).filter_by(id=session_id).first()
|
||||
if session:
|
||||
session.token = self._decrypt_token(session.token)
|
||||
return OAuthSessionModel.model_validate(session)
|
||||
|
||||
return None
|
||||
except Exception as e:
|
||||
log.error(f"Error getting OAuth session by ID: {e}")
|
||||
return None
|
||||
|
||||
def get_session_by_id_and_user_id(
|
||||
self, session_id: str, user_id: str
|
||||
) -> Optional[OAuthSessionModel]:
|
||||
"""Get OAuth session by ID and user ID"""
|
||||
try:
|
||||
with get_db() as db:
|
||||
session = (
|
||||
db.query(OAuthSession)
|
||||
.filter_by(id=session_id, user_id=user_id)
|
||||
.first()
|
||||
)
|
||||
if session:
|
||||
session.token = self._decrypt_token(session.token)
|
||||
return OAuthSessionModel.model_validate(session)
|
||||
)
|
||||
return None
|
||||
except Exception as e:
|
||||
log.error(f"Error getting OAuth session by ID: {e}")
|
||||
return None
|
||||
|
||||
def get_sessions_by_user_id(self, user_id: str) -> List[OAuthSessionModel]:
|
||||
"""Get all OAuth sessions for a user"""
|
||||
try:
|
||||
with get_db() as db:
|
||||
sessions = db.query(OAuthSession).filter_by(user_id=user_id).all()
|
||||
|
||||
|
||||
results = []
|
||||
for session in sessions:
|
||||
session.token = self._decrypt_token(session.token)
|
||||
results.append(OAuthSessionModel.model_validate(session))
|
||||
|
||||
return results
|
||||
|
||||
except Exception as e:
|
||||
log.error(f"Error getting OAuth sessions by user ID: {e}")
|
||||
return []
|
||||
|
||||
def update_session_by_id(
|
||||
self, session_id: str, token: dict
|
||||
) -> Optional[OAuthSessionModel]:
|
||||
"""Update OAuth session tokens"""
|
||||
try:
|
||||
with get_db() as db:
|
||||
current_time = int(time.time())
|
||||
|
||||
db.query(OAuthSession).filter_by(id=session_id).update(
|
||||
{
|
||||
"token": self._encrypt_token(token),
|
||||
"expires_at": token.get("expires_at"),
|
||||
"updated_at": current_time,
|
||||
}
|
||||
)
|
||||
db.commit()
|
||||
session = db.query(OAuthSession).filter_by(id=session_id).first()
|
||||
|
||||
if session:
|
||||
session.token = self._decrypt_token(session.token)
|
||||
return OAuthSessionModel.model_validate(session)
|
||||
|
||||
return None
|
||||
except Exception as e:
|
||||
log.error(f"Error updating OAuth session tokens: {e}")
|
||||
return None
|
||||
|
||||
def delete_session_by_id(self, session_id: str) -> bool:
|
||||
"""Delete an OAuth session"""
|
||||
try:
|
||||
with get_db() as db:
|
||||
result = db.query(OAuthSession).filter_by(id=session_id).delete()
|
||||
db.commit()
|
||||
return result > 0
|
||||
except Exception as e:
|
||||
log.error(f"Error deleting OAuth session: {e}")
|
||||
return False
|
||||
|
||||
def delete_sessions_by_user_id(self, user_id: str) -> bool:
|
||||
"""Delete all OAuth sessions for a user"""
|
||||
try:
|
||||
with get_db() as db:
|
||||
result = db.query(OAuthSession).filter_by(user_id=user_id).delete()
|
||||
db.commit()
|
||||
return True
|
||||
except Exception as e:
|
||||
log.error(f"Error deleting OAuth sessions by user ID: {e}")
|
||||
return False
|
||||
|
||||
|
||||
OAuthSessions = OAuthSessionTable()
|
||||
|
|
@ -19,6 +19,7 @@ from open_webui.models.auths import (
|
|||
)
|
||||
from open_webui.models.users import Users, UpdateProfileForm
|
||||
from open_webui.models.groups import Groups
|
||||
from open_webui.models.oauth_sessions import OAuthSessions
|
||||
|
||||
from open_webui.constants import ERROR_MESSAGES, WEBHOOK_MESSAGES
|
||||
from open_webui.env import (
|
||||
|
|
@ -28,7 +29,6 @@ from open_webui.env import (
|
|||
WEBUI_AUTH_TRUSTED_GROUPS_HEADER,
|
||||
WEBUI_AUTH_COOKIE_SAME_SITE,
|
||||
WEBUI_AUTH_COOKIE_SECURE,
|
||||
ENABLE_OAUTH_SESSION_TOKENS_COOKIES,
|
||||
WEBUI_AUTH_SIGNOUT_REDIRECT_URL,
|
||||
ENABLE_INITIAL_ADMIN_SIGNUP,
|
||||
SRC_LOG_LEVELS,
|
||||
|
|
@ -678,24 +678,27 @@ async def signout(request: Request, response: Response):
|
|||
response.delete_cookie("token")
|
||||
response.delete_cookie("oui-session")
|
||||
|
||||
if ENABLE_OAUTH_SIGNUP.value:
|
||||
# TODO: update this to use oauth_session_tokens in User Object
|
||||
oauth_id_token = request.cookies.get("oauth_id_token")
|
||||
oauth_session_id = request.cookies.get("oauth_session_id")
|
||||
if oauth_session_id:
|
||||
response.delete_cookie("oauth_session_id")
|
||||
|
||||
if oauth_id_token and OPENID_PROVIDER_URL.value:
|
||||
session = OAuthSessions.get_session_by_id(oauth_session_id)
|
||||
oauth_server_metadata_url = (
|
||||
request.app.state.oauth_manager.get_server_metadata_url(session.provider)
|
||||
if session
|
||||
else None
|
||||
) or OPENID_PROVIDER_URL.value
|
||||
|
||||
if session and oauth_server_metadata_url:
|
||||
oauth_id_token = session.token.get("id_token")
|
||||
try:
|
||||
async with ClientSession(trust_env=True) as session:
|
||||
async with session.get(OPENID_PROVIDER_URL.value) as r:
|
||||
async with session.get(oauth_server_metadata_url) as r:
|
||||
if r.status == 200:
|
||||
openid_data = await r.json()
|
||||
logout_url = openid_data.get("end_session_endpoint")
|
||||
|
||||
if logout_url:
|
||||
if ENABLE_OAUTH_SESSION_TOKENS_COOKIES:
|
||||
response.delete_cookie("oauth_id_token")
|
||||
response.delete_cookie("oauth_access_token")
|
||||
response.delete_cookie("oauth_refresh_token")
|
||||
|
||||
return JSONResponse(
|
||||
status_code=200,
|
||||
content={
|
||||
|
|
@ -710,15 +713,14 @@ async def signout(request: Request, response: Response):
|
|||
headers=response.headers,
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=r.status,
|
||||
detail="Failed to fetch OpenID configuration",
|
||||
)
|
||||
raise Exception("Failed to fetch OpenID configuration")
|
||||
|
||||
except Exception as e:
|
||||
log.error(f"OpenID signout error: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Failed to sign out from the OpenID provider.",
|
||||
headers=response.headers,
|
||||
)
|
||||
|
||||
if WEBUI_AUTH_SIGNOUT_REDIRECT_URL:
|
||||
|
|
|
|||
|
|
@ -261,61 +261,63 @@ def get_current_user(
|
|||
return user
|
||||
|
||||
# auth by jwt token
|
||||
try:
|
||||
data = decode_token(token)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid token",
|
||||
)
|
||||
|
||||
if data is not None and "id" in data:
|
||||
user = Users.get_user_by_id(data["id"])
|
||||
if user is None:
|
||||
try:
|
||||
try:
|
||||
data = decode_token(token)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=ERROR_MESSAGES.INVALID_TOKEN,
|
||||
detail="Invalid token",
|
||||
)
|
||||
else:
|
||||
if WEBUI_AUTH_TRUSTED_EMAIL_HEADER:
|
||||
trusted_email = request.headers.get(
|
||||
WEBUI_AUTH_TRUSTED_EMAIL_HEADER, ""
|
||||
).lower()
|
||||
if trusted_email and user.email != trusted_email:
|
||||
# Delete the token cookie
|
||||
response.delete_cookie("token")
|
||||
# Delete OAuth token if present
|
||||
|
||||
if request.cookies.get("oauth_id_token"):
|
||||
response.delete_cookie("oauth_id_token")
|
||||
if request.cookies.get("oauth_access_token"):
|
||||
response.delete_cookie("oauth_access_token")
|
||||
if request.cookies.get("oauth_refresh_token"):
|
||||
response.delete_cookie("oauth_refresh_token")
|
||||
if data is not None and "id" in data:
|
||||
user = Users.get_user_by_id(data["id"])
|
||||
if user is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=ERROR_MESSAGES.INVALID_TOKEN,
|
||||
)
|
||||
else:
|
||||
if WEBUI_AUTH_TRUSTED_EMAIL_HEADER:
|
||||
trusted_email = request.headers.get(
|
||||
WEBUI_AUTH_TRUSTED_EMAIL_HEADER, ""
|
||||
).lower()
|
||||
if trusted_email and user.email != trusted_email:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="User mismatch. Please sign in again.",
|
||||
)
|
||||
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="User mismatch. Please sign in again.",
|
||||
# Add user info to current span
|
||||
current_span = trace.get_current_span()
|
||||
if current_span:
|
||||
current_span.set_attribute("client.user.id", user.id)
|
||||
current_span.set_attribute("client.user.email", user.email)
|
||||
current_span.set_attribute("client.user.role", user.role)
|
||||
current_span.set_attribute("client.auth.type", "jwt")
|
||||
|
||||
# Refresh the user's last active timestamp asynchronously
|
||||
# to prevent blocking the request
|
||||
if background_tasks:
|
||||
background_tasks.add_task(
|
||||
Users.update_user_last_active_by_id, user.id
|
||||
)
|
||||
return user
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=ERROR_MESSAGES.UNAUTHORIZED,
|
||||
)
|
||||
except Exception as e:
|
||||
# Delete the token cookie
|
||||
if request.cookies.get("token"):
|
||||
response.delete_cookie("token")
|
||||
# Delete OAuth session if present
|
||||
if request.cookies.get("oauth_session_id"):
|
||||
response.delete_cookie("oauth_session_id")
|
||||
|
||||
# Add user info to current span
|
||||
current_span = trace.get_current_span()
|
||||
if current_span:
|
||||
current_span.set_attribute("client.user.id", user.id)
|
||||
current_span.set_attribute("client.user.email", user.email)
|
||||
current_span.set_attribute("client.user.role", user.role)
|
||||
current_span.set_attribute("client.auth.type", "jwt")
|
||||
|
||||
# Refresh the user's last active timestamp asynchronously
|
||||
# to prevent blocking the request
|
||||
if background_tasks:
|
||||
background_tasks.add_task(Users.update_user_last_active_by_id, user.id)
|
||||
return user
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=ERROR_MESSAGES.UNAUTHORIZED,
|
||||
)
|
||||
raise e
|
||||
|
||||
|
||||
def get_current_user_by_api_key(api_key: str):
|
||||
|
|
|
|||
|
|
@ -815,6 +815,14 @@ async def process_chat_payload(request, form_data, user, metadata, model):
|
|||
event_emitter = get_event_emitter(metadata)
|
||||
event_call = get_event_call(metadata)
|
||||
|
||||
oauth_token = None
|
||||
try:
|
||||
oauth_token = await request.app.state.oauth_manager.get_oauth_token(
|
||||
user.id, request.cookies.get("oauth_session_id", None)
|
||||
)
|
||||
except Exception as e:
|
||||
log.error(f"Error getting OAuth token: {e}")
|
||||
|
||||
extra_params = {
|
||||
"__event_emitter__": event_emitter,
|
||||
"__event_call__": event_call,
|
||||
|
|
@ -822,6 +830,7 @@ async def process_chat_payload(request, form_data, user, metadata, model):
|
|||
"__metadata__": metadata,
|
||||
"__request__": request,
|
||||
"__model__": model,
|
||||
"__oauth_token__": oauth_token,
|
||||
}
|
||||
|
||||
# Initialize events to store additional event to be sent to the client
|
||||
|
|
|
|||
|
|
@ -4,9 +4,11 @@ import mimetypes
|
|||
import sys
|
||||
import uuid
|
||||
import json
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
import re
|
||||
import fnmatch
|
||||
import time
|
||||
|
||||
import aiohttp
|
||||
from authlib.integrations.starlette_client import OAuth
|
||||
|
|
@ -17,8 +19,12 @@ from fastapi import (
|
|||
)
|
||||
from starlette.responses import RedirectResponse
|
||||
|
||||
|
||||
from open_webui.models.auths import Auths
|
||||
from open_webui.models.oauth_sessions import OAuthSessions
|
||||
from open_webui.models.users import Users
|
||||
|
||||
|
||||
from open_webui.models.groups import Groups, GroupModel, GroupUpdateForm, GroupForm
|
||||
from open_webui.config import (
|
||||
DEFAULT_USER_ROLE,
|
||||
|
|
@ -49,7 +55,7 @@ from open_webui.env import (
|
|||
WEBUI_NAME,
|
||||
WEBUI_AUTH_COOKIE_SAME_SITE,
|
||||
WEBUI_AUTH_COOKIE_SECURE,
|
||||
ENABLE_OAUTH_SESSION_TOKENS_COOKIES,
|
||||
ENABLE_OAUTH_ID_TOKEN_COOKIE,
|
||||
)
|
||||
from open_webui.utils.misc import parse_duration
|
||||
from open_webui.utils.auth import get_password_hash, create_token
|
||||
|
|
@ -131,11 +137,187 @@ class OAuthManager:
|
|||
def __init__(self, app):
|
||||
self.oauth = OAuth()
|
||||
self.app = app
|
||||
|
||||
self._clients = {}
|
||||
for _, provider_config in OAUTH_PROVIDERS.items():
|
||||
provider_config["register"](self.oauth)
|
||||
|
||||
def get_client(self, provider_name):
|
||||
return self.oauth.create_client(provider_name)
|
||||
if provider_name not in self._clients:
|
||||
self._clients[provider_name] = self.oauth.create_client(provider_name)
|
||||
return self._clients[provider_name]
|
||||
|
||||
def get_server_metadata_url(self, provider_name):
|
||||
if provider_name in self._clients:
|
||||
client = self._clients[provider_name]
|
||||
return (
|
||||
client.server_metadata_url
|
||||
if hasattr(client, "server_metadata_url")
|
||||
else None
|
||||
)
|
||||
return None
|
||||
|
||||
def get_oauth_token(
|
||||
self, user_id: str, session_id: str, force_refresh: bool = False
|
||||
):
|
||||
"""
|
||||
Get a valid OAuth token for the user, automatically refreshing if needed.
|
||||
|
||||
Args:
|
||||
user_id: The user ID
|
||||
provider: Optional provider name. If None, gets the most recent session.
|
||||
force_refresh: Force token refresh even if current token appears valid
|
||||
|
||||
Returns:
|
||||
dict: OAuth token data with access_token, or None if no valid token available
|
||||
"""
|
||||
try:
|
||||
# Get the OAuth session
|
||||
session = OAuthSessions.get_session_by_id_and_user_id(session_id, user_id)
|
||||
if not session:
|
||||
log.warning(
|
||||
f"No OAuth session found for user {user_id}, session {session_id}"
|
||||
)
|
||||
return None
|
||||
|
||||
if force_refresh or datetime.now() + timedelta(
|
||||
minutes=5
|
||||
) >= datetime.fromtimestamp(session.expires_at):
|
||||
log.debug(
|
||||
f"Token refresh needed for user {user_id}, provider {session.provider}"
|
||||
)
|
||||
refreshed_token = self._refresh_token(session)
|
||||
if refreshed_token:
|
||||
return refreshed_token
|
||||
else:
|
||||
log.warning(
|
||||
f"Token refresh failed for user {user_id}, provider {session.provider}"
|
||||
)
|
||||
return None
|
||||
return session.token
|
||||
|
||||
except Exception as e:
|
||||
log.error(f"Error getting OAuth token for user {user_id}: {e}")
|
||||
return None
|
||||
|
||||
async def _refresh_token(self, session) -> dict:
|
||||
"""
|
||||
Refresh an OAuth token if needed, with concurrency protection.
|
||||
|
||||
Args:
|
||||
session: The OAuth session object
|
||||
|
||||
Returns:
|
||||
dict: Refreshed token data, or None if refresh failed
|
||||
"""
|
||||
try:
|
||||
# Perform the actual refresh
|
||||
refreshed_token = await self._perform_token_refresh(session)
|
||||
|
||||
if refreshed_token:
|
||||
# Update the session with new token data
|
||||
session = OAuthSessions.update_session_by_id(
|
||||
session.id, refreshed_token
|
||||
)
|
||||
log.info(f"Successfully refreshed token for session {session.id}")
|
||||
return session.token
|
||||
else:
|
||||
log.error(f"Failed to refresh token for session {session.id}")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
log.error(f"Error refreshing token for session {session.id}: {e}")
|
||||
return None
|
||||
|
||||
async def _perform_token_refresh(self, session) -> dict:
|
||||
"""
|
||||
Perform the actual OAuth token refresh.
|
||||
|
||||
Args:
|
||||
session: The OAuth session object
|
||||
|
||||
Returns:
|
||||
dict: New token data, or None if refresh failed
|
||||
"""
|
||||
provider = session.provider
|
||||
token_data = session.token
|
||||
|
||||
if not token_data.get("refresh_token"):
|
||||
log.warning(f"No refresh token available for session {session.id}")
|
||||
return None
|
||||
|
||||
try:
|
||||
client = self.get_client(provider)
|
||||
if not client:
|
||||
log.error(f"No OAuth client found for provider {provider}")
|
||||
return None
|
||||
|
||||
token_endpoint = None
|
||||
async with aiohttp.ClientSession(trust_env=True) as session_http:
|
||||
async with session_http.get(client.gserver_metadata_url) as r:
|
||||
if r.status == 200:
|
||||
openid_data = await r.json()
|
||||
token_endpoint = openid_data.get("token_endpoint")
|
||||
else:
|
||||
log.error(
|
||||
f"Failed to fetch OpenID configuration for provider {provider}"
|
||||
)
|
||||
if not token_endpoint:
|
||||
log.error(f"No token endpoint found for provider {provider}")
|
||||
return None
|
||||
|
||||
# Prepare refresh request
|
||||
refresh_data = {
|
||||
"grant_type": "refresh_token",
|
||||
"refresh_token": token_data["refresh_token"],
|
||||
"client_id": client.client_id,
|
||||
}
|
||||
# Add client_secret if available (some providers require it)
|
||||
if hasattr(client, "client_secret") and client.client_secret:
|
||||
refresh_data["client_secret"] = client.client_secret
|
||||
|
||||
# Make refresh request
|
||||
async with aiohttp.ClientSession(trust_env=True) as session_http:
|
||||
async with session_http.post(
|
||||
token_endpoint,
|
||||
data=refresh_data,
|
||||
headers={"Content-Type": "application/x-www-form-urlencoded"},
|
||||
ssl=AIOHTTP_CLIENT_SESSION_SSL,
|
||||
) as r:
|
||||
if r.status == 200:
|
||||
new_token_data = await r.json()
|
||||
|
||||
# Merge with existing token data (preserve refresh_token if not provided)
|
||||
if "refresh_token" not in new_token_data:
|
||||
new_token_data["refresh_token"] = token_data[
|
||||
"refresh_token"
|
||||
]
|
||||
|
||||
# Add timestamp for tracking
|
||||
new_token_data["issued_at"] = datetime.now().timestamp()
|
||||
|
||||
# Calculate expires_at if we have expires_in
|
||||
if (
|
||||
"expires_in" in new_token_data
|
||||
and "expires_at" not in new_token_data
|
||||
):
|
||||
new_token_data["expires_at"] = (
|
||||
datetime.now().timestamp()
|
||||
+ new_token_data["expires_in"]
|
||||
)
|
||||
|
||||
log.debug(f"Token refresh successful for provider {provider}")
|
||||
return new_token_data
|
||||
else:
|
||||
error_text = await r.text()
|
||||
log.error(
|
||||
f"Token refresh failed for provider {provider}: {r.status} - {error_text}"
|
||||
)
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
log.error(f"Exception during token refresh for provider {provider}: {e}")
|
||||
return None
|
||||
|
||||
def get_user_role(self, user, user_data):
|
||||
user_count = Users.get_num_users()
|
||||
|
|
@ -624,33 +806,42 @@ class OAuthManager:
|
|||
secure=WEBUI_AUTH_COOKIE_SECURE,
|
||||
)
|
||||
|
||||
if ENABLE_OAUTH_SIGNUP.value:
|
||||
if ENABLE_OAUTH_SESSION_TOKENS_COOKIES:
|
||||
oauth_id_token = token.get("id_token")
|
||||
response.set_cookie(
|
||||
key="oauth_id_token",
|
||||
value=oauth_id_token,
|
||||
httponly=True,
|
||||
samesite=WEBUI_AUTH_COOKIE_SAME_SITE,
|
||||
secure=WEBUI_AUTH_COOKIE_SECURE,
|
||||
)
|
||||
# Legacy cookies for compatibility with older frontend versions
|
||||
if ENABLE_OAUTH_ID_TOKEN_COOKIE:
|
||||
response.set_cookie(
|
||||
key="oauth_id_token",
|
||||
value=token.get("id_token"),
|
||||
httponly=True,
|
||||
samesite=WEBUI_AUTH_COOKIE_SAME_SITE,
|
||||
secure=WEBUI_AUTH_COOKIE_SECURE,
|
||||
)
|
||||
|
||||
oauth_access_token = token.get("access_token")
|
||||
response.set_cookie(
|
||||
key="oauth_access_token",
|
||||
value=oauth_access_token,
|
||||
httponly=True,
|
||||
samesite=WEBUI_AUTH_COOKIE_SAME_SITE,
|
||||
secure=WEBUI_AUTH_COOKIE_SECURE,
|
||||
)
|
||||
try:
|
||||
# Add timestamp for tracking
|
||||
token["issued_at"] = datetime.now().timestamp()
|
||||
|
||||
oauth_refresh_token = token.get("refresh_token")
|
||||
response.set_cookie(
|
||||
key="oauth_refresh_token",
|
||||
value=oauth_refresh_token,
|
||||
httponly=True,
|
||||
samesite=WEBUI_AUTH_COOKIE_SAME_SITE,
|
||||
secure=WEBUI_AUTH_COOKIE_SECURE,
|
||||
)
|
||||
# Calculate expires_at if we have expires_in
|
||||
if "expires_in" in token and "expires_at" not in token:
|
||||
token["expires_at"] = datetime.now().timestamp() + token["expires_in"]
|
||||
|
||||
session_id = await OAuthSessions.create_session(
|
||||
user_id=user.id,
|
||||
provider=provider,
|
||||
token=token,
|
||||
)
|
||||
|
||||
response.set_cookie(
|
||||
key="oauth_session_id",
|
||||
value=session_id,
|
||||
httponly=True,
|
||||
samesite=WEBUI_AUTH_COOKIE_SAME_SITE,
|
||||
secure=WEBUI_AUTH_COOKIE_SECURE,
|
||||
)
|
||||
|
||||
log.info(
|
||||
f"Stored OAuth session server-side for user {user.id}, provider {provider}"
|
||||
)
|
||||
except Exception as e:
|
||||
log.error(f"Failed to store OAuth session server-side: {e}")
|
||||
|
||||
return response
|
||||
|
|
|
|||
|
|
@ -129,6 +129,21 @@ async def get_tools(
|
|||
headers["Authorization"] = (
|
||||
f"Bearer {request.state.token.credentials}"
|
||||
)
|
||||
elif auth_type == "oauth":
|
||||
oauth_token = None
|
||||
try:
|
||||
oauth_token = (
|
||||
await request.app.state.oauth_manager.get_oauth_token(
|
||||
user.id,
|
||||
request.cookies.get("oauth_session_id", None),
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
log.error(f"Error getting OAuth token: {e}")
|
||||
|
||||
headers["Authorization"] = (
|
||||
f"Bearer {oauth_token.get('access_token', '')}"
|
||||
)
|
||||
elif auth_type == "request_headers":
|
||||
headers.update(dict(request.headers))
|
||||
|
||||
|
|
|
|||
|
|
@ -287,6 +287,7 @@
|
|||
<option value="session">{$i18n.t('Session')}</option>
|
||||
|
||||
{#if !direct}
|
||||
<option value="oauth">{$i18n.t('OAuth')}</option>
|
||||
<option value="request_headers">{$i18n.t('Request Headers')}</option>
|
||||
{/if}
|
||||
</select>
|
||||
|
|
@ -305,6 +306,12 @@
|
|||
>
|
||||
{$i18n.t('Forwards system user session credentials to authenticate')}
|
||||
</div>
|
||||
{:else if auth_type === 'oauth'}
|
||||
<div
|
||||
class={`text-xs self-center translate-y-[1px] ${($settings?.highContrastMode ?? false) ? 'text-gray-800 dark:text-gray-100' : 'text-gray-500'}`}
|
||||
>
|
||||
{$i18n.t('Forwards user OAuth access token to authenticate')}
|
||||
</div>
|
||||
{:else if auth_type === 'request_headers'}
|
||||
<div
|
||||
class={`text-xs self-center translate-y-[1px] ${($settings?.highContrastMode ?? false) ? 'text-gray-800 dark:text-gray-100' : 'text-gray-500'}`}
|
||||
|
|
|
|||
Loading…
Reference in a new issue