diff --git a/backend/open_webui/env.py b/backend/open_webui/env.py index 0f603c92d9..19d09a62d3 100644 --- a/backend/open_webui/env.py +++ b/backend/open_webui/env.py @@ -339,6 +339,19 @@ else: except Exception: DATABASE_POOL_RECYCLE = 3600 +DATABASE_ENABLE_SQLITE_WAL = (os.environ.get("DATABASE_ENABLE_SQLITE_WAL", "False").lower() == "true") + +DATABASE_DEDUPLICATE_INTERVAL = ( + os.environ.get("DATABASE_DEDUPLICATE_INTERVAL", 0.) +) +if DATABASE_DEDUPLICATE_INTERVAL == "": + DATABASE_DEDUPLICATE_INTERVAL = 0.0 +else: + try: + DATABASE_DEDUPLICATE_INTERVAL = float(DATABASE_DEDUPLICATE_INTERVAL) + except Exception: + DATABASE_DEDUPLICATE_INTERVAL = 0.0 + RESET_CONFIG_ON_START = ( os.environ.get("RESET_CONFIG_ON_START", "False").lower() == "true" ) diff --git a/backend/open_webui/internal/db.py b/backend/open_webui/internal/db.py index d7a200ff20..b6913d87b0 100644 --- a/backend/open_webui/internal/db.py +++ b/backend/open_webui/internal/db.py @@ -14,9 +14,10 @@ from open_webui.env import ( DATABASE_POOL_RECYCLE, DATABASE_POOL_SIZE, DATABASE_POOL_TIMEOUT, + DATABASE_ENABLE_SQLITE_WAL, ) from peewee_migrate import Router -from sqlalchemy import Dialect, create_engine, MetaData, types +from sqlalchemy import Dialect, create_engine, MetaData, event, types from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import scoped_session, sessionmaker from sqlalchemy.pool import QueuePool, NullPool @@ -114,6 +115,16 @@ elif "sqlite" in SQLALCHEMY_DATABASE_URL: engine = create_engine( SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False} ) + + def on_connect(dbapi_connection, connection_record): + cursor = dbapi_connection.cursor() + if DATABASE_ENABLE_SQLITE_WAL: + cursor.execute("PRAGMA journal_mode=WAL") + else: + cursor.execute("PRAGMA journal_mode=DELETE") + cursor.close() + + event.listen(engine, "connect", on_connect) else: if isinstance(DATABASE_POOL_SIZE, int): if DATABASE_POOL_SIZE > 0: diff --git a/backend/open_webui/models/users.py b/backend/open_webui/models/users.py index 60b6ad0c10..47cd7b0eb0 100644 --- a/backend/open_webui/models/users.py +++ b/backend/open_webui/models/users.py @@ -4,8 +4,10 @@ from typing import Optional from open_webui.internal.db import Base, JSONField, get_db +from open_webui.env import DATABASE_DEDUPLICATE_INTERVAL from open_webui.models.chats import Chats from open_webui.models.groups import Groups +from open_webui.utils.misc import deduplicate from pydantic import BaseModel, ConfigDict @@ -311,6 +313,7 @@ class UsersTable: except Exception: return None + @deduplicate(DATABASE_DEDUPLICATE_INTERVAL) def update_user_last_active_by_id(self, id: str) -> Optional[UserModel]: try: with get_db() as db: diff --git a/backend/open_webui/utils/misc.py b/backend/open_webui/utils/misc.py index 2a780209a7..e7a007df38 100644 --- a/backend/open_webui/utils/misc.py +++ b/backend/open_webui/utils/misc.py @@ -1,5 +1,6 @@ import hashlib import re +import threading import time import uuid import logging @@ -478,3 +479,43 @@ def convert_logit_bias_input_to_json(user_input): bias = 100 if bias > 100 else -100 if bias < -100 else bias logit_bias_json[token] = bias return json.dumps(logit_bias_json) + + +def freeze(value): + """ + Freeze a value to make it hashable. + """ + if isinstance(value, dict): + return frozenset((k, freeze(v)) for k, v in value.items()) + elif isinstance(value, list): + return tuple(freeze(v) for v in value) + return value + + +def deduplicate(interval: float = 10.0): + """ + Decorator to prevent a function from being called more than once within a specified duration. + If the function is called again within the duration, it returns None. To avoid returning + different types, the return type of the function should be Optional[T]. + + :param interval: Duration in seconds to wait before allowing the function to be called again. + """ + + def decorator(func): + last_calls = {} + lock = threading.Lock() + + def wrapper(*args, **kwargs): + key = (args, freeze(kwargs)) + now = time.time() + if now - last_calls.get(key, 0) < interval: + return None + with lock: + if now - last_calls.get(key, 0) < interval: + return None + last_calls[key] = now + return func(*args, **kwargs) + + return wrapper + + return decorator