mirror of
https://github.com/open-webui/open-webui.git
synced 2025-12-12 04:15:25 +00:00
refac: decouple api key restrictions from get user
This commit is contained in:
parent
e2ff2ae252
commit
b160eef7eb
2 changed files with 41 additions and 23 deletions
|
|
@ -1218,6 +1218,10 @@ app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH = (
|
||||||
|
|
||||||
app.state.MODELS = {}
|
app.state.MODELS = {}
|
||||||
|
|
||||||
|
# Add the middleware to the app
|
||||||
|
if ENABLE_COMPRESSION_MIDDLEWARE:
|
||||||
|
app.add_middleware(CompressMiddleware)
|
||||||
|
|
||||||
|
|
||||||
class RedirectMiddleware(BaseHTTPMiddleware):
|
class RedirectMiddleware(BaseHTTPMiddleware):
|
||||||
async def dispatch(self, request: Request, call_next):
|
async def dispatch(self, request: Request, call_next):
|
||||||
|
|
@ -1259,14 +1263,47 @@ class RedirectMiddleware(BaseHTTPMiddleware):
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
|
||||||
# Add the middleware to the app
|
|
||||||
if ENABLE_COMPRESSION_MIDDLEWARE:
|
|
||||||
app.add_middleware(CompressMiddleware)
|
|
||||||
|
|
||||||
app.add_middleware(RedirectMiddleware)
|
app.add_middleware(RedirectMiddleware)
|
||||||
app.add_middleware(SecurityHeadersMiddleware)
|
app.add_middleware(SecurityHeadersMiddleware)
|
||||||
|
|
||||||
|
|
||||||
|
class APIKeyRestrictionMiddleware(BaseHTTPMiddleware):
|
||||||
|
async def dispatch(self, request: Request, call_next):
|
||||||
|
auth_header = request.headers.get("Authorization")
|
||||||
|
|
||||||
|
# Only apply restrictions if an sk- API key is used
|
||||||
|
if auth_header and auth_header.startswith("sk-"):
|
||||||
|
# Check if restrictions are enabled
|
||||||
|
if request.app.state.config.ENABLE_API_KEY_ENDPOINT_RESTRICTIONS:
|
||||||
|
allowed_paths = [
|
||||||
|
path.strip()
|
||||||
|
for path in str(
|
||||||
|
request.app.state.config.API_KEY_ALLOWED_ENDPOINTS
|
||||||
|
).split(",")
|
||||||
|
if path.strip()
|
||||||
|
]
|
||||||
|
|
||||||
|
request_path = request.url.path
|
||||||
|
|
||||||
|
# Match exact path or prefix path
|
||||||
|
is_allowed = any(
|
||||||
|
request_path == allowed or request_path.startswith(allowed + "/")
|
||||||
|
for allowed in allowed_paths
|
||||||
|
)
|
||||||
|
|
||||||
|
if not is_allowed:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
detail="API key not allowed to access this endpoint.",
|
||||||
|
)
|
||||||
|
|
||||||
|
response = await call_next(request)
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
app.add_middleware(APIKeyRestrictionMiddleware)
|
||||||
|
|
||||||
|
|
||||||
@app.middleware("http")
|
@app.middleware("http")
|
||||||
async def commit_session_after_request(request: Request, call_next):
|
async def commit_session_after_request(request: Request, call_next):
|
||||||
response = await call_next(request)
|
response = await call_next(request)
|
||||||
|
|
|
||||||
|
|
@ -233,24 +233,6 @@ def get_current_user(
|
||||||
status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.API_KEY_NOT_ALLOWED
|
status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.API_KEY_NOT_ALLOWED
|
||||||
)
|
)
|
||||||
|
|
||||||
if request.app.state.config.ENABLE_API_KEY_ENDPOINT_RESTRICTIONS:
|
|
||||||
allowed_paths = [
|
|
||||||
path.strip()
|
|
||||||
for path in str(
|
|
||||||
request.app.state.config.API_KEY_ALLOWED_ENDPOINTS
|
|
||||||
).split(",")
|
|
||||||
]
|
|
||||||
|
|
||||||
# Check if the request path matches any allowed endpoint.
|
|
||||||
if not any(
|
|
||||||
request.url.path == allowed
|
|
||||||
or request.url.path.startswith(allowed + "/")
|
|
||||||
for allowed in allowed_paths
|
|
||||||
):
|
|
||||||
raise HTTPException(
|
|
||||||
status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.API_KEY_NOT_ALLOWED
|
|
||||||
)
|
|
||||||
|
|
||||||
user = get_current_user_by_api_key(token)
|
user = get_current_user_by_api_key(token)
|
||||||
|
|
||||||
# Add user info to current span
|
# Add user info to current span
|
||||||
|
|
@ -260,7 +242,6 @@ def get_current_user(
|
||||||
current_span.set_attribute("client.user.email", user.email)
|
current_span.set_attribute("client.user.email", user.email)
|
||||||
current_span.set_attribute("client.user.role", user.role)
|
current_span.set_attribute("client.user.role", user.role)
|
||||||
current_span.set_attribute("client.auth.type", "api_key")
|
current_span.set_attribute("client.auth.type", "api_key")
|
||||||
|
|
||||||
return user
|
return user
|
||||||
|
|
||||||
# auth by jwt token
|
# auth by jwt token
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue