From 7b166370432414ce8f186747fb098e0c70fb2d6b Mon Sep 17 00:00:00 2001 From: Timothy Jaeryang Baek Date: Tue, 2 Dec 2025 03:52:38 -0500 Subject: [PATCH] feat: signin rate limit --- backend/open_webui/routers/auths.py | 15 +++ backend/open_webui/utils/rate_limit.py | 139 +++++++++++++++++++++++++ backend/open_webui/utils/redis.py | 23 +++- 3 files changed, 176 insertions(+), 1 deletion(-) create mode 100644 backend/open_webui/utils/rate_limit.py diff --git a/backend/open_webui/routers/auths.py b/backend/open_webui/routers/auths.py index 42302043ed..0bf1d65d0c 100644 --- a/backend/open_webui/routers/auths.py +++ b/backend/open_webui/routers/auths.py @@ -6,6 +6,7 @@ import logging from aiohttp import ClientSession import urllib + from open_webui.models.auths import ( AddUserForm, ApiKey, @@ -65,6 +66,10 @@ from open_webui.utils.auth import ( from open_webui.utils.webhook import post_webhook from open_webui.utils.access_control import get_permissions, has_permission +from open_webui.utils.redis import get_redis_client +from open_webui.utils.rate_limit import RateLimiter + + from typing import Optional, List from ssl import CERT_NONE, CERT_REQUIRED, PROTOCOL_TLS @@ -77,6 +82,10 @@ router = APIRouter() log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["MAIN"]) +signin_rate_limiter = RateLimiter( + redis_client=get_redis_client(), limit=5 * 3, window=60 * 3 +) + ############################ # GetSessionUser ############################ @@ -551,6 +560,12 @@ async def signin(request: Request, response: Response, form_data: SigninForm): admin_email.lower(), lambda pw: verify_password(admin_password, pw) ) else: + if signin_rate_limiter.is_limited(form_data.email.lower()): + raise HTTPException( + status_code=status.HTTP_429_TOO_MANY_REQUESTS, + detail=ERROR_MESSAGES.RATE_LIMIT_EXCEEDED, + ) + password_bytes = form_data.password.encode("utf-8") if len(password_bytes) > 72: # TODO: Implement other hashing algorithms that support longer passwords diff --git a/backend/open_webui/utils/rate_limit.py b/backend/open_webui/utils/rate_limit.py new file mode 100644 index 0000000000..b657a937ab --- /dev/null +++ b/backend/open_webui/utils/rate_limit.py @@ -0,0 +1,139 @@ +import time +from typing import Optional, Dict +from open_webui.env import REDIS_KEY_PREFIX + + +class RateLimiter: + """ + General-purpose rate limiter using Redis with a rolling window strategy. + Falls back to in-memory storage if Redis is not available. + """ + + # In-memory fallback storage + _memory_store: Dict[str, Dict[int, int]] = {} + + def __init__( + self, + redis_client, + limit: int, + window: int, + bucket_size: int = 60, + enabled: bool = True, + ): + """ + :param redis_client: Redis client instance or None + :param limit: Max allowed events in the window + :param window: Time window in seconds + :param bucket_size: Bucket resolution + :param enabled: Turn on/off rate limiting globally + """ + self.r = redis_client + self.limit = limit + self.window = window + self.bucket_size = bucket_size + self.num_buckets = window // bucket_size + self.enabled = enabled + + def _bucket_key(self, key: str, bucket_index: int) -> str: + return f"{REDIS_KEY_PREFIX}:ratelimit:{key.lower()}:{bucket_index}" + + def _current_bucket(self) -> int: + return int(time.time()) // self.bucket_size + + def _redis_available(self) -> bool: + return self.r is not None + + def is_limited(self, key: str) -> bool: + """ + Main rate-limit check. + Gracefully handles missing or failing Redis. + """ + if not self.enabled: + return False + + if self._redis_available(): + try: + return self._is_limited_redis(key) + except Exception: + return self._is_limited_memory(key) + else: + return self._is_limited_memory(key) + + def get_count(self, key: str) -> int: + if not self.enabled: + return 0 + + if self._redis_available(): + try: + return self._get_count_redis(key) + except Exception: + return self._get_count_memory(key) + else: + return self._get_count_memory(key) + + def remaining(self, key: str) -> int: + used = self.get_count(key) + return max(0, self.limit - used) + + def _is_limited_redis(self, key: str) -> bool: + now_bucket = self._current_bucket() + bucket_key = self._bucket_key(key, now_bucket) + + attempts = self.r.incr(bucket_key) + if attempts == 1: + self.r.expire(bucket_key, self.window + self.bucket_size) + + # Collect buckets + buckets = [ + self._bucket_key(key, now_bucket - i) for i in range(self.num_buckets + 1) + ] + + counts = self.r.mget(buckets) + total = sum(int(c) for c in counts if c) + + return total > self.limit + + def _get_count_redis(self, key: str) -> int: + now_bucket = self._current_bucket() + buckets = [ + self._bucket_key(key, now_bucket - i) for i in range(self.num_buckets + 1) + ] + counts = self.r.mget(buckets) + return sum(int(c) for c in counts if c) + + def _is_limited_memory(self, key: str) -> bool: + now_bucket = self._current_bucket() + + # Init storage + if key not in self._memory_store: + self._memory_store[key] = {} + + store = self._memory_store[key] + + # Increment bucket + store[now_bucket] = store.get(now_bucket, 0) + 1 + + # Drop expired buckets + min_bucket = now_bucket - self.num_buckets + expired = [b for b in store if b < min_bucket] + for b in expired: + del store[b] + + # Count totals + total = sum(store.values()) + return total > self.limit + + def _get_count_memory(self, key: str) -> int: + now_bucket = self._current_bucket() + if key not in self._memory_store: + return 0 + + store = self._memory_store[key] + min_bucket = now_bucket - self.num_buckets + + # Remove expired + expired = [b for b in store if b < min_bucket] + for b in expired: + del store[b] + + return sum(store.values()) diff --git a/backend/open_webui/utils/redis.py b/backend/open_webui/utils/redis.py index c60a6fa517..cc29ce6683 100644 --- a/backend/open_webui/utils/redis.py +++ b/backend/open_webui/utils/redis.py @@ -5,7 +5,13 @@ import logging import redis -from open_webui.env import REDIS_SENTINEL_MAX_RETRY_COUNT +from open_webui.env import ( + REDIS_CLUSTER, + REDIS_SENTINEL_HOSTS, + REDIS_SENTINEL_MAX_RETRY_COUNT, + REDIS_SENTINEL_PORT, + REDIS_URL, +) log = logging.getLogger(__name__) @@ -108,6 +114,21 @@ def parse_redis_service_url(redis_url): } +def get_redis_client(async_mode=False): + try: + return get_redis_connection( + redis_url=REDIS_URL, + redis_sentinels=get_sentinels_from_env( + REDIS_SENTINEL_HOSTS, REDIS_SENTINEL_PORT + ), + redis_cluster=REDIS_CLUSTER, + async_mode=async_mode, + ) + except Exception as e: + log.debug(f"Failed to get Redis client: {e}") + return None + + def get_redis_connection( redis_url, redis_sentinels,