diff --git a/backend/open_webui/routers/ollama.py b/backend/open_webui/routers/ollama.py index e6b2e9d9e9..293bc14c04 100644 --- a/backend/open_webui/routers/ollama.py +++ b/backend/open_webui/routers/ollama.py @@ -38,6 +38,9 @@ from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import StreamingResponse from pydantic import BaseModel, ConfigDict, validator from starlette.background import BackgroundTask +from sqlalchemy.orm import Session + +from open_webui.internal.db import get_session from open_webui.models.models import Models @@ -425,7 +428,7 @@ async def get_filtered_models(models, user, db=None): # Filter models based on user access control filtered_models = [] for model in models.get("models", []): - model_info = Models.get_model_by_id(model["model"]) + model_info = Models.get_model_by_id(model["model"], db=db) if model_info: if user.id == model_info.user_id or has_access( user.id, type="read", access_control=model_info.access_control, db=db @@ -1253,6 +1256,7 @@ async def generate_chat_completion( url_idx: Optional[int] = None, user=Depends(get_verified_user), bypass_filter: Optional[bool] = False, + db: Session = Depends(get_session), ): if BYPASS_MODEL_ACCESS_CONTROL: bypass_filter = True @@ -1274,7 +1278,7 @@ async def generate_chat_completion( del payload["metadata"] model_id = payload["model"] - model_info = Models.get_model_by_id(model_id) + model_info = Models.get_model_by_id(model_id, db=db) if model_info: if model_info.base_model_id: @@ -1298,7 +1302,7 @@ async def generate_chat_completion( if not ( user.id == model_info.user_id or has_access( - user.id, type="read", access_control=model_info.access_control + user.id, type="read", access_control=model_info.access_control, db=db ) ): raise HTTPException( @@ -1370,6 +1374,7 @@ async def generate_openai_completion( form_data: dict, url_idx: Optional[int] = None, user=Depends(get_verified_user), + db: Session = Depends(get_session), ): metadata = form_data.pop("metadata", None) @@ -1390,7 +1395,7 @@ async def generate_openai_completion( if ":" not in model_id: model_id = f"{model_id}:latest" - model_info = Models.get_model_by_id(model_id) + model_info = Models.get_model_by_id(model_id, db=db) if model_info: if model_info.base_model_id: payload["model"] = model_info.base_model_id @@ -1404,7 +1409,7 @@ async def generate_openai_completion( if not ( user.id == model_info.user_id or has_access( - user.id, type="read", access_control=model_info.access_control + user.id, type="read", access_control=model_info.access_control, db=db ) ): raise HTTPException( @@ -1449,6 +1454,7 @@ async def generate_openai_chat_completion( form_data: dict, url_idx: Optional[int] = None, user=Depends(get_verified_user), + db: Session = Depends(get_session), ): metadata = form_data.pop("metadata", None) @@ -1469,7 +1475,7 @@ async def generate_openai_chat_completion( if ":" not in model_id: model_id = f"{model_id}:latest" - model_info = Models.get_model_by_id(model_id) + model_info = Models.get_model_by_id(model_id, db=db) if model_info: if model_info.base_model_id: payload["model"] = model_info.base_model_id @@ -1487,7 +1493,7 @@ async def generate_openai_chat_completion( if not ( user.id == model_info.user_id or has_access( - user.id, type="read", access_control=model_info.access_control + user.id, type="read", access_control=model_info.access_control, db=db ) ): raise HTTPException( @@ -1530,6 +1536,7 @@ async def get_openai_models( request: Request, url_idx: Optional[int] = None, user=Depends(get_verified_user), + db: Session = Depends(get_session), ): models = [] @@ -1582,10 +1589,10 @@ async def get_openai_models( # Filter models based on user access control filtered_models = [] for model in models: - model_info = Models.get_model_by_id(model["id"]) + model_info = Models.get_model_by_id(model["id"], db=db) if model_info: if user.id == model_info.user_id or has_access( - user.id, type="read", access_control=model_info.access_control + user.id, type="read", access_control=model_info.access_control, db=db ): filtered_models.append(model) models = filtered_models diff --git a/backend/open_webui/routers/openai.py b/backend/open_webui/routers/openai.py index 3e8a10b9fa..18f937f369 100644 --- a/backend/open_webui/routers/openai.py +++ b/backend/open_webui/routers/openai.py @@ -19,6 +19,9 @@ from fastapi.responses import ( ) from pydantic import BaseModel from starlette.background import BackgroundTask +from sqlalchemy.orm import Session + +from open_webui.internal.db import get_session from open_webui.models.models import Models from open_webui.config import ( @@ -457,7 +460,7 @@ async def get_filtered_models(models, user, db=None): # Filter models based on user access control filtered_models = [] for model in models.get("data", []): - model_info = Models.get_model_by_id(model["id"]) + model_info = Models.get_model_by_id(model["id"], db=db) if model_info: if user.id == model_info.user_id or has_access( user.id, type="read", access_control=model_info.access_control, db=db @@ -797,6 +800,7 @@ async def generate_chat_completion( form_data: dict, user=Depends(get_verified_user), bypass_filter: Optional[bool] = False, + db: Session = Depends(get_session), ): if BYPASS_MODEL_ACCESS_CONTROL: bypass_filter = True @@ -807,7 +811,7 @@ async def generate_chat_completion( metadata = payload.pop("metadata", None) model_id = form_data.get("model") - model_info = Models.get_model_by_id(model_id) + model_info = Models.get_model_by_id(model_id, db=db) # Check model info and override the payload if model_info: @@ -833,7 +837,7 @@ async def generate_chat_completion( if not ( user.id == model_info.user_id or has_access( - user.id, type="read", access_control=model_info.access_control + user.id, type="read", access_control=model_info.access_control, db=db ) ): raise HTTPException(