wip: models

This commit is contained in:
Timothy Jaeryang Baek 2025-08-14 16:34:49 +04:00
parent 30bd4a2910
commit 446699e415
8 changed files with 69 additions and 57 deletions

View file

@ -199,7 +199,7 @@ async def generate_function_chat_completion(
return params
model_id = form_data.get("model")
model_info = Models.get_model_by_id(model_id)
model_info = await Models.get_model_by_id(model_id)
metadata = form_data.pop("metadata", {})

View file

@ -1281,7 +1281,7 @@ async def get_models(
filtered_models.append(model)
continue
model_info = Models.get_model_by_id(model["id"])
model_info = await Models.get_model_by_id(model["id"])
if model_info:
if (
(user.role == "admin" and ENABLE_ADMIN_WORKSPACE_CONTENT_ACCESS)
@ -1401,7 +1401,7 @@ async def chat_completion(
raise Exception("Model not found")
model = request.app.state.MODELS[model_id]
model_info = Models.get_model_by_id(model_id)
model_info = await Models.get_model_by_id(model_id)
# Check if user has access to the model
if not BYPASS_MODEL_ACCESS_CONTROL and (

View file

@ -143,7 +143,7 @@ class ModelForm(BaseModel):
class ModelsTable:
def insert_new_model(
async def insert_new_model(
self, form_data: ModelForm, user_id: str
) -> Optional[ModelModel]:
model = ModelModel(
@ -157,9 +157,9 @@ class ModelsTable:
try:
async with get_db() as db:
result = Model(**model.model_dump())
db.add(result)
db.commit()
db.refresh(result)
await db.add(result)
await db.commit()
await db.refresh(result)
if result:
return ModelModel.model_validate(result)
@ -169,14 +169,19 @@ class ModelsTable:
log.exception(f"Failed to insert a new model: {e}")
return None
def get_all_models(self) -> list[ModelModel]:
async def get_all_models(self) -> list[ModelModel]:
async with get_db() as db:
return [ModelModel.model_validate(model) for model in db.query(Model).all()]
return [
ModelModel.model_validate(model)
for model in await db.query(Model).all()
]
async def get_models(self) -> list[ModelUserResponse]:
async with get_db() as db:
models = []
for model in db.query(Model).filter(Model.base_model_id != None).all():
for model in (
await db.query(Model).filter(Model.base_model_id != None).all()
):
user = await Users.get_user_by_id(model.user_id)
models.append(
ModelUserResponse.model_validate(
@ -188,17 +193,19 @@ class ModelsTable:
)
return models
def get_base_models(self) -> list[ModelModel]:
async def get_base_models(self) -> list[ModelModel]:
async with get_db() as db:
return [
ModelModel.model_validate(model)
for model in db.query(Model).filter(Model.base_model_id == None).all()
for model in await db.query(Model)
.filter(Model.base_model_id == None)
.all()
]
async def get_models_by_user_id(
self, user_id: str, permission: str = "write"
) -> list[ModelUserResponse]:
models = self.get_models()
models = await self.get_models()
return [
model
for model in models
@ -206,74 +213,78 @@ class ModelsTable:
or await has_access(user_id, permission, model.access_control)
]
def get_model_by_id(self, id: str) -> Optional[ModelModel]:
async def get_model_by_id(self, id: str) -> Optional[ModelModel]:
try:
async with get_db() as db:
model = db.get(Model, id)
model = await db.get(Model, id)
return ModelModel.model_validate(model)
except Exception:
return None
def toggle_model_by_id(self, id: str) -> Optional[ModelModel]:
async def toggle_model_by_id(self, id: str) -> Optional[ModelModel]:
async with get_db() as db:
try:
is_active = db.query(Model).filter_by(id=id).first().is_active
is_active = (await db.query(Model).filter_by(id=id).first()).is_active
db.query(Model).filter_by(id=id).update(
await db.query(Model).filter_by(id=id).update(
{
"is_active": not is_active,
"updated_at": int(time.time()),
}
)
db.commit()
await db.commit()
return self.get_model_by_id(id)
return await self.get_model_by_id(id)
except Exception:
return None
def update_model_by_id(self, id: str, model: ModelForm) -> Optional[ModelModel]:
async def update_model_by_id(
self, id: str, model: ModelForm
) -> Optional[ModelModel]:
try:
async with get_db() as db:
# update only the fields that are present in the model
result = (
db.query(Model)
await db.query(Model)
.filter_by(id=id)
.update(model.model_dump(exclude={"id"}))
)
db.commit()
await db.commit()
model = db.get(Model, id)
db.refresh(model)
model = await db.get(Model, id)
await db.refresh(model)
return ModelModel.model_validate(model)
except Exception as e:
log.exception(f"Failed to update the model by id {id}: {e}")
return None
def delete_model_by_id(self, id: str) -> bool:
async def delete_model_by_id(self, id: str) -> bool:
try:
async with get_db() as db:
db.query(Model).filter_by(id=id).delete()
db.commit()
await db.query(Model).filter_by(id=id).delete()
await db.commit()
return True
except Exception:
return False
def delete_all_models(self) -> bool:
async def delete_all_models(self) -> bool:
try:
async with get_db() as db:
db.query(Model).delete()
db.commit()
await db.query(Model).delete()
await db.commit()
return True
except Exception:
return False
def sync_models(self, user_id: str, models: list[ModelModel]) -> list[ModelModel]:
async def sync_models(
self, user_id: str, models: list[ModelModel]
) -> list[ModelModel]:
try:
async with get_db() as db:
# Get existing models
existing_models = db.query(Model).all()
existing_models = await db.query(Model).all()
existing_ids = {model.id for model in existing_models}
# Prepare a set of new model IDs
@ -282,7 +293,7 @@ class ModelsTable:
# Update or insert models
for model in models:
if model.id in existing_ids:
db.query(Model).filter_by(id=model.id).update(
await db.query(Model).filter_by(id=model.id).update(
{
**model.model_dump(),
"user_id": user_id,
@ -297,17 +308,18 @@ class ModelsTable:
"updated_at": int(time.time()),
}
)
db.add(new_model)
await db.add(new_model)
# Remove models that are no longer present
for model in existing_models:
if model.id not in new_model_ids:
db.delete(model)
await db.delete(model)
db.commit()
await db.commit()
return [
ModelModel.model_validate(model) for model in db.query(Model).all()
ModelModel.model_validate(model)
for model in await db.query(Model).all()
]
except Exception as e:
log.exception(f"Error syncing models for user {user_id}: {e}")

View file

@ -602,7 +602,7 @@ async def delete_knowledge_by_id(id: str, user=Depends(get_verified_user)):
log.info(f"Deleting knowledge base: {id} (name: {knowledge.name})")
# Get all models
models = Models.get_all_models()
models = await Models.get_all_models()
log.info(f"Found {len(models)} models to check for knowledge base {id}")
# Update models that reference this knowledge base
@ -626,7 +626,7 @@ async def delete_knowledge_by_id(id: str, user=Depends(get_verified_user)):
access_control=model.access_control,
is_active=model.is_active,
)
Models.update_model_by_id(model.id, model_form)
await Models.update_model_by_id(model.id, model_form)
# Clean up vector DB
try:

View file

@ -28,9 +28,9 @@ router = APIRouter()
@router.get("/", response_model=list[ModelUserResponse])
async def get_models(id: Optional[str] = None, user=Depends(get_verified_user)):
if user.role == "admin" and ENABLE_ADMIN_WORKSPACE_CONTENT_ACCESS:
return Models.get_models()
return await Models.get_models()
else:
return Models.get_models_by_user_id(user.id)
return await Models.get_models_by_user_id(user.id)
###########################
@ -40,7 +40,7 @@ async def get_models(id: Optional[str] = None, user=Depends(get_verified_user)):
@router.get("/base", response_model=list[ModelResponse])
async def get_base_models(user=Depends(get_admin_user)):
return Models.get_base_models()
return await Models.get_base_models()
############################
@ -62,7 +62,7 @@ async def create_new_model(
detail=ERROR_MESSAGES.UNAUTHORIZED,
)
model = Models.get_model_by_id(form_data.id)
model = await Models.get_model_by_id(form_data.id)
if model:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
@ -70,7 +70,7 @@ async def create_new_model(
)
else:
model = Models.insert_new_model(form_data, user.id)
model = await Models.insert_new_model(form_data, user.id)
if model:
return model
else:
@ -87,7 +87,7 @@ async def create_new_model(
@router.get("/export", response_model=list[ModelModel])
async def export_models(user=Depends(get_admin_user)):
return Models.get_models()
return await Models.get_models()
############################
@ -103,7 +103,7 @@ class SyncModelsForm(BaseModel):
async def sync_models(
request: Request, form_data: SyncModelsForm, user=Depends(get_admin_user)
):
return Models.sync_models(user.id, form_data.models)
return await Models.sync_models(user.id, form_data.models)
###########################
@ -114,7 +114,7 @@ async def sync_models(
# Note: We're not using the typical url path param here, but instead using a query parameter to allow '/' in the id
@router.get("/model", response_model=Optional[ModelResponse])
async def get_model_by_id(id: str, user=Depends(get_verified_user)):
model = Models.get_model_by_id(id)
model = await Models.get_model_by_id(id)
if model:
if (
(user.role == "admin" and ENABLE_ADMIN_WORKSPACE_CONTENT_ACCESS)
@ -136,14 +136,14 @@ async def get_model_by_id(id: str, user=Depends(get_verified_user)):
@router.post("/model/toggle", response_model=Optional[ModelResponse])
async def toggle_model_by_id(id: str, user=Depends(get_verified_user)):
model = Models.get_model_by_id(id)
model = await Models.get_model_by_id(id)
if model:
if (
user.role == "admin"
or model.user_id == user.id
or await has_access(user.id, "write", model.access_control)
):
model = Models.toggle_model_by_id(id)
model = await Models.toggle_model_by_id(id)
if model:
return model
@ -175,7 +175,7 @@ async def update_model_by_id(
form_data: ModelForm,
user=Depends(get_verified_user),
):
model = Models.get_model_by_id(id)
model = await Models.get_model_by_id(id)
if not model:
raise HTTPException(
@ -193,7 +193,7 @@ async def update_model_by_id(
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
model = Models.update_model_by_id(id, form_data)
model = await Models.update_model_by_id(id, form_data)
return model

View file

@ -441,7 +441,7 @@ async def get_filtered_models(models, user):
# Filter models based on user access control
filtered_models = []
for model in models.get("models", []):
model_info = Models.get_model_by_id(model["model"])
model_info = await Models.get_model_by_id(model["model"])
if model_info:
if user.id == model_info.user_id or (
await has_access(
@ -1320,7 +1320,7 @@ async def generate_chat_completion(
del payload["metadata"]
model_id = payload["model"]
model_info = Models.get_model_by_id(model_id)
model_info = await Models.get_model_by_id(model_id)
if model_info:
if model_info.base_model_id:

View file

@ -383,7 +383,7 @@ async def get_filtered_models(models, user):
# Filter models based on user access control
filtered_models = []
for model in models.get("data", []):
model_info = Models.get_model_by_id(model["id"])
model_info = await Models.get_model_by_id(model["id"])
if model_info:
if user.id == model_info.user_id or (
await has_access(
@ -738,7 +738,7 @@ async def generate_chat_completion(
metadata = payload.pop("metadata", None)
model_id = form_data.get("model")
model_info = Models.get_model_by_id(model_id)
model_info = await Models.get_model_by_id(model_id)
# Check model info and override the payload
if model_info:

View file

@ -145,7 +145,7 @@ async def get_all_models(request, refresh: bool = False, user: UserModel = None)
for function in Functions.get_functions_by_type("filter", active_only=True)
]
custom_models = Models.get_all_models()
custom_models = await Models.get_all_models()
for custom_model in custom_models:
if custom_model.base_model_id is None:
# Applied directly to a base model