open-webui/backend/open_webui/models/oauth_sessions.py

267 lines
8.7 KiB
Python
Raw Normal View History

import time
import logging
import uuid
from typing import Optional, List
import base64
import hashlib
import json
from cryptography.fernet import Fernet
from open_webui.internal.db import Base, get_db
from open_webui.env import SRC_LOG_LEVELS, OAUTH_SESSION_TOKEN_ENCRYPTION_KEY
from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Column, String, Text, Index
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
####################
# DB MODEL
####################
class OAuthSession(Base):
__tablename__ = "oauth_session"
id = Column(Text, primary_key=True)
user_id = Column(Text, nullable=False)
provider = Column(Text, nullable=False)
token = Column(
Text, nullable=False
) # JSON with access_token, id_token, refresh_token
expires_at = Column(BigInteger, nullable=False)
created_at = Column(BigInteger, nullable=False)
updated_at = Column(BigInteger, nullable=False)
# Add indexes for better performance
__table_args__ = (
Index("idx_oauth_session_user_id", "user_id"),
Index("idx_oauth_session_expires_at", "expires_at"),
Index("idx_oauth_session_user_provider", "user_id", "provider"),
)
class OAuthSessionModel(BaseModel):
id: str
user_id: str
provider: str
token: dict
expires_at: int # timestamp in epoch
created_at: int # timestamp in epoch
updated_at: int # timestamp in epoch
model_config = ConfigDict(from_attributes=True)
####################
# Forms
####################
class OAuthSessionResponse(BaseModel):
id: str
user_id: str
provider: str
expires_at: int
class OAuthSessionTable:
def __init__(self):
self.encryption_key = OAUTH_SESSION_TOKEN_ENCRYPTION_KEY
if not self.encryption_key:
raise Exception("OAUTH_SESSION_TOKEN_ENCRYPTION_KEY is not set")
# check if encryption key is in the right format for Fernet (32 url-safe base64-encoded bytes)
if len(self.encryption_key) != 44:
key_bytes = hashlib.sha256(self.encryption_key.encode()).digest()
self.encryption_key = base64.urlsafe_b64encode(key_bytes)
else:
self.encryption_key = self.encryption_key.encode()
try:
self.fernet = Fernet(self.encryption_key)
except Exception as e:
log.error(f"Error initializing Fernet with provided key: {e}")
raise
def _encrypt_token(self, token) -> str:
"""Encrypt OAuth tokens for storage"""
try:
token_json = json.dumps(token)
encrypted = self.fernet.encrypt(token_json.encode()).decode()
return encrypted
except Exception as e:
log.error(f"Error encrypting tokens: {e}")
raise
def _decrypt_token(self, token: str):
"""Decrypt OAuth tokens from storage"""
try:
decrypted = self.fernet.decrypt(token.encode()).decode()
return json.loads(decrypted)
except Exception as e:
log.error(f"Error decrypting tokens: {e}")
raise
def create_session(
self,
user_id: str,
provider: str,
token: dict,
) -> Optional[OAuthSessionModel]:
"""Create a new OAuth session"""
try:
with get_db() as db:
current_time = int(time.time())
id = str(uuid.uuid4())
result = OAuthSession(
**{
"id": id,
"user_id": user_id,
"provider": provider,
"token": self._encrypt_token(token),
"expires_at": token.get("expires_at"),
"created_at": current_time,
"updated_at": current_time,
}
)
db.add(result)
db.commit()
db.refresh(result)
if result:
result.token = token # Return decrypted token
return OAuthSessionModel.model_validate(result)
else:
return None
except Exception as e:
log.error(f"Error creating OAuth session: {e}")
return None
def get_session_by_id(self, session_id: str) -> Optional[OAuthSessionModel]:
"""Get OAuth session by ID"""
try:
with get_db() as db:
session = db.query(OAuthSession).filter_by(id=session_id).first()
if session:
session.token = self._decrypt_token(session.token)
return OAuthSessionModel.model_validate(session)
2025-09-08 14:09:01 +00:00
return None
except Exception as e:
log.error(f"Error getting OAuth session by ID: {e}")
return None
def get_session_by_id_and_user_id(
self, session_id: str, user_id: str
) -> Optional[OAuthSessionModel]:
"""Get OAuth session by ID and user ID"""
try:
with get_db() as db:
session = (
db.query(OAuthSession)
.filter_by(id=session_id, user_id=user_id)
.first()
)
if session:
session.token = self._decrypt_token(session.token)
return OAuthSessionModel.model_validate(session)
2025-09-08 14:09:01 +00:00
return None
except Exception as e:
log.error(f"Error getting OAuth session by ID: {e}")
return None
2025-09-25 06:49:16 +00:00
def get_session_by_provider_and_user_id(
self, provider: str, user_id: str
) -> Optional[OAuthSessionModel]:
"""Get OAuth session by provider and user ID"""
try:
with get_db() as db:
session = (
db.query(OAuthSession)
.filter_by(provider=provider, user_id=user_id)
.first()
)
if session:
session.token = self._decrypt_token(session.token)
return OAuthSessionModel.model_validate(session)
return None
except Exception as e:
log.error(f"Error getting OAuth session by provider and user ID: {e}")
return None
def get_sessions_by_user_id(self, user_id: str) -> List[OAuthSessionModel]:
"""Get all OAuth sessions for a user"""
try:
with get_db() as db:
sessions = db.query(OAuthSession).filter_by(user_id=user_id).all()
results = []
for session in sessions:
session.token = self._decrypt_token(session.token)
results.append(OAuthSessionModel.model_validate(session))
return results
2025-09-08 14:09:01 +00:00
except Exception as e:
log.error(f"Error getting OAuth sessions by user ID: {e}")
return []
def update_session_by_id(
self, session_id: str, token: dict
) -> Optional[OAuthSessionModel]:
"""Update OAuth session tokens"""
try:
with get_db() as db:
current_time = int(time.time())
db.query(OAuthSession).filter_by(id=session_id).update(
{
"token": self._encrypt_token(token),
"expires_at": token.get("expires_at"),
"updated_at": current_time,
}
)
db.commit()
session = db.query(OAuthSession).filter_by(id=session_id).first()
if session:
session.token = self._decrypt_token(session.token)
return OAuthSessionModel.model_validate(session)
2025-09-08 14:09:01 +00:00
return None
except Exception as e:
log.error(f"Error updating OAuth session tokens: {e}")
return None
def delete_session_by_id(self, session_id: str) -> bool:
"""Delete an OAuth session"""
try:
with get_db() as db:
result = db.query(OAuthSession).filter_by(id=session_id).delete()
db.commit()
return result > 0
except Exception as e:
log.error(f"Error deleting OAuth session: {e}")
return False
def delete_sessions_by_user_id(self, user_id: str) -> bool:
"""Delete all OAuth sessions for a user"""
try:
with get_db() as db:
result = db.query(OAuthSession).filter_by(user_id=user_id).delete()
db.commit()
return True
except Exception as e:
log.error(f"Error deleting OAuth sessions by user ID: {e}")
return False
OAuthSessions = OAuthSessionTable()