mirror of
https://github.com/open-webui/open-webui.git
synced 2025-12-12 04:15:25 +00:00
wip: models
This commit is contained in:
parent
30bd4a2910
commit
446699e415
8 changed files with 69 additions and 57 deletions
|
|
@ -199,7 +199,7 @@ async def generate_function_chat_completion(
|
|||
return params
|
||||
|
||||
model_id = form_data.get("model")
|
||||
model_info = Models.get_model_by_id(model_id)
|
||||
model_info = await Models.get_model_by_id(model_id)
|
||||
|
||||
metadata = form_data.pop("metadata", {})
|
||||
|
||||
|
|
|
|||
|
|
@ -1281,7 +1281,7 @@ async def get_models(
|
|||
filtered_models.append(model)
|
||||
continue
|
||||
|
||||
model_info = Models.get_model_by_id(model["id"])
|
||||
model_info = await Models.get_model_by_id(model["id"])
|
||||
if model_info:
|
||||
if (
|
||||
(user.role == "admin" and ENABLE_ADMIN_WORKSPACE_CONTENT_ACCESS)
|
||||
|
|
@ -1401,7 +1401,7 @@ async def chat_completion(
|
|||
raise Exception("Model not found")
|
||||
|
||||
model = request.app.state.MODELS[model_id]
|
||||
model_info = Models.get_model_by_id(model_id)
|
||||
model_info = await Models.get_model_by_id(model_id)
|
||||
|
||||
# Check if user has access to the model
|
||||
if not BYPASS_MODEL_ACCESS_CONTROL and (
|
||||
|
|
|
|||
|
|
@ -143,7 +143,7 @@ class ModelForm(BaseModel):
|
|||
|
||||
|
||||
class ModelsTable:
|
||||
def insert_new_model(
|
||||
async def insert_new_model(
|
||||
self, form_data: ModelForm, user_id: str
|
||||
) -> Optional[ModelModel]:
|
||||
model = ModelModel(
|
||||
|
|
@ -157,9 +157,9 @@ class ModelsTable:
|
|||
try:
|
||||
async with get_db() as db:
|
||||
result = Model(**model.model_dump())
|
||||
db.add(result)
|
||||
db.commit()
|
||||
db.refresh(result)
|
||||
await db.add(result)
|
||||
await db.commit()
|
||||
await db.refresh(result)
|
||||
|
||||
if result:
|
||||
return ModelModel.model_validate(result)
|
||||
|
|
@ -169,14 +169,19 @@ class ModelsTable:
|
|||
log.exception(f"Failed to insert a new model: {e}")
|
||||
return None
|
||||
|
||||
def get_all_models(self) -> list[ModelModel]:
|
||||
async def get_all_models(self) -> list[ModelModel]:
|
||||
async with get_db() as db:
|
||||
return [ModelModel.model_validate(model) for model in db.query(Model).all()]
|
||||
return [
|
||||
ModelModel.model_validate(model)
|
||||
for model in await db.query(Model).all()
|
||||
]
|
||||
|
||||
async def get_models(self) -> list[ModelUserResponse]:
|
||||
async with get_db() as db:
|
||||
models = []
|
||||
for model in db.query(Model).filter(Model.base_model_id != None).all():
|
||||
for model in (
|
||||
await db.query(Model).filter(Model.base_model_id != None).all()
|
||||
):
|
||||
user = await Users.get_user_by_id(model.user_id)
|
||||
models.append(
|
||||
ModelUserResponse.model_validate(
|
||||
|
|
@ -188,17 +193,19 @@ class ModelsTable:
|
|||
)
|
||||
return models
|
||||
|
||||
def get_base_models(self) -> list[ModelModel]:
|
||||
async def get_base_models(self) -> list[ModelModel]:
|
||||
async with get_db() as db:
|
||||
return [
|
||||
ModelModel.model_validate(model)
|
||||
for model in db.query(Model).filter(Model.base_model_id == None).all()
|
||||
for model in await db.query(Model)
|
||||
.filter(Model.base_model_id == None)
|
||||
.all()
|
||||
]
|
||||
|
||||
async def get_models_by_user_id(
|
||||
self, user_id: str, permission: str = "write"
|
||||
) -> list[ModelUserResponse]:
|
||||
models = self.get_models()
|
||||
models = await self.get_models()
|
||||
return [
|
||||
model
|
||||
for model in models
|
||||
|
|
@ -206,74 +213,78 @@ class ModelsTable:
|
|||
or await has_access(user_id, permission, model.access_control)
|
||||
]
|
||||
|
||||
def get_model_by_id(self, id: str) -> Optional[ModelModel]:
|
||||
async def get_model_by_id(self, id: str) -> Optional[ModelModel]:
|
||||
try:
|
||||
async with get_db() as db:
|
||||
model = db.get(Model, id)
|
||||
model = await db.get(Model, id)
|
||||
return ModelModel.model_validate(model)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def toggle_model_by_id(self, id: str) -> Optional[ModelModel]:
|
||||
async def toggle_model_by_id(self, id: str) -> Optional[ModelModel]:
|
||||
async with get_db() as db:
|
||||
try:
|
||||
is_active = db.query(Model).filter_by(id=id).first().is_active
|
||||
is_active = (await db.query(Model).filter_by(id=id).first()).is_active
|
||||
|
||||
db.query(Model).filter_by(id=id).update(
|
||||
await db.query(Model).filter_by(id=id).update(
|
||||
{
|
||||
"is_active": not is_active,
|
||||
"updated_at": int(time.time()),
|
||||
}
|
||||
)
|
||||
db.commit()
|
||||
await db.commit()
|
||||
|
||||
return self.get_model_by_id(id)
|
||||
return await self.get_model_by_id(id)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def update_model_by_id(self, id: str, model: ModelForm) -> Optional[ModelModel]:
|
||||
async def update_model_by_id(
|
||||
self, id: str, model: ModelForm
|
||||
) -> Optional[ModelModel]:
|
||||
try:
|
||||
async with get_db() as db:
|
||||
# update only the fields that are present in the model
|
||||
result = (
|
||||
db.query(Model)
|
||||
await db.query(Model)
|
||||
.filter_by(id=id)
|
||||
.update(model.model_dump(exclude={"id"}))
|
||||
)
|
||||
db.commit()
|
||||
await db.commit()
|
||||
|
||||
model = db.get(Model, id)
|
||||
db.refresh(model)
|
||||
model = await db.get(Model, id)
|
||||
await db.refresh(model)
|
||||
return ModelModel.model_validate(model)
|
||||
except Exception as e:
|
||||
log.exception(f"Failed to update the model by id {id}: {e}")
|
||||
return None
|
||||
|
||||
def delete_model_by_id(self, id: str) -> bool:
|
||||
async def delete_model_by_id(self, id: str) -> bool:
|
||||
try:
|
||||
async with get_db() as db:
|
||||
db.query(Model).filter_by(id=id).delete()
|
||||
db.commit()
|
||||
await db.query(Model).filter_by(id=id).delete()
|
||||
await db.commit()
|
||||
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def delete_all_models(self) -> bool:
|
||||
async def delete_all_models(self) -> bool:
|
||||
try:
|
||||
async with get_db() as db:
|
||||
db.query(Model).delete()
|
||||
db.commit()
|
||||
await db.query(Model).delete()
|
||||
await db.commit()
|
||||
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def sync_models(self, user_id: str, models: list[ModelModel]) -> list[ModelModel]:
|
||||
async def sync_models(
|
||||
self, user_id: str, models: list[ModelModel]
|
||||
) -> list[ModelModel]:
|
||||
try:
|
||||
async with get_db() as db:
|
||||
# Get existing models
|
||||
existing_models = db.query(Model).all()
|
||||
existing_models = await db.query(Model).all()
|
||||
existing_ids = {model.id for model in existing_models}
|
||||
|
||||
# Prepare a set of new model IDs
|
||||
|
|
@ -282,7 +293,7 @@ class ModelsTable:
|
|||
# Update or insert models
|
||||
for model in models:
|
||||
if model.id in existing_ids:
|
||||
db.query(Model).filter_by(id=model.id).update(
|
||||
await db.query(Model).filter_by(id=model.id).update(
|
||||
{
|
||||
**model.model_dump(),
|
||||
"user_id": user_id,
|
||||
|
|
@ -297,17 +308,18 @@ class ModelsTable:
|
|||
"updated_at": int(time.time()),
|
||||
}
|
||||
)
|
||||
db.add(new_model)
|
||||
await db.add(new_model)
|
||||
|
||||
# Remove models that are no longer present
|
||||
for model in existing_models:
|
||||
if model.id not in new_model_ids:
|
||||
db.delete(model)
|
||||
await db.delete(model)
|
||||
|
||||
db.commit()
|
||||
await db.commit()
|
||||
|
||||
return [
|
||||
ModelModel.model_validate(model) for model in db.query(Model).all()
|
||||
ModelModel.model_validate(model)
|
||||
for model in await db.query(Model).all()
|
||||
]
|
||||
except Exception as e:
|
||||
log.exception(f"Error syncing models for user {user_id}: {e}")
|
||||
|
|
|
|||
|
|
@ -602,7 +602,7 @@ async def delete_knowledge_by_id(id: str, user=Depends(get_verified_user)):
|
|||
log.info(f"Deleting knowledge base: {id} (name: {knowledge.name})")
|
||||
|
||||
# Get all models
|
||||
models = Models.get_all_models()
|
||||
models = await Models.get_all_models()
|
||||
log.info(f"Found {len(models)} models to check for knowledge base {id}")
|
||||
|
||||
# Update models that reference this knowledge base
|
||||
|
|
@ -626,7 +626,7 @@ async def delete_knowledge_by_id(id: str, user=Depends(get_verified_user)):
|
|||
access_control=model.access_control,
|
||||
is_active=model.is_active,
|
||||
)
|
||||
Models.update_model_by_id(model.id, model_form)
|
||||
await Models.update_model_by_id(model.id, model_form)
|
||||
|
||||
# Clean up vector DB
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -28,9 +28,9 @@ router = APIRouter()
|
|||
@router.get("/", response_model=list[ModelUserResponse])
|
||||
async def get_models(id: Optional[str] = None, user=Depends(get_verified_user)):
|
||||
if user.role == "admin" and ENABLE_ADMIN_WORKSPACE_CONTENT_ACCESS:
|
||||
return Models.get_models()
|
||||
return await Models.get_models()
|
||||
else:
|
||||
return Models.get_models_by_user_id(user.id)
|
||||
return await Models.get_models_by_user_id(user.id)
|
||||
|
||||
|
||||
###########################
|
||||
|
|
@ -40,7 +40,7 @@ async def get_models(id: Optional[str] = None, user=Depends(get_verified_user)):
|
|||
|
||||
@router.get("/base", response_model=list[ModelResponse])
|
||||
async def get_base_models(user=Depends(get_admin_user)):
|
||||
return Models.get_base_models()
|
||||
return await Models.get_base_models()
|
||||
|
||||
|
||||
############################
|
||||
|
|
@ -62,7 +62,7 @@ async def create_new_model(
|
|||
detail=ERROR_MESSAGES.UNAUTHORIZED,
|
||||
)
|
||||
|
||||
model = Models.get_model_by_id(form_data.id)
|
||||
model = await Models.get_model_by_id(form_data.id)
|
||||
if model:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
|
|
@ -70,7 +70,7 @@ async def create_new_model(
|
|||
)
|
||||
|
||||
else:
|
||||
model = Models.insert_new_model(form_data, user.id)
|
||||
model = await Models.insert_new_model(form_data, user.id)
|
||||
if model:
|
||||
return model
|
||||
else:
|
||||
|
|
@ -87,7 +87,7 @@ async def create_new_model(
|
|||
|
||||
@router.get("/export", response_model=list[ModelModel])
|
||||
async def export_models(user=Depends(get_admin_user)):
|
||||
return Models.get_models()
|
||||
return await Models.get_models()
|
||||
|
||||
|
||||
############################
|
||||
|
|
@ -103,7 +103,7 @@ class SyncModelsForm(BaseModel):
|
|||
async def sync_models(
|
||||
request: Request, form_data: SyncModelsForm, user=Depends(get_admin_user)
|
||||
):
|
||||
return Models.sync_models(user.id, form_data.models)
|
||||
return await Models.sync_models(user.id, form_data.models)
|
||||
|
||||
|
||||
###########################
|
||||
|
|
@ -114,7 +114,7 @@ async def sync_models(
|
|||
# Note: We're not using the typical url path param here, but instead using a query parameter to allow '/' in the id
|
||||
@router.get("/model", response_model=Optional[ModelResponse])
|
||||
async def get_model_by_id(id: str, user=Depends(get_verified_user)):
|
||||
model = Models.get_model_by_id(id)
|
||||
model = await Models.get_model_by_id(id)
|
||||
if model:
|
||||
if (
|
||||
(user.role == "admin" and ENABLE_ADMIN_WORKSPACE_CONTENT_ACCESS)
|
||||
|
|
@ -136,14 +136,14 @@ async def get_model_by_id(id: str, user=Depends(get_verified_user)):
|
|||
|
||||
@router.post("/model/toggle", response_model=Optional[ModelResponse])
|
||||
async def toggle_model_by_id(id: str, user=Depends(get_verified_user)):
|
||||
model = Models.get_model_by_id(id)
|
||||
model = await Models.get_model_by_id(id)
|
||||
if model:
|
||||
if (
|
||||
user.role == "admin"
|
||||
or model.user_id == user.id
|
||||
or await has_access(user.id, "write", model.access_control)
|
||||
):
|
||||
model = Models.toggle_model_by_id(id)
|
||||
model = await Models.toggle_model_by_id(id)
|
||||
|
||||
if model:
|
||||
return model
|
||||
|
|
@ -175,7 +175,7 @@ async def update_model_by_id(
|
|||
form_data: ModelForm,
|
||||
user=Depends(get_verified_user),
|
||||
):
|
||||
model = Models.get_model_by_id(id)
|
||||
model = await Models.get_model_by_id(id)
|
||||
|
||||
if not model:
|
||||
raise HTTPException(
|
||||
|
|
@ -193,7 +193,7 @@ async def update_model_by_id(
|
|||
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
||||
)
|
||||
|
||||
model = Models.update_model_by_id(id, form_data)
|
||||
model = await Models.update_model_by_id(id, form_data)
|
||||
return model
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -441,7 +441,7 @@ async def get_filtered_models(models, user):
|
|||
# Filter models based on user access control
|
||||
filtered_models = []
|
||||
for model in models.get("models", []):
|
||||
model_info = Models.get_model_by_id(model["model"])
|
||||
model_info = await Models.get_model_by_id(model["model"])
|
||||
if model_info:
|
||||
if user.id == model_info.user_id or (
|
||||
await has_access(
|
||||
|
|
@ -1320,7 +1320,7 @@ async def generate_chat_completion(
|
|||
del payload["metadata"]
|
||||
|
||||
model_id = payload["model"]
|
||||
model_info = Models.get_model_by_id(model_id)
|
||||
model_info = await Models.get_model_by_id(model_id)
|
||||
|
||||
if model_info:
|
||||
if model_info.base_model_id:
|
||||
|
|
|
|||
|
|
@ -383,7 +383,7 @@ async def get_filtered_models(models, user):
|
|||
# Filter models based on user access control
|
||||
filtered_models = []
|
||||
for model in models.get("data", []):
|
||||
model_info = Models.get_model_by_id(model["id"])
|
||||
model_info = await Models.get_model_by_id(model["id"])
|
||||
if model_info:
|
||||
if user.id == model_info.user_id or (
|
||||
await has_access(
|
||||
|
|
@ -738,7 +738,7 @@ async def generate_chat_completion(
|
|||
metadata = payload.pop("metadata", None)
|
||||
|
||||
model_id = form_data.get("model")
|
||||
model_info = Models.get_model_by_id(model_id)
|
||||
model_info = await Models.get_model_by_id(model_id)
|
||||
|
||||
# Check model info and override the payload
|
||||
if model_info:
|
||||
|
|
|
|||
|
|
@ -145,7 +145,7 @@ async def get_all_models(request, refresh: bool = False, user: UserModel = None)
|
|||
for function in Functions.get_functions_by_type("filter", active_only=True)
|
||||
]
|
||||
|
||||
custom_models = Models.get_all_models()
|
||||
custom_models = await Models.get_all_models()
|
||||
for custom_model in custom_models:
|
||||
if custom_model.base_model_id is None:
|
||||
# Applied directly to a base model
|
||||
|
|
|
|||
Loading…
Reference in a new issue