adding configurable rate limit for web search

This commit is contained in:
Christopher Vaz 2025-09-14 17:29:50 -07:00
parent e701c65db4
commit 15cb2e289d
3 changed files with 102 additions and 8 deletions

View file

@@ -730,6 +730,7 @@ def load_oauth_providers():
        }

    if FEISHU_CLIENT_ID.value and FEISHU_CLIENT_SECRET.value:

        def feishu_oauth_register(client: OAuth):
            client.register(
                name="feishu",
@@ -2708,6 +2709,29 @@ WEB_SEARCH_RESULT_COUNT = PersistentConfig(
)

# Web Search Rate Limiting Config
WEB_SEARCH_RATE_LIMIT_ENABLED = PersistentConfig(
    "WEB_SEARCH_RATE_LIMIT_ENABLED",
    "rag.web.search.rate_limit.enabled",
    os.getenv("WEB_SEARCH_RATE_LIMIT_ENABLED", "False").lower() == "true",
)
# The maximum number of requests that can be made to the web search engine per N seconds, where N = WEB_SEARCH_RATE_LIMIT_MIN_SECONDS
WEB_SEARCH_RATE_LIMIT_MAX_REQUESTS = PersistentConfig(
    "WEB_SEARCH_RATE_LIMIT_MAX_REQUESTS",
    "rag.web.search.rate_limit.max_requests",
    int(os.getenv("WEB_SEARCH_RATE_LIMIT_MAX_REQUESTS", "1")),
)
# The length, in seconds, of the sliding window over which at most WEB_SEARCH_RATE_LIMIT_MAX_REQUESTS requests are allowed
WEB_SEARCH_RATE_LIMIT_MIN_SECONDS = PersistentConfig(
    "WEB_SEARCH_RATE_LIMIT_MIN_SECONDS",
    "rag.web.search.rate_limit.min_seconds",
    int(os.getenv("WEB_SEARCH_RATE_LIMIT_MIN_SECONDS", "1")),
)
# You can provide a list of your own websites to filter after performing a web search.
# This ensures the highest level of safety and reliability of the information sources.
WEB_SEARCH_DOMAIN_FILTER_LIST = PersistentConfig(
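
Taken together, the three new settings read as "allow at most WEB_SEARCH_RATE_LIMIT_MAX_REQUESTS web searches per WEB_SEARCH_RATE_LIMIT_MIN_SECONDS-second window". A minimal standalone sketch of that contract, using plain os.getenv reads instead of PersistentConfig, with the same defaults as above (the print statements are illustrative only):

import os

enabled = os.getenv("WEB_SEARCH_RATE_LIMIT_ENABLED", "False").lower() == "true"
max_requests = int(os.getenv("WEB_SEARCH_RATE_LIMIT_MAX_REQUESTS", "1"))
min_seconds = int(os.getenv("WEB_SEARCH_RATE_LIMIT_MIN_SECONDS", "1"))

if enabled:
    print(f"web search limited to {max_requests} request(s) per {min_seconds}s")
else:
    print("web search rate limiting disabled")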

View file

@@ -4,9 +4,11 @@ import mimetypes
import os
import shutil
import asyncio
import time
import uuid
from datetime import datetime, timedelta
from multiprocessing import Value
from pathlib import Path
from typing import Iterator, List, Optional, Sequence, Union
@@ -23,6 +25,9 @@ from fastapi import (
)
from fastapi.middleware.cors import CORSMiddleware
from fastapi.concurrency import run_in_threadpool
from functools import wraps
from collections import deque
from pydantic import BaseModel
import tiktoken
@@ -1738,7 +1743,63 @@ def process_web(
)


from open_webui.config import (
    WEB_SEARCH_RATE_LIMIT_MAX_REQUESTS,
    WEB_SEARCH_RATE_LIMIT_MIN_SECONDS,
    WEB_SEARCH_RATE_LIMIT_ENABLED,
)

# Rate-limiting state shared by all web searches; the decorator below enforces
# the WEB_SEARCH_RATE_LIMIT_* settings imported above.
web_search_lock = asyncio.Lock()

# Timestamps of recent calls, oldest first
web_search_timestamps = deque()

def search_rate_limit(max_calls: int, period: float):
    """
    Async-friendly decorator for limiting function calls to `max_calls` per `period` seconds.
    Works in FastAPI async endpoints without blocking the event loop.
    """

    def decorator(func):
        @wraps(func)
        async def wrapper(*args, **kwargs):
            async with web_search_lock:
                now = asyncio.get_running_loop().time()

                # Remove timestamps older than the period
                while web_search_timestamps and now - web_search_timestamps[0] > period:
                    web_search_timestamps.popleft()

                if len(web_search_timestamps) >= max_calls:
                    # Need to wait until the oldest call is outside the period
                    wait_time = period - (now - web_search_timestamps[0])
                    await asyncio.sleep(wait_time)

                    now = asyncio.get_running_loop().time()
                    while (
                        web_search_timestamps
                        and now - web_search_timestamps[0] > period
                    ):
                        web_search_timestamps.popleft()

                # Record this call
                web_search_timestamps.append(now)

                # Call the actual function
                return await func(*args, **kwargs)

        return wrapper

    return decorator

@search_rate_limit(
    int(WEB_SEARCH_RATE_LIMIT_MAX_REQUESTS.value),
    int(WEB_SEARCH_RATE_LIMIT_MIN_SECONDS.value),
)
async def search_web(request: Request, engine: str, query: str) -> list[SearchResult]:
"""Search the web using a search engine and return the results as a list of SearchResult objects. """Search the web using a search engine and return the results as a list of SearchResult objects.
Will look for a search engine API key in environment variables in the following order: Will look for a search engine API key in environment variables in the following order:
- SEARXNG_QUERY_URL - SEARXNG_QUERY_URL
@@ -1761,6 +1822,8 @@ def search_web(request: Request, engine: str, query: str) -> list[SearchResult]:
        query (str): The query to search for
    """
    logging.info(f"search_web: {engine} query: {query}")

    # TODO: add playwright to search the web
    if engine == "searxng":
        if request.app.state.config.SEARXNG_QUERY_URL:
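
A quick illustration of the sliding-window behaviour, assuming search_rate_limit from the hunk above is in scope and Python 3.10+ (so the module-level asyncio.Lock is loop-agnostic); fake_search and the limits are illustrative, not part of the commit. With max_calls=2 and period=1.0, the first two calls in a burst return immediately and the third is held until the oldest timestamp ages out of the window:

import asyncio
import time

@search_rate_limit(2, 1.0)  # illustrative limits: 2 calls per 1-second window
async def fake_search(i: int) -> int:
    return i

async def demo():
    start = time.monotonic()
    for i in range(3):
        await fake_search(i)
        print(f"call {i} returned at t={time.monotonic() - start:.2f}s")
    # Expected: calls 0 and 1 at ~0.00s, call 2 at ~1.00s

asyncio.run(demo())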
@@ -2001,11 +2064,13 @@ async def process_web_search(
    )

    search_tasks = [
        asyncio.ensure_future(
            run_in_threadpool(
                search_web,
                request,
                request.app.state.config.WEB_SEARCH_ENGINE,
                query,
            )
        )
        for query in form_data.queries
    ]
@@ -2014,6 +2079,7 @@ async def process_web_search(
    for result in search_results:
        if result:
            result = await result
            for item in result:
                if item and item.link:
                    result_items.append(item)
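
The extra `result = await result` is needed because search_web is now a coroutine function: run_in_threadpool invokes it in a worker thread, which only creates a coroutine object, and that object still has to be awaited back on the event loop before the items are usable. A self-contained sketch of the same double-await pattern (fetch stands in for search_web and is not the commit's code):

import asyncio
from fastapi.concurrency import run_in_threadpool

async def fetch(q: str) -> str:  # stand-in for the now-async search_web
    return f"results for {q}"

async def main():
    tasks = [
        asyncio.ensure_future(run_in_threadpool(fetch, q)) for q in ["a", "b"]
    ]
    for result in await asyncio.gather(*tasks):
        # each gathered item is an un-run coroutine object created in the thread
        print(await result)

asyncio.run(main())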

View file

@@ -602,7 +602,11 @@ class OAuthManager:
                or (auth_manager_config.OAUTH_USERNAME_CLAIM not in user_data)
            ):
                user_data: UserInfo = await client.userinfo(token=token)

            if (
                provider == "feishu"
                and isinstance(user_data, dict)
                and "data" in user_data
            ):
                user_data = user_data["data"]

            if not user_data:
                log.warning(f"OAuth callback failed, user data is missing: {token}")