feat:补充额度相关数据库字段

This commit is contained in:
sylarchen1389 2025-12-07 12:26:37 +08:00
parent 5538cc1953
commit 01cceea3ae
6 changed files with 777 additions and 1 deletions

View file

@ -0,0 +1,72 @@
"""Add precharge fields to billing_log
Revision ID: 240e45fa2f01
Revises: f8c9d0e4a3b2
Create Date: 2025-12-06 18:00:00.000000
添加预扣费相关字段到billing_log表
- precharge_id: 预扣费事务IDUUID
- status: 记录状态precharge | settled | refunded
- estimated_tokens: 预估tokens总数
- refund_amount: 退款金额
"""
from alembic import op
import sqlalchemy as sa
revision = "240e45fa2f01"
down_revision = "f8c9d0e4a3b2"
branch_labels = None
depends_on = None
def upgrade():
"""升级数据库:添加预扣费字段"""
connection = op.get_bind()
is_sqlite = connection.dialect.name == "sqlite"
if is_sqlite:
# SQLite: 使用batch模式
with op.batch_alter_table("billing_log") as batch_op:
batch_op.add_column(sa.Column("precharge_id", sa.String(), nullable=True))
batch_op.add_column(
sa.Column("status", sa.String(20), nullable=True, server_default="settled")
)
batch_op.add_column(sa.Column("estimated_tokens", sa.Integer(), nullable=True))
batch_op.add_column(sa.Column("refund_amount", sa.Integer(), nullable=True))
batch_op.create_index("ix_billing_log_precharge_id", ["precharge_id"])
else:
# PostgreSQL: 直接操作
op.add_column("billing_log", sa.Column("precharge_id", sa.String(), nullable=True))
op.add_column(
"billing_log",
sa.Column("status", sa.String(20), nullable=True, server_default="settled"),
)
op.add_column(
"billing_log", sa.Column("estimated_tokens", sa.Integer(), nullable=True)
)
op.add_column(
"billing_log", sa.Column("refund_amount", sa.Integer(), nullable=True)
)
op.create_index("ix_billing_log_precharge_id", "billing_log", ["precharge_id"])
def downgrade():
"""降级数据库:删除预扣费字段"""
connection = op.get_bind()
is_sqlite = connection.dialect.name == "sqlite"
if is_sqlite:
with op.batch_alter_table("billing_log") as batch_op:
batch_op.drop_index("ix_billing_log_precharge_id")
batch_op.drop_column("refund_amount")
batch_op.drop_column("estimated_tokens")
batch_op.drop_column("status")
batch_op.drop_column("precharge_id")
else:
op.drop_index("ix_billing_log_precharge_id", table_name="billing_log")
op.drop_column("billing_log", "refund_amount")
op.drop_column("billing_log", "estimated_tokens")
op.drop_column("billing_log", "status")
op.drop_column("billing_log", "precharge_id")

View file

@ -0,0 +1,27 @@
"""merge billing and reply_to heads
Revision ID: 607801a77d0d
Revises: a5c220713937, e5f8a9b3c2d1
Create Date: 2025-12-05 03:12:14.859612
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
import open_webui.internal.db
# revision identifiers, used by Alembic.
revision: str = '607801a77d0d'
down_revision: Union[str, None] = ('a5c220713937', 'e5f8a9b3c2d1')
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
pass
def downgrade() -> None:
pass

View file

@ -0,0 +1,176 @@
"""Add billing system
Revision ID: e5f8a9b3c2d1
Revises: d31026856c01
Create Date: 2025-12-05 10:00:00.000000
添加计费模块相关表和字段
- user 表新增 balance, total_consumed, billing_status 字段
- 新增 model_pricing 模型定价
- 新增 billing_log 计费日志
- 新增 recharge_log 充值日志
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
revision = "e5f8a9b3c2d1"
down_revision = "d31026856c01"
branch_labels = None
depends_on = None
def upgrade():
"""升级数据库:添加计费系统"""
# 检查数据库类型
connection = op.get_bind()
is_postgresql = connection.dialect.name == "postgresql"
# 根据数据库类型选择NUMERIC类型
numeric_type = postgresql.NUMERIC(20, 6) if is_postgresql else sa.REAL
# 1. 修改 user 表:新增计费字段
op.add_column(
"user",
sa.Column(
"balance",
numeric_type,
server_default="0",
nullable=False,
),
)
op.add_column(
"user",
sa.Column(
"total_consumed",
numeric_type,
server_default="0",
nullable=False,
),
)
op.add_column(
"user",
sa.Column(
"billing_status",
sa.String(20),
server_default="active",
nullable=False,
),
)
# 2. 创建 model_pricing 表
op.create_table(
"model_pricing",
sa.Column("id", sa.String(), nullable=False),
sa.Column("model_id", sa.String(), nullable=False),
sa.Column(
"input_price",
numeric_type,
nullable=False,
),
sa.Column(
"output_price",
numeric_type,
nullable=False,
),
sa.Column(
"enabled",
sa.Boolean(),
server_default="true" if is_postgresql else "1",
nullable=False,
),
sa.Column("created_at", sa.BigInteger(), nullable=False),
sa.Column("updated_at", sa.BigInteger(), nullable=False),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("model_id", name="uq_model_pricing_model_id"),
)
# 3. 创建 billing_log 表
op.create_table(
"billing_log",
sa.Column("id", sa.String(), nullable=False),
sa.Column("user_id", sa.String(), nullable=False),
sa.Column("model_id", sa.String(), nullable=False),
sa.Column(
"prompt_tokens", sa.Integer(), server_default="0"
),
sa.Column(
"completion_tokens", sa.Integer(), server_default="0"
),
sa.Column(
"total_cost",
numeric_type,
nullable=False,
),
sa.Column(
"balance_after",
numeric_type,
nullable=True,
),
sa.Column(
"log_type",
sa.String(20),
server_default="deduct",
),
sa.Column("created_at", sa.BigInteger(), nullable=False),
sa.PrimaryKeyConstraint("id"),
)
# 创建 billing_log 索引
op.create_index(
"idx_billing_log_user_id", "billing_log", ["user_id"], unique=False
)
op.create_index(
"idx_billing_log_created_at", "billing_log", ["created_at"], unique=False
)
op.create_index(
"idx_billing_log_user_created",
"billing_log",
["user_id", "created_at"],
unique=False,
)
# 4. 创建 recharge_log 表
op.create_table(
"recharge_log",
sa.Column("id", sa.String(), nullable=False),
sa.Column("user_id", sa.String(), nullable=False),
sa.Column(
"amount",
numeric_type,
nullable=False,
),
sa.Column("operator_id", sa.String(), nullable=False),
sa.Column("remark", sa.Text(), nullable=True),
sa.Column("created_at", sa.BigInteger(), nullable=False),
sa.PrimaryKeyConstraint("id"),
)
# 创建 recharge_log 索引
op.create_index(
"idx_recharge_log_user_id", "recharge_log", ["user_id"], unique=False
)
op.create_index(
"idx_recharge_log_created_at", "recharge_log", ["created_at"], unique=False
)
def downgrade():
"""降级数据库:移除计费系统"""
# 删除索引
op.drop_index("idx_recharge_log_created_at", "recharge_log")
op.drop_index("idx_recharge_log_user_id", "recharge_log")
op.drop_index("idx_billing_log_user_created", "billing_log")
op.drop_index("idx_billing_log_created_at", "billing_log")
op.drop_index("idx_billing_log_user_id", "billing_log")
# 删除表
op.drop_table("recharge_log")
op.drop_table("billing_log")
op.drop_table("model_pricing")
# 删除 user 表字段
op.drop_column("user", "billing_status")
op.drop_column("user", "total_consumed")
op.drop_column("user", "balance")

View file

@ -0,0 +1,210 @@
"""Change billing amounts to integer (milli-yuan precision)
Revision ID: f8c9d0e4a3b2
Revises: e5f8a9b3c2d1
Create Date: 2025-12-06 12:00:00.000000
将金额字段从 DECIMAL/REAL 改为 INTEGER以毫为单位存储1 = 10000精度0.0001
- user.balance: Decimal -> Integer ()
- user.total_consumed: Decimal -> Integer ()
- model_pricing.input_price: Decimal -> Integer (/百万tokens)
- model_pricing.output_price: Decimal -> Integer (/百万tokens)
- billing_log.total_cost: Decimal -> Integer ()
- billing_log.balance_after: Decimal -> Integer ()
- recharge_log.amount: Decimal -> Integer ()
注意现有数据将乘以10000转换 ->
"""
from alembic import op
import sqlalchemy as sa
revision = "f8c9d0e4a3b2"
down_revision = "607801a77d0d" # merge billing and reply_to heads
branch_labels = None
depends_on = None
def upgrade():
"""升级数据库:将金额字段改为整数(分)"""
connection = op.get_bind()
is_sqlite = connection.dialect.name == "sqlite"
if is_sqlite:
# SQLite 不支持直接修改列类型,需要重建表
# 1. user 表
with op.batch_alter_table("user") as batch_op:
# SQLite: 先转换现有数据(元 * 10000 = 毫)
connection.execute(sa.text("""
UPDATE user
SET balance = CAST(balance * 10000 AS INTEGER),
total_consumed = CAST(total_consumed * 10000 AS INTEGER)
"""))
# 修改列类型
batch_op.alter_column("balance",
existing_type=sa.REAL,
type_=sa.Integer(),
existing_nullable=False,
existing_server_default="0")
batch_op.alter_column("total_consumed",
existing_type=sa.REAL,
type_=sa.Integer(),
existing_nullable=False,
existing_server_default="0")
# 2. model_pricing 表
with op.batch_alter_table("model_pricing") as batch_op:
connection.execute(sa.text("""
UPDATE model_pricing
SET input_price = CAST(input_price * 10000 AS INTEGER),
output_price = CAST(output_price * 10000 AS INTEGER)
"""))
batch_op.alter_column("input_price",
existing_type=sa.REAL,
type_=sa.Integer(),
existing_nullable=False)
batch_op.alter_column("output_price",
existing_type=sa.REAL,
type_=sa.Integer(),
existing_nullable=False)
# 3. billing_log 表
with op.batch_alter_table("billing_log") as batch_op:
connection.execute(sa.text("""
UPDATE billing_log
SET total_cost = CAST(total_cost * 10000 AS INTEGER),
balance_after = CAST(COALESCE(balance_after, 0) * 10000 AS INTEGER)
"""))
batch_op.alter_column("total_cost",
existing_type=sa.REAL,
type_=sa.Integer(),
existing_nullable=False)
batch_op.alter_column("balance_after",
existing_type=sa.REAL,
type_=sa.Integer(),
existing_nullable=True)
# 4. recharge_log 表
with op.batch_alter_table("recharge_log") as batch_op:
connection.execute(sa.text("""
UPDATE recharge_log
SET amount = CAST(amount * 10000 AS INTEGER)
"""))
batch_op.alter_column("amount",
existing_type=sa.REAL,
type_=sa.Integer(),
existing_nullable=False)
else:
# PostgreSQL 支持直接修改
# 1. 先转换数据(元 * 10000 = 毫)
connection.execute(sa.text("""
UPDATE "user"
SET balance = CAST(balance * 10000 AS INTEGER),
total_consumed = CAST(total_consumed * 10000 AS INTEGER)
"""))
connection.execute(sa.text("""
UPDATE model_pricing
SET input_price = CAST(input_price * 10000 AS INTEGER),
output_price = CAST(output_price * 10000 AS INTEGER)
"""))
connection.execute(sa.text("""
UPDATE billing_log
SET total_cost = CAST(total_cost * 10000 AS INTEGER),
balance_after = CAST(COALESCE(balance_after, 0) * 10000 AS INTEGER)
"""))
connection.execute(sa.text("""
UPDATE recharge_log
SET amount = CAST(amount * 10000 AS INTEGER)
"""))
# 2. 修改列类型
op.alter_column("user", "balance", type_=sa.Integer())
op.alter_column("user", "total_consumed", type_=sa.Integer())
op.alter_column("model_pricing", "input_price", type_=sa.Integer())
op.alter_column("model_pricing", "output_price", type_=sa.Integer())
op.alter_column("billing_log", "total_cost", type_=sa.Integer())
op.alter_column("billing_log", "balance_after", type_=sa.Integer())
op.alter_column("recharge_log", "amount", type_=sa.Integer())
def downgrade():
"""降级数据库:将整数(分)改回 Decimal"""
connection = op.get_bind()
is_sqlite = connection.dialect.name == "sqlite"
is_postgresql = connection.dialect.name == "postgresql"
numeric_type = sa.NUMERIC(20, 6) if is_postgresql else sa.REAL
if is_sqlite:
# SQLite 降级
with op.batch_alter_table("user") as batch_op:
batch_op.alter_column("balance", type_=sa.REAL)
batch_op.alter_column("total_consumed", type_=sa.REAL)
with op.batch_alter_table("model_pricing") as batch_op:
batch_op.alter_column("input_price", type_=sa.REAL)
batch_op.alter_column("output_price", type_=sa.REAL)
with op.batch_alter_table("billing_log") as batch_op:
batch_op.alter_column("total_cost", type_=sa.REAL)
batch_op.alter_column("balance_after", type_=sa.REAL)
with op.batch_alter_table("recharge_log") as batch_op:
batch_op.alter_column("amount", type_=sa.REAL)
# 转换数据(毫 / 10000 = 元)
connection.execute(sa.text("""
UPDATE user
SET balance = CAST(balance AS REAL) / 10000.0,
total_consumed = CAST(total_consumed AS REAL) / 10000.0
"""))
connection.execute(sa.text("""
UPDATE model_pricing
SET input_price = CAST(input_price AS REAL) / 10000.0,
output_price = CAST(output_price AS REAL) / 10000.0
"""))
connection.execute(sa.text("""
UPDATE billing_log
SET total_cost = CAST(total_cost AS REAL) / 10000.0,
balance_after = CAST(balance_after AS REAL) / 10000.0
"""))
connection.execute(sa.text("""
UPDATE recharge_log
SET amount = CAST(amount AS REAL) / 10000.0
"""))
else:
# PostgreSQL 降级
op.alter_column("user", "balance", type_=numeric_type)
op.alter_column("user", "total_consumed", type_=numeric_type)
op.alter_column("model_pricing", "input_price", type_=numeric_type)
op.alter_column("model_pricing", "output_price", type_=numeric_type)
op.alter_column("billing_log", "total_cost", type_=numeric_type)
op.alter_column("billing_log", "balance_after", type_=numeric_type)
op.alter_column("recharge_log", "amount", type_=numeric_type)
connection.execute(sa.text("""
UPDATE "user"
SET balance = balance / 10000.0,
total_consumed = total_consumed / 10000.0
"""))
connection.execute(sa.text("""
UPDATE model_pricing
SET input_price = input_price / 10000.0,
output_price = output_price / 10000.0
"""))
connection.execute(sa.text("""
UPDATE billing_log
SET total_cost = total_cost / 10000.0,
balance_after = balance_after / 10000.0
"""))
connection.execute(sa.text("""
UPDATE recharge_log
SET amount = amount / 10000.0
"""))

View file

@ -0,0 +1,276 @@
"""
计费模块数据模型
包含模型定价计费日志充值日志的 ORM 模型和数据访问层
"""
import time
import uuid
from typing import Optional
from pydantic import BaseModel, ConfigDict
from sqlalchemy import Boolean, Column, String, Integer, BigInteger, Text
from open_webui.internal.db import Base, get_db
####################
# ModelPricing DB Schema
####################
class ModelPricing(Base):
"""模型定价表"""
__tablename__ = "model_pricing"
id = Column(String, primary_key=True)
model_id = Column(String, unique=True, nullable=False) # 模型标识,如 "gpt-4o"
input_price = Column(Integer, nullable=False) # 输入价格(分/1k token
output_price = Column(Integer, nullable=False) # 输出价格(分/1k token
enabled = Column(Boolean, default=True, nullable=False) # 是否启用
created_at = Column(BigInteger, nullable=False)
updated_at = Column(BigInteger, nullable=False)
class BillingLog(Base):
"""计费日志表"""
__tablename__ = "billing_log"
id = Column(String, primary_key=True)
user_id = Column(String, nullable=False, index=True)
model_id = Column(String, nullable=False)
prompt_tokens = Column(Integer, default=0)
completion_tokens = Column(Integer, default=0)
total_cost = Column(Integer, nullable=False) # 本次费用(分)
balance_after = Column(Integer) # 扣费后余额(分)
log_type = Column(String(20), default="deduct") # deduct/refund/precharge/settle
created_at = Column(BigInteger, nullable=False, index=True)
# 预扣费相关字段
precharge_id = Column(String, nullable=True, index=True) # 预扣费事务ID
status = Column(String(20), nullable=True, default="settled") # precharge | settled | refunded
estimated_tokens = Column(Integer, nullable=True) # 预估tokens总数
refund_amount = Column(Integer, nullable=True) # 退款金额(毫)
class RechargeLog(Base):
"""充值日志表"""
__tablename__ = "recharge_log"
id = Column(String, primary_key=True)
user_id = Column(String, nullable=False, index=True)
amount = Column(Integer, nullable=False) # 充值金额(分)
operator_id = Column(String, nullable=False) # 操作员ID
remark = Column(Text) # 备注
created_at = Column(BigInteger, nullable=False)
####################
# Pydantic Models
####################
class ModelPricingModel(BaseModel):
"""模型定价 Pydantic 模型(以分为单位)"""
id: str
model_id: str
input_price: int # 分/1k tokens
output_price: int # 分/1k tokens
enabled: bool
created_at: int
updated_at: int
model_config = ConfigDict(from_attributes=True)
class BillingLogModel(BaseModel):
"""计费日志 Pydantic 模型(以分为单位)"""
id: str
user_id: str
model_id: str
prompt_tokens: int
completion_tokens: int
total_cost: int # 分
balance_after: Optional[int] # 分
log_type: str
created_at: int
# 预扣费相关字段
precharge_id: Optional[str] = None
status: Optional[str] = "settled"
estimated_tokens: Optional[int] = None
refund_amount: Optional[int] = None
model_config = ConfigDict(from_attributes=True)
class RechargeLogModel(BaseModel):
"""充值日志 Pydantic 模型(以分为单位)"""
id: str
user_id: str
amount: int # 分
operator_id: str
remark: Optional[str]
created_at: int
model_config = ConfigDict(from_attributes=True)
####################
# Data Access Layer
####################
class ModelPricingTable:
"""模型定价数据访问层"""
def get_by_model_id(self, model_id: str) -> Optional[ModelPricingModel]:
"""根据模型ID获取定价"""
try:
with get_db() as db:
pricing = (
db.query(ModelPricing)
.filter_by(model_id=model_id, enabled=True)
.first()
)
return ModelPricingModel.model_validate(pricing) if pricing else None
except Exception:
return None
def get_all(self) -> list[ModelPricingModel]:
"""获取所有定价"""
with get_db() as db:
pricings = db.query(ModelPricing).filter_by(enabled=True).all()
return [ModelPricingModel.model_validate(p) for p in pricings]
def upsert(
self, model_id: str, input_price: int, output_price: int
) -> ModelPricingModel:
"""创建或更新定价"""
with get_db() as db:
existing = db.query(ModelPricing).filter_by(model_id=model_id).first()
now = int(time.time())
if existing:
# 更新
existing.input_price = input_price
existing.output_price = output_price
existing.updated_at = now
db.commit()
db.refresh(existing)
return ModelPricingModel.model_validate(existing)
else:
# 创建
new_pricing = ModelPricing(
id=str(uuid.uuid4()),
model_id=model_id,
input_price=input_price,
output_price=output_price,
enabled=True,
created_at=now,
updated_at=now,
)
db.add(new_pricing)
db.commit()
db.refresh(new_pricing)
return ModelPricingModel.model_validate(new_pricing)
def delete_by_model_id(self, model_id: str) -> bool:
"""删除定价(软删除,设置 enabled=False"""
try:
with get_db() as db:
result = (
db.query(ModelPricing)
.filter_by(model_id=model_id)
.update({"enabled": False})
)
db.commit()
return result > 0
except Exception:
return False
class BillingLogTable:
"""计费日志数据访问层"""
def get_by_user_id(
self, user_id: str, limit: int = 50, offset: int = 0
) -> list[BillingLogModel]:
"""获取用户计费日志"""
with get_db() as db:
logs = (
db.query(BillingLog)
.filter_by(user_id=user_id)
.order_by(BillingLog.created_at.desc())
.limit(limit)
.offset(offset)
.all()
)
return [BillingLogModel.model_validate(log) for log in logs]
def count_by_user_id(self, user_id: str) -> int:
"""统计用户日志数量"""
with get_db() as db:
return db.query(BillingLog).filter_by(user_id=user_id).count()
class RechargeLogTable:
"""充值日志数据访问层"""
def get_by_user_id(
self, user_id: str, limit: int = 50, offset: int = 0
) -> list[RechargeLogModel]:
"""获取用户充值日志"""
with get_db() as db:
logs = (
db.query(RechargeLog)
.filter_by(user_id=user_id)
.order_by(RechargeLog.created_at.desc())
.limit(limit)
.offset(offset)
.all()
)
return [RechargeLogModel.model_validate(log) for log in logs]
def get_by_user_id_with_operator_name(
self, user_id: str, limit: int = 50, offset: int = 0
) -> list[dict]:
"""获取用户充值日志,包含操作员姓名"""
from open_webui.models.users import User
with get_db() as db:
logs = (
db.query(RechargeLog, User.name.label("operator_name"))
.join(User, RechargeLog.operator_id == User.id)
.filter(RechargeLog.user_id == user_id)
.order_by(RechargeLog.created_at.desc())
.limit(limit)
.offset(offset)
.all()
)
# 转换为字典格式
return [
{
"id": log.RechargeLog.id,
"user_id": log.RechargeLog.user_id,
"amount": log.RechargeLog.amount, # 整数(分)
"operator_id": log.RechargeLog.operator_id,
"operator_name": log.operator_name,
"remark": log.RechargeLog.remark,
"created_at": log.RechargeLog.created_at,
}
for log in logs
]
# 单例实例
ModelPricings = ModelPricingTable()
BillingLogs = BillingLogTable()
RechargeLogs = RechargeLogTable()

View file

@ -11,7 +11,7 @@ from open_webui.utils.misc import throttle
from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Column, String, Text, Date
from sqlalchemy import BigInteger, Column, String, Text, Date, Integer
from sqlalchemy import or_
import datetime
@ -48,6 +48,11 @@ class User(Base):
updated_at = Column(BigInteger)
created_at = Column(BigInteger)
# 计费相关字段(以分为单位存储)
balance = Column(Integer, default=0, nullable=False) # 账户余额(分)
total_consumed = Column(Integer, default=0, nullable=False) # 累计消费(分)
billing_status = Column(String(20), default="active", nullable=False) # active/frozen
class UserSettings(BaseModel):
ui: Optional[dict] = {}
@ -79,6 +84,11 @@ class UserModel(BaseModel):
updated_at: int # timestamp in epoch
created_at: int # timestamp in epoch
# 计费相关字段(以分为单位)
balance: Optional[int] = 0
total_consumed: Optional[int] = 0
billing_status: Optional[str] = "active"
model_config = ConfigDict(from_attributes=True)
@ -272,6 +282,11 @@ class UsersTable:
query = query.order_by(User.role.asc())
else:
query = query.order_by(User.role.desc())
elif order_by == "balance":
if direction == "asc":
query = query.order_by(User.balance.asc())
else:
query = query.order_by(User.balance.desc())
else:
query = query.order_by(User.created_at.desc())