mirror of
https://github.com/open-webui/open-webui.git
synced 2025-12-12 12:25:20 +00:00
feat: server-side OAuth token management system
Co-Authored-By: Classic298 <27028174+Classic298@users.noreply.github.com>
This commit is contained in:
parent
6d38ac41b6
commit
217f4daef0
10 changed files with 627 additions and 92 deletions
|
|
@ -465,8 +465,17 @@ ENABLE_COMPRESSION_MIDDLEWARE = (
|
||||||
os.environ.get("ENABLE_COMPRESSION_MIDDLEWARE", "True").lower() == "true"
|
os.environ.get("ENABLE_COMPRESSION_MIDDLEWARE", "True").lower() == "true"
|
||||||
)
|
)
|
||||||
|
|
||||||
ENABLE_OAUTH_SESSION_TOKENS_COOKIES = (
|
####################################
|
||||||
os.environ.get("ENABLE_OAUTH_SESSION_TOKENS_COOKIES", "True").lower() == "true"
|
# OAUTH Configuration
|
||||||
|
####################################
|
||||||
|
|
||||||
|
|
||||||
|
ENABLE_OAUTH_ID_TOKEN_COOKIE = (
|
||||||
|
os.environ.get("ENABLE_OAUTH_ID_TOKEN_COOKIE", "True").lower() == "true"
|
||||||
|
)
|
||||||
|
|
||||||
|
OAUTH_SESSION_TOKEN_ENCRYPTION_KEY = os.environ.get(
|
||||||
|
"OAUTH_SESSION_TOKEN_ENCRYPTION_KEY", WEBUI_SECRET_KEY
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -592,6 +592,7 @@ app = FastAPI(
|
||||||
)
|
)
|
||||||
|
|
||||||
oauth_manager = OAuthManager(app)
|
oauth_manager = OAuthManager(app)
|
||||||
|
app.state.oauth_manager = oauth_manager
|
||||||
|
|
||||||
app.state.instance_id = None
|
app.state.instance_id = None
|
||||||
app.state.config = AppConfig(
|
app.state.config = AppConfig(
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,52 @@
|
||||||
|
"""Add oauth_session table
|
||||||
|
|
||||||
|
Revision ID: 38d63c18f30f
|
||||||
|
Revises: 3af16a1c9fb6
|
||||||
|
Create Date: 2025-09-08 14:19:59.583921
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = "38d63c18f30f"
|
||||||
|
down_revision: Union[str, None] = "3af16a1c9fb6"
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
# Create oauth_session table
|
||||||
|
op.create_table(
|
||||||
|
"oauth_session",
|
||||||
|
sa.Column("id", sa.Text(), nullable=False),
|
||||||
|
sa.Column("user_id", sa.Text(), nullable=False),
|
||||||
|
sa.Column("provider", sa.Text(), nullable=False),
|
||||||
|
sa.Column("token", sa.Text(), nullable=False),
|
||||||
|
sa.Column("expires_at", sa.BigInteger(), nullable=False),
|
||||||
|
sa.Column("created_at", sa.BigInteger(), nullable=False),
|
||||||
|
sa.Column("updated_at", sa.BigInteger(), nullable=False),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create indexes for better performance
|
||||||
|
op.create_index("idx_oauth_session_user_id", "oauth_session", ["user_id"])
|
||||||
|
op.create_index("idx_oauth_session_expires_at", "oauth_session", ["expires_at"])
|
||||||
|
op.create_index(
|
||||||
|
"idx_oauth_session_user_provider", "oauth_session", ["user_id", "provider"]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
# Drop indexes first
|
||||||
|
op.drop_index("idx_oauth_session_user_provider", table_name="oauth_session")
|
||||||
|
op.drop_index("idx_oauth_session_expires_at", table_name="oauth_session")
|
||||||
|
op.drop_index("idx_oauth_session_user_id", table_name="oauth_session")
|
||||||
|
|
||||||
|
# Drop the table
|
||||||
|
op.drop_table("oauth_session")
|
||||||
247
backend/open_webui/models/oauth_sessions.py
Normal file
247
backend/open_webui/models/oauth_sessions.py
Normal file
|
|
@ -0,0 +1,247 @@
|
||||||
|
import time
|
||||||
|
import logging
|
||||||
|
import uuid
|
||||||
|
from typing import Optional, List
|
||||||
|
import base64
|
||||||
|
import hashlib
|
||||||
|
import json
|
||||||
|
|
||||||
|
from cryptography.fernet import Fernet
|
||||||
|
|
||||||
|
from open_webui.internal.db import Base, get_db
|
||||||
|
from open_webui.env import SRC_LOG_LEVELS, OAUTH_SESSION_TOKEN_ENCRYPTION_KEY
|
||||||
|
|
||||||
|
from pydantic import BaseModel, ConfigDict
|
||||||
|
from sqlalchemy import BigInteger, Column, String, Text, Index
|
||||||
|
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
log.setLevel(SRC_LOG_LEVELS["MODELS"])
|
||||||
|
|
||||||
|
####################
|
||||||
|
# DB MODEL
|
||||||
|
####################
|
||||||
|
|
||||||
|
|
||||||
|
class OAuthSession(Base):
|
||||||
|
__tablename__ = "oauth_session"
|
||||||
|
|
||||||
|
id = Column(Text, primary_key=True)
|
||||||
|
user_id = Column(Text, nullable=False)
|
||||||
|
provider = Column(Text, nullable=False)
|
||||||
|
token = Column(
|
||||||
|
Text, nullable=False
|
||||||
|
) # JSON with access_token, id_token, refresh_token
|
||||||
|
expires_at = Column(BigInteger, nullable=False)
|
||||||
|
created_at = Column(BigInteger, nullable=False)
|
||||||
|
updated_at = Column(BigInteger, nullable=False)
|
||||||
|
|
||||||
|
# Add indexes for better performance
|
||||||
|
__table_args__ = (
|
||||||
|
Index("idx_oauth_session_user_id", "user_id"),
|
||||||
|
Index("idx_oauth_session_expires_at", "expires_at"),
|
||||||
|
Index("idx_oauth_session_user_provider", "user_id", "provider"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class OAuthSessionModel(BaseModel):
|
||||||
|
id: str
|
||||||
|
user_id: str
|
||||||
|
provider: str
|
||||||
|
token: dict
|
||||||
|
expires_at: int # timestamp in epoch
|
||||||
|
created_at: int # timestamp in epoch
|
||||||
|
updated_at: int # timestamp in epoch
|
||||||
|
|
||||||
|
model_config = ConfigDict(from_attributes=True)
|
||||||
|
|
||||||
|
|
||||||
|
####################
|
||||||
|
# Forms
|
||||||
|
####################
|
||||||
|
|
||||||
|
|
||||||
|
class OAuthSessionResponse(BaseModel):
|
||||||
|
id: str
|
||||||
|
user_id: str
|
||||||
|
provider: str
|
||||||
|
expires_at: int
|
||||||
|
|
||||||
|
|
||||||
|
class OAuthSessionTable:
|
||||||
|
def __init__(self):
|
||||||
|
self.encryption_key = OAUTH_SESSION_TOKEN_ENCRYPTION_KEY
|
||||||
|
if not self.encryption_key:
|
||||||
|
raise Exception("OAUTH_SESSION_TOKEN_ENCRYPTION_KEY is not set")
|
||||||
|
|
||||||
|
# check if encryption key is in the right format for Fernet (32 url-safe base64-encoded bytes)
|
||||||
|
if len(self.encryption_key) != 44:
|
||||||
|
key_bytes = hashlib.sha256(self.encryption_key.encode()).digest()
|
||||||
|
self.encryption_key = base64.urlsafe_b64encode(key_bytes)
|
||||||
|
else:
|
||||||
|
self.encryption_key = self.encryption_key.encode()
|
||||||
|
|
||||||
|
try:
|
||||||
|
self.fernet = Fernet(self.encryption_key)
|
||||||
|
except Exception as e:
|
||||||
|
log.error(f"Error initializing Fernet with provided key: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def _encrypt_token(self, token) -> str:
|
||||||
|
"""Encrypt OAuth tokens for storage"""
|
||||||
|
try:
|
||||||
|
token_json = json.dumps(token)
|
||||||
|
encrypted = self.fernet.encrypt(token_json.encode()).decode()
|
||||||
|
return encrypted
|
||||||
|
except Exception as e:
|
||||||
|
log.error(f"Error encrypting tokens: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def _decrypt_token(self, token: str):
|
||||||
|
"""Decrypt OAuth tokens from storage"""
|
||||||
|
try:
|
||||||
|
decrypted = self.fernet.decrypt(token.encode()).decode()
|
||||||
|
return json.loads(decrypted)
|
||||||
|
except Exception as e:
|
||||||
|
log.error(f"Error decrypting tokens: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def create_session(
|
||||||
|
self,
|
||||||
|
user_id: str,
|
||||||
|
provider: str,
|
||||||
|
token: dict,
|
||||||
|
) -> Optional[OAuthSessionModel]:
|
||||||
|
"""Create a new OAuth session"""
|
||||||
|
try:
|
||||||
|
with get_db() as db:
|
||||||
|
current_time = int(time.time())
|
||||||
|
id = str(uuid.uuid4())
|
||||||
|
|
||||||
|
result = OAuthSession(
|
||||||
|
**{
|
||||||
|
"id": id,
|
||||||
|
"user_id": user_id,
|
||||||
|
"provider": provider,
|
||||||
|
"token": self._encrypt_token(token),
|
||||||
|
"expires_at": token.get("expires_at"),
|
||||||
|
"created_at": current_time,
|
||||||
|
"updated_at": current_time,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
db.add(result)
|
||||||
|
db.commit()
|
||||||
|
db.refresh(result)
|
||||||
|
|
||||||
|
if result:
|
||||||
|
result.token = token # Return decrypted token
|
||||||
|
return OAuthSessionModel.model_validate(result)
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
log.error(f"Error creating OAuth session: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_session_by_id(self, session_id: str) -> Optional[OAuthSessionModel]:
|
||||||
|
"""Get OAuth session by ID"""
|
||||||
|
try:
|
||||||
|
with get_db() as db:
|
||||||
|
session = db.query(OAuthSession).filter_by(id=session_id).first()
|
||||||
|
if session:
|
||||||
|
session.token = self._decrypt_token(session.token)
|
||||||
|
return OAuthSessionModel.model_validate(session)
|
||||||
|
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
log.error(f"Error getting OAuth session by ID: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_session_by_id_and_user_id(
|
||||||
|
self, session_id: str, user_id: str
|
||||||
|
) -> Optional[OAuthSessionModel]:
|
||||||
|
"""Get OAuth session by ID and user ID"""
|
||||||
|
try:
|
||||||
|
with get_db() as db:
|
||||||
|
session = (
|
||||||
|
db.query(OAuthSession)
|
||||||
|
.filter_by(id=session_id, user_id=user_id)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
if session:
|
||||||
|
session.token = self._decrypt_token(session.token)
|
||||||
|
return OAuthSessionModel.model_validate(session)
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
log.error(f"Error getting OAuth session by ID: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_sessions_by_user_id(self, user_id: str) -> List[OAuthSessionModel]:
|
||||||
|
"""Get all OAuth sessions for a user"""
|
||||||
|
try:
|
||||||
|
with get_db() as db:
|
||||||
|
sessions = db.query(OAuthSession).filter_by(user_id=user_id).all()
|
||||||
|
|
||||||
|
|
||||||
|
results = []
|
||||||
|
for session in sessions:
|
||||||
|
session.token = self._decrypt_token(session.token)
|
||||||
|
results.append(OAuthSessionModel.model_validate(session))
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
log.error(f"Error getting OAuth sessions by user ID: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
def update_session_by_id(
|
||||||
|
self, session_id: str, token: dict
|
||||||
|
) -> Optional[OAuthSessionModel]:
|
||||||
|
"""Update OAuth session tokens"""
|
||||||
|
try:
|
||||||
|
with get_db() as db:
|
||||||
|
current_time = int(time.time())
|
||||||
|
|
||||||
|
db.query(OAuthSession).filter_by(id=session_id).update(
|
||||||
|
{
|
||||||
|
"token": self._encrypt_token(token),
|
||||||
|
"expires_at": token.get("expires_at"),
|
||||||
|
"updated_at": current_time,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
db.commit()
|
||||||
|
session = db.query(OAuthSession).filter_by(id=session_id).first()
|
||||||
|
|
||||||
|
if session:
|
||||||
|
session.token = self._decrypt_token(session.token)
|
||||||
|
return OAuthSessionModel.model_validate(session)
|
||||||
|
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
log.error(f"Error updating OAuth session tokens: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def delete_session_by_id(self, session_id: str) -> bool:
|
||||||
|
"""Delete an OAuth session"""
|
||||||
|
try:
|
||||||
|
with get_db() as db:
|
||||||
|
result = db.query(OAuthSession).filter_by(id=session_id).delete()
|
||||||
|
db.commit()
|
||||||
|
return result > 0
|
||||||
|
except Exception as e:
|
||||||
|
log.error(f"Error deleting OAuth session: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def delete_sessions_by_user_id(self, user_id: str) -> bool:
|
||||||
|
"""Delete all OAuth sessions for a user"""
|
||||||
|
try:
|
||||||
|
with get_db() as db:
|
||||||
|
result = db.query(OAuthSession).filter_by(user_id=user_id).delete()
|
||||||
|
db.commit()
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
log.error(f"Error deleting OAuth sessions by user ID: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
OAuthSessions = OAuthSessionTable()
|
||||||
|
|
@ -19,6 +19,7 @@ from open_webui.models.auths import (
|
||||||
)
|
)
|
||||||
from open_webui.models.users import Users, UpdateProfileForm
|
from open_webui.models.users import Users, UpdateProfileForm
|
||||||
from open_webui.models.groups import Groups
|
from open_webui.models.groups import Groups
|
||||||
|
from open_webui.models.oauth_sessions import OAuthSessions
|
||||||
|
|
||||||
from open_webui.constants import ERROR_MESSAGES, WEBHOOK_MESSAGES
|
from open_webui.constants import ERROR_MESSAGES, WEBHOOK_MESSAGES
|
||||||
from open_webui.env import (
|
from open_webui.env import (
|
||||||
|
|
@ -28,7 +29,6 @@ from open_webui.env import (
|
||||||
WEBUI_AUTH_TRUSTED_GROUPS_HEADER,
|
WEBUI_AUTH_TRUSTED_GROUPS_HEADER,
|
||||||
WEBUI_AUTH_COOKIE_SAME_SITE,
|
WEBUI_AUTH_COOKIE_SAME_SITE,
|
||||||
WEBUI_AUTH_COOKIE_SECURE,
|
WEBUI_AUTH_COOKIE_SECURE,
|
||||||
ENABLE_OAUTH_SESSION_TOKENS_COOKIES,
|
|
||||||
WEBUI_AUTH_SIGNOUT_REDIRECT_URL,
|
WEBUI_AUTH_SIGNOUT_REDIRECT_URL,
|
||||||
ENABLE_INITIAL_ADMIN_SIGNUP,
|
ENABLE_INITIAL_ADMIN_SIGNUP,
|
||||||
SRC_LOG_LEVELS,
|
SRC_LOG_LEVELS,
|
||||||
|
|
@ -678,24 +678,27 @@ async def signout(request: Request, response: Response):
|
||||||
response.delete_cookie("token")
|
response.delete_cookie("token")
|
||||||
response.delete_cookie("oui-session")
|
response.delete_cookie("oui-session")
|
||||||
|
|
||||||
if ENABLE_OAUTH_SIGNUP.value:
|
oauth_session_id = request.cookies.get("oauth_session_id")
|
||||||
# TODO: update this to use oauth_session_tokens in User Object
|
if oauth_session_id:
|
||||||
oauth_id_token = request.cookies.get("oauth_id_token")
|
response.delete_cookie("oauth_session_id")
|
||||||
|
|
||||||
if oauth_id_token and OPENID_PROVIDER_URL.value:
|
session = OAuthSessions.get_session_by_id(oauth_session_id)
|
||||||
|
oauth_server_metadata_url = (
|
||||||
|
request.app.state.oauth_manager.get_server_metadata_url(session.provider)
|
||||||
|
if session
|
||||||
|
else None
|
||||||
|
) or OPENID_PROVIDER_URL.value
|
||||||
|
|
||||||
|
if session and oauth_server_metadata_url:
|
||||||
|
oauth_id_token = session.token.get("id_token")
|
||||||
try:
|
try:
|
||||||
async with ClientSession(trust_env=True) as session:
|
async with ClientSession(trust_env=True) as session:
|
||||||
async with session.get(OPENID_PROVIDER_URL.value) as r:
|
async with session.get(oauth_server_metadata_url) as r:
|
||||||
if r.status == 200:
|
if r.status == 200:
|
||||||
openid_data = await r.json()
|
openid_data = await r.json()
|
||||||
logout_url = openid_data.get("end_session_endpoint")
|
logout_url = openid_data.get("end_session_endpoint")
|
||||||
|
|
||||||
if logout_url:
|
if logout_url:
|
||||||
if ENABLE_OAUTH_SESSION_TOKENS_COOKIES:
|
|
||||||
response.delete_cookie("oauth_id_token")
|
|
||||||
response.delete_cookie("oauth_access_token")
|
|
||||||
response.delete_cookie("oauth_refresh_token")
|
|
||||||
|
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
status_code=200,
|
status_code=200,
|
||||||
content={
|
content={
|
||||||
|
|
@ -710,15 +713,14 @@ async def signout(request: Request, response: Response):
|
||||||
headers=response.headers,
|
headers=response.headers,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise HTTPException(
|
raise Exception("Failed to fetch OpenID configuration")
|
||||||
status_code=r.status,
|
|
||||||
detail="Failed to fetch OpenID configuration",
|
|
||||||
)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.error(f"OpenID signout error: {str(e)}")
|
log.error(f"OpenID signout error: {str(e)}")
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=500,
|
status_code=500,
|
||||||
detail="Failed to sign out from the OpenID provider.",
|
detail="Failed to sign out from the OpenID provider.",
|
||||||
|
headers=response.headers,
|
||||||
)
|
)
|
||||||
|
|
||||||
if WEBUI_AUTH_SIGNOUT_REDIRECT_URL:
|
if WEBUI_AUTH_SIGNOUT_REDIRECT_URL:
|
||||||
|
|
|
||||||
|
|
@ -261,6 +261,8 @@ def get_current_user(
|
||||||
return user
|
return user
|
||||||
|
|
||||||
# auth by jwt token
|
# auth by jwt token
|
||||||
|
|
||||||
|
try:
|
||||||
try:
|
try:
|
||||||
data = decode_token(token)
|
data = decode_token(token)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
@ -282,17 +284,6 @@ def get_current_user(
|
||||||
WEBUI_AUTH_TRUSTED_EMAIL_HEADER, ""
|
WEBUI_AUTH_TRUSTED_EMAIL_HEADER, ""
|
||||||
).lower()
|
).lower()
|
||||||
if trusted_email and user.email != trusted_email:
|
if trusted_email and user.email != trusted_email:
|
||||||
# Delete the token cookie
|
|
||||||
response.delete_cookie("token")
|
|
||||||
# Delete OAuth token if present
|
|
||||||
|
|
||||||
if request.cookies.get("oauth_id_token"):
|
|
||||||
response.delete_cookie("oauth_id_token")
|
|
||||||
if request.cookies.get("oauth_access_token"):
|
|
||||||
response.delete_cookie("oauth_access_token")
|
|
||||||
if request.cookies.get("oauth_refresh_token"):
|
|
||||||
response.delete_cookie("oauth_refresh_token")
|
|
||||||
|
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
detail="User mismatch. Please sign in again.",
|
detail="User mismatch. Please sign in again.",
|
||||||
|
|
@ -309,13 +300,24 @@ def get_current_user(
|
||||||
# Refresh the user's last active timestamp asynchronously
|
# Refresh the user's last active timestamp asynchronously
|
||||||
# to prevent blocking the request
|
# to prevent blocking the request
|
||||||
if background_tasks:
|
if background_tasks:
|
||||||
background_tasks.add_task(Users.update_user_last_active_by_id, user.id)
|
background_tasks.add_task(
|
||||||
|
Users.update_user_last_active_by_id, user.id
|
||||||
|
)
|
||||||
return user
|
return user
|
||||||
else:
|
else:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
detail=ERROR_MESSAGES.UNAUTHORIZED,
|
detail=ERROR_MESSAGES.UNAUTHORIZED,
|
||||||
)
|
)
|
||||||
|
except Exception as e:
|
||||||
|
# Delete the token cookie
|
||||||
|
if request.cookies.get("token"):
|
||||||
|
response.delete_cookie("token")
|
||||||
|
# Delete OAuth session if present
|
||||||
|
if request.cookies.get("oauth_session_id"):
|
||||||
|
response.delete_cookie("oauth_session_id")
|
||||||
|
|
||||||
|
raise e
|
||||||
|
|
||||||
|
|
||||||
def get_current_user_by_api_key(api_key: str):
|
def get_current_user_by_api_key(api_key: str):
|
||||||
|
|
|
||||||
|
|
@ -815,6 +815,14 @@ async def process_chat_payload(request, form_data, user, metadata, model):
|
||||||
event_emitter = get_event_emitter(metadata)
|
event_emitter = get_event_emitter(metadata)
|
||||||
event_call = get_event_call(metadata)
|
event_call = get_event_call(metadata)
|
||||||
|
|
||||||
|
oauth_token = None
|
||||||
|
try:
|
||||||
|
oauth_token = await request.app.state.oauth_manager.get_oauth_token(
|
||||||
|
user.id, request.cookies.get("oauth_session_id", None)
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
log.error(f"Error getting OAuth token: {e}")
|
||||||
|
|
||||||
extra_params = {
|
extra_params = {
|
||||||
"__event_emitter__": event_emitter,
|
"__event_emitter__": event_emitter,
|
||||||
"__event_call__": event_call,
|
"__event_call__": event_call,
|
||||||
|
|
@ -822,6 +830,7 @@ async def process_chat_payload(request, form_data, user, metadata, model):
|
||||||
"__metadata__": metadata,
|
"__metadata__": metadata,
|
||||||
"__request__": request,
|
"__request__": request,
|
||||||
"__model__": model,
|
"__model__": model,
|
||||||
|
"__oauth_token__": oauth_token,
|
||||||
}
|
}
|
||||||
|
|
||||||
# Initialize events to store additional event to be sent to the client
|
# Initialize events to store additional event to be sent to the client
|
||||||
|
|
|
||||||
|
|
@ -4,9 +4,11 @@ import mimetypes
|
||||||
import sys
|
import sys
|
||||||
import uuid
|
import uuid
|
||||||
import json
|
import json
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
|
||||||
import re
|
import re
|
||||||
import fnmatch
|
import fnmatch
|
||||||
|
import time
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
from authlib.integrations.starlette_client import OAuth
|
from authlib.integrations.starlette_client import OAuth
|
||||||
|
|
@ -17,8 +19,12 @@ from fastapi import (
|
||||||
)
|
)
|
||||||
from starlette.responses import RedirectResponse
|
from starlette.responses import RedirectResponse
|
||||||
|
|
||||||
|
|
||||||
from open_webui.models.auths import Auths
|
from open_webui.models.auths import Auths
|
||||||
|
from open_webui.models.oauth_sessions import OAuthSessions
|
||||||
from open_webui.models.users import Users
|
from open_webui.models.users import Users
|
||||||
|
|
||||||
|
|
||||||
from open_webui.models.groups import Groups, GroupModel, GroupUpdateForm, GroupForm
|
from open_webui.models.groups import Groups, GroupModel, GroupUpdateForm, GroupForm
|
||||||
from open_webui.config import (
|
from open_webui.config import (
|
||||||
DEFAULT_USER_ROLE,
|
DEFAULT_USER_ROLE,
|
||||||
|
|
@ -49,7 +55,7 @@ from open_webui.env import (
|
||||||
WEBUI_NAME,
|
WEBUI_NAME,
|
||||||
WEBUI_AUTH_COOKIE_SAME_SITE,
|
WEBUI_AUTH_COOKIE_SAME_SITE,
|
||||||
WEBUI_AUTH_COOKIE_SECURE,
|
WEBUI_AUTH_COOKIE_SECURE,
|
||||||
ENABLE_OAUTH_SESSION_TOKENS_COOKIES,
|
ENABLE_OAUTH_ID_TOKEN_COOKIE,
|
||||||
)
|
)
|
||||||
from open_webui.utils.misc import parse_duration
|
from open_webui.utils.misc import parse_duration
|
||||||
from open_webui.utils.auth import get_password_hash, create_token
|
from open_webui.utils.auth import get_password_hash, create_token
|
||||||
|
|
@ -131,11 +137,187 @@ class OAuthManager:
|
||||||
def __init__(self, app):
|
def __init__(self, app):
|
||||||
self.oauth = OAuth()
|
self.oauth = OAuth()
|
||||||
self.app = app
|
self.app = app
|
||||||
|
|
||||||
|
self._clients = {}
|
||||||
for _, provider_config in OAUTH_PROVIDERS.items():
|
for _, provider_config in OAUTH_PROVIDERS.items():
|
||||||
provider_config["register"](self.oauth)
|
provider_config["register"](self.oauth)
|
||||||
|
|
||||||
def get_client(self, provider_name):
|
def get_client(self, provider_name):
|
||||||
return self.oauth.create_client(provider_name)
|
if provider_name not in self._clients:
|
||||||
|
self._clients[provider_name] = self.oauth.create_client(provider_name)
|
||||||
|
return self._clients[provider_name]
|
||||||
|
|
||||||
|
def get_server_metadata_url(self, provider_name):
|
||||||
|
if provider_name in self._clients:
|
||||||
|
client = self._clients[provider_name]
|
||||||
|
return (
|
||||||
|
client.server_metadata_url
|
||||||
|
if hasattr(client, "server_metadata_url")
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_oauth_token(
|
||||||
|
self, user_id: str, session_id: str, force_refresh: bool = False
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Get a valid OAuth token for the user, automatically refreshing if needed.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: The user ID
|
||||||
|
provider: Optional provider name. If None, gets the most recent session.
|
||||||
|
force_refresh: Force token refresh even if current token appears valid
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: OAuth token data with access_token, or None if no valid token available
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Get the OAuth session
|
||||||
|
session = OAuthSessions.get_session_by_id_and_user_id(session_id, user_id)
|
||||||
|
if not session:
|
||||||
|
log.warning(
|
||||||
|
f"No OAuth session found for user {user_id}, session {session_id}"
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
if force_refresh or datetime.now() + timedelta(
|
||||||
|
minutes=5
|
||||||
|
) >= datetime.fromtimestamp(session.expires_at):
|
||||||
|
log.debug(
|
||||||
|
f"Token refresh needed for user {user_id}, provider {session.provider}"
|
||||||
|
)
|
||||||
|
refreshed_token = self._refresh_token(session)
|
||||||
|
if refreshed_token:
|
||||||
|
return refreshed_token
|
||||||
|
else:
|
||||||
|
log.warning(
|
||||||
|
f"Token refresh failed for user {user_id}, provider {session.provider}"
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
return session.token
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
log.error(f"Error getting OAuth token for user {user_id}: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def _refresh_token(self, session) -> dict:
|
||||||
|
"""
|
||||||
|
Refresh an OAuth token if needed, with concurrency protection.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session: The OAuth session object
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: Refreshed token data, or None if refresh failed
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Perform the actual refresh
|
||||||
|
refreshed_token = await self._perform_token_refresh(session)
|
||||||
|
|
||||||
|
if refreshed_token:
|
||||||
|
# Update the session with new token data
|
||||||
|
session = OAuthSessions.update_session_by_id(
|
||||||
|
session.id, refreshed_token
|
||||||
|
)
|
||||||
|
log.info(f"Successfully refreshed token for session {session.id}")
|
||||||
|
return session.token
|
||||||
|
else:
|
||||||
|
log.error(f"Failed to refresh token for session {session.id}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
log.error(f"Error refreshing token for session {session.id}: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def _perform_token_refresh(self, session) -> dict:
|
||||||
|
"""
|
||||||
|
Perform the actual OAuth token refresh.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session: The OAuth session object
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: New token data, or None if refresh failed
|
||||||
|
"""
|
||||||
|
provider = session.provider
|
||||||
|
token_data = session.token
|
||||||
|
|
||||||
|
if not token_data.get("refresh_token"):
|
||||||
|
log.warning(f"No refresh token available for session {session.id}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
client = self.get_client(provider)
|
||||||
|
if not client:
|
||||||
|
log.error(f"No OAuth client found for provider {provider}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
token_endpoint = None
|
||||||
|
async with aiohttp.ClientSession(trust_env=True) as session_http:
|
||||||
|
async with session_http.get(client.gserver_metadata_url) as r:
|
||||||
|
if r.status == 200:
|
||||||
|
openid_data = await r.json()
|
||||||
|
token_endpoint = openid_data.get("token_endpoint")
|
||||||
|
else:
|
||||||
|
log.error(
|
||||||
|
f"Failed to fetch OpenID configuration for provider {provider}"
|
||||||
|
)
|
||||||
|
if not token_endpoint:
|
||||||
|
log.error(f"No token endpoint found for provider {provider}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Prepare refresh request
|
||||||
|
refresh_data = {
|
||||||
|
"grant_type": "refresh_token",
|
||||||
|
"refresh_token": token_data["refresh_token"],
|
||||||
|
"client_id": client.client_id,
|
||||||
|
}
|
||||||
|
# Add client_secret if available (some providers require it)
|
||||||
|
if hasattr(client, "client_secret") and client.client_secret:
|
||||||
|
refresh_data["client_secret"] = client.client_secret
|
||||||
|
|
||||||
|
# Make refresh request
|
||||||
|
async with aiohttp.ClientSession(trust_env=True) as session_http:
|
||||||
|
async with session_http.post(
|
||||||
|
token_endpoint,
|
||||||
|
data=refresh_data,
|
||||||
|
headers={"Content-Type": "application/x-www-form-urlencoded"},
|
||||||
|
ssl=AIOHTTP_CLIENT_SESSION_SSL,
|
||||||
|
) as r:
|
||||||
|
if r.status == 200:
|
||||||
|
new_token_data = await r.json()
|
||||||
|
|
||||||
|
# Merge with existing token data (preserve refresh_token if not provided)
|
||||||
|
if "refresh_token" not in new_token_data:
|
||||||
|
new_token_data["refresh_token"] = token_data[
|
||||||
|
"refresh_token"
|
||||||
|
]
|
||||||
|
|
||||||
|
# Add timestamp for tracking
|
||||||
|
new_token_data["issued_at"] = datetime.now().timestamp()
|
||||||
|
|
||||||
|
# Calculate expires_at if we have expires_in
|
||||||
|
if (
|
||||||
|
"expires_in" in new_token_data
|
||||||
|
and "expires_at" not in new_token_data
|
||||||
|
):
|
||||||
|
new_token_data["expires_at"] = (
|
||||||
|
datetime.now().timestamp()
|
||||||
|
+ new_token_data["expires_in"]
|
||||||
|
)
|
||||||
|
|
||||||
|
log.debug(f"Token refresh successful for provider {provider}")
|
||||||
|
return new_token_data
|
||||||
|
else:
|
||||||
|
error_text = await r.text()
|
||||||
|
log.error(
|
||||||
|
f"Token refresh failed for provider {provider}: {r.status} - {error_text}"
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
log.error(f"Exception during token refresh for provider {provider}: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
def get_user_role(self, user, user_data):
|
def get_user_role(self, user, user_data):
|
||||||
user_count = Users.get_num_users()
|
user_count = Users.get_num_users()
|
||||||
|
|
@ -624,33 +806,42 @@ class OAuthManager:
|
||||||
secure=WEBUI_AUTH_COOKIE_SECURE,
|
secure=WEBUI_AUTH_COOKIE_SECURE,
|
||||||
)
|
)
|
||||||
|
|
||||||
if ENABLE_OAUTH_SIGNUP.value:
|
# Legacy cookies for compatibility with older frontend versions
|
||||||
if ENABLE_OAUTH_SESSION_TOKENS_COOKIES:
|
if ENABLE_OAUTH_ID_TOKEN_COOKIE:
|
||||||
oauth_id_token = token.get("id_token")
|
|
||||||
response.set_cookie(
|
response.set_cookie(
|
||||||
key="oauth_id_token",
|
key="oauth_id_token",
|
||||||
value=oauth_id_token,
|
value=token.get("id_token"),
|
||||||
httponly=True,
|
httponly=True,
|
||||||
samesite=WEBUI_AUTH_COOKIE_SAME_SITE,
|
samesite=WEBUI_AUTH_COOKIE_SAME_SITE,
|
||||||
secure=WEBUI_AUTH_COOKIE_SECURE,
|
secure=WEBUI_AUTH_COOKIE_SECURE,
|
||||||
)
|
)
|
||||||
|
|
||||||
oauth_access_token = token.get("access_token")
|
try:
|
||||||
|
# Add timestamp for tracking
|
||||||
|
token["issued_at"] = datetime.now().timestamp()
|
||||||
|
|
||||||
|
# Calculate expires_at if we have expires_in
|
||||||
|
if "expires_in" in token and "expires_at" not in token:
|
||||||
|
token["expires_at"] = datetime.now().timestamp() + token["expires_in"]
|
||||||
|
|
||||||
|
session_id = await OAuthSessions.create_session(
|
||||||
|
user_id=user.id,
|
||||||
|
provider=provider,
|
||||||
|
token=token,
|
||||||
|
)
|
||||||
|
|
||||||
response.set_cookie(
|
response.set_cookie(
|
||||||
key="oauth_access_token",
|
key="oauth_session_id",
|
||||||
value=oauth_access_token,
|
value=session_id,
|
||||||
httponly=True,
|
httponly=True,
|
||||||
samesite=WEBUI_AUTH_COOKIE_SAME_SITE,
|
samesite=WEBUI_AUTH_COOKIE_SAME_SITE,
|
||||||
secure=WEBUI_AUTH_COOKIE_SECURE,
|
secure=WEBUI_AUTH_COOKIE_SECURE,
|
||||||
)
|
)
|
||||||
|
|
||||||
oauth_refresh_token = token.get("refresh_token")
|
log.info(
|
||||||
response.set_cookie(
|
f"Stored OAuth session server-side for user {user.id}, provider {provider}"
|
||||||
key="oauth_refresh_token",
|
|
||||||
value=oauth_refresh_token,
|
|
||||||
httponly=True,
|
|
||||||
samesite=WEBUI_AUTH_COOKIE_SAME_SITE,
|
|
||||||
secure=WEBUI_AUTH_COOKIE_SECURE,
|
|
||||||
)
|
)
|
||||||
|
except Exception as e:
|
||||||
|
log.error(f"Failed to store OAuth session server-side: {e}")
|
||||||
|
|
||||||
return response
|
return response
|
||||||
|
|
|
||||||
|
|
@ -129,6 +129,21 @@ async def get_tools(
|
||||||
headers["Authorization"] = (
|
headers["Authorization"] = (
|
||||||
f"Bearer {request.state.token.credentials}"
|
f"Bearer {request.state.token.credentials}"
|
||||||
)
|
)
|
||||||
|
elif auth_type == "oauth":
|
||||||
|
oauth_token = None
|
||||||
|
try:
|
||||||
|
oauth_token = (
|
||||||
|
await request.app.state.oauth_manager.get_oauth_token(
|
||||||
|
user.id,
|
||||||
|
request.cookies.get("oauth_session_id", None),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
log.error(f"Error getting OAuth token: {e}")
|
||||||
|
|
||||||
|
headers["Authorization"] = (
|
||||||
|
f"Bearer {oauth_token.get('access_token', '')}"
|
||||||
|
)
|
||||||
elif auth_type == "request_headers":
|
elif auth_type == "request_headers":
|
||||||
headers.update(dict(request.headers))
|
headers.update(dict(request.headers))
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -287,6 +287,7 @@
|
||||||
<option value="session">{$i18n.t('Session')}</option>
|
<option value="session">{$i18n.t('Session')}</option>
|
||||||
|
|
||||||
{#if !direct}
|
{#if !direct}
|
||||||
|
<option value="oauth">{$i18n.t('OAuth')}</option>
|
||||||
<option value="request_headers">{$i18n.t('Request Headers')}</option>
|
<option value="request_headers">{$i18n.t('Request Headers')}</option>
|
||||||
{/if}
|
{/if}
|
||||||
</select>
|
</select>
|
||||||
|
|
@ -305,6 +306,12 @@
|
||||||
>
|
>
|
||||||
{$i18n.t('Forwards system user session credentials to authenticate')}
|
{$i18n.t('Forwards system user session credentials to authenticate')}
|
||||||
</div>
|
</div>
|
||||||
|
{:else if auth_type === 'oauth'}
|
||||||
|
<div
|
||||||
|
class={`text-xs self-center translate-y-[1px] ${($settings?.highContrastMode ?? false) ? 'text-gray-800 dark:text-gray-100' : 'text-gray-500'}`}
|
||||||
|
>
|
||||||
|
{$i18n.t('Forwards user OAuth access token to authenticate')}
|
||||||
|
</div>
|
||||||
{:else if auth_type === 'request_headers'}
|
{:else if auth_type === 'request_headers'}
|
||||||
<div
|
<div
|
||||||
class={`text-xs self-center translate-y-[1px] ${($settings?.highContrastMode ?? false) ? 'text-gray-800 dark:text-gray-100' : 'text-gray-500'}`}
|
class={`text-xs self-center translate-y-[1px] ${($settings?.highContrastMode ?? false) ? 'text-gray-800 dark:text-gray-100' : 'text-gray-500'}`}
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue