chore: format

This commit is contained in:
Timothy Jaeryang Baek 2025-08-09 23:57:35 +04:00
parent f85aaa4ed9
commit 77189664c2
6 changed files with 46 additions and 30 deletions

View file

@ -82,19 +82,22 @@ handle_peewee_migration(DATABASE_URL)
SQLALCHEMY_DATABASE_URL = DATABASE_URL SQLALCHEMY_DATABASE_URL = DATABASE_URL
# Handle SQLCipher URLs # Handle SQLCipher URLs
if SQLALCHEMY_DATABASE_URL.startswith('sqlite+sqlcipher://'): if SQLALCHEMY_DATABASE_URL.startswith("sqlite+sqlcipher://"):
database_password = os.environ.get("DATABASE_PASSWORD") database_password = os.environ.get("DATABASE_PASSWORD")
if not database_password or database_password.strip() == "": if not database_password or database_password.strip() == "":
raise ValueError("DATABASE_PASSWORD is required when using sqlite+sqlcipher:// URLs") raise ValueError(
"DATABASE_PASSWORD is required when using sqlite+sqlcipher:// URLs"
)
# Extract database path from SQLCipher URL # Extract database path from SQLCipher URL
db_path = SQLALCHEMY_DATABASE_URL.replace('sqlite+sqlcipher://', '') db_path = SQLALCHEMY_DATABASE_URL.replace("sqlite+sqlcipher://", "")
if db_path.startswith('/'): if db_path.startswith("/"):
db_path = db_path[1:] # Remove leading slash for relative paths db_path = db_path[1:] # Remove leading slash for relative paths
# Create a custom creator function that uses sqlcipher3 # Create a custom creator function that uses sqlcipher3
def create_sqlcipher_connection(): def create_sqlcipher_connection():
import sqlcipher3 import sqlcipher3
conn = sqlcipher3.connect(db_path, check_same_thread=False) conn = sqlcipher3.connect(db_path, check_same_thread=False)
conn.execute(f"PRAGMA key = '{database_password}'") conn.execute(f"PRAGMA key = '{database_password}'")
return conn return conn
@ -102,7 +105,7 @@ if SQLALCHEMY_DATABASE_URL.startswith('sqlite+sqlcipher://'):
engine = create_engine( engine = create_engine(
"sqlite://", # Dummy URL since we're using creator "sqlite://", # Dummy URL since we're using creator
creator=create_sqlcipher_connection, creator=create_sqlcipher_connection,
echo=False echo=False,
) )
log.info("Connected to encrypted SQLite database using SQLCipher") log.info("Connected to encrypted SQLite database using SQLCipher")

View file

@ -46,15 +46,17 @@ class ReconnectingPostgresqlDatabase(CustomReconnectMixin, PostgresqlDatabase):
def register_connection(db_url): def register_connection(db_url):
# Check if using SQLCipher protocol # Check if using SQLCipher protocol
if db_url.startswith('sqlite+sqlcipher://'): if db_url.startswith("sqlite+sqlcipher://"):
database_password = os.environ.get("DATABASE_PASSWORD") database_password = os.environ.get("DATABASE_PASSWORD")
if not database_password or database_password.strip() == "": if not database_password or database_password.strip() == "":
raise ValueError("DATABASE_PASSWORD is required when using sqlite+sqlcipher:// URLs") raise ValueError(
"DATABASE_PASSWORD is required when using sqlite+sqlcipher:// URLs"
)
# Parse the database path from SQLCipher URL # Parse the database path from SQLCipher URL
# Convert sqlite+sqlcipher:///path/to/db.sqlite to /path/to/db.sqlite # Convert sqlite+sqlcipher:///path/to/db.sqlite to /path/to/db.sqlite
db_path = db_url.replace('sqlite+sqlcipher://', '') db_path = db_url.replace("sqlite+sqlcipher://", "")
if db_path.startswith('/'): if db_path.startswith("/"):
db_path = db_path[1:] # Remove leading slash for relative paths db_path = db_path[1:] # Remove leading slash for relative paths
# Use Peewee's native SqlCipherDatabase with encryption # Use Peewee's native SqlCipherDatabase with encryption

View file

@ -63,18 +63,21 @@ def run_migrations_online() -> None:
""" """
# Handle SQLCipher URLs # Handle SQLCipher URLs
if DB_URL and DB_URL.startswith('sqlite+sqlcipher://'): if DB_URL and DB_URL.startswith("sqlite+sqlcipher://"):
if not DATABASE_PASSWORD or DATABASE_PASSWORD.strip() == "": if not DATABASE_PASSWORD or DATABASE_PASSWORD.strip() == "":
raise ValueError("DATABASE_PASSWORD is required when using sqlite+sqlcipher:// URLs") raise ValueError(
"DATABASE_PASSWORD is required when using sqlite+sqlcipher:// URLs"
)
# Extract database path from SQLCipher URL # Extract database path from SQLCipher URL
db_path = DB_URL.replace('sqlite+sqlcipher://', '') db_path = DB_URL.replace("sqlite+sqlcipher://", "")
if db_path.startswith('/'): if db_path.startswith("/"):
db_path = db_path[1:] # Remove leading slash for relative paths db_path = db_path[1:] # Remove leading slash for relative paths
# Create a custom creator function that uses sqlcipher3 # Create a custom creator function that uses sqlcipher3
def create_sqlcipher_connection(): def create_sqlcipher_connection():
import sqlcipher3 import sqlcipher3
conn = sqlcipher3.connect(db_path, check_same_thread=False) conn = sqlcipher3.connect(db_path, check_same_thread=False)
conn.execute(f"PRAGMA key = '{DATABASE_PASSWORD}'") conn.execute(f"PRAGMA key = '{DATABASE_PASSWORD}'")
return conn return conn
@ -82,7 +85,7 @@ def run_migrations_online() -> None:
connectable = create_engine( connectable = create_engine(
"sqlite://", # Dummy URL since we're using creator "sqlite://", # Dummy URL since we're using creator
creator=create_sqlcipher_connection, creator=create_sqlcipher_connection,
echo=False echo=False,
) )
else: else:
# Standard database connection (existing logic) # Standard database connection (existing logic)

View file

@ -421,7 +421,7 @@ class PgvectorClient(VectorDBBase):
documents[qid].append(row.text) documents[qid].append(row.text)
metadatas[qid].append(row.vmetadata) metadatas[qid].append(row.vmetadata)
self.session.rollback() # read-only transaction self.session.rollback() # read-only transaction
return SearchResult( return SearchResult(
ids=ids, distances=distances, documents=documents, metadatas=metadatas ids=ids, distances=distances, documents=documents, metadatas=metadatas
) )
@ -479,7 +479,7 @@ class PgvectorClient(VectorDBBase):
documents = [[result.text for result in results]] documents = [[result.text for result in results]]
metadatas = [[result.vmetadata for result in results]] metadatas = [[result.vmetadata for result in results]]
self.session.rollback() # read-only transaction self.session.rollback() # read-only transaction
return GetResult( return GetResult(
ids=ids, ids=ids,
documents=documents, documents=documents,
@ -527,7 +527,7 @@ class PgvectorClient(VectorDBBase):
documents = [[result.text for result in results]] documents = [[result.text for result in results]]
metadatas = [[result.vmetadata for result in results]] metadatas = [[result.vmetadata for result in results]]
self.session.rollback() # read-only transaction self.session.rollback() # read-only transaction
return GetResult(ids=ids, documents=documents, metadatas=metadatas) return GetResult(ids=ids, documents=documents, metadatas=metadatas)
except Exception as e: except Exception as e:
self.session.rollback() self.session.rollback()
@ -598,7 +598,7 @@ class PgvectorClient(VectorDBBase):
.first() .first()
is not None is not None
) )
self.session.rollback() # read-only transaction self.session.rollback() # read-only transaction
return exists return exists
except Exception as e: except Exception as e:
self.session.rollback() self.session.rollback()

View file

@ -60,7 +60,11 @@ class QdrantClient(VectorDBBase):
timeout=self.QDRANT_TIMEOUT, timeout=self.QDRANT_TIMEOUT,
) )
else: else:
self.client = Qclient(url=self.QDRANT_URI, api_key=self.QDRANT_API_KEY, timeout=QDRANT_TIMEOUT,) self.client = Qclient(
url=self.QDRANT_URI,
api_key=self.QDRANT_API_KEY,
timeout=QDRANT_TIMEOUT,
)
def _result_to_get_result(self, points) -> GetResult: def _result_to_get_result(self, points) -> GetResult:
ids = [] ids = []

View file

@ -76,7 +76,11 @@ class QdrantClient(VectorDBBase):
timeout=self.QDRANT_TIMEOUT, timeout=self.QDRANT_TIMEOUT,
) )
if self.PREFER_GRPC if self.PREFER_GRPC
else Qclient(url=self.QDRANT_URI, api_key=self.QDRANT_API_KEY, timeout=self.QDRANT_TIMEOUT,) else Qclient(
url=self.QDRANT_URI,
api_key=self.QDRANT_API_KEY,
timeout=self.QDRANT_TIMEOUT,
)
) )
# Main collection types for multi-tenancy # Main collection types for multi-tenancy