mirror of
https://github.com/open-webui/open-webui.git
synced 2025-12-12 12:25:20 +00:00
chore: format
This commit is contained in:
parent
f85aaa4ed9
commit
77189664c2
6 changed files with 46 additions and 30 deletions
|
|
@ -82,29 +82,32 @@ handle_peewee_migration(DATABASE_URL)
|
||||||
SQLALCHEMY_DATABASE_URL = DATABASE_URL
|
SQLALCHEMY_DATABASE_URL = DATABASE_URL
|
||||||
|
|
||||||
# Handle SQLCipher URLs
|
# Handle SQLCipher URLs
|
||||||
if SQLALCHEMY_DATABASE_URL.startswith('sqlite+sqlcipher://'):
|
if SQLALCHEMY_DATABASE_URL.startswith("sqlite+sqlcipher://"):
|
||||||
database_password = os.environ.get("DATABASE_PASSWORD")
|
database_password = os.environ.get("DATABASE_PASSWORD")
|
||||||
if not database_password or database_password.strip() == "":
|
if not database_password or database_password.strip() == "":
|
||||||
raise ValueError("DATABASE_PASSWORD is required when using sqlite+sqlcipher:// URLs")
|
raise ValueError(
|
||||||
|
"DATABASE_PASSWORD is required when using sqlite+sqlcipher:// URLs"
|
||||||
|
)
|
||||||
|
|
||||||
# Extract database path from SQLCipher URL
|
# Extract database path from SQLCipher URL
|
||||||
db_path = SQLALCHEMY_DATABASE_URL.replace('sqlite+sqlcipher://', '')
|
db_path = SQLALCHEMY_DATABASE_URL.replace("sqlite+sqlcipher://", "")
|
||||||
if db_path.startswith('/'):
|
if db_path.startswith("/"):
|
||||||
db_path = db_path[1:] # Remove leading slash for relative paths
|
db_path = db_path[1:] # Remove leading slash for relative paths
|
||||||
|
|
||||||
# Create a custom creator function that uses sqlcipher3
|
# Create a custom creator function that uses sqlcipher3
|
||||||
def create_sqlcipher_connection():
|
def create_sqlcipher_connection():
|
||||||
import sqlcipher3
|
import sqlcipher3
|
||||||
|
|
||||||
conn = sqlcipher3.connect(db_path, check_same_thread=False)
|
conn = sqlcipher3.connect(db_path, check_same_thread=False)
|
||||||
conn.execute(f"PRAGMA key = '{database_password}'")
|
conn.execute(f"PRAGMA key = '{database_password}'")
|
||||||
return conn
|
return conn
|
||||||
|
|
||||||
engine = create_engine(
|
engine = create_engine(
|
||||||
"sqlite://", # Dummy URL since we're using creator
|
"sqlite://", # Dummy URL since we're using creator
|
||||||
creator=create_sqlcipher_connection,
|
creator=create_sqlcipher_connection,
|
||||||
echo=False
|
echo=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
log.info("Connected to encrypted SQLite database using SQLCipher")
|
log.info("Connected to encrypted SQLite database using SQLCipher")
|
||||||
|
|
||||||
elif "sqlite" in SQLALCHEMY_DATABASE_URL:
|
elif "sqlite" in SQLALCHEMY_DATABASE_URL:
|
||||||
|
|
|
||||||
|
|
@ -46,23 +46,25 @@ class ReconnectingPostgresqlDatabase(CustomReconnectMixin, PostgresqlDatabase):
|
||||||
|
|
||||||
def register_connection(db_url):
|
def register_connection(db_url):
|
||||||
# Check if using SQLCipher protocol
|
# Check if using SQLCipher protocol
|
||||||
if db_url.startswith('sqlite+sqlcipher://'):
|
if db_url.startswith("sqlite+sqlcipher://"):
|
||||||
database_password = os.environ.get("DATABASE_PASSWORD")
|
database_password = os.environ.get("DATABASE_PASSWORD")
|
||||||
if not database_password or database_password.strip() == "":
|
if not database_password or database_password.strip() == "":
|
||||||
raise ValueError("DATABASE_PASSWORD is required when using sqlite+sqlcipher:// URLs")
|
raise ValueError(
|
||||||
|
"DATABASE_PASSWORD is required when using sqlite+sqlcipher:// URLs"
|
||||||
|
)
|
||||||
|
|
||||||
# Parse the database path from SQLCipher URL
|
# Parse the database path from SQLCipher URL
|
||||||
# Convert sqlite+sqlcipher:///path/to/db.sqlite to /path/to/db.sqlite
|
# Convert sqlite+sqlcipher:///path/to/db.sqlite to /path/to/db.sqlite
|
||||||
db_path = db_url.replace('sqlite+sqlcipher://', '')
|
db_path = db_url.replace("sqlite+sqlcipher://", "")
|
||||||
if db_path.startswith('/'):
|
if db_path.startswith("/"):
|
||||||
db_path = db_path[1:] # Remove leading slash for relative paths
|
db_path = db_path[1:] # Remove leading slash for relative paths
|
||||||
|
|
||||||
# Use Peewee's native SqlCipherDatabase with encryption
|
# Use Peewee's native SqlCipherDatabase with encryption
|
||||||
db = SqlCipherDatabase(db_path, passphrase=database_password)
|
db = SqlCipherDatabase(db_path, passphrase=database_password)
|
||||||
db.autoconnect = True
|
db.autoconnect = True
|
||||||
db.reuse_if_open = True
|
db.reuse_if_open = True
|
||||||
log.info("Connected to encrypted SQLite database using SQLCipher")
|
log.info("Connected to encrypted SQLite database using SQLCipher")
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# Standard database connection (existing logic)
|
# Standard database connection (existing logic)
|
||||||
db = connect(db_url, unquote_user=True, unquote_password=True)
|
db = connect(db_url, unquote_user=True, unquote_password=True)
|
||||||
|
|
|
||||||
|
|
@ -63,26 +63,29 @@ def run_migrations_online() -> None:
|
||||||
|
|
||||||
"""
|
"""
|
||||||
# Handle SQLCipher URLs
|
# Handle SQLCipher URLs
|
||||||
if DB_URL and DB_URL.startswith('sqlite+sqlcipher://'):
|
if DB_URL and DB_URL.startswith("sqlite+sqlcipher://"):
|
||||||
if not DATABASE_PASSWORD or DATABASE_PASSWORD.strip() == "":
|
if not DATABASE_PASSWORD or DATABASE_PASSWORD.strip() == "":
|
||||||
raise ValueError("DATABASE_PASSWORD is required when using sqlite+sqlcipher:// URLs")
|
raise ValueError(
|
||||||
|
"DATABASE_PASSWORD is required when using sqlite+sqlcipher:// URLs"
|
||||||
|
)
|
||||||
|
|
||||||
# Extract database path from SQLCipher URL
|
# Extract database path from SQLCipher URL
|
||||||
db_path = DB_URL.replace('sqlite+sqlcipher://', '')
|
db_path = DB_URL.replace("sqlite+sqlcipher://", "")
|
||||||
if db_path.startswith('/'):
|
if db_path.startswith("/"):
|
||||||
db_path = db_path[1:] # Remove leading slash for relative paths
|
db_path = db_path[1:] # Remove leading slash for relative paths
|
||||||
|
|
||||||
# Create a custom creator function that uses sqlcipher3
|
# Create a custom creator function that uses sqlcipher3
|
||||||
def create_sqlcipher_connection():
|
def create_sqlcipher_connection():
|
||||||
import sqlcipher3
|
import sqlcipher3
|
||||||
|
|
||||||
conn = sqlcipher3.connect(db_path, check_same_thread=False)
|
conn = sqlcipher3.connect(db_path, check_same_thread=False)
|
||||||
conn.execute(f"PRAGMA key = '{DATABASE_PASSWORD}'")
|
conn.execute(f"PRAGMA key = '{DATABASE_PASSWORD}'")
|
||||||
return conn
|
return conn
|
||||||
|
|
||||||
connectable = create_engine(
|
connectable = create_engine(
|
||||||
"sqlite://", # Dummy URL since we're using creator
|
"sqlite://", # Dummy URL since we're using creator
|
||||||
creator=create_sqlcipher_connection,
|
creator=create_sqlcipher_connection,
|
||||||
echo=False
|
echo=False,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# Standard database connection (existing logic)
|
# Standard database connection (existing logic)
|
||||||
|
|
|
||||||
|
|
@ -421,7 +421,7 @@ class PgvectorClient(VectorDBBase):
|
||||||
documents[qid].append(row.text)
|
documents[qid].append(row.text)
|
||||||
metadatas[qid].append(row.vmetadata)
|
metadatas[qid].append(row.vmetadata)
|
||||||
|
|
||||||
self.session.rollback() # read-only transaction
|
self.session.rollback() # read-only transaction
|
||||||
return SearchResult(
|
return SearchResult(
|
||||||
ids=ids, distances=distances, documents=documents, metadatas=metadatas
|
ids=ids, distances=distances, documents=documents, metadatas=metadatas
|
||||||
)
|
)
|
||||||
|
|
@ -479,7 +479,7 @@ class PgvectorClient(VectorDBBase):
|
||||||
documents = [[result.text for result in results]]
|
documents = [[result.text for result in results]]
|
||||||
metadatas = [[result.vmetadata for result in results]]
|
metadatas = [[result.vmetadata for result in results]]
|
||||||
|
|
||||||
self.session.rollback() # read-only transaction
|
self.session.rollback() # read-only transaction
|
||||||
return GetResult(
|
return GetResult(
|
||||||
ids=ids,
|
ids=ids,
|
||||||
documents=documents,
|
documents=documents,
|
||||||
|
|
@ -527,7 +527,7 @@ class PgvectorClient(VectorDBBase):
|
||||||
documents = [[result.text for result in results]]
|
documents = [[result.text for result in results]]
|
||||||
metadatas = [[result.vmetadata for result in results]]
|
metadatas = [[result.vmetadata for result in results]]
|
||||||
|
|
||||||
self.session.rollback() # read-only transaction
|
self.session.rollback() # read-only transaction
|
||||||
return GetResult(ids=ids, documents=documents, metadatas=metadatas)
|
return GetResult(ids=ids, documents=documents, metadatas=metadatas)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.session.rollback()
|
self.session.rollback()
|
||||||
|
|
@ -598,7 +598,7 @@ class PgvectorClient(VectorDBBase):
|
||||||
.first()
|
.first()
|
||||||
is not None
|
is not None
|
||||||
)
|
)
|
||||||
self.session.rollback() # read-only transaction
|
self.session.rollback() # read-only transaction
|
||||||
return exists
|
return exists
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.session.rollback()
|
self.session.rollback()
|
||||||
|
|
|
||||||
|
|
@ -60,7 +60,11 @@ class QdrantClient(VectorDBBase):
|
||||||
timeout=self.QDRANT_TIMEOUT,
|
timeout=self.QDRANT_TIMEOUT,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.client = Qclient(url=self.QDRANT_URI, api_key=self.QDRANT_API_KEY, timeout=QDRANT_TIMEOUT,)
|
self.client = Qclient(
|
||||||
|
url=self.QDRANT_URI,
|
||||||
|
api_key=self.QDRANT_API_KEY,
|
||||||
|
timeout=QDRANT_TIMEOUT,
|
||||||
|
)
|
||||||
|
|
||||||
def _result_to_get_result(self, points) -> GetResult:
|
def _result_to_get_result(self, points) -> GetResult:
|
||||||
ids = []
|
ids = []
|
||||||
|
|
|
||||||
|
|
@ -76,7 +76,11 @@ class QdrantClient(VectorDBBase):
|
||||||
timeout=self.QDRANT_TIMEOUT,
|
timeout=self.QDRANT_TIMEOUT,
|
||||||
)
|
)
|
||||||
if self.PREFER_GRPC
|
if self.PREFER_GRPC
|
||||||
else Qclient(url=self.QDRANT_URI, api_key=self.QDRANT_API_KEY, timeout=self.QDRANT_TIMEOUT,)
|
else Qclient(
|
||||||
|
url=self.QDRANT_URI,
|
||||||
|
api_key=self.QDRANT_API_KEY,
|
||||||
|
timeout=self.QDRANT_TIMEOUT,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Main collection types for multi-tenancy
|
# Main collection types for multi-tenancy
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue