diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index 0dc1c34f89..7aeb44ea06 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -1218,6 +1218,10 @@ app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH = ( app.state.MODELS = {} +# Add the middleware to the app +if ENABLE_COMPRESSION_MIDDLEWARE: + app.add_middleware(CompressMiddleware) + class RedirectMiddleware(BaseHTTPMiddleware): async def dispatch(self, request: Request, call_next): @@ -1259,14 +1263,47 @@ class RedirectMiddleware(BaseHTTPMiddleware): return response -# Add the middleware to the app -if ENABLE_COMPRESSION_MIDDLEWARE: - app.add_middleware(CompressMiddleware) - app.add_middleware(RedirectMiddleware) app.add_middleware(SecurityHeadersMiddleware) +class APIKeyRestrictionMiddleware(BaseHTTPMiddleware): + async def dispatch(self, request: Request, call_next): + auth_header = request.headers.get("Authorization") + + # Only apply restrictions if an sk- API key is used + if auth_header and auth_header.startswith("sk-"): + # Check if restrictions are enabled + if request.app.state.config.ENABLE_API_KEY_ENDPOINT_RESTRICTIONS: + allowed_paths = [ + path.strip() + for path in str( + request.app.state.config.API_KEY_ALLOWED_ENDPOINTS + ).split(",") + if path.strip() + ] + + request_path = request.url.path + + # Match exact path or prefix path + is_allowed = any( + request_path == allowed or request_path.startswith(allowed + "/") + for allowed in allowed_paths + ) + + if not is_allowed: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="API key not allowed to access this endpoint.", + ) + + response = await call_next(request) + return response + + +app.add_middleware(APIKeyRestrictionMiddleware) + + @app.middleware("http") async def commit_session_after_request(request: Request, call_next): response = await call_next(request) diff --git a/backend/open_webui/utils/auth.py b/backend/open_webui/utils/auth.py index e34803ade1..b7c49de442 100644 --- a/backend/open_webui/utils/auth.py +++ b/backend/open_webui/utils/auth.py @@ -233,24 +233,6 @@ def get_current_user( status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.API_KEY_NOT_ALLOWED ) - if request.app.state.config.ENABLE_API_KEY_ENDPOINT_RESTRICTIONS: - allowed_paths = [ - path.strip() - for path in str( - request.app.state.config.API_KEY_ALLOWED_ENDPOINTS - ).split(",") - ] - - # Check if the request path matches any allowed endpoint. - if not any( - request.url.path == allowed - or request.url.path.startswith(allowed + "/") - for allowed in allowed_paths - ): - raise HTTPException( - status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.API_KEY_NOT_ALLOWED - ) - user = get_current_user_by_api_key(token) # Add user info to current span @@ -260,7 +242,6 @@ def get_current_user( current_span.set_attribute("client.user.email", user.email) current_span.set_attribute("client.user.role", user.role) current_span.set_attribute("client.auth.type", "api_key") - return user # auth by jwt token