mirror of
https://github.com/open-webui/open-webui.git
synced 2025-12-12 04:15:25 +00:00
wip: models
This commit is contained in:
parent
30bd4a2910
commit
446699e415
8 changed files with 69 additions and 57 deletions
|
|
@ -199,7 +199,7 @@ async def generate_function_chat_completion(
|
||||||
return params
|
return params
|
||||||
|
|
||||||
model_id = form_data.get("model")
|
model_id = form_data.get("model")
|
||||||
model_info = Models.get_model_by_id(model_id)
|
model_info = await Models.get_model_by_id(model_id)
|
||||||
|
|
||||||
metadata = form_data.pop("metadata", {})
|
metadata = form_data.pop("metadata", {})
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1281,7 +1281,7 @@ async def get_models(
|
||||||
filtered_models.append(model)
|
filtered_models.append(model)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
model_info = Models.get_model_by_id(model["id"])
|
model_info = await Models.get_model_by_id(model["id"])
|
||||||
if model_info:
|
if model_info:
|
||||||
if (
|
if (
|
||||||
(user.role == "admin" and ENABLE_ADMIN_WORKSPACE_CONTENT_ACCESS)
|
(user.role == "admin" and ENABLE_ADMIN_WORKSPACE_CONTENT_ACCESS)
|
||||||
|
|
@ -1401,7 +1401,7 @@ async def chat_completion(
|
||||||
raise Exception("Model not found")
|
raise Exception("Model not found")
|
||||||
|
|
||||||
model = request.app.state.MODELS[model_id]
|
model = request.app.state.MODELS[model_id]
|
||||||
model_info = Models.get_model_by_id(model_id)
|
model_info = await Models.get_model_by_id(model_id)
|
||||||
|
|
||||||
# Check if user has access to the model
|
# Check if user has access to the model
|
||||||
if not BYPASS_MODEL_ACCESS_CONTROL and (
|
if not BYPASS_MODEL_ACCESS_CONTROL and (
|
||||||
|
|
|
||||||
|
|
@ -143,7 +143,7 @@ class ModelForm(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
class ModelsTable:
|
class ModelsTable:
|
||||||
def insert_new_model(
|
async def insert_new_model(
|
||||||
self, form_data: ModelForm, user_id: str
|
self, form_data: ModelForm, user_id: str
|
||||||
) -> Optional[ModelModel]:
|
) -> Optional[ModelModel]:
|
||||||
model = ModelModel(
|
model = ModelModel(
|
||||||
|
|
@ -157,9 +157,9 @@ class ModelsTable:
|
||||||
try:
|
try:
|
||||||
async with get_db() as db:
|
async with get_db() as db:
|
||||||
result = Model(**model.model_dump())
|
result = Model(**model.model_dump())
|
||||||
db.add(result)
|
await db.add(result)
|
||||||
db.commit()
|
await db.commit()
|
||||||
db.refresh(result)
|
await db.refresh(result)
|
||||||
|
|
||||||
if result:
|
if result:
|
||||||
return ModelModel.model_validate(result)
|
return ModelModel.model_validate(result)
|
||||||
|
|
@ -169,14 +169,19 @@ class ModelsTable:
|
||||||
log.exception(f"Failed to insert a new model: {e}")
|
log.exception(f"Failed to insert a new model: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_all_models(self) -> list[ModelModel]:
|
async def get_all_models(self) -> list[ModelModel]:
|
||||||
async with get_db() as db:
|
async with get_db() as db:
|
||||||
return [ModelModel.model_validate(model) for model in db.query(Model).all()]
|
return [
|
||||||
|
ModelModel.model_validate(model)
|
||||||
|
for model in await db.query(Model).all()
|
||||||
|
]
|
||||||
|
|
||||||
async def get_models(self) -> list[ModelUserResponse]:
|
async def get_models(self) -> list[ModelUserResponse]:
|
||||||
async with get_db() as db:
|
async with get_db() as db:
|
||||||
models = []
|
models = []
|
||||||
for model in db.query(Model).filter(Model.base_model_id != None).all():
|
for model in (
|
||||||
|
await db.query(Model).filter(Model.base_model_id != None).all()
|
||||||
|
):
|
||||||
user = await Users.get_user_by_id(model.user_id)
|
user = await Users.get_user_by_id(model.user_id)
|
||||||
models.append(
|
models.append(
|
||||||
ModelUserResponse.model_validate(
|
ModelUserResponse.model_validate(
|
||||||
|
|
@ -188,17 +193,19 @@ class ModelsTable:
|
||||||
)
|
)
|
||||||
return models
|
return models
|
||||||
|
|
||||||
def get_base_models(self) -> list[ModelModel]:
|
async def get_base_models(self) -> list[ModelModel]:
|
||||||
async with get_db() as db:
|
async with get_db() as db:
|
||||||
return [
|
return [
|
||||||
ModelModel.model_validate(model)
|
ModelModel.model_validate(model)
|
||||||
for model in db.query(Model).filter(Model.base_model_id == None).all()
|
for model in await db.query(Model)
|
||||||
|
.filter(Model.base_model_id == None)
|
||||||
|
.all()
|
||||||
]
|
]
|
||||||
|
|
||||||
async def get_models_by_user_id(
|
async def get_models_by_user_id(
|
||||||
self, user_id: str, permission: str = "write"
|
self, user_id: str, permission: str = "write"
|
||||||
) -> list[ModelUserResponse]:
|
) -> list[ModelUserResponse]:
|
||||||
models = self.get_models()
|
models = await self.get_models()
|
||||||
return [
|
return [
|
||||||
model
|
model
|
||||||
for model in models
|
for model in models
|
||||||
|
|
@ -206,74 +213,78 @@ class ModelsTable:
|
||||||
or await has_access(user_id, permission, model.access_control)
|
or await has_access(user_id, permission, model.access_control)
|
||||||
]
|
]
|
||||||
|
|
||||||
def get_model_by_id(self, id: str) -> Optional[ModelModel]:
|
async def get_model_by_id(self, id: str) -> Optional[ModelModel]:
|
||||||
try:
|
try:
|
||||||
async with get_db() as db:
|
async with get_db() as db:
|
||||||
model = db.get(Model, id)
|
model = await db.get(Model, id)
|
||||||
return ModelModel.model_validate(model)
|
return ModelModel.model_validate(model)
|
||||||
except Exception:
|
except Exception:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def toggle_model_by_id(self, id: str) -> Optional[ModelModel]:
|
async def toggle_model_by_id(self, id: str) -> Optional[ModelModel]:
|
||||||
async with get_db() as db:
|
async with get_db() as db:
|
||||||
try:
|
try:
|
||||||
is_active = db.query(Model).filter_by(id=id).first().is_active
|
is_active = (await db.query(Model).filter_by(id=id).first()).is_active
|
||||||
|
|
||||||
db.query(Model).filter_by(id=id).update(
|
await db.query(Model).filter_by(id=id).update(
|
||||||
{
|
{
|
||||||
"is_active": not is_active,
|
"is_active": not is_active,
|
||||||
"updated_at": int(time.time()),
|
"updated_at": int(time.time()),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
db.commit()
|
await db.commit()
|
||||||
|
|
||||||
return self.get_model_by_id(id)
|
return await self.get_model_by_id(id)
|
||||||
except Exception:
|
except Exception:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def update_model_by_id(self, id: str, model: ModelForm) -> Optional[ModelModel]:
|
async def update_model_by_id(
|
||||||
|
self, id: str, model: ModelForm
|
||||||
|
) -> Optional[ModelModel]:
|
||||||
try:
|
try:
|
||||||
async with get_db() as db:
|
async with get_db() as db:
|
||||||
# update only the fields that are present in the model
|
# update only the fields that are present in the model
|
||||||
result = (
|
result = (
|
||||||
db.query(Model)
|
await db.query(Model)
|
||||||
.filter_by(id=id)
|
.filter_by(id=id)
|
||||||
.update(model.model_dump(exclude={"id"}))
|
.update(model.model_dump(exclude={"id"}))
|
||||||
)
|
)
|
||||||
db.commit()
|
await db.commit()
|
||||||
|
|
||||||
model = db.get(Model, id)
|
model = await db.get(Model, id)
|
||||||
db.refresh(model)
|
await db.refresh(model)
|
||||||
return ModelModel.model_validate(model)
|
return ModelModel.model_validate(model)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.exception(f"Failed to update the model by id {id}: {e}")
|
log.exception(f"Failed to update the model by id {id}: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def delete_model_by_id(self, id: str) -> bool:
|
async def delete_model_by_id(self, id: str) -> bool:
|
||||||
try:
|
try:
|
||||||
async with get_db() as db:
|
async with get_db() as db:
|
||||||
db.query(Model).filter_by(id=id).delete()
|
await db.query(Model).filter_by(id=id).delete()
|
||||||
db.commit()
|
await db.commit()
|
||||||
|
|
||||||
return True
|
return True
|
||||||
except Exception:
|
except Exception:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def delete_all_models(self) -> bool:
|
async def delete_all_models(self) -> bool:
|
||||||
try:
|
try:
|
||||||
async with get_db() as db:
|
async with get_db() as db:
|
||||||
db.query(Model).delete()
|
await db.query(Model).delete()
|
||||||
db.commit()
|
await db.commit()
|
||||||
|
|
||||||
return True
|
return True
|
||||||
except Exception:
|
except Exception:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def sync_models(self, user_id: str, models: list[ModelModel]) -> list[ModelModel]:
|
async def sync_models(
|
||||||
|
self, user_id: str, models: list[ModelModel]
|
||||||
|
) -> list[ModelModel]:
|
||||||
try:
|
try:
|
||||||
async with get_db() as db:
|
async with get_db() as db:
|
||||||
# Get existing models
|
# Get existing models
|
||||||
existing_models = db.query(Model).all()
|
existing_models = await db.query(Model).all()
|
||||||
existing_ids = {model.id for model in existing_models}
|
existing_ids = {model.id for model in existing_models}
|
||||||
|
|
||||||
# Prepare a set of new model IDs
|
# Prepare a set of new model IDs
|
||||||
|
|
@ -282,7 +293,7 @@ class ModelsTable:
|
||||||
# Update or insert models
|
# Update or insert models
|
||||||
for model in models:
|
for model in models:
|
||||||
if model.id in existing_ids:
|
if model.id in existing_ids:
|
||||||
db.query(Model).filter_by(id=model.id).update(
|
await db.query(Model).filter_by(id=model.id).update(
|
||||||
{
|
{
|
||||||
**model.model_dump(),
|
**model.model_dump(),
|
||||||
"user_id": user_id,
|
"user_id": user_id,
|
||||||
|
|
@ -297,17 +308,18 @@ class ModelsTable:
|
||||||
"updated_at": int(time.time()),
|
"updated_at": int(time.time()),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
db.add(new_model)
|
await db.add(new_model)
|
||||||
|
|
||||||
# Remove models that are no longer present
|
# Remove models that are no longer present
|
||||||
for model in existing_models:
|
for model in existing_models:
|
||||||
if model.id not in new_model_ids:
|
if model.id not in new_model_ids:
|
||||||
db.delete(model)
|
await db.delete(model)
|
||||||
|
|
||||||
db.commit()
|
await db.commit()
|
||||||
|
|
||||||
return [
|
return [
|
||||||
ModelModel.model_validate(model) for model in db.query(Model).all()
|
ModelModel.model_validate(model)
|
||||||
|
for model in await db.query(Model).all()
|
||||||
]
|
]
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.exception(f"Error syncing models for user {user_id}: {e}")
|
log.exception(f"Error syncing models for user {user_id}: {e}")
|
||||||
|
|
|
||||||
|
|
@ -602,7 +602,7 @@ async def delete_knowledge_by_id(id: str, user=Depends(get_verified_user)):
|
||||||
log.info(f"Deleting knowledge base: {id} (name: {knowledge.name})")
|
log.info(f"Deleting knowledge base: {id} (name: {knowledge.name})")
|
||||||
|
|
||||||
# Get all models
|
# Get all models
|
||||||
models = Models.get_all_models()
|
models = await Models.get_all_models()
|
||||||
log.info(f"Found {len(models)} models to check for knowledge base {id}")
|
log.info(f"Found {len(models)} models to check for knowledge base {id}")
|
||||||
|
|
||||||
# Update models that reference this knowledge base
|
# Update models that reference this knowledge base
|
||||||
|
|
@ -626,7 +626,7 @@ async def delete_knowledge_by_id(id: str, user=Depends(get_verified_user)):
|
||||||
access_control=model.access_control,
|
access_control=model.access_control,
|
||||||
is_active=model.is_active,
|
is_active=model.is_active,
|
||||||
)
|
)
|
||||||
Models.update_model_by_id(model.id, model_form)
|
await Models.update_model_by_id(model.id, model_form)
|
||||||
|
|
||||||
# Clean up vector DB
|
# Clean up vector DB
|
||||||
try:
|
try:
|
||||||
|
|
|
||||||
|
|
@ -28,9 +28,9 @@ router = APIRouter()
|
||||||
@router.get("/", response_model=list[ModelUserResponse])
|
@router.get("/", response_model=list[ModelUserResponse])
|
||||||
async def get_models(id: Optional[str] = None, user=Depends(get_verified_user)):
|
async def get_models(id: Optional[str] = None, user=Depends(get_verified_user)):
|
||||||
if user.role == "admin" and ENABLE_ADMIN_WORKSPACE_CONTENT_ACCESS:
|
if user.role == "admin" and ENABLE_ADMIN_WORKSPACE_CONTENT_ACCESS:
|
||||||
return Models.get_models()
|
return await Models.get_models()
|
||||||
else:
|
else:
|
||||||
return Models.get_models_by_user_id(user.id)
|
return await Models.get_models_by_user_id(user.id)
|
||||||
|
|
||||||
|
|
||||||
###########################
|
###########################
|
||||||
|
|
@ -40,7 +40,7 @@ async def get_models(id: Optional[str] = None, user=Depends(get_verified_user)):
|
||||||
|
|
||||||
@router.get("/base", response_model=list[ModelResponse])
|
@router.get("/base", response_model=list[ModelResponse])
|
||||||
async def get_base_models(user=Depends(get_admin_user)):
|
async def get_base_models(user=Depends(get_admin_user)):
|
||||||
return Models.get_base_models()
|
return await Models.get_base_models()
|
||||||
|
|
||||||
|
|
||||||
############################
|
############################
|
||||||
|
|
@ -62,7 +62,7 @@ async def create_new_model(
|
||||||
detail=ERROR_MESSAGES.UNAUTHORIZED,
|
detail=ERROR_MESSAGES.UNAUTHORIZED,
|
||||||
)
|
)
|
||||||
|
|
||||||
model = Models.get_model_by_id(form_data.id)
|
model = await Models.get_model_by_id(form_data.id)
|
||||||
if model:
|
if model:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
|
@ -70,7 +70,7 @@ async def create_new_model(
|
||||||
)
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
model = Models.insert_new_model(form_data, user.id)
|
model = await Models.insert_new_model(form_data, user.id)
|
||||||
if model:
|
if model:
|
||||||
return model
|
return model
|
||||||
else:
|
else:
|
||||||
|
|
@ -87,7 +87,7 @@ async def create_new_model(
|
||||||
|
|
||||||
@router.get("/export", response_model=list[ModelModel])
|
@router.get("/export", response_model=list[ModelModel])
|
||||||
async def export_models(user=Depends(get_admin_user)):
|
async def export_models(user=Depends(get_admin_user)):
|
||||||
return Models.get_models()
|
return await Models.get_models()
|
||||||
|
|
||||||
|
|
||||||
############################
|
############################
|
||||||
|
|
@ -103,7 +103,7 @@ class SyncModelsForm(BaseModel):
|
||||||
async def sync_models(
|
async def sync_models(
|
||||||
request: Request, form_data: SyncModelsForm, user=Depends(get_admin_user)
|
request: Request, form_data: SyncModelsForm, user=Depends(get_admin_user)
|
||||||
):
|
):
|
||||||
return Models.sync_models(user.id, form_data.models)
|
return await Models.sync_models(user.id, form_data.models)
|
||||||
|
|
||||||
|
|
||||||
###########################
|
###########################
|
||||||
|
|
@ -114,7 +114,7 @@ async def sync_models(
|
||||||
# Note: We're not using the typical url path param here, but instead using a query parameter to allow '/' in the id
|
# Note: We're not using the typical url path param here, but instead using a query parameter to allow '/' in the id
|
||||||
@router.get("/model", response_model=Optional[ModelResponse])
|
@router.get("/model", response_model=Optional[ModelResponse])
|
||||||
async def get_model_by_id(id: str, user=Depends(get_verified_user)):
|
async def get_model_by_id(id: str, user=Depends(get_verified_user)):
|
||||||
model = Models.get_model_by_id(id)
|
model = await Models.get_model_by_id(id)
|
||||||
if model:
|
if model:
|
||||||
if (
|
if (
|
||||||
(user.role == "admin" and ENABLE_ADMIN_WORKSPACE_CONTENT_ACCESS)
|
(user.role == "admin" and ENABLE_ADMIN_WORKSPACE_CONTENT_ACCESS)
|
||||||
|
|
@ -136,14 +136,14 @@ async def get_model_by_id(id: str, user=Depends(get_verified_user)):
|
||||||
|
|
||||||
@router.post("/model/toggle", response_model=Optional[ModelResponse])
|
@router.post("/model/toggle", response_model=Optional[ModelResponse])
|
||||||
async def toggle_model_by_id(id: str, user=Depends(get_verified_user)):
|
async def toggle_model_by_id(id: str, user=Depends(get_verified_user)):
|
||||||
model = Models.get_model_by_id(id)
|
model = await Models.get_model_by_id(id)
|
||||||
if model:
|
if model:
|
||||||
if (
|
if (
|
||||||
user.role == "admin"
|
user.role == "admin"
|
||||||
or model.user_id == user.id
|
or model.user_id == user.id
|
||||||
or await has_access(user.id, "write", model.access_control)
|
or await has_access(user.id, "write", model.access_control)
|
||||||
):
|
):
|
||||||
model = Models.toggle_model_by_id(id)
|
model = await Models.toggle_model_by_id(id)
|
||||||
|
|
||||||
if model:
|
if model:
|
||||||
return model
|
return model
|
||||||
|
|
@ -175,7 +175,7 @@ async def update_model_by_id(
|
||||||
form_data: ModelForm,
|
form_data: ModelForm,
|
||||||
user=Depends(get_verified_user),
|
user=Depends(get_verified_user),
|
||||||
):
|
):
|
||||||
model = Models.get_model_by_id(id)
|
model = await Models.get_model_by_id(id)
|
||||||
|
|
||||||
if not model:
|
if not model:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
|
|
@ -193,7 +193,7 @@ async def update_model_by_id(
|
||||||
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
||||||
)
|
)
|
||||||
|
|
||||||
model = Models.update_model_by_id(id, form_data)
|
model = await Models.update_model_by_id(id, form_data)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -441,7 +441,7 @@ async def get_filtered_models(models, user):
|
||||||
# Filter models based on user access control
|
# Filter models based on user access control
|
||||||
filtered_models = []
|
filtered_models = []
|
||||||
for model in models.get("models", []):
|
for model in models.get("models", []):
|
||||||
model_info = Models.get_model_by_id(model["model"])
|
model_info = await Models.get_model_by_id(model["model"])
|
||||||
if model_info:
|
if model_info:
|
||||||
if user.id == model_info.user_id or (
|
if user.id == model_info.user_id or (
|
||||||
await has_access(
|
await has_access(
|
||||||
|
|
@ -1320,7 +1320,7 @@ async def generate_chat_completion(
|
||||||
del payload["metadata"]
|
del payload["metadata"]
|
||||||
|
|
||||||
model_id = payload["model"]
|
model_id = payload["model"]
|
||||||
model_info = Models.get_model_by_id(model_id)
|
model_info = await Models.get_model_by_id(model_id)
|
||||||
|
|
||||||
if model_info:
|
if model_info:
|
||||||
if model_info.base_model_id:
|
if model_info.base_model_id:
|
||||||
|
|
|
||||||
|
|
@ -383,7 +383,7 @@ async def get_filtered_models(models, user):
|
||||||
# Filter models based on user access control
|
# Filter models based on user access control
|
||||||
filtered_models = []
|
filtered_models = []
|
||||||
for model in models.get("data", []):
|
for model in models.get("data", []):
|
||||||
model_info = Models.get_model_by_id(model["id"])
|
model_info = await Models.get_model_by_id(model["id"])
|
||||||
if model_info:
|
if model_info:
|
||||||
if user.id == model_info.user_id or (
|
if user.id == model_info.user_id or (
|
||||||
await has_access(
|
await has_access(
|
||||||
|
|
@ -738,7 +738,7 @@ async def generate_chat_completion(
|
||||||
metadata = payload.pop("metadata", None)
|
metadata = payload.pop("metadata", None)
|
||||||
|
|
||||||
model_id = form_data.get("model")
|
model_id = form_data.get("model")
|
||||||
model_info = Models.get_model_by_id(model_id)
|
model_info = await Models.get_model_by_id(model_id)
|
||||||
|
|
||||||
# Check model info and override the payload
|
# Check model info and override the payload
|
||||||
if model_info:
|
if model_info:
|
||||||
|
|
|
||||||
|
|
@ -145,7 +145,7 @@ async def get_all_models(request, refresh: bool = False, user: UserModel = None)
|
||||||
for function in Functions.get_functions_by_type("filter", active_only=True)
|
for function in Functions.get_functions_by_type("filter", active_only=True)
|
||||||
]
|
]
|
||||||
|
|
||||||
custom_models = Models.get_all_models()
|
custom_models = await Models.get_all_models()
|
||||||
for custom_model in custom_models:
|
for custom_model in custom_models:
|
||||||
if custom_model.base_model_id is None:
|
if custom_model.base_model_id is None:
|
||||||
# Applied directly to a base model
|
# Applied directly to a base model
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue