From 5e65f98525051efdf0d6253001e703e47241a08a Mon Sep 17 00:00:00 2001 From: sylarchen1389 Date: Sun, 7 Dec 2025 12:27:37 +0800 Subject: [PATCH] =?UTF-8?q?feat=EF=BC=9A=E8=A1=A5=E5=85=85=E9=A2=9D?= =?UTF-8?q?=E5=BA=A6=E7=9B=B8=E5=85=B3api?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/open_webui/routers/billing.py | 361 ++++++++++++++++++++++++++ 1 file changed, 361 insertions(+) create mode 100644 backend/open_webui/routers/billing.py diff --git a/backend/open_webui/routers/billing.py b/backend/open_webui/routers/billing.py new file mode 100644 index 0000000000..2774c8453c --- /dev/null +++ b/backend/open_webui/routers/billing.py @@ -0,0 +1,361 @@ +""" +计费管理 API 路由 + +提供余额查询、充值、消费记录、统计报表、模型定价等管理接口 +""" + +import time +import logging +from decimal import Decimal +from typing import Optional + +from fastapi import APIRouter, Depends, HTTPException +from pydantic import BaseModel, Field + +from open_webui.models.users import Users, User +from open_webui.models.billing import ModelPricings, BillingLogs, RechargeLogs, BillingLog +from open_webui.utils.auth import get_admin_user, get_verified_user +from open_webui.utils.billing import recharge_user +from open_webui.internal.db import get_db + +log = logging.getLogger(__name__) + +router = APIRouter() + + +#################### +# Request/Response Models +#################### + + +class BalanceResponse(BaseModel): + """余额响应""" + + balance: float = Field(..., description="当前余额(元)") + total_consumed: float = Field(..., description="累计消费(元)") + billing_status: str = Field(..., description="账户状态: active/frozen") + + +class RechargeRequest(BaseModel): + """充值请求""" + + user_id: str = Field(..., description="用户ID") + amount: Decimal = Field(..., gt=0, description="充值金额(元)") + remark: str = Field(default="", description="备注") + + +class RechargeResponse(BaseModel): + """充值响应""" + + balance: float = Field(..., description="充值后余额(元)") + status: str = Field(..., description="账户状态") + + +class BillingLogResponse(BaseModel): + """计费日志响应""" + + id: str + model_id: str + cost: float + balance_after: Optional[float] + type: str + prompt_tokens: int + completion_tokens: int + created_at: int + precharge_id: Optional[str] = None # 预扣费事务ID,用于关联 precharge 和 settle + + +class PricingRequest(BaseModel): + """定价请求""" + + model_id: str = Field(..., description="模型标识") + input_price: Decimal = Field(..., ge=0, description="输入价格(元/百万token)") + output_price: Decimal = Field(..., ge=0, description="输出价格(元/百万token)") + + +class PricingResponse(BaseModel): + """定价响应""" + + model_id: str + input_price: float + output_price: float + source: str = Field(..., description="来源: database/default") + + +class DailyStats(BaseModel): + """每日统计""" + + date: str + cost: float + + +class ModelStats(BaseModel): + """模型统计""" + + model: str + cost: float + count: int + + +class StatsResponse(BaseModel): + """统计报表响应""" + + daily: list[DailyStats] + by_model: list[ModelStats] + + +#################### +# API Endpoints +#################### + + +@router.get("/balance", response_model=BalanceResponse) +async def get_balance(user=Depends(get_verified_user)): + """ + 查询当前用户余额 + + 需要登录 + """ + user_data = Users.get_user_by_id(user.id) + if not user_data: + raise HTTPException(status_code=404, detail="用户不存在") + + return BalanceResponse( + balance=float(user_data.balance or 0), + total_consumed=float(user_data.total_consumed or 0), + billing_status=user_data.billing_status or "active", + ) + + +@router.post("/recharge", response_model=RechargeResponse) +async def recharge(req: RechargeRequest, admin=Depends(get_admin_user)): + """ + 管理员充值 + + 需要管理员权限 + """ + try: + balance = recharge_user( + user_id=req.user_id, + amount=req.amount, + operator_id=admin.id, + remark=req.remark, + ) + + # 获取用户状态 + user_data = Users.get_user_by_id(req.user_id) + if not user_data: + raise HTTPException(status_code=404, detail="用户不存在") + + return RechargeResponse( + balance=float(balance), status=user_data.billing_status or "active" + ) + except HTTPException: + raise + except Exception as e: + log.error(f"充值失败: {e}") + raise HTTPException(status_code=500, detail=f"充值失败: {str(e)}") + + +@router.get("/logs", response_model=list[BillingLogResponse]) +async def get_logs( + user=Depends(get_verified_user), limit: int = 50, offset: int = 0 +): + """ + 查询当前用户消费记录 + + 需要登录 + """ + try: + logs = BillingLogs.get_by_user_id(user.id, limit=limit, offset=offset) + + return [ + BillingLogResponse( + id=log.id, + model_id=log.model_id, + cost=float(log.total_cost), + balance_after=float(log.balance_after) if log.balance_after else None, + type=log.log_type, + prompt_tokens=log.prompt_tokens, + completion_tokens=log.completion_tokens, + created_at=log.created_at, + precharge_id=log.precharge_id, # 添加预扣费事务ID + ) + for log in logs + ] + except Exception as e: + log.error(f"查询日志失败: {e}") + raise HTTPException(status_code=500, detail=f"查询日志失败: {str(e)}") + + +@router.get("/stats", response_model=StatsResponse) +async def get_stats(user=Depends(get_verified_user), days: int = 7): + """ + 查询统计报表 + + 按日统计和按模型统计 + + 需要登录 + """ + try: + from sqlalchemy import func + + with get_db() as db: + # cutoff: 纳秒级时间戳 + cutoff = int((time.time() - days * 86400) * 1000000000) + + # 按日统计 + daily_query = ( + db.query( + func.date_trunc( + "day", + # created_at 是纳秒级,需要除以 1000000000 转换为秒级 + func.to_timestamp(BillingLog.created_at / 1000000000) + ).label("date"), + func.sum(BillingLog.total_cost).label("total"), + ) + .filter( + BillingLog.user_id == user.id, + BillingLog.created_at >= cutoff, + BillingLog.log_type == "deduct", + ) + .group_by("date") + .order_by("date") + .all() + ) + + # 按模型统计 + by_model_query = ( + db.query( + BillingLog.model_id, + func.sum(BillingLog.total_cost).label("total"), + func.count().label("count"), + ) + .filter( + BillingLog.user_id == user.id, + BillingLog.created_at >= cutoff, + BillingLog.log_type == "deduct", + ) + .group_by(BillingLog.model_id) + .order_by(func.sum(BillingLog.total_cost).desc()) + .all() + ) + + # cost 单位转换:毫 → 元(除以 10000) + return StatsResponse( + daily=[ + DailyStats(date=str(d[0].date()), cost=d[1] / 10000 if d[1] else 0) + for d in daily_query + ], + by_model=[ + ModelStats(model=m[0], cost=m[1] / 10000 if m[1] else 0, count=m[2]) + for m in by_model_query + ], + ) + except Exception as e: + log.error(f"查询统计失败: {e}") + raise HTTPException(status_code=500, detail=f"查询统计失败: {str(e)}") + + +@router.post("/pricing", response_model=PricingResponse) +async def set_pricing(req: PricingRequest, admin=Depends(get_admin_user)): + """ + 设置模型定价 + + 需要管理员权限 + """ + try: + pricing = ModelPricings.upsert( + model_id=req.model_id, + input_price=req.input_price, + output_price=req.output_price, + ) + + return PricingResponse( + model_id=pricing.model_id, + input_price=float(pricing.input_price), + output_price=float(pricing.output_price), + source="database", + ) + except Exception as e: + log.error(f"设置定价失败: {e}") + raise HTTPException(status_code=500, detail=f"设置定价失败: {str(e)}") + + +@router.get("/pricing/{model_id}", response_model=PricingResponse) +async def get_pricing(model_id: str): + """ + 查询模型定价 + + 公开接口,无需登录 + """ + try: + pricing = ModelPricings.get_by_model_id(model_id) + + if pricing: + return PricingResponse( + model_id=pricing.model_id, + input_price=float(pricing.input_price), + output_price=float(pricing.output_price), + source="database", + ) + else: + # 返回默认价格 + from open_webui.utils.billing import DEFAULT_PRICING + + default = DEFAULT_PRICING.get(model_id, DEFAULT_PRICING["default"]) + return PricingResponse( + model_id=model_id, + input_price=float(default["input"]), + output_price=float(default["output"]), + source="default", + ) + except Exception as e: + log.error(f"查询定价失败: {e}") + raise HTTPException(status_code=500, detail=f"查询定价失败: {str(e)}") + + +@router.get("/pricing", response_model=list[PricingResponse]) +async def list_pricing(): + """ + 列出所有模型定价 + + 公开接口,无需登录 + """ + try: + pricings = ModelPricings.get_all() + + return [ + PricingResponse( + model_id=p.model_id, + input_price=float(p.input_price), + output_price=float(p.output_price), + source="database", + ) + for p in pricings + ] + except Exception as e: + log.error(f"列出定价失败: {e}") + raise HTTPException(status_code=500, detail=f"列出定价失败: {str(e)}") + + +@router.get("/recharge/logs/{user_id}") +async def get_recharge_logs( + user_id: str, + limit: int = 50, + offset: int = 0, + admin=Depends(get_admin_user) +): + """ + 查询用户充值记录 (仅管理员) + + 需要管理员权限 + """ + try: + logs = RechargeLogs.get_by_user_id_with_operator_name( + user_id, limit=limit, offset=offset + ) + return logs + except Exception as e: + log.error(f"查询充值记录失败: {e}") + raise HTTPException(status_code=500, detail=f"查询充值记录失败: {str(e)}")