mirror of
https://github.com/open-webui/open-webui.git
synced 2025-12-11 20:05:19 +00:00
enh: revoked token handling
This commit is contained in:
parent
e486490451
commit
c4ecad0605
2 changed files with 83 additions and 22 deletions
|
|
@ -46,6 +46,7 @@ from pydantic import BaseModel
|
|||
from open_webui.utils.misc import parse_duration, validate_email_format
|
||||
from open_webui.utils.auth import (
|
||||
decode_token,
|
||||
invalidate_token,
|
||||
create_api_key,
|
||||
create_token,
|
||||
get_admin_user,
|
||||
|
|
@ -702,6 +703,19 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
|
|||
|
||||
@router.get("/signout")
|
||||
async def signout(request: Request, response: Response):
|
||||
|
||||
# get auth token from headers or cookies
|
||||
token = None
|
||||
auth_header = request.headers.get("Authorization")
|
||||
if auth_header:
|
||||
auth_cred = get_http_authorization_cred(auth_header)
|
||||
token = auth_cred.credentials
|
||||
else:
|
||||
token = request.cookies.get("token")
|
||||
|
||||
if token:
|
||||
await invalidate_token(request, token)
|
||||
|
||||
response.delete_cookie("token")
|
||||
response.delete_cookie("oui-session")
|
||||
response.delete_cookie("oauth_id_token")
|
||||
|
|
|
|||
|
|
@ -30,6 +30,7 @@ from open_webui.constants import ERROR_MESSAGES
|
|||
from open_webui.env import (
|
||||
OFFLINE_MODE,
|
||||
LICENSE_BLOB,
|
||||
REDIS_KEY_PREFIX,
|
||||
pk,
|
||||
WEBUI_SECRET_KEY,
|
||||
TRUSTED_SIGNATURE_KEY,
|
||||
|
|
@ -180,6 +181,9 @@ def create_token(data: dict, expires_delta: Union[timedelta, None] = None) -> st
|
|||
expire = datetime.now(UTC) + expires_delta
|
||||
payload.update({"exp": expire})
|
||||
|
||||
jti = str(uuid.uuid4())
|
||||
payload.update({"jti": jti})
|
||||
|
||||
encoded_jwt = jwt.encode(payload, SESSION_SECRET, algorithm=ALGORITHM)
|
||||
return encoded_jwt
|
||||
|
||||
|
|
@ -192,6 +196,43 @@ def decode_token(token: str) -> Optional[dict]:
|
|||
return None
|
||||
|
||||
|
||||
async def is_valid_token(request, decoded) -> bool:
|
||||
# Require Redis to check revoked tokens
|
||||
if request.app.state.redis:
|
||||
jti = decoded.get("jti")
|
||||
|
||||
if jti:
|
||||
revoked = await request.app.state.redis.get(
|
||||
f"{REDIS_KEY_PREFIX}:auth:token:{jti}:revoked"
|
||||
)
|
||||
if revoked:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
async def invalidate_token(request, token):
|
||||
decoded = decode_token(token)
|
||||
|
||||
# Require Redis to store revoked tokens
|
||||
if request.app.state.redis:
|
||||
jti = decoded.get("jti")
|
||||
exp = decoded.get("exp")
|
||||
|
||||
if jti:
|
||||
ttl = exp - int(
|
||||
datetime.now(UTC).timestamp()
|
||||
) # Calculate time-to-live for the token
|
||||
|
||||
if ttl > 0:
|
||||
# Store the revoked token in Redis with an expiration time
|
||||
await request.app.state.redis.set(
|
||||
f"{REDIS_KEY_PREFIX}:auth:token:{jti}:revoked",
|
||||
"1",
|
||||
ex=ttl,
|
||||
)
|
||||
|
||||
|
||||
def extract_token_from_auth_header(auth_header: str):
|
||||
return auth_header[len("Bearer ") :]
|
||||
|
||||
|
|
@ -211,7 +252,7 @@ def get_http_authorization_cred(auth_header: Optional[str]):
|
|||
return None
|
||||
|
||||
|
||||
def get_current_user(
|
||||
async def get_current_user(
|
||||
request: Request,
|
||||
response: Response,
|
||||
background_tasks: BackgroundTasks,
|
||||
|
|
@ -230,16 +271,7 @@ def get_current_user(
|
|||
|
||||
# auth by api key
|
||||
if token.startswith("sk-"):
|
||||
user = get_current_user_by_api_key(token)
|
||||
|
||||
if not request.state.enable_api_keys or not has_permission(
|
||||
user.id,
|
||||
"features.api_keys",
|
||||
request.app.state.config.USER_PERMISSIONS,
|
||||
):
|
||||
raise HTTPException(
|
||||
status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.API_KEY_NOT_ALLOWED
|
||||
)
|
||||
user = get_current_user_by_api_key(request, token)
|
||||
|
||||
# Add user info to current span
|
||||
current_span = trace.get_current_span()
|
||||
|
|
@ -248,10 +280,10 @@ def get_current_user(
|
|||
current_span.set_attribute("client.user.email", user.email)
|
||||
current_span.set_attribute("client.user.role", user.role)
|
||||
current_span.set_attribute("client.auth.type", "api_key")
|
||||
|
||||
return user
|
||||
|
||||
# auth by jwt token
|
||||
|
||||
try:
|
||||
try:
|
||||
data = decode_token(token)
|
||||
|
|
@ -262,6 +294,12 @@ def get_current_user(
|
|||
)
|
||||
|
||||
if data is not None and "id" in data:
|
||||
if data.get("jti") and not await is_valid_token(request, data):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid token",
|
||||
)
|
||||
|
||||
user = Users.get_user_by_id(data["id"])
|
||||
if user is None:
|
||||
raise HTTPException(
|
||||
|
|
@ -314,7 +352,7 @@ def get_current_user(
|
|||
raise e
|
||||
|
||||
|
||||
def get_current_user_by_api_key(api_key: str):
|
||||
def get_current_user_by_api_key(request, api_key: str):
|
||||
user = Users.get_user_by_api_key(api_key)
|
||||
|
||||
if user is None:
|
||||
|
|
@ -322,7 +360,16 @@ def get_current_user_by_api_key(api_key: str):
|
|||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=ERROR_MESSAGES.INVALID_TOKEN,
|
||||
)
|
||||
else:
|
||||
|
||||
if not request.state.enable_api_keys or not has_permission(
|
||||
user.id,
|
||||
"features.api_keys",
|
||||
request.app.state.config.USER_PERMISSIONS,
|
||||
):
|
||||
raise HTTPException(
|
||||
status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.API_KEY_NOT_ALLOWED
|
||||
)
|
||||
|
||||
# Add user info to current span
|
||||
current_span = trace.get_current_span()
|
||||
if current_span:
|
||||
|
|
|
|||
Loading…
Reference in a new issue