open-webui/backend/open_webui/routers/billing.py
2025-12-07 13:51:03 +08:00

361 lines
10 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
计费管理 API 路由
提供余额查询、充值、消费记录、统计报表、模型定价等管理接口
"""
import time
import logging
from decimal import Decimal
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel, Field
from open_webui.models.users import Users, User
from open_webui.models.billing import ModelPricings, BillingLogs, RechargeLogs, BillingLog
from open_webui.utils.auth import get_admin_user, get_verified_user
from open_webui.utils.billing import recharge_user
from open_webui.internal.db import get_db
# Module logger used by the endpoints' error paths.
log = logging.getLogger(__name__)
# Router carrying the billing endpoints below; presumably mounted under a
# prefix by the application — confirm where it is included.
router = APIRouter()
####################
# Request/Response Models
####################
class BalanceResponse(BaseModel):
    """Balance summary returned by the ``/balance`` endpoint."""

    # Current balance, described as yuan.
    # NOTE(review): get_balance passes the stored value through as-is, while
    # get_stats divides stored costs by 10000 ("毫" -> yuan); confirm which
    # unit this field actually carries.
    balance: float = Field(..., description="当前余额(元)")
    # Cumulative spend; same unit caveat as ``balance``.
    total_consumed: float = Field(..., description="累计消费(元)")
    # Account state, "active" or "frozen" per the field description.
    billing_status: str = Field(..., description="账户状态: active/frozen")
class RechargeRequest(BaseModel):
    """Admin request to credit or debit a user's balance."""

    # ID of the user whose balance is changed.
    user_id: str = Field(..., description="用户ID")
    # Amount in "毫" (1 yuan = 10000): positive credits, negative debits.
    # NOTE(review): `ne` is not a built-in Pydantic numeric constraint (those
    # are gt/ge/lt/le/multiple_of), so zero is most likely NOT rejected by
    # validation — zero amounts need an explicit check.
    amount: int = Field(..., ne=0, description="充值/扣费金额1元 = 10000毫正数充值负数扣费")
    # Free-form remark, passed through to recharge_user.
    remark: str = Field(default="", description="备注")
class RechargeResponse(BaseModel):
    """Result of a recharge: the resulting balance and the account status."""

    # Balance after the operation as returned by recharge_user; the field
    # description states the unit as 1 yuan = 10000 "毫".
    balance: float = Field(..., description="充值后余额1元 = 10000毫")
    # User's billing status read after the operation.
    status: str = Field(..., description="账户状态")
class BillingLogResponse(BaseModel):
    """One billing log entry as exposed by the ``/logs`` endpoint."""

    id: str
    model_id: str
    # Cost of the entry, copied from the log's total_cost without conversion.
    cost: float
    # Balance after this entry, or None.
    balance_after: Optional[float]
    # Log type; the stats endpoint only aggregates "deduct" entries.
    type: str
    prompt_tokens: int
    completion_tokens: int
    # Creation timestamp; the stats endpoint treats created_at as nanoseconds.
    created_at: int
    precharge_id: Optional[str] = None  # Pre-charge transaction ID, used to link a precharge entry with its settle entry
class PricingRequest(BaseModel):
    """Admin request to set (insert or update) a model's pricing."""

    # Model identifier.
    model_id: str = Field(..., description="模型标识")
    # Input price in yuan per million tokens; validated to be >= 0.
    input_price: Decimal = Field(..., ge=0, description="输入价格(元/百万token")
    # Output price in yuan per million tokens; validated to be >= 0.
    output_price: Decimal = Field(..., ge=0, description="输出价格(元/百万token")
class PricingResponse(BaseModel):
    """Pricing of one model, plus where the numbers came from."""

    model_id: str
    input_price: float
    output_price: float
    # "database" for stored pricing, "default" for the built-in fallback table.
    source: str = Field(..., description="来源: database/default")
class DailyStats(BaseModel):
    """Total spend for one calendar day."""

    # Day formatted as YYYY-MM-DD.
    date: str
    # Spend for that day, converted to yuan by the stats endpoint.
    cost: float
class ModelStats(BaseModel):
    """Total spend and number of billed entries for one model."""

    # Model identifier.
    model: str
    # Spend on that model, converted to yuan by the stats endpoint.
    cost: float
    # Number of matching billing log entries.
    count: int
class StatsResponse(BaseModel):
    """Spending report: per-day totals and per-model totals."""

    # Per-day totals, ascending by date.
    daily: list[DailyStats]
    # Per-model totals, costliest model first.
    by_model: list[ModelStats]
####################
# API Endpoints
####################
@router.get("/balance", response_model=BalanceResponse)
async def get_balance(user=Depends(get_verified_user)):
    """
    Return the balance summary of the logged-in user. Requires login.

    The user record is re-read so the figures are the stored ones. A missing
    balance or total_consumed is reported as 0, and a missing status as
    "active".

    Raises:
        HTTPException(404): the user record cannot be found.
    """
    record = Users.get_user_by_id(user.id)
    if not record:
        raise HTTPException(status_code=404, detail="用户不存在")

    current_balance = float(record.balance or 0)
    consumed = float(record.total_consumed or 0)
    status = record.billing_status or "active"
    return BalanceResponse(
        balance=current_balance,
        total_consumed=consumed,
        billing_status=status,
    )
@router.post("/recharge", response_model=RechargeResponse)
async def recharge(req: RechargeRequest, admin=Depends(get_admin_user)):
    """
    Credit or debit a user's balance. Requires admin privileges.

    ``req.amount`` is in "毫" (1 yuan = 10000). Positive values credit the
    account and negative values debit it.

    Returns:
        RechargeResponse with the balance returned by ``recharge_user`` and
        the user's billing status read back afterwards ("active" if unset).

    Raises:
        HTTPException(400): ``req.amount`` is zero.
        HTTPException(404): the target user does not exist.
        HTTPException(500): any other failure while recharging.
    """
    # RechargeRequest declares `ne=0`, but `ne` is not a Pydantic constraint,
    # so a zero amount passes validation; reject it explicitly here.
    if req.amount == 0:
        raise HTTPException(status_code=400, detail="充值金额不能为0")
    try:
        # Check the target before mutating anything, so an unknown user_id
        # yields a 404 instead of whatever recharge_user raises for it.
        if not Users.get_user_by_id(req.user_id):
            raise HTTPException(status_code=404, detail="用户不存在")
        balance = recharge_user(
            user_id=req.user_id,
            amount=req.amount,
            operator_id=admin.id,
            remark=req.remark,
        )
        # Re-read the user so the reported status reflects the state after
        # the recharge.
        user_data = Users.get_user_by_id(req.user_id)
        if not user_data:
            raise HTTPException(status_code=404, detail="用户不存在")
        return RechargeResponse(
            balance=float(balance), status=user_data.billing_status or "active"
        )
    except HTTPException:
        raise
    except Exception as e:
        # log.exception keeps the traceback, which log.error dropped.
        log.exception(f"充值失败: {e}")
        raise HTTPException(status_code=500, detail=f"充值失败: {str(e)}") from e
@router.get("/logs", response_model=list[BillingLogResponse])
async def get_logs(
    user=Depends(get_verified_user), limit: int = 50, offset: int = 0
):
    """
    Return a page of the logged-in user's billing log entries. Requires login.

    Args:
        limit: Maximum number of entries to return (passed to the model layer).
        offset: Number of entries to skip (passed to the model layer).

    Returns:
        A list of BillingLogResponse. ``cost`` is ``total_cost`` as stored,
        without unit conversion.

    Raises:
        HTTPException(500): the query failed for any reason.
    """
    try:
        # Named `entries`/`entry` so the module-level logger `log` is not
        # shadowed by the comprehension variable.
        entries = BillingLogs.get_by_user_id(user.id, limit=limit, offset=offset)
        return [
            BillingLogResponse(
                id=entry.id,
                model_id=entry.model_id,
                cost=float(entry.total_cost),
                # `is not None` (not truthiness) so a balance of exactly 0 is
                # reported as 0.0 rather than as missing.
                balance_after=(
                    float(entry.balance_after)
                    if entry.balance_after is not None
                    else None
                ),
                type=entry.log_type,
                prompt_tokens=entry.prompt_tokens,
                completion_tokens=entry.completion_tokens,
                created_at=entry.created_at,
                # Pre-charge transaction ID linking precharge and settle entries
                precharge_id=entry.precharge_id,
            )
            for entry in entries
        ]
    except Exception as e:
        log.exception(f"查询日志失败: {e}")
        raise HTTPException(status_code=500, detail=f"查询日志失败: {str(e)}") from e
def _stats_milli_to_yuan(amount):
    """Convert a summed cost from "毫" to yuan (divide by 10000); falsy sums become 0."""
    return amount / 10000 if amount else 0


def _stats_daily_rows(db, filters):
    """
    Sum ``total_cost`` per calendar day over the BillingLog rows matching ``filters``.

    Returns a list of ``(YYYY-MM-DD, total)`` pairs in ascending date order.
    """
    from sqlalchemy import func

    if db.get_bind().dialect.name == "postgresql":
        day = func.date_trunc(
            "day",
            # created_at is in nanoseconds; to_timestamp expects seconds
            func.to_timestamp(BillingLog.created_at / 1000000000),
        ).label("date")
        rows = (
            db.query(day, func.sum(BillingLog.total_cost).label("total"))
            .filter(*filters)
            .group_by("date")
            .order_by("date")
            .all()
        )
        return [(str(row[0].date()), row[1]) for row in rows]

    # date_trunc/to_timestamp are PostgreSQL functions (SQLite, for example,
    # has neither), so other backends fetch the raw rows and bucket them by
    # UTC calendar day here.
    from datetime import datetime, timezone

    rows = (
        db.query(BillingLog.created_at, BillingLog.total_cost)
        .filter(*filters)
        .all()
    )
    totals = {}
    for created_at, cost in rows:
        day = (
            datetime.fromtimestamp(created_at / 1000000000, tz=timezone.utc)
            .date()
            .isoformat()
        )
        totals[day] = totals.get(day, 0) + (cost or 0)
    return sorted(totals.items())


def _stats_model_rows(db, filters):
    """Return ``(model_id, summed total_cost, row count)`` per model, costliest first."""
    from sqlalchemy import func

    return (
        db.query(
            BillingLog.model_id,
            func.sum(BillingLog.total_cost).label("total"),
            func.count().label("count"),
        )
        .filter(*filters)
        .group_by(BillingLog.model_id)
        .order_by(func.sum(BillingLog.total_cost).desc())
        .all()
    )


@router.get("/stats", response_model=StatsResponse)
async def get_stats(user=Depends(get_verified_user), days: int = 7):
    """
    Return the logged-in user's spending report for the last ``days`` days.

    Requires login. Only the user's "deduct" billing entries inside the window
    are aggregated, twice: per calendar day (ascending) and per model (costliest
    first, with an entry count). Costs are converted from "毫" to yuan
    (divide by 10000).

    Raises:
        HTTPException(500): the aggregation failed for any reason.
    """
    try:
        with get_db() as db:
            # created_at is stored in nanoseconds, so the cutoff is too
            cutoff = int((time.time() - days * 86400) * 1000000000)
            filters = (
                BillingLog.user_id == user.id,
                BillingLog.created_at >= cutoff,
                BillingLog.log_type == "deduct",
            )
            daily_rows = _stats_daily_rows(db, filters)
            model_rows = _stats_model_rows(db, filters)
            return StatsResponse(
                daily=[
                    DailyStats(date=day, cost=_stats_milli_to_yuan(total))
                    for day, total in daily_rows
                ],
                by_model=[
                    ModelStats(
                        model=row[0],
                        cost=_stats_milli_to_yuan(row[1]),
                        count=row[2],
                    )
                    for row in model_rows
                ],
            )
    except Exception as e:
        log.exception(f"查询统计失败: {e}")
        raise HTTPException(status_code=500, detail=f"查询统计失败: {str(e)}") from e
@router.post("/pricing", response_model=PricingResponse)
async def set_pricing(req: PricingRequest, admin=Depends(get_admin_user)):
    """
    Insert or update the pricing of one model. Requires admin privileges.

    Returns the stored pricing, tagged ``source="database"``.

    Raises:
        HTTPException(500): the upsert or response construction failed.
    """
    try:
        saved = ModelPricings.upsert(
            model_id=req.model_id,
            input_price=req.input_price,
            output_price=req.output_price,
        )
        response = PricingResponse(
            model_id=saved.model_id,
            input_price=float(saved.input_price),
            output_price=float(saved.output_price),
            source="database",
        )
    except Exception as e:
        log.error(f"设置定价失败: {e}")
        raise HTTPException(status_code=500, detail=f"设置定价失败: {str(e)}")
    return response
@router.get("/pricing/{model_id}", response_model=PricingResponse)
async def get_pricing(model_id: str):
    """
    Look up the pricing of a single model. Public endpoint, no login required.

    A stored row wins and is tagged ``source="database"``. Otherwise the entry
    for ``model_id`` in ``DEFAULT_PRICING`` is used, or that table's
    ``"default"`` entry, tagged ``source="default"``. The ``"default"`` entry
    is looked up eagerly, so it must exist even when ``model_id`` is listed.

    Raises:
        HTTPException(500): the lookup failed for any reason.
    """
    try:
        stored = ModelPricings.get_by_model_id(model_id)
        if not stored:
            # No row for this model: answer from the built-in price table
            from open_webui.utils.billing import DEFAULT_PRICING

            fallback = DEFAULT_PRICING.get(model_id, DEFAULT_PRICING["default"])
            return PricingResponse(
                model_id=model_id,
                input_price=float(fallback["input"]),
                output_price=float(fallback["output"]),
                source="default",
            )
        return PricingResponse(
            model_id=stored.model_id,
            input_price=float(stored.input_price),
            output_price=float(stored.output_price),
            source="database",
        )
    except Exception as e:
        log.error(f"查询定价失败: {e}")
        raise HTTPException(status_code=500, detail=f"查询定价失败: {str(e)}")
@router.get("/pricing", response_model=list[PricingResponse])
async def list_pricing():
    """
    List every model pricing stored in the database. Public endpoint, no login required.

    Only stored rows are returned, each tagged ``source="database"``. Built-in
    defaults are not included.

    Raises:
        HTTPException(500): the listing failed for any reason.
    """

    def as_response(row):
        return PricingResponse(
            model_id=row.model_id,
            input_price=float(row.input_price),
            output_price=float(row.output_price),
            source="database",
        )

    try:
        return [as_response(row) for row in ModelPricings.get_all()]
    except Exception as e:
        log.error(f"列出定价失败: {e}")
        raise HTTPException(status_code=500, detail=f"列出定价失败: {str(e)}")
@router.get("/recharge/logs/{user_id}")
async def get_recharge_logs(
    user_id: str,
    limit: int = 50,
    offset: int = 0,
    admin=Depends(get_admin_user),
):
    """
    Return a page of a user's recharge records. Requires admin privileges.

    The records come straight from
    ``RechargeLogs.get_by_user_id_with_operator_name``. No response model is
    declared, so they are serialized as returned.

    Raises:
        HTTPException(500): the query failed for any reason.
    """
    try:
        return RechargeLogs.get_by_user_id_with_operator_name(
            user_id, limit=limit, offset=offset
        )
    except Exception as e:
        log.error(f"查询充值记录失败: {e}")
        raise HTTPException(status_code=500, detail=f"查询充值记录失败: {str(e)}")