mirror of
https://github.com/open-webui/open-webui.git
synced 2025-12-11 20:05:19 +00:00
adding configurable rate limit for web search
This commit is contained in:
parent
e701c65db4
commit
15cb2e289d
3 changed files with 102 additions and 8 deletions
|
|
@ -730,6 +730,7 @@ def load_oauth_providers():
|
||||||
}
|
}
|
||||||
|
|
||||||
if FEISHU_CLIENT_ID.value and FEISHU_CLIENT_SECRET.value:
|
if FEISHU_CLIENT_ID.value and FEISHU_CLIENT_SECRET.value:
|
||||||
|
|
||||||
def feishu_oauth_register(client: OAuth):
|
def feishu_oauth_register(client: OAuth):
|
||||||
client.register(
|
client.register(
|
||||||
name="feishu",
|
name="feishu",
|
||||||
|
|
@ -2708,6 +2709,29 @@ WEB_SEARCH_RESULT_COUNT = PersistentConfig(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Web Search Rate Limiting Config
|
||||||
|
|
||||||
|
# Master switch for web search rate limiting.
# The environment value is compared against "true" so the default is a real
# boolean; a bare lower-cased string such as "false" is truthy and would make
# the switch impossible to turn off.
WEB_SEARCH_RATE_LIMIT_ENABLED = PersistentConfig(
    "WEB_SEARCH_RATE_LIMIT_ENABLED",
    "rag.web.search.rate_limit.enabled",
    os.getenv("WEB_SEARCH_RATE_LIMIT_ENABLED", "False").lower() == "true",
)
|
||||||
|
|
||||||
|
# Upper bound on the number of web search requests allowed within one window
# of WEB_SEARCH_RATE_LIMIT_MIN_SECONDS seconds. Defaults to 1.
WEB_SEARCH_RATE_LIMIT_MAX_REQUESTS = PersistentConfig(
    "WEB_SEARCH_RATE_LIMIT_MAX_REQUESTS",
    "rag.web.search.rate_limit.max_requests",
    int(os.environ.get("WEB_SEARCH_RATE_LIMIT_MAX_REQUESTS", "1")),
)
|
||||||
|
|
||||||
|
# Length, in seconds, of the window over which
# WEB_SEARCH_RATE_LIMIT_MAX_REQUESTS is counted. Defaults to 1.
WEB_SEARCH_RATE_LIMIT_MIN_SECONDS = PersistentConfig(
    "WEB_SEARCH_RATE_LIMIT_MIN_SECONDS",
    "rag.web.search.rate_limit.min_seconds",
    int(
        os.getenv(
            # Primary name: matches the setting's own name.
            "WEB_SEARCH_RATE_LIMIT_MIN_SECONDS",
            # Fallback: the misspelled variable this setting originally read,
            # kept so existing deployments continue to work.
            os.getenv("WEB_SEARCH_RATE_MIN_SECONDS", "1"),
        )
    ),
)
|
||||||
|
|
||||||
|
|
||||||
# You can provide a list of your own websites to filter after performing a web search.
|
# You can provide a list of your own websites to filter after performing a web search.
|
||||||
# This ensures the highest level of safety and reliability of the information sources.
|
# This ensures the highest level of safety and reliability of the information sources.
|
||||||
WEB_SEARCH_DOMAIN_FILTER_LIST = PersistentConfig(
|
WEB_SEARCH_DOMAIN_FILTER_LIST = PersistentConfig(
|
||||||
|
|
|
||||||
|
|
@ -4,9 +4,11 @@ import mimetypes
|
||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import time
|
||||||
|
|
||||||
import uuid
|
import uuid
|
||||||
from datetime import datetime
|
from datetime import datetime, timedelta
|
||||||
|
from multiprocessing import Value
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Iterator, List, Optional, Sequence, Union
|
from typing import Iterator, List, Optional, Sequence, Union
|
||||||
|
|
||||||
|
|
@ -23,6 +25,9 @@ from fastapi import (
|
||||||
)
|
)
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
from fastapi.concurrency import run_in_threadpool
|
from fastapi.concurrency import run_in_threadpool
|
||||||
|
|
||||||
|
from functools import wraps
|
||||||
|
from collections import deque
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
import tiktoken
|
import tiktoken
|
||||||
|
|
||||||
|
|
@ -1738,7 +1743,63 @@ def process_web(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def search_web(request: Request, engine: str, query: str) -> list[SearchResult]:
|
from open_webui.config import (
|
||||||
|
WEB_SEARCH_RATE_LIMIT_MAX_REQUESTS,
|
||||||
|
WEB_SEARCH_RATE_LIMIT_MIN_SECONDS,
|
||||||
|
WEB_SEARCH_RATE_LIMIT_ENABLED,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Rate limiting for web search, configured by the WEB_SEARCH_RATE_LIMIT_* settings imported above.
|
||||||
|
|
||||||
|
|
||||||
|
# One lock and one timestamp window are shared by every function wrapped with
# search_rate_limit, so the limit is enforced across all of them together.
web_search_lock = asyncio.Lock()
# Event-loop timestamps of the calls admitted in the current window, oldest first.
web_search_timestamps = deque()


def search_rate_limit(max_calls: int, period: float):
    """
    Decorator factory limiting an async function to `max_calls` calls per `period` seconds.

    Callers that would exceed the limit are suspended with `asyncio.sleep`
    (the event loop is not blocked) until the oldest recorded call has left
    the window.

    Args:
        max_calls: Maximum number of calls admitted per window; must be >= 1.
        period: Window length in seconds.

    Returns:
        A decorator that wraps a coroutine function and returns a coroutine
        function with the same signature and return value.

    Raises:
        ValueError: If `max_calls` is less than 1 (no call could ever be admitted).
    """
    if max_calls < 1:
        raise ValueError("max_calls must be at least 1")

    def decorator(func):
        @wraps(func)
        async def wrapper(*args, **kwargs):
            clock = asyncio.get_running_loop().time

            # The lock is held only for the bookkeeping (including any wait),
            # so waiting callers are admitted one at a time.
            async with web_search_lock:
                while True:
                    now = clock()

                    # Drop timestamps that have aged out of the window. `>=` is
                    # used so a call exactly `period` seconds old no longer counts.
                    while (
                        web_search_timestamps
                        and now - web_search_timestamps[0] >= period
                    ):
                        web_search_timestamps.popleft()

                    if len(web_search_timestamps) < max_calls:
                        break

                    # Wait until the oldest call leaves the window, then
                    # re-check: timers may fire marginally early, so a single
                    # sleep is not guaranteed to free a slot.
                    await asyncio.sleep(period - (now - web_search_timestamps[0]))

                # Record this call
                web_search_timestamps.append(now)

            # Call the actual function
            return await func(*args, **kwargs)

        return wrapper

    return decorator
|
||||||
|
|
||||||
|
|
||||||
|
@search_rate_limit(
|
||||||
|
int(WEB_SEARCH_RATE_LIMIT_MAX_REQUESTS.value),
|
||||||
|
int(WEB_SEARCH_RATE_LIMIT_MIN_SECONDS.value),
|
||||||
|
)
|
||||||
|
async def search_web(request: Request, engine: str, query: str) -> list[SearchResult]:
|
||||||
"""Search the web using a search engine and return the results as a list of SearchResult objects.
|
"""Search the web using a search engine and return the results as a list of SearchResult objects.
|
||||||
Will look for a search engine API key in environment variables in the following order:
|
Will look for a search engine API key in environment variables in the following order:
|
||||||
- SEARXNG_QUERY_URL
|
- SEARXNG_QUERY_URL
|
||||||
|
|
@ -1761,6 +1822,8 @@ def search_web(request: Request, engine: str, query: str) -> list[SearchResult]:
|
||||||
query (str): The query to search for
|
query (str): The query to search for
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
logging.info(f"search_web: {engine} query: {query}")
|
||||||
|
|
||||||
# TODO: add playwright to search the web
|
# TODO: add playwright to search the web
|
||||||
if engine == "searxng":
|
if engine == "searxng":
|
||||||
if request.app.state.config.SEARXNG_QUERY_URL:
|
if request.app.state.config.SEARXNG_QUERY_URL:
|
||||||
|
|
@ -2001,11 +2064,13 @@ async def process_web_search(
|
||||||
)
|
)
|
||||||
|
|
||||||
search_tasks = [
|
search_tasks = [
|
||||||
run_in_threadpool(
|
asyncio.ensure_future(
|
||||||
search_web,
|
run_in_threadpool(
|
||||||
request,
|
search_web,
|
||||||
request.app.state.config.WEB_SEARCH_ENGINE,
|
request,
|
||||||
query,
|
request.app.state.config.WEB_SEARCH_ENGINE,
|
||||||
|
query,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
for query in form_data.queries
|
for query in form_data.queries
|
||||||
]
|
]
|
||||||
|
|
@ -2014,6 +2079,7 @@ async def process_web_search(
|
||||||
|
|
||||||
for result in search_results:
|
for result in search_results:
|
||||||
if result:
|
if result:
|
||||||
|
result = await result
|
||||||
for item in result:
|
for item in result:
|
||||||
if item and item.link:
|
if item and item.link:
|
||||||
result_items.append(item)
|
result_items.append(item)
|
||||||
|
|
|
||||||
|
|
@ -602,7 +602,11 @@ class OAuthManager:
|
||||||
or (auth_manager_config.OAUTH_USERNAME_CLAIM not in user_data)
|
or (auth_manager_config.OAUTH_USERNAME_CLAIM not in user_data)
|
||||||
):
|
):
|
||||||
user_data: UserInfo = await client.userinfo(token=token)
|
user_data: UserInfo = await client.userinfo(token=token)
|
||||||
if provider == "feishu" and isinstance(user_data, dict) and "data" in user_data:
|
if (
|
||||||
|
provider == "feishu"
|
||||||
|
and isinstance(user_data, dict)
|
||||||
|
and "data" in user_data
|
||||||
|
):
|
||||||
user_data = user_data["data"]
|
user_data = user_data["data"]
|
||||||
if not user_data:
|
if not user_data:
|
||||||
log.warning(f"OAuth callback failed, user data is missing: {token}")
|
log.warning(f"OAuth callback failed, user data is missing: {token}")
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue