mirror of
https://github.com/open-webui/open-webui.git
synced 2025-12-12 04:15:25 +00:00
enh: revoked token handling
This commit is contained in:
parent
e486490451
commit
c4ecad0605
2 changed files with 83 additions and 22 deletions
|
|
@ -46,6 +46,7 @@ from pydantic import BaseModel
|
||||||
from open_webui.utils.misc import parse_duration, validate_email_format
|
from open_webui.utils.misc import parse_duration, validate_email_format
|
||||||
from open_webui.utils.auth import (
|
from open_webui.utils.auth import (
|
||||||
decode_token,
|
decode_token,
|
||||||
|
invalidate_token,
|
||||||
create_api_key,
|
create_api_key,
|
||||||
create_token,
|
create_token,
|
||||||
get_admin_user,
|
get_admin_user,
|
||||||
|
|
@ -702,6 +703,19 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
|
||||||
|
|
||||||
@router.get("/signout")
|
@router.get("/signout")
|
||||||
async def signout(request: Request, response: Response):
|
async def signout(request: Request, response: Response):
|
||||||
|
|
||||||
|
# get auth token from headers or cookies
|
||||||
|
token = None
|
||||||
|
auth_header = request.headers.get("Authorization")
|
||||||
|
if auth_header:
|
||||||
|
auth_cred = get_http_authorization_cred(auth_header)
|
||||||
|
token = auth_cred.credentials
|
||||||
|
else:
|
||||||
|
token = request.cookies.get("token")
|
||||||
|
|
||||||
|
if token:
|
||||||
|
await invalidate_token(request, token)
|
||||||
|
|
||||||
response.delete_cookie("token")
|
response.delete_cookie("token")
|
||||||
response.delete_cookie("oui-session")
|
response.delete_cookie("oui-session")
|
||||||
response.delete_cookie("oauth_id_token")
|
response.delete_cookie("oauth_id_token")
|
||||||
|
|
|
||||||
|
|
@ -30,6 +30,7 @@ from open_webui.constants import ERROR_MESSAGES
|
||||||
from open_webui.env import (
|
from open_webui.env import (
|
||||||
OFFLINE_MODE,
|
OFFLINE_MODE,
|
||||||
LICENSE_BLOB,
|
LICENSE_BLOB,
|
||||||
|
REDIS_KEY_PREFIX,
|
||||||
pk,
|
pk,
|
||||||
WEBUI_SECRET_KEY,
|
WEBUI_SECRET_KEY,
|
||||||
TRUSTED_SIGNATURE_KEY,
|
TRUSTED_SIGNATURE_KEY,
|
||||||
|
|
@ -180,6 +181,9 @@ def create_token(data: dict, expires_delta: Union[timedelta, None] = None) -> st
|
||||||
expire = datetime.now(UTC) + expires_delta
|
expire = datetime.now(UTC) + expires_delta
|
||||||
payload.update({"exp": expire})
|
payload.update({"exp": expire})
|
||||||
|
|
||||||
|
jti = str(uuid.uuid4())
|
||||||
|
payload.update({"jti": jti})
|
||||||
|
|
||||||
encoded_jwt = jwt.encode(payload, SESSION_SECRET, algorithm=ALGORITHM)
|
encoded_jwt = jwt.encode(payload, SESSION_SECRET, algorithm=ALGORITHM)
|
||||||
return encoded_jwt
|
return encoded_jwt
|
||||||
|
|
||||||
|
|
@ -192,6 +196,43 @@ def decode_token(token: str) -> Optional[dict]:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
async def is_valid_token(request, decoded) -> bool:
|
||||||
|
# Require Redis to check revoked tokens
|
||||||
|
if request.app.state.redis:
|
||||||
|
jti = decoded.get("jti")
|
||||||
|
|
||||||
|
if jti:
|
||||||
|
revoked = await request.app.state.redis.get(
|
||||||
|
f"{REDIS_KEY_PREFIX}:auth:token:{jti}:revoked"
|
||||||
|
)
|
||||||
|
if revoked:
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
async def invalidate_token(request, token):
|
||||||
|
decoded = decode_token(token)
|
||||||
|
|
||||||
|
# Require Redis to store revoked tokens
|
||||||
|
if request.app.state.redis:
|
||||||
|
jti = decoded.get("jti")
|
||||||
|
exp = decoded.get("exp")
|
||||||
|
|
||||||
|
if jti:
|
||||||
|
ttl = exp - int(
|
||||||
|
datetime.now(UTC).timestamp()
|
||||||
|
) # Calculate time-to-live for the token
|
||||||
|
|
||||||
|
if ttl > 0:
|
||||||
|
# Store the revoked token in Redis with an expiration time
|
||||||
|
await request.app.state.redis.set(
|
||||||
|
f"{REDIS_KEY_PREFIX}:auth:token:{jti}:revoked",
|
||||||
|
"1",
|
||||||
|
ex=ttl,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def extract_token_from_auth_header(auth_header: str):
|
def extract_token_from_auth_header(auth_header: str):
|
||||||
return auth_header[len("Bearer ") :]
|
return auth_header[len("Bearer ") :]
|
||||||
|
|
||||||
|
|
@ -211,7 +252,7 @@ def get_http_authorization_cred(auth_header: Optional[str]):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def get_current_user(
|
async def get_current_user(
|
||||||
request: Request,
|
request: Request,
|
||||||
response: Response,
|
response: Response,
|
||||||
background_tasks: BackgroundTasks,
|
background_tasks: BackgroundTasks,
|
||||||
|
|
@ -230,16 +271,7 @@ def get_current_user(
|
||||||
|
|
||||||
# auth by api key
|
# auth by api key
|
||||||
if token.startswith("sk-"):
|
if token.startswith("sk-"):
|
||||||
user = get_current_user_by_api_key(token)
|
user = get_current_user_by_api_key(request, token)
|
||||||
|
|
||||||
if not request.state.enable_api_keys or not has_permission(
|
|
||||||
user.id,
|
|
||||||
"features.api_keys",
|
|
||||||
request.app.state.config.USER_PERMISSIONS,
|
|
||||||
):
|
|
||||||
raise HTTPException(
|
|
||||||
status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.API_KEY_NOT_ALLOWED
|
|
||||||
)
|
|
||||||
|
|
||||||
# Add user info to current span
|
# Add user info to current span
|
||||||
current_span = trace.get_current_span()
|
current_span = trace.get_current_span()
|
||||||
|
|
@ -248,10 +280,10 @@ def get_current_user(
|
||||||
current_span.set_attribute("client.user.email", user.email)
|
current_span.set_attribute("client.user.email", user.email)
|
||||||
current_span.set_attribute("client.user.role", user.role)
|
current_span.set_attribute("client.user.role", user.role)
|
||||||
current_span.set_attribute("client.auth.type", "api_key")
|
current_span.set_attribute("client.auth.type", "api_key")
|
||||||
|
|
||||||
return user
|
return user
|
||||||
|
|
||||||
# auth by jwt token
|
# auth by jwt token
|
||||||
|
|
||||||
try:
|
try:
|
||||||
try:
|
try:
|
||||||
data = decode_token(token)
|
data = decode_token(token)
|
||||||
|
|
@ -262,6 +294,12 @@ def get_current_user(
|
||||||
)
|
)
|
||||||
|
|
||||||
if data is not None and "id" in data:
|
if data is not None and "id" in data:
|
||||||
|
if data.get("jti") and not await is_valid_token(request, data):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Invalid token",
|
||||||
|
)
|
||||||
|
|
||||||
user = Users.get_user_by_id(data["id"])
|
user = Users.get_user_by_id(data["id"])
|
||||||
if user is None:
|
if user is None:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
|
|
@ -314,7 +352,7 @@ def get_current_user(
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
|
|
||||||
def get_current_user_by_api_key(api_key: str):
|
def get_current_user_by_api_key(request, api_key: str):
|
||||||
user = Users.get_user_by_api_key(api_key)
|
user = Users.get_user_by_api_key(api_key)
|
||||||
|
|
||||||
if user is None:
|
if user is None:
|
||||||
|
|
@ -322,7 +360,16 @@ def get_current_user_by_api_key(api_key: str):
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
detail=ERROR_MESSAGES.INVALID_TOKEN,
|
detail=ERROR_MESSAGES.INVALID_TOKEN,
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
|
if not request.state.enable_api_keys or not has_permission(
|
||||||
|
user.id,
|
||||||
|
"features.api_keys",
|
||||||
|
request.app.state.config.USER_PERMISSIONS,
|
||||||
|
):
|
||||||
|
raise HTTPException(
|
||||||
|
status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.API_KEY_NOT_ALLOWED
|
||||||
|
)
|
||||||
|
|
||||||
# Add user info to current span
|
# Add user info to current span
|
||||||
current_span = trace.get_current_span()
|
current_span = trace.get_current_span()
|
||||||
if current_span:
|
if current_span:
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue