mirror of
https://github.com/open-webui/open-webui.git
synced 2025-12-11 20:05:19 +00:00
feat: signin rate limit
This commit is contained in:
parent
734c04ebf0
commit
7b16637043
3 changed files with 176 additions and 1 deletions
|
|
@ -6,6 +6,7 @@ import logging
|
||||||
from aiohttp import ClientSession
|
from aiohttp import ClientSession
|
||||||
import urllib
|
import urllib
|
||||||
|
|
||||||
|
|
||||||
from open_webui.models.auths import (
|
from open_webui.models.auths import (
|
||||||
AddUserForm,
|
AddUserForm,
|
||||||
ApiKey,
|
ApiKey,
|
||||||
|
|
@ -65,6 +66,10 @@ from open_webui.utils.auth import (
|
||||||
from open_webui.utils.webhook import post_webhook
|
from open_webui.utils.webhook import post_webhook
|
||||||
from open_webui.utils.access_control import get_permissions, has_permission
|
from open_webui.utils.access_control import get_permissions, has_permission
|
||||||
|
|
||||||
|
from open_webui.utils.redis import get_redis_client
|
||||||
|
from open_webui.utils.rate_limit import RateLimiter
|
||||||
|
|
||||||
|
|
||||||
from typing import Optional, List
|
from typing import Optional, List
|
||||||
|
|
||||||
from ssl import CERT_NONE, CERT_REQUIRED, PROTOCOL_TLS
|
from ssl import CERT_NONE, CERT_REQUIRED, PROTOCOL_TLS
|
||||||
|
|
@ -77,6 +82,10 @@ router = APIRouter()
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
log.setLevel(SRC_LOG_LEVELS["MAIN"])
|
log.setLevel(SRC_LOG_LEVELS["MAIN"])
|
||||||
|
|
||||||
|
# Module-level limiter shared by the signin endpoint: at most 15 attempts
# per email within a rolling 3-minute window (falls back to in-memory
# counting when Redis is unavailable).
signin_rate_limiter = RateLimiter(
    redis_client=get_redis_client(), limit=5 * 3, window=60 * 3
)
|
||||||
|
|
||||||
############################
|
############################
|
||||||
# GetSessionUser
|
# GetSessionUser
|
||||||
############################
|
############################
|
||||||
|
|
@ -551,6 +560,12 @@ async def signin(request: Request, response: Response, form_data: SigninForm):
|
||||||
admin_email.lower(), lambda pw: verify_password(admin_password, pw)
|
admin_email.lower(), lambda pw: verify_password(admin_password, pw)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
if signin_rate_limiter.is_limited(form_data.email.lower()):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
|
||||||
|
detail=ERROR_MESSAGES.RATE_LIMIT_EXCEEDED,
|
||||||
|
)
|
||||||
|
|
||||||
password_bytes = form_data.password.encode("utf-8")
|
password_bytes = form_data.password.encode("utf-8")
|
||||||
if len(password_bytes) > 72:
|
if len(password_bytes) > 72:
|
||||||
# TODO: Implement other hashing algorithms that support longer passwords
|
# TODO: Implement other hashing algorithms that support longer passwords
|
||||||
|
|
|
||||||
139
backend/open_webui/utils/rate_limit.py
Normal file
139
backend/open_webui/utils/rate_limit.py
Normal file
|
|
@ -0,0 +1,139 @@
|
||||||
|
import time
|
||||||
|
from typing import Optional, Dict
|
||||||
|
from open_webui.env import REDIS_KEY_PREFIX
|
||||||
|
|
||||||
|
|
||||||
|
class RateLimiter:
    """
    General-purpose rate limiter using Redis with a rolling window strategy.

    Attempts are counted in fixed-size time buckets; a key is limited once
    the sum of attempts across the buckets covering ``window`` exceeds
    ``limit``. Falls back to in-memory storage if Redis is not available
    or a Redis call fails.
    """

    # In-memory fallback storage: key -> {bucket_index: count}.
    # Class-level so all limiter instances share counts for a key, mirroring
    # the shared Redis key namespace built in ``_bucket_key``.
    _memory_store: Dict[str, Dict[int, int]] = {}

    def __init__(
        self,
        redis_client,
        limit: int,
        window: int,
        bucket_size: int = 60,
        enabled: bool = True,
    ):
        """
        :param redis_client: Redis client instance or None
        :param limit: Max allowed events in the window
        :param window: Time window in seconds
        :param bucket_size: Bucket resolution in seconds
        :param enabled: Turn on/off rate limiting globally
        """
        self.r = redis_client
        self.limit = limit
        self.window = window
        self.bucket_size = bucket_size
        # NOTE: truncating division — a window not divisible by bucket_size
        # is effectively rounded down to whole buckets.
        self.num_buckets = window // bucket_size
        self.enabled = enabled

    def _bucket_key(self, key: str, bucket_index: int) -> str:
        # Keys are normalized to lowercase so e.g. email lookups are
        # case-insensitive.
        return f"{REDIS_KEY_PREFIX}:ratelimit:{key.lower()}:{bucket_index}"

    def _current_bucket(self) -> int:
        return int(time.time()) // self.bucket_size

    def _redis_available(self) -> bool:
        return self.r is not None

    def is_limited(self, key: str) -> bool:
        """
        Record one attempt for ``key`` and report whether it is rate limited.

        Gracefully handles missing or failing Redis by falling back to the
        in-memory store.
        """
        if not self.enabled:
            return False

        if self._redis_available():
            try:
                return self._is_limited_redis(key)
            except Exception:
                return self._is_limited_memory(key)
        else:
            return self._is_limited_memory(key)

    def get_count(self, key: str) -> int:
        """Return the number of attempts recorded for ``key`` in the window."""
        if not self.enabled:
            return 0

        if self._redis_available():
            try:
                return self._get_count_redis(key)
            except Exception:
                return self._get_count_memory(key)
        else:
            return self._get_count_memory(key)

    def remaining(self, key: str) -> int:
        """Return how many attempts ``key`` has left before being limited."""
        used = self.get_count(key)
        return max(0, self.limit - used)

    def _is_limited_redis(self, key: str) -> bool:
        now_bucket = self._current_bucket()
        bucket_key = self._bucket_key(key, now_bucket)

        attempts = self.r.incr(bucket_key)
        if attempts == 1:
            # First hit in this bucket: let it expire once it can no longer
            # fall inside any window that includes it.
            self.r.expire(bucket_key, self.window + self.bucket_size)

        # Sum the current bucket plus the previous num_buckets buckets.
        buckets = [
            self._bucket_key(key, now_bucket - i) for i in range(self.num_buckets + 1)
        ]

        counts = self.r.mget(buckets)
        total = sum(int(c) for c in counts if c)

        return total > self.limit

    def _get_count_redis(self, key: str) -> int:
        now_bucket = self._current_bucket()
        buckets = [
            self._bucket_key(key, now_bucket - i) for i in range(self.num_buckets + 1)
        ]
        counts = self.r.mget(buckets)
        return sum(int(c) for c in counts if c)

    def _is_limited_memory(self, key: str) -> bool:
        # BUGFIX: normalize the key exactly like the Redis path
        # (``_bucket_key`` lowercases) so the fallback is also
        # case-insensitive; previously "User@X" and "user@x" had
        # separate in-memory budgets.
        key = key.lower()
        now_bucket = self._current_bucket()

        # Init storage for this key on first use.
        store = self._memory_store.setdefault(key, {})

        # Increment current bucket.
        store[now_bucket] = store.get(now_bucket, 0) + 1

        # Drop buckets that fell out of the window.
        min_bucket = now_bucket - self.num_buckets
        for b in [b for b in store if b < min_bucket]:
            del store[b]

        return sum(store.values()) > self.limit

    def _get_count_memory(self, key: str) -> int:
        # Same normalization as _is_limited_memory / the Redis path.
        key = key.lower()
        now_bucket = self._current_bucket()
        store = self._memory_store.get(key)
        if store is None:
            return 0

        # Remove expired buckets before counting.
        min_bucket = now_bucket - self.num_buckets
        for b in [b for b in store if b < min_bucket]:
            del store[b]

        return sum(store.values())
|
||||||
|
|
@ -5,7 +5,13 @@ import logging
|
||||||
|
|
||||||
import redis
|
import redis
|
||||||
|
|
||||||
from open_webui.env import REDIS_SENTINEL_MAX_RETRY_COUNT
|
from open_webui.env import (
|
||||||
|
REDIS_CLUSTER,
|
||||||
|
REDIS_SENTINEL_HOSTS,
|
||||||
|
REDIS_SENTINEL_MAX_RETRY_COUNT,
|
||||||
|
REDIS_SENTINEL_PORT,
|
||||||
|
REDIS_URL,
|
||||||
|
)
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -108,6 +114,21 @@ def parse_redis_service_url(redis_url):
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def get_redis_client(async_mode=False):
    """
    Best-effort factory for a Redis connection built from the environment.

    Returns the connection from ``get_redis_connection`` (sync or async per
    ``async_mode``), or None if obtaining one fails for any reason.
    """
    try:
        sentinels = get_sentinels_from_env(REDIS_SENTINEL_HOSTS, REDIS_SENTINEL_PORT)
        return get_redis_connection(
            redis_url=REDIS_URL,
            redis_sentinels=sentinels,
            redis_cluster=REDIS_CLUSTER,
            async_mode=async_mode,
        )
    except Exception as e:
        # Deliberate best-effort: callers treat None as "Redis unavailable".
        log.debug(f"Failed to get Redis client: {e}")
        return None
|
||||||
|
|
||||||
|
|
||||||
def get_redis_connection(
|
def get_redis_connection(
|
||||||
redis_url,
|
redis_url,
|
||||||
redis_sentinels,
|
redis_sentinels,
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue