2025-09-08 14:05:43 +00:00
|
|
|
import time
|
|
|
|
|
import logging
|
|
|
|
|
import uuid
|
|
|
|
|
from typing import Optional, List
|
|
|
|
|
import base64
|
|
|
|
|
import hashlib
|
|
|
|
|
import json
|
|
|
|
|
|
|
|
|
|
from cryptography.fernet import Fernet
|
|
|
|
|
|
|
|
|
|
from open_webui.internal.db import Base, get_db
|
|
|
|
|
from open_webui.env import SRC_LOG_LEVELS, OAUTH_SESSION_TOKEN_ENCRYPTION_KEY
|
|
|
|
|
|
|
|
|
|
from pydantic import BaseModel, ConfigDict
|
|
|
|
|
from sqlalchemy import BigInteger, Column, String, Text, Index
|
|
|
|
|
|
|
|
|
|
# Module-level logger, named after this module.
log = logging.getLogger(__name__)
|
|
|
|
|
# Verbosity is taken from the "MODELS" entry of SRC_LOG_LEVELS.
log.setLevel(SRC_LOG_LEVELS["MODELS"])
|
|
|
|
|
|
|
|
|
|
####################
|
|
|
|
|
# DB MODEL
|
|
|
|
|
####################
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class OAuthSession(Base):
    """ORM mapping for the ``oauth_session`` table.

    Each row holds the OAuth token payload stored for one user/provider
    pair. All columns are declared ``nullable=False`` except the primary
    key, which is implicitly non-null.
    """

    __tablename__ = "oauth_session"

    # Primary key; OAuthSessionTable.create_session fills it with a UUID4 string.
    id = Column(Text, primary_key=True)
    user_id = Column(Text, nullable=False)
    provider = Column(Text, nullable=False)
    # OAuthSessionTable writes this column as a Fernet-encrypted JSON string,
    # so the raw column value is ciphertext, not readable JSON.
    token = Column(
        Text, nullable=False
    )  # JSON with access_token, id_token, refresh_token
    # Copied from the token payload's "expires_at" entry when the row is written.
    expires_at = Column(BigInteger, nullable=False)
    # created_at / updated_at are set from int(time.time()) by OAuthSessionTable.
    created_at = Column(BigInteger, nullable=False)
    updated_at = Column(BigInteger, nullable=False)

    # Add indexes for better performance
    __table_args__ = (
        Index("idx_oauth_session_user_id", "user_id"),
        Index("idx_oauth_session_expires_at", "expires_at"),
        Index("idx_oauth_session_user_provider", "user_id", "provider"),
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Pydantic view of an OAuthSession row. Unlike the table column, `token` is a
# dict here: OAuthSessionTable decrypts the stored value before validating.
# (Documented with comments rather than a docstring so the generated schema
# description is left untouched.)
class OAuthSessionModel(BaseModel):
    id: str
    user_id: str
    provider: str
    token: dict  # decrypted token payload
    expires_at: int  # timestamp in epoch
    created_at: int  # timestamp in epoch
    updated_at: int  # timestamp in epoch

    # from_attributes lets model_validate() read fields off an ORM object.
    model_config = ConfigDict(from_attributes=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
####################
|
|
|
|
|
# Forms
|
|
|
|
|
####################
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Session metadata without the token payload.
# NOTE(review): presumably the shape exposed to API clients so tokens are
# never returned -- confirm against the routers that use it.
class OAuthSessionResponse(BaseModel):
    id: str
    user_id: str
    provider: str
    expires_at: int
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class OAuthSessionTable:
    """Data-access layer for ``oauth_session`` rows with encrypted tokens.

    Token payloads are serialized to JSON and encrypted with Fernet before
    being written, and decrypted again when rows are converted into
    ``OAuthSessionModel`` objects.

    The constructor and the two private crypto helpers raise on failure.
    Every public method instead logs the error and returns a sentinel
    (``None``, ``[]`` or ``False``).
    """

    def __init__(self):
        """Build the Fernet cipher from ``OAUTH_SESSION_TOKEN_ENCRYPTION_KEY``.

        Raises:
            Exception: If the key is empty/unset, or if Fernet rejects the
                resulting key material (the original error is re-raised
                after being logged).
        """
        self.encryption_key = OAUTH_SESSION_TOKEN_ENCRYPTION_KEY
        if not self.encryption_key:
            raise Exception("OAUTH_SESSION_TOKEN_ENCRYPTION_KEY is not set")

        # A Fernet key is 32 bytes in url-safe base64, i.e. 44 characters.
        # A value of exactly that length is used as the key verbatim; any
        # other length is treated as a passphrase whose SHA-256 digest
        # (32 bytes) is base64-encoded into a key.
        if len(self.encryption_key) != 44:
            key_bytes = hashlib.sha256(self.encryption_key.encode()).digest()
            self.encryption_key = base64.urlsafe_b64encode(key_bytes)
        else:
            self.encryption_key = self.encryption_key.encode()

        try:
            self.fernet = Fernet(self.encryption_key)
        except Exception as e:
            log.error(f"Error initializing Fernet with provided key: {e}")
            raise

    def _encrypt_token(self, token) -> str:
        """Serialize ``token`` to JSON and return it Fernet-encrypted as text.

        Raises:
            Exception: Any serialization or encryption error, re-raised
                after being logged.
        """
        try:
            token_json = json.dumps(token)
            return self.fernet.encrypt(token_json.encode()).decode()
        except Exception as e:
            log.error(f"Error encrypting tokens: {e}")
            raise

    def _decrypt_token(self, token: str):
        """Decrypt a stored token string and return the parsed JSON value.

        Raises:
            Exception: Any decryption or JSON parsing error, re-raised
                after being logged.
        """
        try:
            decrypted = self.fernet.decrypt(token.encode()).decode()
            return json.loads(decrypted)
        except Exception as e:
            log.error(f"Error decrypting tokens: {e}")
            raise

    def _to_model(self, session, token=None) -> OAuthSessionModel:
        """Build an ``OAuthSessionModel`` from an ORM row without mutating it.

        The row is only read. Plaintext is never assigned back onto the
        ORM-managed object, so a later flush of the database session cannot
        write a decrypted token into the encrypted column.

        Args:
            session: The ``OAuthSession`` row to convert.
            token: Already-known plaintext payload. When ``None`` the row's
                stored token is decrypted instead.

        Raises:
            Exception: If decryption or model validation fails.
        """
        if token is None:
            token = self._decrypt_token(session.token)
        return OAuthSessionModel.model_validate(
            {
                "id": session.id,
                "user_id": session.user_id,
                "provider": session.provider,
                "token": token,
                "expires_at": session.expires_at,
                "created_at": session.created_at,
                "updated_at": session.updated_at,
            }
        )

    def create_session(
        self,
        user_id: str,
        provider: str,
        token: dict,
    ) -> Optional[OAuthSessionModel]:
        """Create a new OAuth session.

        The row gets a fresh UUID4 id, the encrypted ``token`` and
        ``token["expires_at"]`` as its expiry.

        Returns:
            The created session carrying the caller's plaintext ``token``,
            or ``None`` if anything failed (the error is logged).
        """
        try:
            with get_db() as db:
                current_time = int(time.time())
                session_id = str(uuid.uuid4())

                result = OAuthSession(
                    id=session_id,
                    user_id=user_id,
                    provider=provider,
                    token=self._encrypt_token(token),
                    expires_at=token.get("expires_at"),
                    created_at=current_time,
                    updated_at=current_time,
                )

                db.add(result)
                db.commit()
                db.refresh(result)

                # Hand back the caller's plaintext token, not the ciphertext.
                return self._to_model(result, token=token)
        except Exception as e:
            log.error(f"Error creating OAuth session: {e}")
            return None

    def get_session_by_id(self, session_id: str) -> Optional[OAuthSessionModel]:
        """Get OAuth session by ID.

        Returns:
            The session with its token decrypted, or ``None`` when no row
            matches or an error occurred (the error is logged).
        """
        try:
            with get_db() as db:
                session = db.query(OAuthSession).filter_by(id=session_id).first()
                if session:
                    return self._to_model(session)

                return None
        except Exception as e:
            log.error(f"Error getting OAuth session by ID: {e}")
            return None

    def get_session_by_id_and_user_id(
        self, session_id: str, user_id: str
    ) -> Optional[OAuthSessionModel]:
        """Get OAuth session by ID and user ID.

        Returns:
            The session with its token decrypted, or ``None`` when no row
            matches both values or an error occurred (the error is logged).
        """
        try:
            with get_db() as db:
                session = (
                    db.query(OAuthSession)
                    .filter_by(id=session_id, user_id=user_id)
                    .first()
                )
                if session:
                    return self._to_model(session)

                return None
        except Exception as e:
            log.error(f"Error getting OAuth session by ID and user ID: {e}")
            return None

    def get_session_by_provider_and_user_id(
        self, provider: str, user_id: str
    ) -> Optional[OAuthSessionModel]:
        """Get OAuth session by provider and user ID.

        Returns:
            The first matching session with its token decrypted, or ``None``
            when no row matches or an error occurred (the error is logged).
        """
        try:
            with get_db() as db:
                session = (
                    db.query(OAuthSession)
                    .filter_by(provider=provider, user_id=user_id)
                    .first()
                )
                if session:
                    return self._to_model(session)

                return None
        except Exception as e:
            log.error(f"Error getting OAuth session by provider and user ID: {e}")
            return None

    def get_sessions_by_user_id(self, user_id: str) -> List[OAuthSessionModel]:
        """Get all OAuth sessions for a user.

        Returns:
            Every session of ``user_id`` with tokens decrypted. If any
            single row fails to decrypt or validate, the error is logged
            and an empty list is returned.
        """
        try:
            with get_db() as db:
                sessions = db.query(OAuthSession).filter_by(user_id=user_id).all()
                return [self._to_model(session) for session in sessions]
        except Exception as e:
            log.error(f"Error getting OAuth sessions by user ID: {e}")
            return []

    def update_session_by_id(
        self, session_id: str, token: dict
    ) -> Optional[OAuthSessionModel]:
        """Update OAuth session tokens.

        Replaces the stored token with the encrypted ``token``, sets the
        expiry from ``token["expires_at"]`` and bumps ``updated_at``.

        Returns:
            The session as re-read after the commit, or ``None`` when no
            such row exists or an error occurred (the error is logged).
        """
        try:
            with get_db() as db:
                current_time = int(time.time())

                db.query(OAuthSession).filter_by(id=session_id).update(
                    {
                        "token": self._encrypt_token(token),
                        "expires_at": token.get("expires_at"),
                        "updated_at": current_time,
                    }
                )
                db.commit()

                # Re-read the row so the result reflects what was persisted.
                session = db.query(OAuthSession).filter_by(id=session_id).first()
                if session:
                    return self._to_model(session)

                return None
        except Exception as e:
            log.error(f"Error updating OAuth session tokens: {e}")
            return None

    def delete_session_by_id(self, session_id: str) -> bool:
        """Delete an OAuth session.

        Returns:
            ``True`` if at least one row was deleted, ``False`` if none
            matched or an error occurred (the error is logged).
        """
        try:
            with get_db() as db:
                result = db.query(OAuthSession).filter_by(id=session_id).delete()
                db.commit()
                return result > 0
        except Exception as e:
            log.error(f"Error deleting OAuth session: {e}")
            return False

    def delete_sessions_by_user_id(self, user_id: str) -> bool:
        """Delete all OAuth sessions for a user.

        Returns:
            ``True`` when the delete committed, even if the user had no
            sessions; ``False`` if an error occurred (the error is logged).
        """
        try:
            with get_db() as db:
                db.query(OAuthSession).filter_by(user_id=user_id).delete()
                db.commit()
                return True
        except Exception as e:
            log.error(f"Error deleting OAuth sessions by user ID: {e}")
            return False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Shared module-level instance. It is constructed at import time, so importing
# this module raises if OAUTH_SESSION_TOKEN_ENCRYPTION_KEY is unset or unusable.
OAuthSessions = OAuthSessionTable()
|