feat:补充额度相关数据库字段

This commit is contained in:
sylarchen1389 2025-12-07 12:26:37 +08:00
parent 5538cc1953
commit 01cceea3ae
6 changed files with 777 additions and 1 deletions

View file

@ -0,0 +1,72 @@
"""Add precharge fields to billing_log
Revision ID: 240e45fa2f01
Revises: f8c9d0e4a3b2
Create Date: 2025-12-06 18:00:00.000000
添加预扣费相关字段到billing_log表
- precharge_id: 预扣费事务IDUUID
- status: 记录状态precharge | settled | refunded
- estimated_tokens: 预估tokens总数
- refund_amount: 退款金额
"""
from alembic import op
import sqlalchemy as sa
revision = "240e45fa2f01"
down_revision = "f8c9d0e4a3b2"
branch_labels = None
depends_on = None
def upgrade():
"""升级数据库:添加预扣费字段"""
connection = op.get_bind()
is_sqlite = connection.dialect.name == "sqlite"
if is_sqlite:
# SQLite: 使用batch模式
with op.batch_alter_table("billing_log") as batch_op:
batch_op.add_column(sa.Column("precharge_id", sa.String(), nullable=True))
batch_op.add_column(
sa.Column("status", sa.String(20), nullable=True, server_default="settled")
)
batch_op.add_column(sa.Column("estimated_tokens", sa.Integer(), nullable=True))
batch_op.add_column(sa.Column("refund_amount", sa.Integer(), nullable=True))
batch_op.create_index("ix_billing_log_precharge_id", ["precharge_id"])
else:
# PostgreSQL: 直接操作
op.add_column("billing_log", sa.Column("precharge_id", sa.String(), nullable=True))
op.add_column(
"billing_log",
sa.Column("status", sa.String(20), nullable=True, server_default="settled"),
)
op.add_column(
"billing_log", sa.Column("estimated_tokens", sa.Integer(), nullable=True)
)
op.add_column(
"billing_log", sa.Column("refund_amount", sa.Integer(), nullable=True)
)
op.create_index("ix_billing_log_precharge_id", "billing_log", ["precharge_id"])
def downgrade():
"""降级数据库:删除预扣费字段"""
connection = op.get_bind()
is_sqlite = connection.dialect.name == "sqlite"
if is_sqlite:
with op.batch_alter_table("billing_log") as batch_op:
batch_op.drop_index("ix_billing_log_precharge_id")
batch_op.drop_column("refund_amount")
batch_op.drop_column("estimated_tokens")
batch_op.drop_column("status")
batch_op.drop_column("precharge_id")
else:
op.drop_index("ix_billing_log_precharge_id", table_name="billing_log")
op.drop_column("billing_log", "refund_amount")
op.drop_column("billing_log", "estimated_tokens")
op.drop_column("billing_log", "status")
op.drop_column("billing_log", "precharge_id")

View file

@ -0,0 +1,27 @@
"""merge billing and reply_to heads
Revision ID: 607801a77d0d
Revises: a5c220713937, e5f8a9b3c2d1
Create Date: 2025-12-05 03:12:14.859612
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
import open_webui.internal.db
# revision identifiers, used by Alembic.
revision: str = '607801a77d0d'
down_revision: Union[str, None] = ('a5c220713937', 'e5f8a9b3c2d1')
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
pass
def downgrade() -> None:
pass

View file

@ -0,0 +1,176 @@
"""Add billing system
Revision ID: e5f8a9b3c2d1
Revises: d31026856c01
Create Date: 2025-12-05 10:00:00.000000
添加计费模块相关表和字段
- user 表新增 balance, total_consumed, billing_status 字段
- 新增 model_pricing 模型定价
- 新增 billing_log 计费日志
- 新增 recharge_log 充值日志
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
revision = "e5f8a9b3c2d1"
down_revision = "d31026856c01"
branch_labels = None
depends_on = None
def upgrade():
"""升级数据库:添加计费系统"""
# 检查数据库类型
connection = op.get_bind()
is_postgresql = connection.dialect.name == "postgresql"
# 根据数据库类型选择NUMERIC类型
numeric_type = postgresql.NUMERIC(20, 6) if is_postgresql else sa.REAL
# 1. 修改 user 表:新增计费字段
op.add_column(
"user",
sa.Column(
"balance",
numeric_type,
server_default="0",
nullable=False,
),
)
op.add_column(
"user",
sa.Column(
"total_consumed",
numeric_type,
server_default="0",
nullable=False,
),
)
op.add_column(
"user",
sa.Column(
"billing_status",
sa.String(20),
server_default="active",
nullable=False,
),
)
# 2. 创建 model_pricing 表
op.create_table(
"model_pricing",
sa.Column("id", sa.String(), nullable=False),
sa.Column("model_id", sa.String(), nullable=False),
sa.Column(
"input_price",
numeric_type,
nullable=False,
),
sa.Column(
"output_price",
numeric_type,
nullable=False,
),
sa.Column(
"enabled",
sa.Boolean(),
server_default="true" if is_postgresql else "1",
nullable=False,
),
sa.Column("created_at", sa.BigInteger(), nullable=False),
sa.Column("updated_at", sa.BigInteger(), nullable=False),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("model_id", name="uq_model_pricing_model_id"),
)
# 3. 创建 billing_log 表
op.create_table(
"billing_log",
sa.Column("id", sa.String(), nullable=False),
sa.Column("user_id", sa.String(), nullable=False),
sa.Column("model_id", sa.String(), nullable=False),
sa.Column(
"prompt_tokens", sa.Integer(), server_default="0"
),
sa.Column(
"completion_tokens", sa.Integer(), server_default="0"
),
sa.Column(
"total_cost",
numeric_type,
nullable=False,
),
sa.Column(
"balance_after",
numeric_type,
nullable=True,
),
sa.Column(
"log_type",
sa.String(20),
server_default="deduct",
),
sa.Column("created_at", sa.BigInteger(), nullable=False),
sa.PrimaryKeyConstraint("id"),
)
# 创建 billing_log 索引
op.create_index(
"idx_billing_log_user_id", "billing_log", ["user_id"], unique=False
)
op.create_index(
"idx_billing_log_created_at", "billing_log", ["created_at"], unique=False
)
op.create_index(
"idx_billing_log_user_created",
"billing_log",
["user_id", "created_at"],
unique=False,
)
# 4. 创建 recharge_log 表
op.create_table(
"recharge_log",
sa.Column("id", sa.String(), nullable=False),
sa.Column("user_id", sa.String(), nullable=False),
sa.Column(
"amount",
numeric_type,
nullable=False,
),
sa.Column("operator_id", sa.String(), nullable=False),
sa.Column("remark", sa.Text(), nullable=True),
sa.Column("created_at", sa.BigInteger(), nullable=False),
sa.PrimaryKeyConstraint("id"),
)
# 创建 recharge_log 索引
op.create_index(
"idx_recharge_log_user_id", "recharge_log", ["user_id"], unique=False
)
op.create_index(
"idx_recharge_log_created_at", "recharge_log", ["created_at"], unique=False
)
def downgrade():
"""降级数据库:移除计费系统"""
# 删除索引
op.drop_index("idx_recharge_log_created_at", "recharge_log")
op.drop_index("idx_recharge_log_user_id", "recharge_log")
op.drop_index("idx_billing_log_user_created", "billing_log")
op.drop_index("idx_billing_log_created_at", "billing_log")
op.drop_index("idx_billing_log_user_id", "billing_log")
# 删除表
op.drop_table("recharge_log")
op.drop_table("billing_log")
op.drop_table("model_pricing")
# 删除 user 表字段
op.drop_column("user", "billing_status")
op.drop_column("user", "total_consumed")
op.drop_column("user", "balance")

View file

@ -0,0 +1,210 @@
"""Change billing amounts to integer (milli-yuan precision)
Revision ID: f8c9d0e4a3b2
Revises: e5f8a9b3c2d1
Create Date: 2025-12-06 12:00:00.000000
将金额字段从 DECIMAL/REAL 改为 INTEGER以毫为单位存储1 = 10000精度0.0001
- user.balance: Decimal -> Integer ()
- user.total_consumed: Decimal -> Integer ()
- model_pricing.input_price: Decimal -> Integer (/百万tokens)
- model_pricing.output_price: Decimal -> Integer (/百万tokens)
- billing_log.total_cost: Decimal -> Integer ()
- billing_log.balance_after: Decimal -> Integer ()
- recharge_log.amount: Decimal -> Integer ()
注意现有数据将乘以10000转换 ->
"""
from alembic import op
import sqlalchemy as sa
revision = "f8c9d0e4a3b2"
down_revision = "607801a77d0d" # merge billing and reply_to heads
branch_labels = None
depends_on = None
def upgrade():
"""升级数据库:将金额字段改为整数(分)"""
connection = op.get_bind()
is_sqlite = connection.dialect.name == "sqlite"
if is_sqlite:
# SQLite 不支持直接修改列类型,需要重建表
# 1. user 表
with op.batch_alter_table("user") as batch_op:
# SQLite: 先转换现有数据(元 * 10000 = 毫)
connection.execute(sa.text("""
UPDATE user
SET balance = CAST(balance * 10000 AS INTEGER),
total_consumed = CAST(total_consumed * 10000 AS INTEGER)
"""))
# 修改列类型
batch_op.alter_column("balance",
existing_type=sa.REAL,
type_=sa.Integer(),
existing_nullable=False,
existing_server_default="0")
batch_op.alter_column("total_consumed",
existing_type=sa.REAL,
type_=sa.Integer(),
existing_nullable=False,
existing_server_default="0")
# 2. model_pricing 表
with op.batch_alter_table("model_pricing") as batch_op:
connection.execute(sa.text("""
UPDATE model_pricing
SET input_price = CAST(input_price * 10000 AS INTEGER),
output_price = CAST(output_price * 10000 AS INTEGER)
"""))
batch_op.alter_column("input_price",
existing_type=sa.REAL,
type_=sa.Integer(),
existing_nullable=False)
batch_op.alter_column("output_price",
existing_type=sa.REAL,
type_=sa.Integer(),
existing_nullable=False)
# 3. billing_log 表
with op.batch_alter_table("billing_log") as batch_op:
connection.execute(sa.text("""
UPDATE billing_log
SET total_cost = CAST(total_cost * 10000 AS INTEGER),
balance_after = CAST(COALESCE(balance_after, 0) * 10000 AS INTEGER)
"""))
batch_op.alter_column("total_cost",
existing_type=sa.REAL,
type_=sa.Integer(),
existing_nullable=False)
batch_op.alter_column("balance_after",
existing_type=sa.REAL,
type_=sa.Integer(),
existing_nullable=True)
# 4. recharge_log 表
with op.batch_alter_table("recharge_log") as batch_op:
connection.execute(sa.text("""
UPDATE recharge_log
SET amount = CAST(amount * 10000 AS INTEGER)
"""))
batch_op.alter_column("amount",
existing_type=sa.REAL,
type_=sa.Integer(),
existing_nullable=False)
else:
# PostgreSQL 支持直接修改
# 1. 先转换数据(元 * 10000 = 毫)
connection.execute(sa.text("""
UPDATE "user"
SET balance = CAST(balance * 10000 AS INTEGER),
total_consumed = CAST(total_consumed * 10000 AS INTEGER)
"""))
connection.execute(sa.text("""
UPDATE model_pricing
SET input_price = CAST(input_price * 10000 AS INTEGER),
output_price = CAST(output_price * 10000 AS INTEGER)
"""))
connection.execute(sa.text("""
UPDATE billing_log
SET total_cost = CAST(total_cost * 10000 AS INTEGER),
balance_after = CAST(COALESCE(balance_after, 0) * 10000 AS INTEGER)
"""))
connection.execute(sa.text("""
UPDATE recharge_log
SET amount = CAST(amount * 10000 AS INTEGER)
"""))
# 2. 修改列类型
op.alter_column("user", "balance", type_=sa.Integer())
op.alter_column("user", "total_consumed", type_=sa.Integer())
op.alter_column("model_pricing", "input_price", type_=sa.Integer())
op.alter_column("model_pricing", "output_price", type_=sa.Integer())
op.alter_column("billing_log", "total_cost", type_=sa.Integer())
op.alter_column("billing_log", "balance_after", type_=sa.Integer())
op.alter_column("recharge_log", "amount", type_=sa.Integer())
def downgrade():
"""降级数据库:将整数(分)改回 Decimal"""
connection = op.get_bind()
is_sqlite = connection.dialect.name == "sqlite"
is_postgresql = connection.dialect.name == "postgresql"
numeric_type = sa.NUMERIC(20, 6) if is_postgresql else sa.REAL
if is_sqlite:
# SQLite 降级
with op.batch_alter_table("user") as batch_op:
batch_op.alter_column("balance", type_=sa.REAL)
batch_op.alter_column("total_consumed", type_=sa.REAL)
with op.batch_alter_table("model_pricing") as batch_op:
batch_op.alter_column("input_price", type_=sa.REAL)
batch_op.alter_column("output_price", type_=sa.REAL)
with op.batch_alter_table("billing_log") as batch_op:
batch_op.alter_column("total_cost", type_=sa.REAL)
batch_op.alter_column("balance_after", type_=sa.REAL)
with op.batch_alter_table("recharge_log") as batch_op:
batch_op.alter_column("amount", type_=sa.REAL)
# 转换数据(毫 / 10000 = 元)
connection.execute(sa.text("""
UPDATE user
SET balance = CAST(balance AS REAL) / 10000.0,
total_consumed = CAST(total_consumed AS REAL) / 10000.0
"""))
connection.execute(sa.text("""
UPDATE model_pricing
SET input_price = CAST(input_price AS REAL) / 10000.0,
output_price = CAST(output_price AS REAL) / 10000.0
"""))
connection.execute(sa.text("""
UPDATE billing_log
SET total_cost = CAST(total_cost AS REAL) / 10000.0,
balance_after = CAST(balance_after AS REAL) / 10000.0
"""))
connection.execute(sa.text("""
UPDATE recharge_log
SET amount = CAST(amount AS REAL) / 10000.0
"""))
else:
# PostgreSQL 降级
op.alter_column("user", "balance", type_=numeric_type)
op.alter_column("user", "total_consumed", type_=numeric_type)
op.alter_column("model_pricing", "input_price", type_=numeric_type)
op.alter_column("model_pricing", "output_price", type_=numeric_type)
op.alter_column("billing_log", "total_cost", type_=numeric_type)
op.alter_column("billing_log", "balance_after", type_=numeric_type)
op.alter_column("recharge_log", "amount", type_=numeric_type)
connection.execute(sa.text("""
UPDATE "user"
SET balance = balance / 10000.0,
total_consumed = total_consumed / 10000.0
"""))
connection.execute(sa.text("""
UPDATE model_pricing
SET input_price = input_price / 10000.0,
output_price = output_price / 10000.0
"""))
connection.execute(sa.text("""
UPDATE billing_log
SET total_cost = total_cost / 10000.0,
balance_after = balance_after / 10000.0
"""))
connection.execute(sa.text("""
UPDATE recharge_log
SET amount = amount / 10000.0
"""))

View file

@ -0,0 +1,276 @@
"""
计费模块数据模型
包含模型定价计费日志充值日志的 ORM 模型和数据访问层
"""
import time
import uuid
from typing import Optional
from pydantic import BaseModel, ConfigDict
from sqlalchemy import Boolean, Column, String, Integer, BigInteger, Text
from open_webui.internal.db import Base, get_db
####################
# ModelPricing DB Schema
####################
class ModelPricing(Base):
"""模型定价表"""
__tablename__ = "model_pricing"
id = Column(String, primary_key=True)
model_id = Column(String, unique=True, nullable=False) # 模型标识,如 "gpt-4o"
input_price = Column(Integer, nullable=False) # 输入价格(分/1k token
output_price = Column(Integer, nullable=False) # 输出价格(分/1k token
enabled = Column(Boolean, default=True, nullable=False) # 是否启用
created_at = Column(BigInteger, nullable=False)
updated_at = Column(BigInteger, nullable=False)
class BillingLog(Base):
"""计费日志表"""
__tablename__ = "billing_log"
id = Column(String, primary_key=True)
user_id = Column(String, nullable=False, index=True)
model_id = Column(String, nullable=False)
prompt_tokens = Column(Integer, default=0)
completion_tokens = Column(Integer, default=0)
total_cost = Column(Integer, nullable=False) # 本次费用(分)
balance_after = Column(Integer) # 扣费后余额(分)
log_type = Column(String(20), default="deduct") # deduct/refund/precharge/settle
created_at = Column(BigInteger, nullable=False, index=True)
# 预扣费相关字段
precharge_id = Column(String, nullable=True, index=True) # 预扣费事务ID
status = Column(String(20), nullable=True, default="settled") # precharge | settled | refunded
estimated_tokens = Column(Integer, nullable=True) # 预估tokens总数
refund_amount = Column(Integer, nullable=True) # 退款金额(毫)
class RechargeLog(Base):
"""充值日志表"""
__tablename__ = "recharge_log"
id = Column(String, primary_key=True)
user_id = Column(String, nullable=False, index=True)
amount = Column(Integer, nullable=False) # 充值金额(分)
operator_id = Column(String, nullable=False) # 操作员ID
remark = Column(Text) # 备注
created_at = Column(BigInteger, nullable=False)
####################
# Pydantic Models
####################
class ModelPricingModel(BaseModel):
"""模型定价 Pydantic 模型(以分为单位)"""
id: str
model_id: str
input_price: int # 分/1k tokens
output_price: int # 分/1k tokens
enabled: bool
created_at: int
updated_at: int
model_config = ConfigDict(from_attributes=True)
class BillingLogModel(BaseModel):
"""计费日志 Pydantic 模型(以分为单位)"""
id: str
user_id: str
model_id: str
prompt_tokens: int
completion_tokens: int
total_cost: int # 分
balance_after: Optional[int] # 分
log_type: str
created_at: int
# 预扣费相关字段
precharge_id: Optional[str] = None
status: Optional[str] = "settled"
estimated_tokens: Optional[int] = None
refund_amount: Optional[int] = None
model_config = ConfigDict(from_attributes=True)
class RechargeLogModel(BaseModel):
"""充值日志 Pydantic 模型(以分为单位)"""
id: str
user_id: str
amount: int # 分
operator_id: str
remark: Optional[str]
created_at: int
model_config = ConfigDict(from_attributes=True)
####################
# Data Access Layer
####################
class ModelPricingTable:
"""模型定价数据访问层"""
def get_by_model_id(self, model_id: str) -> Optional[ModelPricingModel]:
"""根据模型ID获取定价"""
try:
with get_db() as db:
pricing = (
db.query(ModelPricing)
.filter_by(model_id=model_id, enabled=True)
.first()
)
return ModelPricingModel.model_validate(pricing) if pricing else None
except Exception:
return None
def get_all(self) -> list[ModelPricingModel]:
"""获取所有定价"""
with get_db() as db:
pricings = db.query(ModelPricing).filter_by(enabled=True).all()
return [ModelPricingModel.model_validate(p) for p in pricings]
def upsert(
self, model_id: str, input_price: int, output_price: int
) -> ModelPricingModel:
"""创建或更新定价"""
with get_db() as db:
existing = db.query(ModelPricing).filter_by(model_id=model_id).first()
now = int(time.time())
if existing:
# 更新
existing.input_price = input_price
existing.output_price = output_price
existing.updated_at = now
db.commit()
db.refresh(existing)
return ModelPricingModel.model_validate(existing)
else:
# 创建
new_pricing = ModelPricing(
id=str(uuid.uuid4()),
model_id=model_id,
input_price=input_price,
output_price=output_price,
enabled=True,
created_at=now,
updated_at=now,
)
db.add(new_pricing)
db.commit()
db.refresh(new_pricing)
return ModelPricingModel.model_validate(new_pricing)
def delete_by_model_id(self, model_id: str) -> bool:
"""删除定价(软删除,设置 enabled=False"""
try:
with get_db() as db:
result = (
db.query(ModelPricing)
.filter_by(model_id=model_id)
.update({"enabled": False})
)
db.commit()
return result > 0
except Exception:
return False
class BillingLogTable:
"""计费日志数据访问层"""
def get_by_user_id(
self, user_id: str, limit: int = 50, offset: int = 0
) -> list[BillingLogModel]:
"""获取用户计费日志"""
with get_db() as db:
logs = (
db.query(BillingLog)
.filter_by(user_id=user_id)
.order_by(BillingLog.created_at.desc())
.limit(limit)
.offset(offset)
.all()
)
return [BillingLogModel.model_validate(log) for log in logs]
def count_by_user_id(self, user_id: str) -> int:
"""统计用户日志数量"""
with get_db() as db:
return db.query(BillingLog).filter_by(user_id=user_id).count()
class RechargeLogTable:
"""充值日志数据访问层"""
def get_by_user_id(
self, user_id: str, limit: int = 50, offset: int = 0
) -> list[RechargeLogModel]:
"""获取用户充值日志"""
with get_db() as db:
logs = (
db.query(RechargeLog)
.filter_by(user_id=user_id)
.order_by(RechargeLog.created_at.desc())
.limit(limit)
.offset(offset)
.all()
)
return [RechargeLogModel.model_validate(log) for log in logs]
def get_by_user_id_with_operator_name(
self, user_id: str, limit: int = 50, offset: int = 0
) -> list[dict]:
"""获取用户充值日志,包含操作员姓名"""
from open_webui.models.users import User
with get_db() as db:
logs = (
db.query(RechargeLog, User.name.label("operator_name"))
.join(User, RechargeLog.operator_id == User.id)
.filter(RechargeLog.user_id == user_id)
.order_by(RechargeLog.created_at.desc())
.limit(limit)
.offset(offset)
.all()
)
# 转换为字典格式
return [
{
"id": log.RechargeLog.id,
"user_id": log.RechargeLog.user_id,
"amount": log.RechargeLog.amount, # 整数(分)
"operator_id": log.RechargeLog.operator_id,
"operator_name": log.operator_name,
"remark": log.RechargeLog.remark,
"created_at": log.RechargeLog.created_at,
}
for log in logs
]
# 单例实例
ModelPricings = ModelPricingTable()
BillingLogs = BillingLogTable()
RechargeLogs = RechargeLogTable()

View file

@ -11,7 +11,7 @@ from open_webui.utils.misc import throttle
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Column, String, Text, Date from sqlalchemy import BigInteger, Column, String, Text, Date, Integer
from sqlalchemy import or_ from sqlalchemy import or_
import datetime import datetime
@ -48,6 +48,11 @@ class User(Base):
updated_at = Column(BigInteger) updated_at = Column(BigInteger)
created_at = Column(BigInteger) created_at = Column(BigInteger)
# 计费相关字段(以分为单位存储)
balance = Column(Integer, default=0, nullable=False) # 账户余额(分)
total_consumed = Column(Integer, default=0, nullable=False) # 累计消费(分)
billing_status = Column(String(20), default="active", nullable=False) # active/frozen
class UserSettings(BaseModel): class UserSettings(BaseModel):
ui: Optional[dict] = {} ui: Optional[dict] = {}
@ -79,6 +84,11 @@ class UserModel(BaseModel):
updated_at: int # timestamp in epoch updated_at: int # timestamp in epoch
created_at: int # timestamp in epoch created_at: int # timestamp in epoch
# 计费相关字段(以分为单位)
balance: Optional[int] = 0
total_consumed: Optional[int] = 0
billing_status: Optional[str] = "active"
model_config = ConfigDict(from_attributes=True) model_config = ConfigDict(from_attributes=True)
@ -272,6 +282,11 @@ class UsersTable:
query = query.order_by(User.role.asc()) query = query.order_by(User.role.asc())
else: else:
query = query.order_by(User.role.desc()) query = query.order_by(User.role.desc())
elif order_by == "balance":
if direction == "asc":
query = query.order_by(User.balance.asc())
else:
query = query.order_by(User.balance.desc())
else: else:
query = query.order_by(User.created_at.desc()) query = query.order_by(User.created_at.desc())