diff --git a/backend/open_webui/config.py b/backend/open_webui/config.py
index 207d4405ee..a6b4ff967e 100644
--- a/backend/open_webui/config.py
+++ b/backend/open_webui/config.py
@@ -730,6 +730,7 @@ def load_oauth_providers():
         }

     if FEISHU_CLIENT_ID.value and FEISHU_CLIENT_SECRET.value:
+
         def feishu_oauth_register(client: OAuth):
             client.register(
                 name="feishu",
@@ -2708,6 +2709,29 @@ WEB_SEARCH_RESULT_COUNT = PersistentConfig(
 )


+# Web Search Rate Limiting Config
+
+WEB_SEARCH_RATE_LIMIT_ENABLED = PersistentConfig(
+    "WEB_SEARCH_RATE_LIMIT_ENABLED",
+    "rag.web.search.rate_limit.enabled",
+    os.getenv("WEB_SEARCH_RATE_LIMIT_ENABLED", "False").lower() == "true",
+)
+
+# The maximum number of requests that can be made to the web search engine per N seconds, where N=WEB_SEARCH_RATE_LIMIT_MIN_SECONDS
+WEB_SEARCH_RATE_LIMIT_MAX_REQUESTS = PersistentConfig(
+    "WEB_SEARCH_RATE_LIMIT_MAX_REQUESTS",
+    "rag.web.search.rate_limit.max_requests",
+    int(os.getenv("WEB_SEARCH_RATE_LIMIT_MAX_REQUESTS", "1")),
+)
+
+# The length N, in seconds, of the rate-limiting window referenced above
+WEB_SEARCH_RATE_LIMIT_MIN_SECONDS = PersistentConfig(
+    "WEB_SEARCH_RATE_LIMIT_MIN_SECONDS",
+    "rag.web.search.rate_limit.min_seconds",
+    int(os.getenv("WEB_SEARCH_RATE_LIMIT_MIN_SECONDS", "1")),
+)
+
+
 # You can provide a list of your own websites to filter after performing a web search.
 # This ensures the highest level of safety and reliability of the information sources.
 WEB_SEARCH_DOMAIN_FILTER_LIST = PersistentConfig(
diff --git a/backend/open_webui/routers/retrieval.py b/backend/open_webui/routers/retrieval.py
index 1f32791ba6..84a5c37dbf 100644
--- a/backend/open_webui/routers/retrieval.py
+++ b/backend/open_webui/routers/retrieval.py
@@ -23,6 +23,9 @@ from fastapi import (
 )
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.concurrency import run_in_threadpool
+
+from functools import wraps
+from collections import deque
 from pydantic import BaseModel
 import tiktoken
@@ -1738,7 +1741,67 @@ def process_web(
     )


+from open_webui.config import (
+    WEB_SEARCH_RATE_LIMIT_MAX_REQUESTS,
+    WEB_SEARCH_RATE_LIMIT_MIN_SECONDS,
+    WEB_SEARCH_RATE_LIMIT_ENABLED,
+)
+
+# Web search rate limiting (uses the search-specific config values imported above)
+
+web_search_lock = asyncio.Lock()
+# Timestamps of recent calls within the rate-limit window, oldest first
+web_search_timestamps = deque()
+
+
+def search_rate_limit(max_calls: int, period: float):
+    """
+    Async-friendly decorator limiting calls to `max_calls` per `period` seconds.
+    Bookkeeping runs on the event loop without blocking it; the wrapped
+ """ + + def decorator(func): + @wraps(func) + async def wrapper(*args, **kwargs): + async with web_search_lock: + now = asyncio.get_event_loop().time() + + # Remove timestamps older than the period + while web_search_timestamps and now - web_search_timestamps[0] > period: + web_search_timestamps.popleft() + + if len(web_search_timestamps) >= max_calls: + # Need to wait until oldest call is outside the period + wait_time = period - (now - web_search_timestamps[0]) + await asyncio.sleep(wait_time) + now = asyncio.get_event_loop().time() + while ( + web_search_timestamps + and now - web_search_timestamps[0] > period + ): + web_search_timestamps.popleft() + + # Record this call + web_search_timestamps.append(now) + + # Call the actual function + return await func(*args, **kwargs) + + return wrapper + + return decorator + + +@search_rate_limit( + int(WEB_SEARCH_RATE_LIMIT_MAX_REQUESTS.value), + int(WEB_SEARCH_RATE_LIMIT_MIN_SECONDS.value), +) +async def search_web(request: Request, engine: str, query: str) -> list[SearchResult]: """Search the web using a search engine and return the results as a list of SearchResult objects. Will look for a search engine API key in environment variables in the following order: - SEARXNG_QUERY_URL @@ -1761,6 +1822,8 @@ def search_web(request: Request, engine: str, query: str) -> list[SearchResult]: query (str): The query to search for """ + logging.info(f"search_web: {engine} query: {query}") + # TODO: add playwright to search the web if engine == "searxng": if request.app.state.config.SEARXNG_QUERY_URL: @@ -2001,11 +2064,13 @@ async def process_web_search( ) search_tasks = [ - run_in_threadpool( - search_web, - request, - request.app.state.config.WEB_SEARCH_ENGINE, - query, + asyncio.ensure_future( + run_in_threadpool( + search_web, + request, + request.app.state.config.WEB_SEARCH_ENGINE, + query, + ) ) for query in form_data.queries ] @@ -2014,6 +2079,7 @@ async def process_web_search( for result in search_results: if result: + result = await result for item in result: if item and item.link: result_items.append(item) diff --git a/backend/open_webui/utils/oauth.py b/backend/open_webui/utils/oauth.py index f48a1756b7..9090c38ce5 100644 --- a/backend/open_webui/utils/oauth.py +++ b/backend/open_webui/utils/oauth.py @@ -602,7 +602,11 @@ class OAuthManager: or (auth_manager_config.OAUTH_USERNAME_CLAIM not in user_data) ): user_data: UserInfo = await client.userinfo(token=token) - if provider == "feishu" and isinstance(user_data, dict) and "data" in user_data: + if ( + provider == "feishu" + and isinstance(user_data, dict) + and "data" in user_data + ): user_data = user_data["data"] if not user_data: log.warning(f"OAuth callback failed, user data is missing: {token}")