diff --git a/backend/open_webui/internal/db.py b/backend/open_webui/internal/db.py index 63bae5f33a..d7a200ff20 100644 --- a/backend/open_webui/internal/db.py +++ b/backend/open_webui/internal/db.py @@ -82,29 +82,32 @@ handle_peewee_migration(DATABASE_URL) SQLALCHEMY_DATABASE_URL = DATABASE_URL # Handle SQLCipher URLs -if SQLALCHEMY_DATABASE_URL.startswith('sqlite+sqlcipher://'): +if SQLALCHEMY_DATABASE_URL.startswith("sqlite+sqlcipher://"): database_password = os.environ.get("DATABASE_PASSWORD") if not database_password or database_password.strip() == "": - raise ValueError("DATABASE_PASSWORD is required when using sqlite+sqlcipher:// URLs") - + raise ValueError( + "DATABASE_PASSWORD is required when using sqlite+sqlcipher:// URLs" + ) + # Extract database path from SQLCipher URL - db_path = SQLALCHEMY_DATABASE_URL.replace('sqlite+sqlcipher://', '') - if db_path.startswith('/'): + db_path = SQLALCHEMY_DATABASE_URL.replace("sqlite+sqlcipher://", "") + if db_path.startswith("/"): db_path = db_path[1:] # Remove leading slash for relative paths - + # Create a custom creator function that uses sqlcipher3 def create_sqlcipher_connection(): import sqlcipher3 + conn = sqlcipher3.connect(db_path, check_same_thread=False) conn.execute(f"PRAGMA key = '{database_password}'") return conn - + engine = create_engine( "sqlite://", # Dummy URL since we're using creator creator=create_sqlcipher_connection, - echo=False + echo=False, ) - + log.info("Connected to encrypted SQLite database using SQLCipher") elif "sqlite" in SQLALCHEMY_DATABASE_URL: diff --git a/backend/open_webui/internal/wrappers.py b/backend/open_webui/internal/wrappers.py index 4e62ea4c5b..944bf2d952 100644 --- a/backend/open_webui/internal/wrappers.py +++ b/backend/open_webui/internal/wrappers.py @@ -46,23 +46,25 @@ class ReconnectingPostgresqlDatabase(CustomReconnectMixin, PostgresqlDatabase): def register_connection(db_url): # Check if using SQLCipher protocol - if db_url.startswith('sqlite+sqlcipher://'): + if db_url.startswith("sqlite+sqlcipher://"): database_password = os.environ.get("DATABASE_PASSWORD") if not database_password or database_password.strip() == "": - raise ValueError("DATABASE_PASSWORD is required when using sqlite+sqlcipher:// URLs") - + raise ValueError( + "DATABASE_PASSWORD is required when using sqlite+sqlcipher:// URLs" + ) + # Parse the database path from SQLCipher URL # Convert sqlite+sqlcipher:///path/to/db.sqlite to /path/to/db.sqlite - db_path = db_url.replace('sqlite+sqlcipher://', '') - if db_path.startswith('/'): + db_path = db_url.replace("sqlite+sqlcipher://", "") + if db_path.startswith("/"): db_path = db_path[1:] # Remove leading slash for relative paths - + # Use Peewee's native SqlCipherDatabase with encryption db = SqlCipherDatabase(db_path, passphrase=database_password) db.autoconnect = True db.reuse_if_open = True log.info("Connected to encrypted SQLite database using SQLCipher") - + else: # Standard database connection (existing logic) db = connect(db_url, unquote_user=True, unquote_password=True) diff --git a/backend/open_webui/migrations/env.py b/backend/open_webui/migrations/env.py index 2a3cd469b1..7db9251282 100644 --- a/backend/open_webui/migrations/env.py +++ b/backend/open_webui/migrations/env.py @@ -63,26 +63,29 @@ def run_migrations_online() -> None: """ # Handle SQLCipher URLs - if DB_URL and DB_URL.startswith('sqlite+sqlcipher://'): + if DB_URL and DB_URL.startswith("sqlite+sqlcipher://"): if not DATABASE_PASSWORD or DATABASE_PASSWORD.strip() == "": - raise ValueError("DATABASE_PASSWORD is required when using sqlite+sqlcipher:// URLs") - + raise ValueError( + "DATABASE_PASSWORD is required when using sqlite+sqlcipher:// URLs" + ) + # Extract database path from SQLCipher URL - db_path = DB_URL.replace('sqlite+sqlcipher://', '') - if db_path.startswith('/'): + db_path = DB_URL.replace("sqlite+sqlcipher://", "") + if db_path.startswith("/"): db_path = db_path[1:] # Remove leading slash for relative paths - + # Create a custom creator function that uses sqlcipher3 def create_sqlcipher_connection(): import sqlcipher3 + conn = sqlcipher3.connect(db_path, check_same_thread=False) conn.execute(f"PRAGMA key = '{DATABASE_PASSWORD}'") return conn - + connectable = create_engine( "sqlite://", # Dummy URL since we're using creator creator=create_sqlcipher_connection, - echo=False + echo=False, ) else: # Standard database connection (existing logic) diff --git a/backend/open_webui/retrieval/vector/dbs/pgvector.py b/backend/open_webui/retrieval/vector/dbs/pgvector.py index c1e120a4da..9deb61f5a3 100644 --- a/backend/open_webui/retrieval/vector/dbs/pgvector.py +++ b/backend/open_webui/retrieval/vector/dbs/pgvector.py @@ -421,7 +421,7 @@ class PgvectorClient(VectorDBBase): documents[qid].append(row.text) metadatas[qid].append(row.vmetadata) - self.session.rollback() # read-only transaction + self.session.rollback() # read-only transaction return SearchResult( ids=ids, distances=distances, documents=documents, metadatas=metadatas ) @@ -479,7 +479,7 @@ class PgvectorClient(VectorDBBase): documents = [[result.text for result in results]] metadatas = [[result.vmetadata for result in results]] - self.session.rollback() # read-only transaction + self.session.rollback() # read-only transaction return GetResult( ids=ids, documents=documents, @@ -527,7 +527,7 @@ class PgvectorClient(VectorDBBase): documents = [[result.text for result in results]] metadatas = [[result.vmetadata for result in results]] - self.session.rollback() # read-only transaction + self.session.rollback() # read-only transaction return GetResult(ids=ids, documents=documents, metadatas=metadatas) except Exception as e: self.session.rollback() @@ -598,7 +598,7 @@ class PgvectorClient(VectorDBBase): .first() is not None ) - self.session.rollback() # read-only transaction + self.session.rollback() # read-only transaction return exists except Exception as e: self.session.rollback() diff --git a/backend/open_webui/retrieval/vector/dbs/qdrant.py b/backend/open_webui/retrieval/vector/dbs/qdrant.py index 7997865467..ea43297499 100644 --- a/backend/open_webui/retrieval/vector/dbs/qdrant.py +++ b/backend/open_webui/retrieval/vector/dbs/qdrant.py @@ -60,7 +60,11 @@ class QdrantClient(VectorDBBase): timeout=self.QDRANT_TIMEOUT, ) else: - self.client = Qclient(url=self.QDRANT_URI, api_key=self.QDRANT_API_KEY, timeout=QDRANT_TIMEOUT,) + self.client = Qclient( + url=self.QDRANT_URI, + api_key=self.QDRANT_API_KEY, + timeout=QDRANT_TIMEOUT, + ) def _result_to_get_result(self, points) -> GetResult: ids = [] diff --git a/backend/open_webui/retrieval/vector/dbs/qdrant_multitenancy.py b/backend/open_webui/retrieval/vector/dbs/qdrant_multitenancy.py index c120f07632..ed4a8bab34 100644 --- a/backend/open_webui/retrieval/vector/dbs/qdrant_multitenancy.py +++ b/backend/open_webui/retrieval/vector/dbs/qdrant_multitenancy.py @@ -76,7 +76,11 @@ class QdrantClient(VectorDBBase): timeout=self.QDRANT_TIMEOUT, ) if self.PREFER_GRPC - else Qclient(url=self.QDRANT_URI, api_key=self.QDRANT_API_KEY, timeout=self.QDRANT_TIMEOUT,) + else Qclient( + url=self.QDRANT_URI, + api_key=self.QDRANT_API_KEY, + timeout=self.QDRANT_TIMEOUT, + ) ) # Main collection types for multi-tenancy