diff --git a/backend/open_webui/config.py b/backend/open_webui/config.py
index 983db4e04b..887a0ea722 100644
--- a/backend/open_webui/config.py
+++ b/backend/open_webui/config.py
@@ -2967,6 +2967,32 @@ WEB_SEARCH_RESULT_COUNT = PersistentConfig(
 )
 
 
+# Web Search Rate Limiting Config
+
+# Whether global web-search rate limiting is enabled. The env value is
+# parsed into a real boolean; the previous code stored the lowercased
+# string itself, which made "false" truthy.
+WEB_SEARCH_RATE_LIMIT_ENABLED = PersistentConfig(
+    "WEB_SEARCH_RATE_LIMIT_ENABLED",
+    "rag.web.search.rate_limit.enabled",
+    os.getenv("WEB_SEARCH_RATE_LIMIT_ENABLED", "False").lower() in ("true", "1"),
+)
+
+# The maximum number of requests that can be made to the web search engine
+# per N seconds, where N = WEB_SEARCH_RATE_LIMIT_MIN_SECONDS.
+WEB_SEARCH_RATE_LIMIT_MAX_REQUESTS = PersistentConfig(
+    "WEB_SEARCH_RATE_LIMIT_MAX_REQUESTS",
+    "rag.web.search.rate_limit.max_requests",
+    int(os.getenv("WEB_SEARCH_RATE_LIMIT_MAX_REQUESTS", "1")),
+)
+
+# The length of the rate-limit window, in seconds. Note: the env var name
+# matches the config name (the earlier draft read the misspelled
+# WEB_SEARCH_RATE_MIN_SECONDS, so the documented variable was ignored).
+WEB_SEARCH_RATE_LIMIT_MIN_SECONDS = PersistentConfig(
+    "WEB_SEARCH_RATE_LIMIT_MIN_SECONDS",
+    "rag.web.search.rate_limit.min_seconds",
+    int(os.getenv("WEB_SEARCH_RATE_LIMIT_MIN_SECONDS", "1")),
+)
+
+
 # You can provide a list of your own websites to filter after performing a web search.
 # This ensures the highest level of safety and reliability of the information sources.
WEB_SEARCH_DOMAIN_FILTER_LIST = PersistentConfig( diff --git a/backend/open_webui/routers/retrieval.py b/backend/open_webui/routers/retrieval.py index 08ffde1733..ee3642e367 100644 --- a/backend/open_webui/routers/retrieval.py +++ b/backend/open_webui/routers/retrieval.py @@ -4,10 +4,12 @@ import mimetypes import os import shutil import asyncio +import time import re import uuid -from datetime import datetime +from datetime import datetime, timedelta +from multiprocessing import Value from pathlib import Path from typing import Iterator, List, Optional, Sequence, Union @@ -24,6 +26,9 @@ from fastapi import ( ) from fastapi.middleware.cors import CORSMiddleware from fastapi.concurrency import run_in_threadpool + +from functools import wraps +from collections import deque from pydantic import BaseModel import tiktoken @@ -1766,7 +1771,62 @@ async def process_web( detail=ERROR_MESSAGES.DEFAULT(e), ) +from open_webui.config import ( + WEB_SEARCH_RATE_LIMIT_MAX_REQUESTS, + WEB_SEARCH_RATE_LIMIT_MIN_SECONDS, + WEB_SEARCH_RATE_LIMIT_ENABLED, +) +# Rate Limit (Specifically for search: This references environment variables named for search) + + +web_search_lock = asyncio.Lock() +# Track timestamps of previous calls +web_search_timestamps = deque() + + +def search_rate_limit(max_calls: int, period: float): + """ + Async-friendly decorator for limiting function calls to `max_calls` per `period` seconds. + Works in FastAPI async endpoints without blocking the event loop. 
+ """ + + def decorator(func): + @wraps(func) + async def wrapper(*args, **kwargs): + async with web_search_lock: + now = asyncio.get_event_loop().time() + + # Remove timestamps older than the period + while web_search_timestamps and now - web_search_timestamps[0] > period: + web_search_timestamps.popleft() + + if len(web_search_timestamps) >= max_calls: + # Need to wait until oldest call is outside the period + wait_time = period - (now - web_search_timestamps[0]) + await asyncio.sleep(wait_time) + now = asyncio.get_event_loop().time() + while ( + web_search_timestamps + and now - web_search_timestamps[0] > period + ): + web_search_timestamps.popleft() + + # Record this call + web_search_timestamps.append(now) + + # Call the actual function + return await func(*args, **kwargs) + + return wrapper + + return decorator + + +@search_rate_limit( + int(WEB_SEARCH_RATE_LIMIT_MAX_REQUESTS.value), + int(WEB_SEARCH_RATE_LIMIT_MIN_SECONDS.value), +) def search_web( request: Request, engine: str, query: str, user=None ) -> list[SearchResult]: @@ -1792,6 +1852,8 @@ def search_web( query (str): The query to search for """ + logging.info(f"search_web: {engine} query: {query}") + # TODO: add playwright to search the web if engine == "ollama_cloud": return search_ollama_cloud( @@ -2087,6 +2149,7 @@ async def process_web_search( for result in search_results: if result: + result = await result for item in result: if item and item.link: result_items.append(item)