mirror of
https://github.com/open-webui/open-webui.git
synced 2025-12-12 04:15:25 +00:00
refac: decouple api key restrictions from get user
This commit is contained in:
parent
e2ff2ae252
commit
b160eef7eb
2 changed files with 41 additions and 23 deletions
|
|
@ -1218,6 +1218,10 @@ app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH = (
|
|||
|
||||
app.state.MODELS = {}
|
||||
|
||||
# Add the middleware to the app
|
||||
if ENABLE_COMPRESSION_MIDDLEWARE:
|
||||
app.add_middleware(CompressMiddleware)
|
||||
|
||||
|
||||
class RedirectMiddleware(BaseHTTPMiddleware):
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
|
|
@ -1259,14 +1263,47 @@ class RedirectMiddleware(BaseHTTPMiddleware):
|
|||
return response
|
||||
|
||||
|
||||
# Add the middleware to the app
|
||||
if ENABLE_COMPRESSION_MIDDLEWARE:
|
||||
app.add_middleware(CompressMiddleware)
|
||||
|
||||
app.add_middleware(RedirectMiddleware)
|
||||
app.add_middleware(SecurityHeadersMiddleware)
|
||||
|
||||
|
||||
class APIKeyRestrictionMiddleware(BaseHTTPMiddleware):
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
auth_header = request.headers.get("Authorization")
|
||||
|
||||
# Only apply restrictions if an sk- API key is used
|
||||
if auth_header and auth_header.startswith("sk-"):
|
||||
# Check if restrictions are enabled
|
||||
if request.app.state.config.ENABLE_API_KEY_ENDPOINT_RESTRICTIONS:
|
||||
allowed_paths = [
|
||||
path.strip()
|
||||
for path in str(
|
||||
request.app.state.config.API_KEY_ALLOWED_ENDPOINTS
|
||||
).split(",")
|
||||
if path.strip()
|
||||
]
|
||||
|
||||
request_path = request.url.path
|
||||
|
||||
# Match exact path or prefix path
|
||||
is_allowed = any(
|
||||
request_path == allowed or request_path.startswith(allowed + "/")
|
||||
for allowed in allowed_paths
|
||||
)
|
||||
|
||||
if not is_allowed:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="API key not allowed to access this endpoint.",
|
||||
)
|
||||
|
||||
response = await call_next(request)
|
||||
return response
|
||||
|
||||
|
||||
app.add_middleware(APIKeyRestrictionMiddleware)
|
||||
|
||||
|
||||
@app.middleware("http")
|
||||
async def commit_session_after_request(request: Request, call_next):
|
||||
response = await call_next(request)
|
||||
|
|
|
|||
|
|
@ -233,24 +233,6 @@ def get_current_user(
|
|||
status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.API_KEY_NOT_ALLOWED
|
||||
)
|
||||
|
||||
if request.app.state.config.ENABLE_API_KEY_ENDPOINT_RESTRICTIONS:
|
||||
allowed_paths = [
|
||||
path.strip()
|
||||
for path in str(
|
||||
request.app.state.config.API_KEY_ALLOWED_ENDPOINTS
|
||||
).split(",")
|
||||
]
|
||||
|
||||
# Check if the request path matches any allowed endpoint.
|
||||
if not any(
|
||||
request.url.path == allowed
|
||||
or request.url.path.startswith(allowed + "/")
|
||||
for allowed in allowed_paths
|
||||
):
|
||||
raise HTTPException(
|
||||
status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.API_KEY_NOT_ALLOWED
|
||||
)
|
||||
|
||||
user = get_current_user_by_api_key(token)
|
||||
|
||||
# Add user info to current span
|
||||
|
|
@ -260,7 +242,6 @@ def get_current_user(
|
|||
current_span.set_attribute("client.user.email", user.email)
|
||||
current_span.set_attribute("client.user.role", user.role)
|
||||
current_span.set_attribute("client.auth.type", "api_key")
|
||||
|
||||
return user
|
||||
|
||||
# auth by jwt token
|
||||
|
|
|
|||
Loading…
Reference in a new issue