enh: revoked token handling

This commit is contained in:
Timothy Jaeryang Baek 2025-11-19 06:08:59 -05:00
parent e486490451
commit c4ecad0605
2 changed files with 83 additions and 22 deletions

View file

@ -46,6 +46,7 @@ from pydantic import BaseModel
from open_webui.utils.misc import parse_duration, validate_email_format
from open_webui.utils.auth import (
decode_token,
invalidate_token,
create_api_key,
create_token,
get_admin_user,
@ -702,6 +703,19 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
@router.get("/signout")
async def signout(request: Request, response: Response):
# get auth token from headers or cookies
token = None
auth_header = request.headers.get("Authorization")
if auth_header:
auth_cred = get_http_authorization_cred(auth_header)
token = auth_cred.credentials
else:
token = request.cookies.get("token")
if token:
await invalidate_token(request, token)
response.delete_cookie("token")
response.delete_cookie("oui-session")
response.delete_cookie("oauth_id_token")

View file

@ -30,6 +30,7 @@ from open_webui.constants import ERROR_MESSAGES
from open_webui.env import (
OFFLINE_MODE,
LICENSE_BLOB,
REDIS_KEY_PREFIX,
pk,
WEBUI_SECRET_KEY,
TRUSTED_SIGNATURE_KEY,
@ -180,6 +181,9 @@ def create_token(data: dict, expires_delta: Union[timedelta, None] = None) -> st
expire = datetime.now(UTC) + expires_delta
payload.update({"exp": expire})
jti = str(uuid.uuid4())
payload.update({"jti": jti})
encoded_jwt = jwt.encode(payload, SESSION_SECRET, algorithm=ALGORITHM)
return encoded_jwt
@ -192,6 +196,43 @@ def decode_token(token: str) -> Optional[dict]:
return None
async def is_valid_token(request, decoded) -> bool:
# Require Redis to check revoked tokens
if request.app.state.redis:
jti = decoded.get("jti")
if jti:
revoked = await request.app.state.redis.get(
f"{REDIS_KEY_PREFIX}:auth:token:{jti}:revoked"
)
if revoked:
return False
return True
async def invalidate_token(request, token):
decoded = decode_token(token)
# Require Redis to store revoked tokens
if request.app.state.redis:
jti = decoded.get("jti")
exp = decoded.get("exp")
if jti:
ttl = exp - int(
datetime.now(UTC).timestamp()
) # Calculate time-to-live for the token
if ttl > 0:
# Store the revoked token in Redis with an expiration time
await request.app.state.redis.set(
f"{REDIS_KEY_PREFIX}:auth:token:{jti}:revoked",
"1",
ex=ttl,
)
def extract_token_from_auth_header(auth_header: str):
return auth_header[len("Bearer ") :]
@ -211,7 +252,7 @@ def get_http_authorization_cred(auth_header: Optional[str]):
return None
def get_current_user(
async def get_current_user(
request: Request,
response: Response,
background_tasks: BackgroundTasks,
@ -230,16 +271,7 @@ def get_current_user(
# auth by api key
if token.startswith("sk-"):
user = get_current_user_by_api_key(token)
if not request.state.enable_api_keys or not has_permission(
user.id,
"features.api_keys",
request.app.state.config.USER_PERMISSIONS,
):
raise HTTPException(
status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.API_KEY_NOT_ALLOWED
)
user = get_current_user_by_api_key(request, token)
# Add user info to current span
current_span = trace.get_current_span()
@ -248,10 +280,10 @@ def get_current_user(
current_span.set_attribute("client.user.email", user.email)
current_span.set_attribute("client.user.role", user.role)
current_span.set_attribute("client.auth.type", "api_key")
return user
# auth by jwt token
try:
try:
data = decode_token(token)
@ -262,6 +294,12 @@ def get_current_user(
)
if data is not None and "id" in data:
if data.get("jti") and not await is_valid_token(request, data):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid token",
)
user = Users.get_user_by_id(data["id"])
if user is None:
raise HTTPException(
@ -314,7 +352,7 @@ def get_current_user(
raise e
def get_current_user_by_api_key(api_key: str):
def get_current_user_by_api_key(request, api_key: str):
user = Users.get_user_by_api_key(api_key)
if user is None:
@ -322,16 +360,25 @@ def get_current_user_by_api_key(api_key: str):
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.INVALID_TOKEN,
)
else:
# Add user info to current span
current_span = trace.get_current_span()
if current_span:
current_span.set_attribute("client.user.id", user.id)
current_span.set_attribute("client.user.email", user.email)
current_span.set_attribute("client.user.role", user.role)
current_span.set_attribute("client.auth.type", "api_key")
Users.update_user_last_active_by_id(user.id)
if not request.state.enable_api_keys or not has_permission(
user.id,
"features.api_keys",
request.app.state.config.USER_PERMISSIONS,
):
raise HTTPException(
status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.API_KEY_NOT_ALLOWED
)
# Add user info to current span
current_span = trace.get_current_span()
if current_span:
current_span.set_attribute("client.user.id", user.id)
current_span.set_attribute("client.user.email", user.email)
current_span.set_attribute("client.user.role", user.role)
current_span.set_attribute("client.auth.type", "api_key")
Users.update_user_last_active_by_id(user.id)
return user