perf: fix N+1 query issues in user group access control validation

- Pre-fetch user group IDs in get_*_by_user_id methods across models layer
- Pass user_group_ids to has_access to avoid repeated group queries
- Reduce query count from 1+N to 1+1 pattern for access control validation
- Apply consistent optimization across knowledge, models, notes, prompts, and tools

Signed-off-by: Sihyeon Jang <sihyeon.jang@navercorp.com>
This commit is contained in:
Sihyeon Jang 2025-09-03 05:56:48 +09:00
parent 22c4ef4fb0
commit eff06538a6
5 changed files with 14 additions and 5 deletions

View file

@ -8,6 +8,7 @@ from open_webui.internal.db import Base, get_db
from open_webui.env import SRC_LOG_LEVELS from open_webui.env import SRC_LOG_LEVELS
from open_webui.models.files import FileMetadataResponse from open_webui.models.files import FileMetadataResponse
from open_webui.models.groups import Groups
from open_webui.models.users import Users, UserResponse from open_webui.models.users import Users, UserResponse
@ -147,11 +148,12 @@ class KnowledgeTable:
self, user_id: str, permission: str = "write" self, user_id: str, permission: str = "write"
) -> list[KnowledgeUserModel]: ) -> list[KnowledgeUserModel]:
knowledge_bases = self.get_knowledge_bases() knowledge_bases = self.get_knowledge_bases()
user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user_id)}
return [ return [
knowledge_base knowledge_base
for knowledge_base in knowledge_bases for knowledge_base in knowledge_bases
if knowledge_base.user_id == user_id if knowledge_base.user_id == user_id
or has_access(user_id, permission, knowledge_base.access_control) or has_access(user_id, permission, knowledge_base.access_control, user_group_ids)
] ]
def get_knowledge_by_id(self, id: str) -> Optional[KnowledgeModel]: def get_knowledge_by_id(self, id: str) -> Optional[KnowledgeModel]:

View file

@ -5,6 +5,7 @@ from typing import Optional
from open_webui.internal.db import Base, JSONField, get_db from open_webui.internal.db import Base, JSONField, get_db
from open_webui.env import SRC_LOG_LEVELS from open_webui.env import SRC_LOG_LEVELS
from open_webui.models.groups import Groups
from open_webui.models.users import Users, UserResponse from open_webui.models.users import Users, UserResponse
@ -199,11 +200,12 @@ class ModelsTable:
self, user_id: str, permission: str = "write" self, user_id: str, permission: str = "write"
) -> list[ModelUserResponse]: ) -> list[ModelUserResponse]:
models = self.get_models() models = self.get_models()
user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user_id)}
return [ return [
model model
for model in models for model in models
if model.user_id == user_id if model.user_id == user_id
or has_access(user_id, permission, model.access_control) or has_access(user_id, permission, model.access_control, user_group_ids)
] ]
def get_model_by_id(self, id: str) -> Optional[ModelModel]: def get_model_by_id(self, id: str) -> Optional[ModelModel]:

View file

@ -4,6 +4,7 @@ import uuid
from typing import Optional from typing import Optional
from open_webui.internal.db import Base, get_db from open_webui.internal.db import Base, get_db
from open_webui.models.groups import Groups
from open_webui.utils.access_control import has_access from open_webui.utils.access_control import has_access
from open_webui.models.users import Users, UserResponse from open_webui.models.users import Users, UserResponse
@ -105,11 +106,12 @@ class NoteTable:
self, user_id: str, permission: str = "write" self, user_id: str, permission: str = "write"
) -> list[NoteModel]: ) -> list[NoteModel]:
notes = self.get_notes() notes = self.get_notes()
user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user_id)}
return [ return [
note note
for note in notes for note in notes
if note.user_id == user_id if note.user_id == user_id
or has_access(user_id, permission, note.access_control) or has_access(user_id, permission, note.access_control, user_group_ids)
] ]
def get_note_by_id(self, id: str) -> Optional[NoteModel]: def get_note_by_id(self, id: str) -> Optional[NoteModel]:

View file

@ -2,6 +2,7 @@ import time
from typing import Optional from typing import Optional
from open_webui.internal.db import Base, get_db from open_webui.internal.db import Base, get_db
from open_webui.models.groups import Groups
from open_webui.models.users import Users, UserResponse from open_webui.models.users import Users, UserResponse
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict
@ -122,12 +123,13 @@ class PromptsTable:
self, user_id: str, permission: str = "write" self, user_id: str, permission: str = "write"
) -> list[PromptUserResponse]: ) -> list[PromptUserResponse]:
prompts = self.get_prompts() prompts = self.get_prompts()
user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user_id)}
return [ return [
prompt prompt
for prompt in prompts for prompt in prompts
if prompt.user_id == user_id if prompt.user_id == user_id
or has_access(user_id, permission, prompt.access_control) or has_access(user_id, permission, prompt.access_control, user_group_ids)
] ]
def update_prompt_by_command( def update_prompt_by_command(

View file

@ -161,12 +161,13 @@ class ToolsTable:
self, user_id: str, permission: str = "write" self, user_id: str, permission: str = "write"
) -> list[ToolUserModel]: ) -> list[ToolUserModel]:
tools = self.get_tools() tools = self.get_tools()
user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user_id)}
return [ return [
tool tool
for tool in tools for tool in tools
if tool.user_id == user_id if tool.user_id == user_id
or has_access(user_id, permission, tool.access_control) or has_access(user_id, permission, tool.access_control, user_group_ids)
] ]
def get_tool_valves_by_id(self, id: str) -> Optional[dict]: def get_tool_valves_by_id(self, id: str) -> Optional[dict]: