open-webui/backend/open_webui/routers/billing.py
2025-12-08 09:37:01 +08:00

637 lines
19 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
计费管理 API 路由
提供余额查询、充值、消费记录、统计报表、模型定价等管理接口
"""
import logging
import time
from decimal import Decimal
from typing import Optional

from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi.responses import PlainTextResponse
from pydantic import BaseModel, Field, field_validator

from open_webui.internal.db import get_db
from open_webui.models.billing import BillingLog, BillingLogs, ModelPricings, RechargeLogs
from open_webui.models.users import User, Users
from open_webui.utils.auth import get_admin_user, get_verified_user
from open_webui.utils.billing import recharge_user
# Module-level logger shared by every endpoint in this file.
log = logging.getLogger(__name__)
# Router holding the billing endpoints below; presumably mounted under a
# billing prefix by the application (TODO confirm where it is included).
router = APIRouter()
####################
# Request/Response Models
####################
class BalanceResponse(BaseModel):
    """Balance, cumulative spend and billing status of the current user."""

    # NOTE(review): the descriptions say yuan, but get_balance passes the stored
    # value through unconverted, while elsewhere in this file stored amounts are
    # in 毫 (1 yuan = 10000 毫). Confirm which unit the frontend expects.
    balance: float = Field(..., description="当前余额(元)")
    total_consumed: float = Field(..., description="累计消费(元)")
    billing_status: str = Field(..., description="账户状态: active/frozen")
class RechargeRequest(BaseModel):
    """Admin recharge / deduction request.

    ``amount`` is expressed in 毫 (1 yuan = 10000 毫). A positive value credits
    the account, a negative value debits it, and zero is rejected.
    """

    user_id: str = Field(..., description="用户ID")
    # pydantic's Field has no ``ne`` constraint: unknown keyword arguments are
    # only copied into the JSON schema and never validated. The previous
    # ``ne=0`` therefore let zero through, so non-zero is enforced by the
    # validator below instead.
    amount: int = Field(..., description="充值/扣费金额1元 = 10000毫正数充值负数扣费")
    remark: str = Field(default="", description="备注")

    @field_validator("amount")
    @classmethod
    def _amount_must_be_non_zero(cls, value: int) -> int:
        """Reject zero; only credits (> 0) and debits (< 0) are meaningful."""
        if value == 0:
            raise ValueError("amount must not be zero")
        return value
class RechargeResponse(BaseModel):
    """Result of an admin recharge: the new balance and the account status."""

    # Balance after the operation, as returned by recharge_user (in 毫 per the description).
    balance: float = Field(..., description="充值后余额1元 = 10000毫")
    status: str = Field(..., description="账户状态")
class BillingLogResponse(BaseModel):
    """A single billing log entry of the current user."""

    id: str
    model_id: str
    # get_logs fills this from BillingLog.total_cost without conversion, while
    # get_stats divides the same column by 10000. NOTE(review): confirm unit.
    cost: float
    balance_after: Optional[float]  # may be None
    type: str  # BillingLog.log_type; get_stats counts "deduct" and "settle"
    prompt_tokens: int
    completion_tokens: int
    created_at: int  # get_stats treats BillingLog.created_at as nanoseconds since the epoch
    precharge_id: Optional[str] = None  # pre-charge transaction id linking precharge and settle entries
class PricingRequest(BaseModel):
    """Request to set the price of a model."""

    model_id: str = Field(..., description="模型标识")
    # Prices in yuan per million tokens (per the field descriptions); must be >= 0.
    input_price: Decimal = Field(..., ge=0, description="输入价格(元/百万token")
    output_price: Decimal = Field(..., ge=0, description="输出价格(元/百万token")
class PricingResponse(BaseModel):
    """Price of a model and where it came from."""

    model_id: str
    input_price: float
    output_price: float
    # "database" for a stored price, "default" when DEFAULT_PRICING was used.
    source: str = Field(..., description="来源: database/default")
class DailyStats(BaseModel):
    """Spending for one time bucket of the stats chart."""

    date: str  # bucket label, formatted according to the requested granularity
    cost: float  # total cost of the bucket (sum of by_model)
    by_model: dict[str, float] = {}  # cost per model; pydantic copies this mutable default per instance
class ModelStats(BaseModel):
    """Aggregated spending of one model over the report window."""

    model: str
    cost: float  # total cost, converted by get_stats (total_cost / 10000)
    count: int  # number of matching billing log rows
class StatsResponse(BaseModel):
    """Spending report: per-period series plus per-model totals."""

    daily: list[DailyStats]
    by_model: list[ModelStats]
    models: list[str] = []  # all models seen (used by the frontend to build the stacked-chart series)
####################
# API Endpoints
####################
@router.get("/balance", response_model=BalanceResponse)
async def get_balance(user=Depends(get_verified_user)):
    """Return the caller's balance, cumulative spend and billing status.

    Requires a verified (logged-in) user. Missing numeric values are reported
    as 0 and a missing status as ``"active"``.

    Raises:
        HTTPException: 404 if the user record cannot be found.
    """
    account = Users.get_user_by_id(user.id)
    if not account:
        raise HTTPException(status_code=404, detail="用户不存在")

    current_balance = float(account.balance or 0)
    consumed = float(account.total_consumed or 0)
    status = account.billing_status or "active"
    return BalanceResponse(
        balance=current_balance,
        total_consumed=consumed,
        billing_status=status,
    )
@router.post("/recharge", response_model=RechargeResponse)
async def recharge(req: RechargeRequest, admin=Depends(get_admin_user)):
    """Credit (or, with a negative amount, debit) a user's balance.

    Admin only. The work is delegated to ``recharge_user`` with the calling
    admin recorded as operator. The response carries the resulting balance
    and the user's current billing status.

    Raises:
        HTTPException: 404 if the user cannot be found after the operation,
            500 for any other failure.
    """
    try:
        new_balance = recharge_user(
            user_id=req.user_id,
            amount=req.amount,
            operator_id=admin.id,
            remark=req.remark,
        )

        target = Users.get_user_by_id(req.user_id)
        if not target:
            raise HTTPException(status_code=404, detail="用户不存在")

        return RechargeResponse(
            balance=float(new_balance),
            status=target.billing_status or "active",
        )
    except HTTPException:
        # Deliberate HTTP errors (such as the 404 above) pass through unchanged.
        raise
    except Exception as e:
        log.error(f"充值失败: {e}")
        raise HTTPException(status_code=500, detail=f"充值失败: {str(e)}")
@router.get("/logs", response_model=list[BillingLogResponse])
async def get_logs(
    user=Depends(get_verified_user), limit: int = 50, offset: int = 0
):
    """List the current user's billing log entries, paginated.

    Requires a verified (logged-in) user.

    Args:
        limit: Maximum number of entries to return.
        offset: Number of entries to skip.

    Returns:
        One ``BillingLogResponse`` per entry. ``cost`` is ``total_cost``
        converted to float without unit conversion.

    Raises:
        HTTPException: 500 if the lookup fails.
    """
    try:
        entries = BillingLogs.get_by_user_id(user.id, limit=limit, offset=offset)
        # The loop variable is deliberately not named ``log`` so it does not
        # shadow the module logger used in the except clause.
        return [
            BillingLogResponse(
                id=entry.id,
                model_id=entry.model_id,
                cost=float(entry.total_cost),
                # Compare with None explicitly: a balance of exactly 0 is a
                # real value and must not be reported as missing.
                balance_after=(
                    float(entry.balance_after)
                    if entry.balance_after is not None
                    else None
                ),
                type=entry.log_type,
                prompt_tokens=entry.prompt_tokens,
                completion_tokens=entry.completion_tokens,
                created_at=entry.created_at,
                precharge_id=entry.precharge_id,  # links precharge and settle entries
            )
            for entry in entries
        ]
    except Exception as e:
        log.error(f"查询日志失败: {e}")
        raise HTTPException(status_code=500, detail=f"查询日志失败: {str(e)}")
def _stats_periods(granularity: str, days: int, now):
    """Build the bucket layout used by ``get_stats``.

    Args:
        granularity: ``"hour"`` or ``"month"``; any other value means daily.
        days: Number of daily buckets (only used for daily granularity).
        now: Reference time (naive local ``datetime``). The last bucket is
            the one containing it.

    Returns:
        ``(trunc_unit, date_format, periods)``: the ``date_trunc`` unit, the
        ``strftime`` format used as the bucket key, and the chronologically
        ordered bucket start times:

        * hour  -> the last 24 hours, keyed ``%H:00``
        * month -> the last 12 months, keyed ``%Y-%m``
        * day   -> the last ``days`` days, keyed ``%m-%d``
    """
    from datetime import timedelta
    from dateutil.relativedelta import relativedelta

    if granularity == "hour":
        periods = [
            (now - timedelta(hours=i)).replace(minute=0, second=0, microsecond=0)
            for i in range(23, -1, -1)
        ]
        return "hour", "%H:00", periods
    if granularity == "month":
        periods = [
            (now - relativedelta(months=i)).replace(
                day=1, hour=0, minute=0, second=0, microsecond=0
            )
            for i in range(11, -1, -1)
        ]
        return "month", "%Y-%m", periods
    periods = [
        (now - timedelta(days=i)).replace(hour=0, minute=0, second=0, microsecond=0)
        for i in range(days - 1, -1, -1)
    ]
    return "day", "%m-%d", periods


@router.get("/stats", response_model=StatsResponse)
async def get_stats(
    user=Depends(get_verified_user),
    days: int = 7,
    granularity: str = "day"
):
    """Return the current user's spending report.

    Requires a verified (logged-in) user. Only ``deduct`` and ``settle`` log
    entries are counted, and costs are converted from 毫 to yuan (/ 10000).

    Args:
        days: Number of days shown when ``granularity`` is daily.
        granularity: ``"hour"`` (last 24 hours), ``"month"`` (last 12
            months), or anything else for daily buckets.

    Returns:
        ``StatsResponse`` with one ``DailyStats`` per displayed bucket (empty
        buckets included with cost 0), per-model totals over the same window,
        and the sorted list of models seen.

    Raises:
        HTTPException: 500 if the query fails.
    """
    try:
        from sqlalchemy import func
        from datetime import datetime

        with get_db() as db:
            now = datetime.now()
            trunc_unit, date_format, all_periods = _stats_periods(granularity, days, now)

            # Start the query at the first displayed bucket rather than at
            # "now - days". Otherwise older rows (a) inflate the per-model
            # totals without appearing in the chart and (b) in hourly mode,
            # land on a "%H:00" key that also names one of the displayed hours.
            window_start = all_periods[0] if all_periods else now
            # BillingLog.created_at is stored in nanoseconds (see to_timestamp below).
            cutoff = int(window_start.timestamp()) * 1_000_000_000

            base_filters = (
                BillingLog.user_id == user.id,
                BillingLog.created_at >= cutoff,
                BillingLog.log_type.in_(["deduct", "settle"]),
            )

            # Cost per (time bucket, model), used for the stacked chart.
            # NOTE(review): date_trunc / to_timestamp are PostgreSQL functions;
            # this query will fail on SQLite.
            daily_by_model_query = (
                db.query(
                    func.date_trunc(
                        trunc_unit,
                        func.to_timestamp(BillingLog.created_at / 1000000000),
                    ).label("date"),
                    BillingLog.model_id,
                    func.sum(BillingLog.total_cost).label("total"),
                )
                .filter(*base_filters)
                .group_by("date", BillingLog.model_id)
                .order_by("date")
                .all()
            )

            # {period_key: {model_id: cost_in_yuan}}
            data_dict: dict[str, dict[str, float]] = {}
            all_models: set[str] = set()
            for bucket, model_id, total in daily_by_model_query:
                if not (bucket and model_id):
                    continue
                key = bucket.strftime(date_format)
                cost = total / 10000 if total else 0  # 毫 -> yuan
                all_models.add(model_id)
                per_model = data_dict.setdefault(key, {})
                # Accumulate rather than overwrite in case two buckets map to
                # the same display key.
                per_model[model_id] = per_model.get(model_id, 0) + cost

            log.debug(f"统计查询: granularity={granularity}, days={days}, 记录数={len(daily_by_model_query)}, 模型数={len(all_models)}")

            # Emit every displayed bucket, filling gaps with zero cost.
            daily_stats = []
            for period in all_periods:
                key = period.strftime(date_format)
                by_model = data_dict.get(key, {})
                daily_stats.append(
                    DailyStats(date=key, cost=sum(by_model.values()), by_model=by_model)
                )

            log.debug(f"生成时间序列: 数量={len(daily_stats)}, 模型列表={list(all_models)}")

            # Per-model totals over the same window, most expensive first.
            by_model_query = (
                db.query(
                    BillingLog.model_id,
                    func.sum(BillingLog.total_cost).label("total"),
                    func.count().label("count"),
                )
                .filter(*base_filters)
                .group_by(BillingLog.model_id)
                .order_by(func.sum(BillingLog.total_cost).desc())
                .all()
            )

            # Rows without a model id are skipped, as in the per-period chart;
            # ModelStats.model is a required str and would reject None.
            return StatsResponse(
                daily=daily_stats,
                by_model=[
                    ModelStats(model=m[0], cost=m[1] / 10000 if m[1] else 0, count=m[2])
                    for m in by_model_query
                    if m[0]
                ],
                models=sorted(all_models),  # alphabetical model list
            )
    except Exception as e:
        log.error(f"查询统计失败: {e}")
        raise HTTPException(status_code=500, detail=f"查询统计失败: {str(e)}")
@router.post("/pricing", response_model=PricingResponse)
async def set_pricing(req: PricingRequest, admin=Depends(get_admin_user)):
    """Create or update the price of a model. Admin only.

    Returns the stored pricing tagged with ``source="database"``.

    Raises:
        HTTPException: 500 if the upsert fails.
    """
    try:
        saved = ModelPricings.upsert(
            model_id=req.model_id,
            input_price=req.input_price,
            output_price=req.output_price,
        )
        stored_input = float(saved.input_price)
        stored_output = float(saved.output_price)
        return PricingResponse(
            model_id=saved.model_id,
            input_price=stored_input,
            output_price=stored_output,
            source="database",
        )
    except Exception as e:
        log.error(f"设置定价失败: {e}")
        raise HTTPException(status_code=500, detail=f"设置定价失败: {str(e)}")
@router.get("/pricing/{model_id}", response_model=PricingResponse)
async def get_pricing(model_id: str):
    """Return the price of ``model_id``. Public, no login required.

    A price stored in the database wins (``source="database"``). Otherwise the
    ``DEFAULT_PRICING`` entry for the model is used, falling back to its
    ``"default"`` entry (``source="default"``).

    Raises:
        HTTPException: 500 on any lookup failure.
    """
    try:
        stored = ModelPricings.get_by_model_id(model_id)
        if not stored:
            # No database row: fall back to the built-in price table.
            from open_webui.utils.billing import DEFAULT_PRICING

            fallback = DEFAULT_PRICING.get(model_id, DEFAULT_PRICING["default"])
            return PricingResponse(
                model_id=model_id,
                input_price=float(fallback["input"]),
                output_price=float(fallback["output"]),
                source="default",
            )

        return PricingResponse(
            model_id=stored.model_id,
            input_price=float(stored.input_price),
            output_price=float(stored.output_price),
            source="database",
        )
    except Exception as e:
        log.error(f"查询定价失败: {e}")
        raise HTTPException(status_code=500, detail=f"查询定价失败: {str(e)}")
@router.get("/pricing", response_model=list[PricingResponse])
async def list_pricing():
    """List every model price stored in the database. Public, no login required.

    Raises:
        HTTPException: 500 if the listing fails.
    """
    try:
        rows = ModelPricings.get_all()
        result = []
        for row in rows:
            result.append(
                PricingResponse(
                    model_id=row.model_id,
                    input_price=float(row.input_price),
                    output_price=float(row.output_price),
                    source="database",
                )
            )
        return result
    except Exception as e:
        log.error(f"列出定价失败: {e}")
        raise HTTPException(status_code=500, detail=f"列出定价失败: {str(e)}")
@router.get("/recharge/logs/{user_id}")
async def get_recharge_logs(
    user_id: str,
    limit: int = 50,
    offset: int = 0,
    admin=Depends(get_admin_user)
):
    """Return a user's recharge history, including operator names. Admin only.

    Args:
        user_id: User whose recharge records are listed.
        limit: Maximum number of records to return.
        offset: Number of records to skip.

    Raises:
        HTTPException: 500 if the lookup fails.
    """
    try:
        return RechargeLogs.get_by_user_id_with_operator_name(
            user_id, limit=limit, offset=offset
        )
    except Exception as e:
        log.error(f"查询充值记录失败: {e}")
        raise HTTPException(status_code=500, detail=f"查询充值记录失败: {str(e)}")
####################
# Payment API (用户自助充值)
####################
class CreateOrderRequest(BaseModel):
    """Self-service top-up order request."""

    # Amount in yuan. The schema allows (0, 10000]; create_payment_order
    # additionally rejects amounts below 1 with a 400.
    amount: float = Field(..., gt=0, le=10000, description="充值金额1-10000")
class CreateOrderResponse(BaseModel):
    """Newly created top-up order together with its payment QR code."""

    order_id: str
    out_trade_no: str  # merchant order number sent to Alipay
    qr_code: str
    amount: float  # amount in yuan, echoed from the request
    expired_at: int  # unix timestamp in seconds (creation time + 900 s)
class OrderStatusResponse(BaseModel):
    """Status of a payment order, polled by the frontend."""

    order_id: str
    status: str
    amount: float  # amount in yuan (stored value in 毫 / 10000)
    paid_at: Optional[int] = None  # unix timestamp in seconds, set once paid
@router.post("/payment/create", response_model=CreateOrderResponse)
async def create_payment_order(req: CreateOrderRequest, user=Depends(get_verified_user)):
    """Create a self-service top-up order and return its Alipay QR code.

    Requires a verified (logged-in) user. The order is stored with its amount
    in 毫 (1 yuan = 10000 毫) and expires 15 minutes after creation.

    Raises:
        HTTPException: 503 if Alipay is not configured, 400 if the amount is
            outside 1-10000 yuan, 500 if Alipay refuses to create the order.
    """
    import uuid as uuid_module
    from open_webui.utils.alipay import create_qr_payment, is_alipay_configured
    from open_webui.models.billing import PaymentOrders

    if not is_alipay_configured():
        raise HTTPException(status_code=503, detail="支付功能暂未开放,请联系管理员")

    # The schema only guarantees 0 < amount <= 10000; the business minimum is 1 yuan.
    if req.amount < 1:
        raise HTTPException(status_code=400, detail="充值金额最低1元")
    if req.amount > 10000:
        raise HTTPException(status_code=400, detail="充值金额最高10000元")

    # Order number: "CK" + unix seconds + 8 random upper-case hex characters.
    out_trade_no = f"CK{int(time.time())}{uuid_module.uuid4().hex[:8].upper()}"

    # NOTE(review): create_qr_payment is called synchronously from an async
    # handler; if it performs blocking network I/O it stalls the event loop.
    success, msg, qr_code = create_qr_payment(
        out_trade_no=out_trade_no,
        amount_yuan=req.amount,
        subject="Cakumi账户充值",
    )
    if not success:
        log.error(f"创建支付订单失败: {msg}")
        raise HTTPException(status_code=500, detail=f"创建订单失败: {msg}")

    now = int(time.time())
    expired_at = now + 900  # expires after 15 minutes

    # Round instead of truncating. Binary floats are inexact, so e.g.
    # 1.13 * 10000 evaluates to 11299.999999999998, which int() would turn
    # into 11299, under-crediting the user by one 毫.
    amount_milli = round(req.amount * 10000)  # yuan -> 毫

    order = PaymentOrders.create(
        user_id=user.id,
        out_trade_no=out_trade_no,
        amount=amount_milli,
        qr_code=qr_code,
        expired_at=expired_at,
        payment_method="alipay",
    )
    log.info(f"创建支付订单成功: {out_trade_no}, 用户={user.id}, 金额={req.amount}")

    return CreateOrderResponse(
        order_id=order.id,
        out_trade_no=out_trade_no,
        qr_code=qr_code,
        amount=req.amount,
        expired_at=expired_at,
    )
@router.get("/payment/status/{order_id}", response_model=OrderStatusResponse)
async def get_payment_status(order_id: str, user=Depends(get_verified_user)):
    """Return the status of one of the caller's payment orders.

    Polled by the frontend while the QR code is displayed. ``amount`` is
    converted from 毫 to yuan.

    Raises:
        HTTPException: 404 if the order does not exist, 403 if it belongs to
            another user.
    """
    from open_webui.models.billing import PaymentOrders

    found = PaymentOrders.get_by_id(order_id)
    if not found:
        raise HTTPException(status_code=404, detail="订单不存在")

    # Users may only inspect their own orders.
    if found.user_id != user.id:
        raise HTTPException(status_code=403, detail="无权查询该订单")

    return OrderStatusResponse(
        order_id=found.id,
        status=found.status,
        amount=found.amount / 10000,
        paid_at=found.paid_at,
    )
@router.post("/payment/notify")
async def alipay_notify(request: Request):
    """Handle Alipay's asynchronous payment notification.

    Public endpoint (no login). Authenticity is checked with
    ``verify_notify_sign``. For a successful trade, the matching order is
    marked paid and its amount (in 毫) is credited to the user's balance
    together with a ``RechargeLog`` entry.

    ``request`` must be annotated as ``Request``. Without the annotation,
    FastAPI treats the parameter as a required query string and rejects every
    notification.

    Alipay expects the literal plain-text body ``success``; a JSON-encoded
    ``"success"`` counts as a failure and the notification is re-sent. Every
    exit therefore returns a ``PlainTextResponse``, and ``fail`` is returned
    only when signature verification fails.
    """
    from open_webui.utils.alipay import verify_notify_sign
    from open_webui.models.billing import PaymentOrders, RechargeLog
    from open_webui.models.users import User as UserModel
    import uuid as uuid_module

    # Notification parameters arrive as a form-encoded POST body.
    form_data = await request.form()
    params = dict(form_data)
    log.info(f"收到支付宝回调: {params.get('out_trade_no')}")

    if not verify_notify_sign(params):
        log.error("支付宝回调验签失败")
        return PlainTextResponse("fail")

    out_trade_no = params.get("out_trade_no")
    trade_no = params.get("trade_no")
    trade_status = params.get("trade_status")
    # NOTE(review): params["total_amount"] is not compared with order.amount.

    # Only successful trades change state; other statuses are just acknowledged.
    if trade_status not in ["TRADE_SUCCESS", "TRADE_FINISHED"]:
        log.info(f"支付宝回调,非成功状态: {trade_status}")
        return PlainTextResponse("success")

    order = PaymentOrders.get_by_out_trade_no(out_trade_no)
    if not order:
        log.error(f"支付宝回调,订单不存在: {out_trade_no}")
        return PlainTextResponse("success")

    # Idempotency: repeated notifications for an already-paid order are no-ops.
    # NOTE(review): check-then-update is not atomic; two concurrent deliveries
    # could both pass this check and credit twice. A conditional UPDATE or a
    # row lock would close that window.
    if order.status == "paid":
        log.info(f"支付宝回调,订单已处理: {out_trade_no}")
        return PlainTextResponse("success")

    now = int(time.time())
    # The order is marked paid before crediting, so a re-delivery can never
    # credit twice. The trade-off is that a failed credit below has to be
    # repaired out of band.
    PaymentOrders.update_status(
        out_trade_no=out_trade_no,
        status="paid",
        trade_no=trade_no,
        paid_at=now,
    )

    try:
        with get_db() as db:
            user = db.query(UserModel).filter_by(id=order.user_id).first()
            if user:
                user.balance = (user.balance or 0) + order.amount
                db.add(
                    RechargeLog(
                        id=str(uuid_module.uuid4()),
                        user_id=order.user_id,
                        amount=order.amount,
                        operator_id="system",  # automatic, system-initiated recharge
                        remark=f"支付宝充值,订单号: {out_trade_no}",
                        created_at=now,
                    )
                )
                # One commit: the balance change and its log entry are
                # persisted together instead of in two separate transactions.
                db.commit()
                log.info(
                    f"支付成功: 用户={order.user_id}, 金额={order.amount / 10000:.2f}元, "
                    f"订单={out_trade_no}"
                )
            else:
                log.error(f"支付回调处理失败: 用户不存在 {order.user_id}, 订单={out_trade_no}")
    except Exception as e:
        log.exception(f"支付回调处理失败: {e}")

    # Acknowledge even if crediting failed: the order is already marked paid,
    # so a retry would be short-circuited by the idempotency check anyway.
    return PlainTextResponse("success")
@router.get("/payment/config")
async def get_payment_config():
    """Report whether self-service payment is available.

    Public endpoint; the frontend uses it to decide whether to show the
    top-up feature.
    """
    from open_webui.utils.alipay import is_alipay_configured

    enabled = is_alipay_configured()
    return {"alipay_enabled": enabled}