open-webui/backend/open_webui/models/models.py

466 lines
14 KiB
Python
Raw Normal View History

import logging
2024-08-27 22:10:27 +00:00
import time
from typing import Optional
2024-12-10 08:54:13 +00:00
from open_webui.internal.db import Base, JSONField, get_db
from open_webui.env import SRC_LOG_LEVELS
2024-11-15 09:29:07 +00:00
from open_webui.models.groups import Groups
2024-12-10 08:54:13 +00:00
from open_webui.models.users import Users, UserResponse
2024-11-15 09:29:07 +00:00
2024-08-27 22:10:27 +00:00
from pydantic import BaseModel, ConfigDict
2024-11-15 09:29:07 +00:00
from sqlalchemy import Index
from sqlalchemy import BigInteger, Column, Text, JSON, Boolean, Integer, String
2024-05-24 07:26:00 +00:00
2024-11-15 09:29:07 +00:00
2024-11-17 00:51:55 +00:00
from open_webui.utils.access_control import has_access
2024-11-15 09:29:07 +00:00
# Module-level logger; its level is taken from the "MODELS" entry of SRC_LOG_LEVELS.
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
####################
# Models DB Schema
####################
# ModelParams describes the data stored in the `params` field of the Model table.
class ModelParams(BaseModel):
    # No fields are declared here; extra="allow" lets arbitrary keys through
    # and keeps them on the instance.
    model_config = ConfigDict(extra="allow")
# ModelMeta describes the data stored in the `meta` field of the Model table.
class ModelMeta(BaseModel):
    # Keys beyond the declared fields are accepted and kept (extra="allow").
    model_config = ConfigDict(extra="allow")

    # Falls back to the bundled favicon when no image is supplied.
    profile_image_url: Optional[str] = "/static/favicon.png"

    # User-facing description of the model.
    description: Optional[str] = None

    capabilities: Optional[dict] = None
class Model(Base):
    """ORM mapping for the ``model`` table."""

    __tablename__ = "model"

    # Unique model identifier, used in API calls.
    id = Column(Text, primary_key=True)

    # ID of the user who created the model; used for access control.
    user_id = Column(Text)

    # ID of the underlying base model; NULL marks the row as a base model.
    base_model_id = Column(Text, nullable=True)

    # Human-readable display name.
    name = Column(Text)

    # Icon URL, shown in model lists.
    icon_url = Column(Text, nullable=True)

    # Provider identifier, e.g. openai, anthropic, ollama.
    provider = Column(String(50), nullable=True)

    # Short description text shown by the frontend.
    description = Column(Text, nullable=True)

    # Model runtime parameters, stored as JSON.
    params = Column(JSONField)

    # Model metadata, stored as JSON.
    meta = Column(JSONField)

    # Context length limit; defaults to 4096.
    context_length = Column(Integer, nullable=False, default=4096)

    # Tags used for grouping and filtering.
    tags = Column(JSONField, nullable=True)

    # Sort weight; larger values are listed first.
    sort_order = Column(Integer, nullable=False, default=0)

    # Access-control rules: None = public, {} = private,
    # any other JSON = custom permissions.
    access_control = Column(JSON, nullable=True)

    # Activation flag; False means the model is disabled.
    is_active = Column(Boolean, default=True)

    # Last-update time as a Unix timestamp.
    updated_at = Column(BigInteger)

    # Creation time as a Unix timestamp.
    created_at = Column(BigInteger)

    __table_args__ = (Index("idx_model_provider", "provider"),)
class ModelModel(BaseModel):
    # from_attributes=True allows validation straight from a `Model` ORM row.
    model_config = ConfigDict(from_attributes=True)

    id: str
    user_id: str
    base_model_id: Optional[str] = None

    name: str
    icon_url: Optional[str] = None
    provider: Optional[str] = None
    description: Optional[str] = None

    params: ModelParams
    meta: ModelMeta

    context_length: int = 4096
    tags: Optional[dict] = None
    sort_order: int = 0

    access_control: Optional[dict] = None
    is_active: bool

    # Both timestamps are epoch values.
    updated_at: int
    created_at: int
####################
# Forms
####################
2024-11-18 13:37:04 +00:00
class ModelUserResponse(ModelModel):
    # Every ModelModel field plus an optional `user` payload.
    user: Optional[UserResponse] = None
2024-11-15 09:29:07 +00:00
2024-11-16 02:21:41 +00:00
2024-11-18 13:37:04 +00:00
class ModelResponse(ModelModel):
    # Adds nothing to ModelModel; it only provides a distinct response type name.
    pass
2024-05-24 07:26:00 +00:00
class ModelForm(BaseModel):
    # Input payload for creating or updating a model. Compared with ModelModel
    # it has no user_id or timestamps, and `is_active` defaults to True.
    id: str
    base_model_id: Optional[str] = None

    name: str
    icon_url: Optional[str] = None
    provider: Optional[str] = None
    description: Optional[str] = None

    meta: ModelMeta
    params: ModelParams

    context_length: int = 4096
    tags: Optional[dict] = None
    sort_order: int = 0

    access_control: Optional[dict] = None
    is_active: bool = True
2024-05-24 07:26:00 +00:00
class ModelsTable:
    """
    Data-access layer for the ``model`` table.

    Provides create / read / update / delete operations, a per-user
    filtered listing, and a bulk ``sync_models`` that makes the table
    mirror a supplied list.

    Two kinds of rows are distinguished by ``base_model_id``:

    - base models: ``base_model_id`` is NULL
    - custom models: ``base_model_id`` is set and refers to a base model
    """

    def insert_new_model(
        self, form_data: ModelForm, user_id: str
    ) -> Optional[ModelModel]:
        """
        Insert a new model row.

        Args:
            form_data: Model form data (id, name, params, meta, ...).
            user_id: ID of the user recorded as the model's creator.

        Returns:
            The stored model, or None if the database operation failed.
        """
        # One clock read, so created_at and updated_at are always equal
        # on a freshly inserted row.
        now = int(time.time())
        model = ModelModel(
            **{
                **form_data.model_dump(),
                "user_id": user_id,
                "created_at": now,
                "updated_at": now,
            }
        )

        try:
            with get_db() as db:
                row = Model(**model.model_dump())
                db.add(row)
                db.commit()
                db.refresh(row)
                return ModelModel.model_validate(row)
        except Exception as e:
            log.exception(f"Failed to insert a new model: {e}")
            return None

    def get_all_models(self) -> list[ModelModel]:
        """
        Return every row of the table, base and custom models alike.

        Returns:
            All models.
        """
        with get_db() as db:
            return [ModelModel.model_validate(row) for row in db.query(Model).all()]

    def get_models(self) -> list[ModelUserResponse]:
        """
        Return the custom models (``base_model_id`` is not NULL).

        Each entry also carries its creator's user record, or None when
        no matching user was returned.

        Returns:
            Custom models with user information attached.
        """
        with get_db() as db:
            # is_not(None) renders "IS NOT NULL", same as the `!= None` form.
            rows = db.query(Model).filter(Model.base_model_id.is_not(None)).all()

            # Fetch all creators in one call rather than one lookup per model.
            user_ids = list({row.user_id for row in rows})
            users = Users.get_users_by_user_ids(user_ids) if user_ids else []
            users_by_id = {user.id: user for user in users}

            models = []
            for row in rows:
                user = users_by_id.get(row.user_id)
                models.append(
                    ModelUserResponse.model_validate(
                        {
                            **ModelModel.model_validate(row).model_dump(),
                            "user": user.model_dump() if user else None,
                        }
                    )
                )
            return models

    def get_base_models(self) -> list[ModelModel]:
        """
        Return the base models (``base_model_id`` is NULL).

        Returns:
            Base models.
        """
        with get_db() as db:
            return [
                ModelModel.model_validate(row)
                for row in db.query(Model).filter(Model.base_model_id.is_(None)).all()
            ]

    def get_models_by_user_id(
        self, user_id: str, permission: str = "write"
    ) -> list[ModelUserResponse]:
        """
        Return the custom models a user may access.

        A model is included when the user created it, or when
        ``has_access`` accepts the user (and the user's group ids) for
        the requested permission against the model's ``access_control``.

        Args:
            user_id: User ID.
            permission: Permission name handed to ``has_access``;
                defaults to "write".

        Returns:
            The accessible custom models, with user information attached.
        """
        models = self.get_models()
        user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user_id)}
        return [
            model
            for model in models
            if model.user_id == user_id
            or has_access(user_id, permission, model.access_control, user_group_ids)
        ]

    def get_model_by_id(self, id: str) -> Optional[ModelModel]:
        """
        Look up a model by its ID.

        Args:
            id: Model ID.

        Returns:
            The model, or None when it does not exist or the lookup failed.
        """
        try:
            with get_db() as db:
                row = db.get(Model, id)
                # A missing row is an ordinary outcome, not an error.
                if row is None:
                    return None
                return ModelModel.model_validate(row)
        except Exception as e:
            # Genuine failures are logged instead of vanishing silently.
            log.exception(f"Failed to get the model by id {id}: {e}")
            return None

    def toggle_model_by_id(self, id: str) -> Optional[ModelModel]:
        """
        Flip a model's ``is_active`` flag and bump ``updated_at``.

        Args:
            id: Model ID.

        Returns:
            The updated model, or None when the model does not exist or
            the update failed.
        """
        with get_db() as db:
            try:
                row = db.query(Model).filter_by(id=id).first()
                if row is None:
                    return None

                db.query(Model).filter_by(id=id).update(
                    {
                        "is_active": not row.is_active,
                        "updated_at": int(time.time()),
                    }
                )
                db.commit()
                return self.get_model_by_id(id)
            except Exception as e:
                log.exception(f"Failed to toggle the model by id {id}: {e}")
                return None

    def update_model_by_id(self, id: str, model: ModelForm) -> Optional[ModelModel]:
        """
        Update a model from form data.

        Every ModelForm field except ``id`` is written, and ``updated_at``
        is set to the current time.

        Args:
            id: Model ID.
            model: New model data.

        Returns:
            The updated model, or None when the model does not exist or
            the update failed.
        """
        try:
            with get_db() as db:
                changes = model.model_dump(exclude={"id"})
                # Keep the last-update timestamp in step with the edit, as
                # toggle_model_by_id and sync_models already do.
                changes["updated_at"] = int(time.time())

                db.query(Model).filter_by(id=id).update(changes)
                db.commit()

                row = db.get(Model, id)
                if row is None:
                    # Unknown id: nothing was updated.
                    return None
                db.refresh(row)
                return ModelModel.model_validate(row)
        except Exception as e:
            log.exception(f"Failed to update the model by id {id}: {e}")
            return None

    def delete_model_by_id(self, id: str) -> bool:
        """
        Delete the model with the given ID.

        Args:
            id: Model ID.

        Returns:
            True when the delete statement ran without error (also when no
            row matched), False when it raised.
        """
        try:
            with get_db() as db:
                db.query(Model).filter_by(id=id).delete()
                db.commit()
            return True
        except Exception as e:
            log.exception(f"Failed to delete the model by id {id}: {e}")
            return False

    def delete_all_models(self) -> bool:
        """
        Delete every row of the table. Destructive.

        Returns:
            True when the delete ran without error, False when it raised.
        """
        try:
            with get_db() as db:
                db.query(Model).delete()
                db.commit()
            return True
        except Exception as e:
            log.exception(f"Failed to delete all models: {e}")
            return False

    def sync_models(self, user_id: str, models: list[ModelModel]) -> list[ModelModel]:
        """
        Make the table mirror the supplied model list.

        Steps:
            1. Update rows whose ID appears in ``models``.
            2. Insert models whose ID is not in the table yet.
            3. Delete rows whose ID is absent from ``models``.

        Args:
            user_id: User ID written to every updated or inserted row.
            models: The list of models the table should contain.

        Returns:
            All models in the table after the sync, or an empty list when
            the sync failed.
        """
        try:
            with get_db() as db:
                existing_models = db.query(Model).all()
                existing_ids = {row.id for row in existing_models}
                new_model_ids = {model.id for model in models}

                # One timestamp for the whole batch.
                now = int(time.time())

                for model in models:
                    values = {
                        **model.model_dump(),
                        "user_id": user_id,
                        "updated_at": now,
                    }
                    if model.id in existing_ids:
                        db.query(Model).filter_by(id=model.id).update(values)
                    else:
                        db.add(Model(**values))

                # Remove rows that are no longer part of the incoming list.
                for row in existing_models:
                    if row.id not in new_model_ids:
                        db.delete(row)

                db.commit()

                return [
                    ModelModel.model_validate(row) for row in db.query(Model).all()
                ]
        except Exception as e:
            log.exception(f"Error syncing models for user {user_id}: {e}")
            return []
# Shared module-level ModelsTable instance.
Models = ModelsTable()