mirror of
https://github.com/open-webui/open-webui.git
synced 2025-12-12 04:15:25 +00:00
perf: fix N+1 query issues in user group access control validation
- Pre-fetch user group IDs in get_*_by_user_id methods across models layer - Pass user_group_ids to has_access to avoid repeated group queries - Reduce query count from 1+N to 1+1 pattern for access control validation - Apply consistent optimization across knowledge, models, notes, prompts, and tools Signed-off-by: Sihyeon Jang <sihyeon.jang@navercorp.com>
This commit is contained in:
parent
22c4ef4fb0
commit
eff06538a6
5 changed files with 14 additions and 5 deletions
|
|
@ -8,6 +8,7 @@ from open_webui.internal.db import Base, get_db
|
||||||
from open_webui.env import SRC_LOG_LEVELS
|
from open_webui.env import SRC_LOG_LEVELS
|
||||||
|
|
||||||
from open_webui.models.files import FileMetadataResponse
|
from open_webui.models.files import FileMetadataResponse
|
||||||
|
from open_webui.models.groups import Groups
|
||||||
from open_webui.models.users import Users, UserResponse
|
from open_webui.models.users import Users, UserResponse
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -147,11 +148,12 @@ class KnowledgeTable:
|
||||||
self, user_id: str, permission: str = "write"
|
self, user_id: str, permission: str = "write"
|
||||||
) -> list[KnowledgeUserModel]:
|
) -> list[KnowledgeUserModel]:
|
||||||
knowledge_bases = self.get_knowledge_bases()
|
knowledge_bases = self.get_knowledge_bases()
|
||||||
|
user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user_id)}
|
||||||
return [
|
return [
|
||||||
knowledge_base
|
knowledge_base
|
||||||
for knowledge_base in knowledge_bases
|
for knowledge_base in knowledge_bases
|
||||||
if knowledge_base.user_id == user_id
|
if knowledge_base.user_id == user_id
|
||||||
or has_access(user_id, permission, knowledge_base.access_control)
|
or has_access(user_id, permission, knowledge_base.access_control, user_group_ids)
|
||||||
]
|
]
|
||||||
|
|
||||||
def get_knowledge_by_id(self, id: str) -> Optional[KnowledgeModel]:
|
def get_knowledge_by_id(self, id: str) -> Optional[KnowledgeModel]:
|
||||||
|
|
|
||||||
|
|
@ -5,6 +5,7 @@ from typing import Optional
|
||||||
from open_webui.internal.db import Base, JSONField, get_db
|
from open_webui.internal.db import Base, JSONField, get_db
|
||||||
from open_webui.env import SRC_LOG_LEVELS
|
from open_webui.env import SRC_LOG_LEVELS
|
||||||
|
|
||||||
|
from open_webui.models.groups import Groups
|
||||||
from open_webui.models.users import Users, UserResponse
|
from open_webui.models.users import Users, UserResponse
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -199,11 +200,12 @@ class ModelsTable:
|
||||||
self, user_id: str, permission: str = "write"
|
self, user_id: str, permission: str = "write"
|
||||||
) -> list[ModelUserResponse]:
|
) -> list[ModelUserResponse]:
|
||||||
models = self.get_models()
|
models = self.get_models()
|
||||||
|
user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user_id)}
|
||||||
return [
|
return [
|
||||||
model
|
model
|
||||||
for model in models
|
for model in models
|
||||||
if model.user_id == user_id
|
if model.user_id == user_id
|
||||||
or has_access(user_id, permission, model.access_control)
|
or has_access(user_id, permission, model.access_control, user_group_ids)
|
||||||
]
|
]
|
||||||
|
|
||||||
def get_model_by_id(self, id: str) -> Optional[ModelModel]:
|
def get_model_by_id(self, id: str) -> Optional[ModelModel]:
|
||||||
|
|
|
||||||
|
|
@ -4,6 +4,7 @@ import uuid
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from open_webui.internal.db import Base, get_db
|
from open_webui.internal.db import Base, get_db
|
||||||
|
from open_webui.models.groups import Groups
|
||||||
from open_webui.utils.access_control import has_access
|
from open_webui.utils.access_control import has_access
|
||||||
from open_webui.models.users import Users, UserResponse
|
from open_webui.models.users import Users, UserResponse
|
||||||
|
|
||||||
|
|
@ -105,11 +106,12 @@ class NoteTable:
|
||||||
self, user_id: str, permission: str = "write"
|
self, user_id: str, permission: str = "write"
|
||||||
) -> list[NoteModel]:
|
) -> list[NoteModel]:
|
||||||
notes = self.get_notes()
|
notes = self.get_notes()
|
||||||
|
user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user_id)}
|
||||||
return [
|
return [
|
||||||
note
|
note
|
||||||
for note in notes
|
for note in notes
|
||||||
if note.user_id == user_id
|
if note.user_id == user_id
|
||||||
or has_access(user_id, permission, note.access_control)
|
or has_access(user_id, permission, note.access_control, user_group_ids)
|
||||||
]
|
]
|
||||||
|
|
||||||
def get_note_by_id(self, id: str) -> Optional[NoteModel]:
|
def get_note_by_id(self, id: str) -> Optional[NoteModel]:
|
||||||
|
|
|
||||||
|
|
@ -2,6 +2,7 @@ import time
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from open_webui.internal.db import Base, get_db
|
from open_webui.internal.db import Base, get_db
|
||||||
|
from open_webui.models.groups import Groups
|
||||||
from open_webui.models.users import Users, UserResponse
|
from open_webui.models.users import Users, UserResponse
|
||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict
|
from pydantic import BaseModel, ConfigDict
|
||||||
|
|
@ -122,12 +123,13 @@ class PromptsTable:
|
||||||
self, user_id: str, permission: str = "write"
|
self, user_id: str, permission: str = "write"
|
||||||
) -> list[PromptUserResponse]:
|
) -> list[PromptUserResponse]:
|
||||||
prompts = self.get_prompts()
|
prompts = self.get_prompts()
|
||||||
|
user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user_id)}
|
||||||
|
|
||||||
return [
|
return [
|
||||||
prompt
|
prompt
|
||||||
for prompt in prompts
|
for prompt in prompts
|
||||||
if prompt.user_id == user_id
|
if prompt.user_id == user_id
|
||||||
or has_access(user_id, permission, prompt.access_control)
|
or has_access(user_id, permission, prompt.access_control, user_group_ids)
|
||||||
]
|
]
|
||||||
|
|
||||||
def update_prompt_by_command(
|
def update_prompt_by_command(
|
||||||
|
|
|
||||||
|
|
@ -161,12 +161,13 @@ class ToolsTable:
|
||||||
self, user_id: str, permission: str = "write"
|
self, user_id: str, permission: str = "write"
|
||||||
) -> list[ToolUserModel]:
|
) -> list[ToolUserModel]:
|
||||||
tools = self.get_tools()
|
tools = self.get_tools()
|
||||||
|
user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user_id)}
|
||||||
|
|
||||||
return [
|
return [
|
||||||
tool
|
tool
|
||||||
for tool in tools
|
for tool in tools
|
||||||
if tool.user_id == user_id
|
if tool.user_id == user_id
|
||||||
or has_access(user_id, permission, tool.access_control)
|
or has_access(user_id, permission, tool.access_control, user_group_ids)
|
||||||
]
|
]
|
||||||
|
|
||||||
def get_tool_valves_by_id(self, id: str) -> Optional[dict]:
|
def get_tool_valves_by_id(self, id: str) -> Optional[dict]:
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue