refac: decouple api key restrictions from get user

This commit is contained in:
Timothy Jaeryang Baek 2025-11-13 19:52:04 -05:00
parent e2ff2ae252
commit b160eef7eb
2 changed files with 41 additions and 23 deletions

View file

@ -1218,6 +1218,10 @@ app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH = (
app.state.MODELS = {} app.state.MODELS = {}
# Add the middleware to the app
if ENABLE_COMPRESSION_MIDDLEWARE:
app.add_middleware(CompressMiddleware)
class RedirectMiddleware(BaseHTTPMiddleware): class RedirectMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next): async def dispatch(self, request: Request, call_next):
@ -1259,14 +1263,47 @@ class RedirectMiddleware(BaseHTTPMiddleware):
return response return response
# Add the middleware to the app
if ENABLE_COMPRESSION_MIDDLEWARE:
app.add_middleware(CompressMiddleware)
app.add_middleware(RedirectMiddleware) app.add_middleware(RedirectMiddleware)
app.add_middleware(SecurityHeadersMiddleware) app.add_middleware(SecurityHeadersMiddleware)
class APIKeyRestrictionMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next):
auth_header = request.headers.get("Authorization")
# Only apply restrictions if an sk- API key is used
if auth_header and auth_header.startswith("sk-"):
# Check if restrictions are enabled
if request.app.state.config.ENABLE_API_KEY_ENDPOINT_RESTRICTIONS:
allowed_paths = [
path.strip()
for path in str(
request.app.state.config.API_KEY_ALLOWED_ENDPOINTS
).split(",")
if path.strip()
]
request_path = request.url.path
# Match exact path or prefix path
is_allowed = any(
request_path == allowed or request_path.startswith(allowed + "/")
for allowed in allowed_paths
)
if not is_allowed:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="API key not allowed to access this endpoint.",
)
response = await call_next(request)
return response
app.add_middleware(APIKeyRestrictionMiddleware)
@app.middleware("http") @app.middleware("http")
async def commit_session_after_request(request: Request, call_next): async def commit_session_after_request(request: Request, call_next):
response = await call_next(request) response = await call_next(request)

View file

@ -233,24 +233,6 @@ def get_current_user(
status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.API_KEY_NOT_ALLOWED status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.API_KEY_NOT_ALLOWED
) )
if request.app.state.config.ENABLE_API_KEY_ENDPOINT_RESTRICTIONS:
allowed_paths = [
path.strip()
for path in str(
request.app.state.config.API_KEY_ALLOWED_ENDPOINTS
).split(",")
]
# Check if the request path matches any allowed endpoint.
if not any(
request.url.path == allowed
or request.url.path.startswith(allowed + "/")
for allowed in allowed_paths
):
raise HTTPException(
status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.API_KEY_NOT_ALLOWED
)
user = get_current_user_by_api_key(token) user = get_current_user_by_api_key(token)
# Add user info to current span # Add user info to current span
@ -260,7 +242,6 @@ def get_current_user(
current_span.set_attribute("client.user.email", user.email) current_span.set_attribute("client.user.email", user.email)
current_span.set_attribute("client.user.role", user.role) current_span.set_attribute("client.user.role", user.role)
current_span.set_attribute("client.auth.type", "api_key") current_span.set_attribute("client.auth.type", "api_key")
return user return user
# auth by jwt token # auth by jwt token