diff --git a/backend/open_webui/apps/ollama/main.py b/backend/open_webui/apps/ollama/main.py index e566e032bd..07bf43510b 100644 --- a/backend/open_webui/apps/ollama/main.py +++ b/backend/open_webui/apps/ollama/main.py @@ -13,9 +13,7 @@ import requests from open_webui.apps.webui.models.models import Models from open_webui.config import ( CORS_ALLOW_ORIGIN, - ENABLE_MODEL_FILTER, ENABLE_OLLAMA_API, - MODEL_FILTER_LIST, OLLAMA_BASE_URLS, OLLAMA_API_CONFIGS, UPLOAD_DIR, @@ -66,32 +64,16 @@ app.add_middleware( app.state.config = AppConfig() -app.state.config.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER -app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST - app.state.config.ENABLE_OLLAMA_API = ENABLE_OLLAMA_API app.state.config.OLLAMA_BASE_URLS = OLLAMA_BASE_URLS app.state.config.OLLAMA_API_CONFIGS = OLLAMA_API_CONFIGS -app.state.MODELS = {} - # TODO: Implement a more intelligent load balancing mechanism for distributing requests among multiple backend instances. # Current implementation uses a simple round-robin approach (random.choice). Consider incorporating algorithms like weighted round-robin, # least connections, or least response time for better resource utilization and performance optimization. -@app.middleware("http") -async def check_url(request: Request, call_next): - if len(app.state.MODELS) == 0: - await get_all_models() - else: - pass - - response = await call_next(request) - return response - - @app.head("/") @app.get("/") async def get_status(): @@ -326,8 +308,6 @@ async def get_all_models(): else: models = {"models": []} - app.state.MODELS = {model["model"]: model for model in models["models"]} - return models @@ -339,16 +319,18 @@ async def get_ollama_tags( if url_idx is None: models = await get_all_models() - if app.state.config.ENABLE_MODEL_FILTER: - if user.role == "user": - models["models"] = list( - filter( - lambda model: model["name"] - in app.state.config.MODEL_FILTER_LIST, - models["models"], - ) - ) - return models + # TODO: Check User Group and Filter Models + # if app.state.config.ENABLE_MODEL_FILTER: + # if user.role == "user": + # models["models"] = list( + # filter( + # lambda model: model["name"] + # in app.state.config.MODEL_FILTER_LIST, + # models["models"], + # ) + # ) + # return models + return models else: url = app.state.config.OLLAMA_BASE_URLS[url_idx] @@ -473,8 +455,11 @@ async def push_model( user=Depends(get_admin_user), ): if url_idx is None: - if form_data.name in app.state.MODELS: - url_idx = app.state.MODELS[form_data.name]["urls"][0] + model_list = await get_all_models() + models = {model["model"]: model for model in model_list["models"]} + + if form_data.name in models: + url_idx = models[form_data.name]["urls"][0] else: raise HTTPException( status_code=400, @@ -523,8 +508,11 @@ async def copy_model( user=Depends(get_admin_user), ): if url_idx is None: - if form_data.source in app.state.MODELS: - url_idx = app.state.MODELS[form_data.source]["urls"][0] + model_list = await get_all_models() + models = {model["model"]: model for model in model_list["models"]} + + if form_data.source in models: + url_idx = models[form_data.source]["urls"][0] else: raise HTTPException( status_code=400, @@ -579,8 +567,11 @@ async def delete_model( user=Depends(get_admin_user), ): if url_idx is None: - if form_data.name in app.state.MODELS: - url_idx = app.state.MODELS[form_data.name]["urls"][0] + model_list = await get_all_models() + models = {model["model"]: model for model in model_list["models"]} + + if form_data.name in models: + url_idx = models[form_data.name]["urls"][0] else: raise HTTPException( status_code=400, @@ -628,13 +619,16 @@ async def delete_model( @app.post("/api/show") async def show_model_info(form_data: ModelNameForm, user=Depends(get_verified_user)): - if form_data.name not in app.state.MODELS: + model_list = await get_all_models() + models = {model["model"]: model for model in model_list["models"]} + + if form_data.name not in models: raise HTTPException( status_code=400, detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name), ) - url_idx = random.choice(app.state.MODELS[form_data.name]["urls"]) + url_idx = random.choice(models[form_data.name]["urls"]) url = app.state.config.OLLAMA_BASE_URLS[url_idx] log.info(f"url: {url}") @@ -704,23 +698,26 @@ async def generate_embeddings( url_idx: Optional[int] = None, user=Depends(get_verified_user), ): - return generate_ollama_embeddings(form_data=form_data, url_idx=url_idx) + return await generate_ollama_embeddings(form_data=form_data, url_idx=url_idx) -def generate_ollama_embeddings( +async def generate_ollama_embeddings( form_data: GenerateEmbeddingsForm, url_idx: Optional[int] = None, ): log.info(f"generate_ollama_embeddings {form_data}") if url_idx is None: + model_list = await get_all_models() + models = {model["model"]: model for model in model_list["models"]} + model = form_data.model if ":" not in model: model = f"{model}:latest" - if model in app.state.MODELS: - url_idx = random.choice(app.state.MODELS[model]["urls"]) + if model in models: + url_idx = random.choice(models[model]["urls"]) else: raise HTTPException( status_code=400, @@ -771,20 +768,23 @@ def generate_ollama_embeddings( ) -def generate_ollama_batch_embeddings( +async def generate_ollama_batch_embeddings( form_data: GenerateEmbedForm, url_idx: Optional[int] = None, ): log.info(f"generate_ollama_batch_embeddings {form_data}") if url_idx is None: + model_list = await get_all_models() + models = {model["model"]: model for model in model_list["models"]} + model = form_data.model if ":" not in model: model = f"{model}:latest" - if model in app.state.MODELS: - url_idx = random.choice(app.state.MODELS[model]["urls"]) + if model in models: + url_idx = random.choice(models[model]["urls"]) else: raise HTTPException( status_code=400, @@ -854,13 +854,16 @@ async def generate_completion( user=Depends(get_verified_user), ): if url_idx is None: + model_list = await get_all_models() + models = {model["model"]: model for model in model_list["models"]} + model = form_data.model if ":" not in model: model = f"{model}:latest" - if model in app.state.MODELS: - url_idx = random.choice(app.state.MODELS[model]["urls"]) + if model in models: + url_idx = random.choice(models[model]["urls"]) else: raise HTTPException( status_code=400, @@ -895,14 +898,17 @@ class GenerateChatCompletionForm(BaseModel): keep_alive: Optional[Union[int, str]] = None -def get_ollama_url(url_idx: Optional[int], model: str): +async def get_ollama_url(url_idx: Optional[int], model: str): if url_idx is None: - if model not in app.state.MODELS: + model_list = await get_all_models() + models = {model["model"]: model for model in model_list["models"]} + + if model not in models: raise HTTPException( status_code=400, detail=ERROR_MESSAGES.MODEL_NOT_FOUND(model), ) - url_idx = random.choice(app.state.MODELS[model]["urls"]) + url_idx = random.choice(models[model]["urls"]) url = app.state.config.OLLAMA_BASE_URLS[url_idx] return url @@ -922,12 +928,14 @@ async def generate_chat_completion( model_id = form_data.model - if not bypass_filter and app.state.config.ENABLE_MODEL_FILTER: - if user.role == "user" and model_id not in app.state.config.MODEL_FILTER_LIST: - raise HTTPException( - status_code=403, - detail="Model not found", - ) + # TODO: Check User Group and Filter Models + # if not bypass_filter: + # if app.state.config.ENABLE_MODEL_FILTER: + # if user.role == "user" and model_id not in app.state.config.MODEL_FILTER_LIST: + # raise HTTPException( + # status_code=403, + # detail="Model not found", + # ) model_info = Models.get_model_by_id(model_id) @@ -949,7 +957,7 @@ async def generate_chat_completion( if ":" not in payload["model"]: payload["model"] = f"{payload['model']}:latest" - url = get_ollama_url(url_idx, payload["model"]) + url = await get_ollama_url(url_idx, payload["model"]) log.info(f"url: {url}") log.debug(f"generate_chat_completion() - 2.payload = {payload}") @@ -1008,12 +1016,13 @@ async def generate_openai_chat_completion( model_id = completion_form.model - if app.state.config.ENABLE_MODEL_FILTER: - if user.role == "user" and model_id not in app.state.config.MODEL_FILTER_LIST: - raise HTTPException( - status_code=403, - detail="Model not found", - ) + # TODO: Check User Group and Filter Models + # if app.state.config.ENABLE_MODEL_FILTER: + # if user.role == "user" and model_id not in app.state.config.MODEL_FILTER_LIST: + # raise HTTPException( + # status_code=403, + # detail="Model not found", + # ) model_info = Models.get_model_by_id(model_id) @@ -1030,7 +1039,7 @@ async def generate_openai_chat_completion( if ":" not in payload["model"]: payload["model"] = f"{payload['model']}:latest" - url = get_ollama_url(url_idx, payload["model"]) + url = await get_ollama_url(url_idx, payload["model"]) log.info(f"url: {url}") api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {}) @@ -1054,15 +1063,16 @@ async def get_openai_models( if url_idx is None: models = await get_all_models() - if app.state.config.ENABLE_MODEL_FILTER: - if user.role == "user": - models["models"] = list( - filter( - lambda model: model["name"] - in app.state.config.MODEL_FILTER_LIST, - models["models"], - ) - ) + # TODO: Check User Group and Filter Models + # if app.state.config.ENABLE_MODEL_FILTER: + # if user.role == "user": + # models["models"] = list( + # filter( + # lambda model: model["name"] + # in app.state.config.MODEL_FILTER_LIST, + # models["models"], + # ) + # ) return { "data": [ diff --git a/backend/open_webui/apps/openai/main.py b/backend/open_webui/apps/openai/main.py index 5a4dba62fd..cbea604678 100644 --- a/backend/open_webui/apps/openai/main.py +++ b/backend/open_webui/apps/openai/main.py @@ -11,9 +11,7 @@ from open_webui.apps.webui.models.models import Models from open_webui.config import ( CACHE_DIR, CORS_ALLOW_ORIGIN, - ENABLE_MODEL_FILTER, ENABLE_OPENAI_API, - MODEL_FILTER_LIST, OPENAI_API_BASE_URLS, OPENAI_API_KEYS, OPENAI_API_CONFIGS, @@ -39,6 +37,8 @@ from open_webui.utils.payload import ( ) from open_webui.utils.utils import get_admin_user, get_verified_user +from open_webui.utils.access_control import has_access + log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["OPENAI"]) @@ -61,25 +61,11 @@ app.add_middleware( app.state.config = AppConfig() -app.state.config.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER -app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST - app.state.config.ENABLE_OPENAI_API = ENABLE_OPENAI_API app.state.config.OPENAI_API_BASE_URLS = OPENAI_API_BASE_URLS app.state.config.OPENAI_API_KEYS = OPENAI_API_KEYS app.state.config.OPENAI_API_CONFIGS = OPENAI_API_CONFIGS -app.state.MODELS = {} - - -@app.middleware("http") -async def check_url(request: Request, call_next): - if len(app.state.MODELS) == 0: - await get_all_models() - - response = await call_next(request) - return response - @app.get("/config") async def get_config(user=Depends(get_admin_user)): @@ -264,7 +250,7 @@ def merge_models_lists(model_lists): return merged_list -async def get_all_models_raw() -> list: +async def get_all_models_responses() -> list: if not app.state.config.ENABLE_OPENAI_API: return [] @@ -335,22 +321,13 @@ async def get_all_models_raw() -> list: return responses -@overload -async def get_all_models(raw: Literal[True]) -> list: ... - - -@overload -async def get_all_models(raw: Literal[False] = False) -> dict[str, list]: ... - - -async def get_all_models(raw=False) -> dict[str, list] | list: +async def get_all_models() -> dict[str, list]: log.info("get_all_models()") - if not app.state.config.ENABLE_OPENAI_API: - return [] if raw else {"data": []} - responses = await get_all_models_raw() - if raw: - return responses + if not app.state.config.ENABLE_OPENAI_API: + return {"data": []} + + responses = await get_all_models_responses() def extract_data(response): if response and "data" in response: @@ -360,9 +337,7 @@ async def get_all_models(raw=False) -> dict[str, list] | list: return None models = {"data": merge_models_lists(map(extract_data, responses))} - log.debug(f"models: {models}") - app.state.MODELS = {model["id"]: model for model in models["data"]} return models @@ -370,18 +345,12 @@ async def get_all_models(raw=False) -> dict[str, list] | list: @app.get("/models") @app.get("/models/{url_idx}") async def get_models(url_idx: Optional[int] = None, user=Depends(get_verified_user)): + models = { + "data": [], + } + if url_idx is None: models = await get_all_models() - if app.state.config.ENABLE_MODEL_FILTER: - if user.role == "user": - models["data"] = list( - filter( - lambda model: model["id"] in app.state.config.MODEL_FILTER_LIST, - models["data"], - ) - ) - return models - return models else: url = app.state.config.OPENAI_API_BASE_URLS[url_idx] key = app.state.config.OPENAI_API_KEYS[url_idx] @@ -389,6 +358,7 @@ async def get_models(url_idx: Optional[int] = None, user=Depends(get_verified_us headers = {} headers["Authorization"] = f"Bearer {key}" headers["Content-Type"] = "application/json" + if ENABLE_FORWARD_USER_INFO_HEADERS: headers["X-OpenWebUI-User-Name"] = user.name headers["X-OpenWebUI-User-Id"] = user.id @@ -430,8 +400,7 @@ async def get_models(url_idx: Optional[int] = None, user=Depends(get_verified_us ) ] - return response_data - + models = response_data except aiohttp.ClientError as e: # ClientError covers all aiohttp requests issues log.exception(f"Client error: {str(e)}") @@ -445,6 +414,22 @@ async def get_models(url_idx: Optional[int] = None, user=Depends(get_verified_us error_detail = f"Unexpected error: {str(e)}" raise HTTPException(status_code=500, detail=error_detail) + if user.role == "user": + # Filter models based on user access control + filtered_models = [] + for model in models.get("data", []): + model_info = Models.get_model_by_id(model["id"]) + if model_info: + if has_access( + user.id, type="read", access_control=model_info.access_control + ): + filtered_models.append(model) + else: + filtered_models.append(model) + models["data"] = filtered_models + + return models + class ConnectionVerificationForm(BaseModel): url: str @@ -492,11 +477,10 @@ async def verify_connection( @app.post("/chat/completions") -@app.post("/chat/completions/{url_idx}") async def generate_chat_completion( form_data: dict, - url_idx: Optional[int] = None, user=Depends(get_verified_user), + bypass_filter: Optional[bool] = False, ): idx = 0 payload = {**form_data} @@ -507,6 +491,7 @@ async def generate_chat_completion( model_id = form_data.get("model") model_info = Models.get_model_by_id(model_id) + # Check model info and override the payload if model_info: if model_info.base_model_id: payload["model"] = model_info.base_model_id @@ -515,9 +500,33 @@ async def generate_chat_completion( payload = apply_model_params_to_body_openai(params, payload) payload = apply_model_system_prompt_to_body(params, payload, user) - model = app.state.MODELS[payload.get("model")] - idx = model["urlIdx"] + # Check if user has access to the model + if user.role == "user" and not has_access( + user.id, type="read", access_control=model_info.access_control + ): + raise HTTPException( + status_code=403, + detail="Model not found", + ) + # Attemp to get urlIdx from the model + models = await get_all_models() + + # Find the model from the list + model = next( + (model for model in models["data"] if model["id"] == payload.get("model")), + None, + ) + + if model: + idx = model["urlIdx"] + else: + raise HTTPException( + status_code=404, + detail="Model not found", + ) + + # Get the API config for the model api_config = app.state.config.OPENAI_API_CONFIGS.get( app.state.config.OPENAI_API_BASE_URLS[idx], {} ) @@ -526,6 +535,7 @@ async def generate_chat_completion( if prefix_id: payload["model"] = payload["model"].replace(f"{prefix_id}.", "") + # Add user info to the payload if the model is a pipeline if "pipeline" in model and model.get("pipeline"): payload["user"] = { "name": user.name, @@ -536,8 +546,9 @@ async def generate_chat_completion( url = app.state.config.OPENAI_API_BASE_URLS[idx] key = app.state.config.OPENAI_API_KEYS[idx] - is_o1 = payload["model"].lower().startswith("o1-") + # Fix: O1 does not support the "max_tokens" parameter, Modify "max_tokens" to "max_completion_tokens" + is_o1 = payload["model"].lower().startswith("o1-") # Change max_completion_tokens to max_tokens (Backward compatible) if "api.openai.com" not in url and not is_o1: if "max_completion_tokens" in payload: diff --git a/backend/open_webui/apps/retrieval/utils.py b/backend/open_webui/apps/retrieval/utils.py index 8f855473ee..7d92b7350d 100644 --- a/backend/open_webui/apps/retrieval/utils.py +++ b/backend/open_webui/apps/retrieval/utils.py @@ -3,6 +3,7 @@ import os import uuid from typing import Optional, Union +import asyncio import requests from huggingface_hub import snapshot_download @@ -291,7 +292,13 @@ def get_embedding_function( if embedding_engine == "": return lambda query: embedding_function.encode(query).tolist() elif embedding_engine in ["ollama", "openai"]: - func = lambda query: generate_embeddings( + + # Wrapper to run the async generate_embeddings synchronously. + def sync_generate_embeddings(*args, **kwargs): + return asyncio.run(generate_embeddings(*args, **kwargs)) + + # Semantic expectation from the original version (using sync wrapper). + func = lambda query: sync_generate_embeddings( engine=embedding_engine, model=embedding_model, text=query, @@ -469,7 +476,7 @@ def get_model_path(model: str, update_model: bool = False): return model -def generate_openai_batch_embeddings( +async def generate_openai_batch_embeddings( model: str, texts: list[str], key: str, url: str = "https://api.openai.com/v1" ) -> Optional[list[list[float]]]: try: @@ -492,14 +499,16 @@ def generate_openai_batch_embeddings( return None -def generate_embeddings(engine: str, model: str, text: Union[str, list[str]], **kwargs): +async def generate_embeddings( + engine: str, model: str, text: Union[str, list[str]], **kwargs +): if engine == "ollama": if isinstance(text, list): - embeddings = generate_ollama_batch_embeddings( + embeddings = await generate_ollama_batch_embeddings( GenerateEmbedForm(**{"model": model, "input": text}) ) else: - embeddings = generate_ollama_batch_embeddings( + embeddings = await generate_ollama_batch_embeddings( GenerateEmbedForm(**{"model": model, "input": [text]}) ) return ( @@ -512,9 +521,9 @@ def generate_embeddings(engine: str, model: str, text: Union[str, list[str]], ** url = kwargs.get("url", "https://api.openai.com/v1") if isinstance(text, list): - embeddings = generate_openai_batch_embeddings(model, text, key, url) + embeddings = await generate_openai_batch_embeddings(model, text, key, url) else: - embeddings = generate_openai_batch_embeddings(model, [text], key, url) + embeddings = await generate_openai_batch_embeddings(model, [text], key, url) return embeddings[0] if isinstance(text, str) else embeddings diff --git a/backend/open_webui/apps/webui/main.py b/backend/open_webui/apps/webui/main.py index 7564475d43..ae54ab29a4 100644 --- a/backend/open_webui/apps/webui/main.py +++ b/backend/open_webui/apps/webui/main.py @@ -12,6 +12,7 @@ from open_webui.apps.webui.routers import ( chats, folders, configs, + groups, files, functions, memories, @@ -85,7 +86,11 @@ from open_webui.utils.payload import ( from open_webui.utils.tools import get_tools -app = FastAPI(docs_url="/docs" if ENV == "dev" else None, openapi_url="/openapi.json" if ENV == "dev" else None, redoc_url=None) +app = FastAPI( + docs_url="/docs" if ENV == "dev" else None, + openapi_url="/openapi.json" if ENV == "dev" else None, + redoc_url=None, +) log = logging.getLogger(__name__) @@ -105,6 +110,8 @@ app.state.config.ADMIN_EMAIL = ADMIN_EMAIL app.state.config.DEFAULT_MODELS = DEFAULT_MODELS app.state.config.DEFAULT_PROMPT_SUGGESTIONS = DEFAULT_PROMPT_SUGGESTIONS app.state.config.DEFAULT_USER_ROLE = DEFAULT_USER_ROLE + + app.state.config.USER_PERMISSIONS = USER_PERMISSIONS app.state.config.WEBHOOK_URL = WEBHOOK_URL app.state.config.BANNERS = WEBUI_BANNERS @@ -137,7 +144,6 @@ app.state.config.LDAP_USE_TLS = LDAP_USE_TLS app.state.config.LDAP_CA_CERT_FILE = LDAP_CA_CERT_FILE app.state.config.LDAP_CIPHERS = LDAP_CIPHERS -app.state.MODELS = {} app.state.TOOLS = {} app.state.FUNCTIONS = {} @@ -161,13 +167,15 @@ app.include_router(models.router, prefix="/models", tags=["models"]) app.include_router(knowledge.router, prefix="/knowledge", tags=["knowledge"]) app.include_router(prompts.router, prefix="/prompts", tags=["prompts"]) app.include_router(tools.router, prefix="/tools", tags=["tools"]) -app.include_router(functions.router, prefix="/functions", tags=["functions"]) app.include_router(memories.router, prefix="/memories", tags=["memories"]) +app.include_router(folders.router, prefix="/folders", tags=["folders"]) + +app.include_router(groups.router, prefix="/groups", tags=["groups"]) +app.include_router(files.router, prefix="/files", tags=["files"]) +app.include_router(functions.router, prefix="/functions", tags=["functions"]) app.include_router(evaluations.router, prefix="/evaluations", tags=["evaluations"]) -app.include_router(folders.router, prefix="/folders", tags=["folders"]) -app.include_router(files.router, prefix="/files", tags=["files"]) app.include_router(utils.router, prefix="/utils", tags=["utils"]) @@ -362,7 +370,7 @@ def get_function_params(function_module, form_data, user, extra_params=None): return params -async def generate_function_chat_completion(form_data, user): +async def generate_function_chat_completion(form_data, user, models: dict = {}): model_id = form_data.get("model") model_info = Models.get_model_by_id(model_id) @@ -405,7 +413,7 @@ async def generate_function_chat_completion(form_data, user): user, { **extra_params, - "__model__": app.state.MODELS[form_data["model"]], + "__model__": models.get(form_data["model"], None), "__messages__": form_data["messages"], "__files__": files, }, diff --git a/backend/open_webui/apps/webui/models/documents.py b/backend/open_webui/apps/webui/models/documents.py deleted file mode 100644 index 0b96c25744..0000000000 --- a/backend/open_webui/apps/webui/models/documents.py +++ /dev/null @@ -1,157 +0,0 @@ -import json -import logging -import time -from typing import Optional - -from open_webui.apps.webui.internal.db import Base, get_db -from open_webui.env import SRC_LOG_LEVELS -from pydantic import BaseModel, ConfigDict -from sqlalchemy import BigInteger, Column, String, Text - -log = logging.getLogger(__name__) -log.setLevel(SRC_LOG_LEVELS["MODELS"]) - -#################### -# Documents DB Schema -#################### - - -class Document(Base): - __tablename__ = "document" - - collection_name = Column(String, primary_key=True) - name = Column(String, unique=True) - title = Column(Text) - filename = Column(Text) - content = Column(Text, nullable=True) - user_id = Column(String) - timestamp = Column(BigInteger) - - -class DocumentModel(BaseModel): - model_config = ConfigDict(from_attributes=True) - - collection_name: str - name: str - title: str - filename: str - content: Optional[str] = None - user_id: str - timestamp: int # timestamp in epoch - - -#################### -# Forms -#################### - - -class DocumentResponse(BaseModel): - collection_name: str - name: str - title: str - filename: str - content: Optional[dict] = None - user_id: str - timestamp: int # timestamp in epoch - - -class DocumentUpdateForm(BaseModel): - name: str - title: str - - -class DocumentForm(DocumentUpdateForm): - collection_name: str - filename: str - content: Optional[str] = None - - -class DocumentsTable: - def insert_new_doc( - self, user_id: str, form_data: DocumentForm - ) -> Optional[DocumentModel]: - with get_db() as db: - document = DocumentModel( - **{ - **form_data.model_dump(), - "user_id": user_id, - "timestamp": int(time.time()), - } - ) - - try: - result = Document(**document.model_dump()) - db.add(result) - db.commit() - db.refresh(result) - if result: - return DocumentModel.model_validate(result) - else: - return None - except Exception: - return None - - def get_doc_by_name(self, name: str) -> Optional[DocumentModel]: - try: - with get_db() as db: - document = db.query(Document).filter_by(name=name).first() - return DocumentModel.model_validate(document) if document else None - except Exception: - return None - - def get_docs(self) -> list[DocumentModel]: - with get_db() as db: - return [ - DocumentModel.model_validate(doc) for doc in db.query(Document).all() - ] - - def update_doc_by_name( - self, name: str, form_data: DocumentUpdateForm - ) -> Optional[DocumentModel]: - try: - with get_db() as db: - db.query(Document).filter_by(name=name).update( - { - "title": form_data.title, - "name": form_data.name, - "timestamp": int(time.time()), - } - ) - db.commit() - return self.get_doc_by_name(form_data.name) - except Exception as e: - log.exception(e) - return None - - def update_doc_content_by_name( - self, name: str, updated: dict - ) -> Optional[DocumentModel]: - try: - doc = self.get_doc_by_name(name) - doc_content = json.loads(doc.content if doc.content else "{}") - doc_content = {**doc_content, **updated} - - with get_db() as db: - db.query(Document).filter_by(name=name).update( - { - "content": json.dumps(doc_content), - "timestamp": int(time.time()), - } - ) - db.commit() - return self.get_doc_by_name(name) - except Exception as e: - log.exception(e) - return None - - def delete_doc_by_name(self, name: str) -> bool: - try: - with get_db() as db: - db.query(Document).filter_by(name=name).delete() - db.commit() - return True - except Exception: - return False - - -Documents = DocumentsTable() diff --git a/backend/open_webui/apps/webui/models/groups.py b/backend/open_webui/apps/webui/models/groups.py new file mode 100644 index 0000000000..e687374ea8 --- /dev/null +++ b/backend/open_webui/apps/webui/models/groups.py @@ -0,0 +1,181 @@ +import json +import logging +import time +from typing import Optional +import uuid + +from open_webui.apps.webui.internal.db import Base, get_db +from open_webui.env import SRC_LOG_LEVELS + +from open_webui.apps.webui.models.files import FileMetadataResponse + + +from pydantic import BaseModel, ConfigDict +from sqlalchemy import BigInteger, Column, String, Text, JSON + + +log = logging.getLogger(__name__) +log.setLevel(SRC_LOG_LEVELS["MODELS"]) + +#################### +# UserGroup DB Schema +#################### + + +class Group(Base): + __tablename__ = "group" + + id = Column(Text, unique=True, primary_key=True) + user_id = Column(Text) + + name = Column(Text) + description = Column(Text) + + data = Column(JSON, nullable=True) + meta = Column(JSON, nullable=True) + + permissions = Column(JSON, nullable=True) + user_ids = Column(JSON, nullable=True) + + created_at = Column(BigInteger) + updated_at = Column(BigInteger) + + +class GroupModel(BaseModel): + model_config = ConfigDict(from_attributes=True) + id: str + user_id: str + + name: str + description: str + + data: Optional[dict] = None + meta: Optional[dict] = None + + permissions: Optional[dict] = None + user_ids: list[str] = [] + + created_at: int # timestamp in epoch + updated_at: int # timestamp in epoch + + +#################### +# Forms +#################### + + +class GroupResponse(BaseModel): + id: str + user_id: str + name: str + description: str + permissions: Optional[dict] = None + data: Optional[dict] = None + meta: Optional[dict] = None + user_ids: list[str] = [] + created_at: int # timestamp in epoch + updated_at: int # timestamp in epoch + + +class GroupForm(BaseModel): + name: str + description: str + + +class GroupUpdateForm(GroupForm): + permissions: Optional[dict] = None + user_ids: Optional[list[str]] = None + admin_ids: Optional[list[str]] = None + + +class GroupTable: + def insert_new_group( + self, user_id: str, form_data: GroupForm + ) -> Optional[GroupModel]: + with get_db() as db: + group = GroupModel( + **{ + **form_data.model_dump(), + "id": str(uuid.uuid4()), + "user_id": user_id, + "created_at": int(time.time()), + "updated_at": int(time.time()), + } + ) + + try: + result = Group(**group.model_dump()) + db.add(result) + db.commit() + db.refresh(result) + if result: + return GroupModel.model_validate(result) + else: + return None + + except Exception: + return None + + def get_groups(self) -> list[GroupModel]: + with get_db() as db: + return [ + GroupModel.model_validate(group) + for group in db.query(Group).order_by(Group.updated_at.desc()).all() + ] + + def get_groups_by_member_id(self, user_id: str) -> list[GroupModel]: + with get_db() as db: + return [ + GroupModel.model_validate(group) + for group in db.query(Group) + .filter(Group.user_ids.contains([user_id])) + .order_by(Group.updated_at.desc()) + .all() + ] + + def get_group_by_id(self, id: str) -> Optional[GroupModel]: + try: + with get_db() as db: + group = db.query(Group).filter_by(id=id).first() + return GroupModel.model_validate(group) if group else None + except Exception: + return None + + def update_group_by_id( + self, id: str, form_data: GroupUpdateForm, overwrite: bool = False + ) -> Optional[GroupModel]: + try: + with get_db() as db: + db.query(Group).filter_by(id=id).update( + { + **form_data.model_dump(exclude_none=True), + "updated_at": int(time.time()), + } + ) + db.commit() + return self.get_group_by_id(id=id) + except Exception as e: + log.exception(e) + return None + + def delete_group_by_id(self, id: str) -> bool: + try: + with get_db() as db: + db.query(Group).filter_by(id=id).delete() + db.commit() + return True + except Exception: + return False + + def delete_all_groups(self) -> bool: + with get_db() as db: + try: + db.query(Group).delete() + db.commit() + + return True + except Exception: + return False + + +Groups = GroupTable() diff --git a/backend/open_webui/apps/webui/models/knowledge.py b/backend/open_webui/apps/webui/models/knowledge.py index 269ad8cc3c..2d0e33f1bb 100644 --- a/backend/open_webui/apps/webui/models/knowledge.py +++ b/backend/open_webui/apps/webui/models/knowledge.py @@ -13,6 +13,7 @@ from open_webui.apps.webui.models.files import FileMetadataResponse from pydantic import BaseModel, ConfigDict from sqlalchemy import BigInteger, Column, String, Text, JSON +from open_webui.utils.access_control import has_access log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["MODELS"]) @@ -34,6 +35,23 @@ class Knowledge(Base): data = Column(JSON, nullable=True) meta = Column(JSON, nullable=True) + access_control = Column(JSON, nullable=True) # Controls data access levels. + # Defines access control rules for this entry. + # - `None`: Public access, available to all users with the "user" role. + # - `{}`: Private access, restricted exclusively to the owner. + # - Custom permissions: Specific access control for reading and writing; + # Can specify group or user-level restrictions: + # { + # "read": { + # "group_ids": ["group_id1", "group_id2"], + # "user_ids": ["user_id1", "user_id2"] + # }, + # "write": { + # "group_ids": ["group_id1", "group_id2"], + # "user_ids": ["user_id1", "user_id2"] + # } + # } + created_at = Column(BigInteger) updated_at = Column(BigInteger) @@ -50,6 +68,8 @@ class KnowledgeModel(BaseModel): data: Optional[dict] = None meta: Optional[dict] = None + access_control: Optional[dict] = None + created_at: int # timestamp in epoch updated_at: int # timestamp in epoch @@ -65,6 +85,8 @@ class KnowledgeResponse(BaseModel): description: str data: Optional[dict] = None meta: Optional[dict] = None + + access_control: Optional[dict] = None created_at: int # timestamp in epoch updated_at: int # timestamp in epoch @@ -75,12 +97,7 @@ class KnowledgeForm(BaseModel): name: str description: str data: Optional[dict] = None - - -class KnowledgeUpdateForm(BaseModel): - name: Optional[str] = None - description: Optional[str] = None - data: Optional[dict] = None + access_control: Optional[dict] = None class KnowledgeTable: @@ -110,7 +127,7 @@ class KnowledgeTable: except Exception: return None - def get_knowledge_items(self) -> list[KnowledgeModel]: + def get_knowledge_bases(self) -> list[KnowledgeModel]: with get_db() as db: return [ KnowledgeModel.model_validate(knowledge) @@ -119,6 +136,17 @@ class KnowledgeTable: .all() ] + def get_knowledge_bases_by_user_id( + self, user_id: str, permission: str = "write" + ) -> list[KnowledgeModel]: + knowledge_bases = self.get_knowledge_bases() + return [ + knowledge_base + for knowledge_base in knowledge_bases + if knowledge_base.user_id == user_id + or has_access(user_id, permission, knowledge_base.access_control) + ] + def get_knowledge_by_id(self, id: str) -> Optional[KnowledgeModel]: try: with get_db() as db: @@ -128,14 +156,32 @@ class KnowledgeTable: return None def update_knowledge_by_id( - self, id: str, form_data: KnowledgeUpdateForm, overwrite: bool = False + self, id: str, form_data: KnowledgeForm, overwrite: bool = False ) -> Optional[KnowledgeModel]: try: with get_db() as db: knowledge = self.get_knowledge_by_id(id=id) db.query(Knowledge).filter_by(id=id).update( { - **form_data.model_dump(exclude_none=True), + **form_data.model_dump(), + "updated_at": int(time.time()), + } + ) + db.commit() + return self.get_knowledge_by_id(id=id) + except Exception as e: + log.exception(e) + return None + + def update_knowledge_data_by_id( + self, id: str, data: dict + ) -> Optional[KnowledgeModel]: + try: + with get_db() as db: + knowledge = self.get_knowledge_by_id(id=id) + db.query(Knowledge).filter_by(id=id).update( + { + "data": data, "updated_at": int(time.time()), } ) diff --git a/backend/open_webui/apps/webui/models/models.py b/backend/open_webui/apps/webui/models/models.py index 9bdffb9bcc..46591bd953 100644 --- a/backend/open_webui/apps/webui/models/models.py +++ b/backend/open_webui/apps/webui/models/models.py @@ -4,8 +4,19 @@ from typing import Optional from open_webui.apps.webui.internal.db import Base, JSONField, get_db from open_webui.env import SRC_LOG_LEVELS + +from open_webui.apps.webui.models.groups import Groups + + from pydantic import BaseModel, ConfigDict -from sqlalchemy import BigInteger, Column, Text + +from sqlalchemy import or_, and_, func +from sqlalchemy.dialects import postgresql, sqlite +from sqlalchemy import BigInteger, Column, Text, JSON, Boolean + + +from open_webui.utils.access_control import has_access + log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["MODELS"]) @@ -67,6 +78,25 @@ class Model(Base): Holds a JSON encoded blob of metadata, see `ModelMeta`. """ + access_control = Column(JSON, nullable=True) # Controls data access levels. + # Defines access control rules for this entry. + # - `None`: Public access, available to all users with the "user" role. + # - `{}`: Private access, restricted exclusively to the owner. + # - Custom permissions: Specific access control for reading and writing; + # Can specify group or user-level restrictions: + # { + # "read": { + # "group_ids": ["group_id1", "group_id2"], + # "user_ids": ["user_id1", "user_id2"] + # }, + # "write": { + # "group_ids": ["group_id1", "group_id2"], + # "user_ids": ["user_id1", "user_id2"] + # } + # } + + is_active = Column(Boolean, default=True) + updated_at = Column(BigInteger) created_at = Column(BigInteger) @@ -80,6 +110,9 @@ class ModelModel(BaseModel): params: ModelParams meta: ModelMeta + access_control: Optional[dict] = None + + is_active: bool updated_at: int # timestamp in epoch created_at: int # timestamp in epoch @@ -93,8 +126,16 @@ class ModelModel(BaseModel): class ModelResponse(BaseModel): id: str + user_id: str + base_model_id: Optional[str] = None + name: str + params: ModelParams meta: ModelMeta + + access_control: Optional[dict] = None + + is_active: bool updated_at: int # timestamp in epoch created_at: int # timestamp in epoch @@ -105,6 +146,8 @@ class ModelForm(BaseModel): name: str meta: ModelMeta params: ModelParams + access_control: Optional[dict] = None + is_active: bool = True class ModelsTable: @@ -138,6 +181,31 @@ class ModelsTable: with get_db() as db: return [ModelModel.model_validate(model) for model in db.query(Model).all()] + def get_models(self) -> list[ModelModel]: + with get_db() as db: + return [ + ModelModel.model_validate(model) + for model in db.query(Model).filter(Model.base_model_id != None).all() + ] + + def get_base_models(self) -> list[ModelModel]: + with get_db() as db: + return [ + ModelModel.model_validate(model) + for model in db.query(Model).filter(Model.base_model_id == None).all() + ] + + def get_models_by_user_id( + self, user_id: str, permission: str = "write" + ) -> list[ModelModel]: + models = self.get_all_models() + return [ + model + for model in models + if model.user_id == user_id + or has_access(user_id, permission, model.access_control) + ] + def get_model_by_id(self, id: str) -> Optional[ModelModel]: try: with get_db() as db: @@ -146,6 +214,23 @@ class ModelsTable: except Exception: return None + def toggle_model_by_id(self, id: str) -> Optional[ModelModel]: + with get_db() as db: + try: + is_active = db.query(Model).filter_by(id=id).first().is_active + + db.query(Model).filter_by(id=id).update( + { + "is_active": not is_active, + "updated_at": int(time.time()), + } + ) + db.commit() + + return self.get_model_by_id(id) + except Exception: + return None + def update_model_by_id(self, id: str, model: ModelForm) -> Optional[ModelModel]: try: with get_db() as db: @@ -153,7 +238,7 @@ class ModelsTable: result = ( db.query(Model) .filter_by(id=id) - .update(model.model_dump(exclude={"id"}, exclude_none=True)) + .update(model.model_dump(exclude={"id"})) ) db.commit() diff --git a/backend/open_webui/apps/webui/models/prompts.py b/backend/open_webui/apps/webui/models/prompts.py index 6b98e5c535..ea4a229f79 100644 --- a/backend/open_webui/apps/webui/models/prompts.py +++ b/backend/open_webui/apps/webui/models/prompts.py @@ -2,8 +2,12 @@ import time from typing import Optional from open_webui.apps.webui.internal.db import Base, get_db +from open_webui.apps.webui.models.groups import Groups + from pydantic import BaseModel, ConfigDict -from sqlalchemy import BigInteger, Column, String, Text +from sqlalchemy import BigInteger, Column, String, Text, JSON + +from open_webui.utils.access_control import has_access #################### # Prompts DB Schema @@ -19,6 +23,23 @@ class Prompt(Base): content = Column(Text) timestamp = Column(BigInteger) + access_control = Column(JSON, nullable=True) # Controls data access levels. + # Defines access control rules for this entry. + # - `None`: Public access, available to all users with the "user" role. + # - `{}`: Private access, restricted exclusively to the owner. + # - Custom permissions: Specific access control for reading and writing; + # Can specify group or user-level restrictions: + # { + # "read": { + # "group_ids": ["group_id1", "group_id2"], + # "user_ids": ["user_id1", "user_id2"] + # }, + # "write": { + # "group_ids": ["group_id1", "group_id2"], + # "user_ids": ["user_id1", "user_id2"] + # } + # } + class PromptModel(BaseModel): command: str @@ -27,6 +48,7 @@ class PromptModel(BaseModel): content: str timestamp: int # timestamp in epoch + access_control: Optional[dict] = None model_config = ConfigDict(from_attributes=True) @@ -39,6 +61,7 @@ class PromptForm(BaseModel): command: str title: str content: str + access_control: Optional[dict] = None class PromptsTable: @@ -48,16 +71,14 @@ class PromptsTable: prompt = PromptModel( **{ "user_id": user_id, - "command": form_data.command, - "title": form_data.title, - "content": form_data.content, + **form_data.model_dump(), "timestamp": int(time.time()), } ) try: with get_db() as db: - result = Prompt(**prompt.dict()) + result = Prompt(**prompt.model_dump()) db.add(result) db.commit() db.refresh(result) @@ -82,6 +103,18 @@ class PromptsTable: PromptModel.model_validate(prompt) for prompt in db.query(Prompt).all() ] + def get_prompts_by_user_id( + self, user_id: str, permission: str = "write" + ) -> list[PromptModel]: + prompts = self.get_prompts() + + return [ + prompt + for prompt in prompts + if prompt.user_id == user_id + or has_access(user_id, permission, prompt.access_control) + ] + def update_prompt_by_command( self, command: str, form_data: PromptForm ) -> Optional[PromptModel]: @@ -90,6 +123,7 @@ class PromptsTable: prompt = db.query(Prompt).filter_by(command=command).first() prompt.title = form_data.title prompt.content = form_data.content + prompt.access_control = form_data.access_control prompt.timestamp = int(time.time()) db.commit() return PromptModel.model_validate(prompt) diff --git a/backend/open_webui/apps/webui/models/tools.py b/backend/open_webui/apps/webui/models/tools.py index e06f83452b..63570bee61 100644 --- a/backend/open_webui/apps/webui/models/tools.py +++ b/backend/open_webui/apps/webui/models/tools.py @@ -6,7 +6,10 @@ from open_webui.apps.webui.internal.db import Base, JSONField, get_db from open_webui.apps.webui.models.users import Users from open_webui.env import SRC_LOG_LEVELS from pydantic import BaseModel, ConfigDict -from sqlalchemy import BigInteger, Column, String, Text +from sqlalchemy import BigInteger, Column, String, Text, JSON + +from open_webui.utils.access_control import has_access + log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["MODELS"]) @@ -26,6 +29,24 @@ class Tool(Base): specs = Column(JSONField) meta = Column(JSONField) valves = Column(JSONField) + + access_control = Column(JSON, nullable=True) # Controls data access levels. + # Defines access control rules for this entry. + # - `None`: Public access, available to all users with the "user" role. + # - `{}`: Private access, restricted exclusively to the owner. + # - Custom permissions: Specific access control for reading and writing; + # Can specify group or user-level restrictions: + # { + # "read": { + # "group_ids": ["group_id1", "group_id2"], + # "user_ids": ["user_id1", "user_id2"] + # }, + # "write": { + # "group_ids": ["group_id1", "group_id2"], + # "user_ids": ["user_id1", "user_id2"] + # } + # } + updated_at = Column(BigInteger) created_at = Column(BigInteger) @@ -42,6 +63,8 @@ class ToolModel(BaseModel): content: str specs: list[dict] meta: ToolMeta + access_control: Optional[dict] = None + updated_at: int # timestamp in epoch created_at: int # timestamp in epoch @@ -58,6 +81,7 @@ class ToolResponse(BaseModel): user_id: str name: str meta: ToolMeta + access_control: Optional[dict] = None updated_at: int # timestamp in epoch created_at: int # timestamp in epoch @@ -67,6 +91,7 @@ class ToolForm(BaseModel): name: str content: str meta: ToolMeta + access_control: Optional[dict] = None class ToolValves(BaseModel): @@ -113,6 +138,18 @@ class ToolsTable: with get_db() as db: return [ToolModel.model_validate(tool) for tool in db.query(Tool).all()] + def get_tools_by_user_id( + self, user_id: str, permission: str = "write" + ) -> list[ToolModel]: + tools = self.get_tools() + + return [ + tool + for tool in tools + if tool.user_id == user_id + or has_access(user_id, permission, tool.access_control) + ] + def get_tool_valves_by_id(self, id: str) -> Optional[dict]: try: with get_db() as db: diff --git a/backend/open_webui/apps/webui/routers/auths.py b/backend/open_webui/apps/webui/routers/auths.py index feea350cc7..d3592f03bb 100644 --- a/backend/open_webui/apps/webui/routers/auths.py +++ b/backend/open_webui/apps/webui/routers/auths.py @@ -40,10 +40,12 @@ from open_webui.utils.utils import ( get_password_hash, ) from open_webui.utils.webhook import post_webhook +from open_webui.utils.access_control import get_permissions + from typing import Optional, List -from ldap3 import Server, Connection, ALL, Tls from ssl import CERT_REQUIRED, PROTOCOL_TLS +from ldap3 import Server, Connection, ALL, Tls from ldap3.utils.conv import escape_filter_chars router = APIRouter() @@ -58,6 +60,7 @@ log.setLevel(SRC_LOG_LEVELS["MAIN"]) class SessionUserResponse(Token, UserResponse): expires_at: Optional[int] = None + permissions: Optional[dict] = None @router.get("/", response_model=SessionUserResponse) @@ -90,6 +93,10 @@ async def get_session_user( secure=WEBUI_SESSION_COOKIE_SECURE, ) + user_permissions = get_permissions( + user.id, request.app.state.config.USER_PERMISSIONS + ) + return { "token": token, "token_type": "Bearer", @@ -99,6 +106,7 @@ async def get_session_user( "name": user.name, "role": user.role, "profile_image_url": user.profile_image_url, + "permissions": user_permissions, } @@ -163,40 +171,67 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm): LDAP_APP_PASSWORD = request.app.state.config.LDAP_APP_PASSWORD LDAP_USE_TLS = request.app.state.config.LDAP_USE_TLS LDAP_CA_CERT_FILE = request.app.state.config.LDAP_CA_CERT_FILE - LDAP_CIPHERS = request.app.state.config.LDAP_CIPHERS if request.app.state.config.LDAP_CIPHERS else 'ALL' + LDAP_CIPHERS = ( + request.app.state.config.LDAP_CIPHERS + if request.app.state.config.LDAP_CIPHERS + else "ALL" + ) if not ENABLE_LDAP: raise HTTPException(400, detail="LDAP authentication is not enabled") try: - tls = Tls(validate=CERT_REQUIRED, version=PROTOCOL_TLS, ca_certs_file=LDAP_CA_CERT_FILE, ciphers=LDAP_CIPHERS) + tls = Tls( + validate=CERT_REQUIRED, + version=PROTOCOL_TLS, + ca_certs_file=LDAP_CA_CERT_FILE, + ciphers=LDAP_CIPHERS, + ) except Exception as e: log.error(f"An error occurred on TLS: {str(e)}") raise HTTPException(400, detail=str(e)) try: - server = Server(host=LDAP_SERVER_HOST, port=LDAP_SERVER_PORT, get_info=ALL, use_ssl=LDAP_USE_TLS, tls=tls) - connection_app = Connection(server, LDAP_APP_DN, LDAP_APP_PASSWORD, auto_bind='NONE', authentication='SIMPLE') + server = Server( + host=LDAP_SERVER_HOST, + port=LDAP_SERVER_PORT, + get_info=ALL, + use_ssl=LDAP_USE_TLS, + tls=tls, + ) + connection_app = Connection( + server, + LDAP_APP_DN, + LDAP_APP_PASSWORD, + auto_bind="NONE", + authentication="SIMPLE", + ) if not connection_app.bind(): raise HTTPException(400, detail="Application account bind failed") search_success = connection_app.search( search_base=LDAP_SEARCH_BASE, - search_filter=f'(&({LDAP_ATTRIBUTE_FOR_USERNAME}={escape_filter_chars(form_data.user.lower())}){LDAP_SEARCH_FILTERS})', - attributes=[f'{LDAP_ATTRIBUTE_FOR_USERNAME}', 'mail', 'cn'] + search_filter=f"(&({LDAP_ATTRIBUTE_FOR_USERNAME}={escape_filter_chars(form_data.user.lower())}){LDAP_SEARCH_FILTERS})", + attributes=[f"{LDAP_ATTRIBUTE_FOR_USERNAME}", "mail", "cn"], ) if not search_success: raise HTTPException(400, detail="User not found in the LDAP server") entry = connection_app.entries[0] - username = str(entry[f'{LDAP_ATTRIBUTE_FOR_USERNAME}']).lower() - mail = str(entry['mail']) - cn = str(entry['cn']) + username = str(entry[f"{LDAP_ATTRIBUTE_FOR_USERNAME}"]).lower() + mail = str(entry["mail"]) + cn = str(entry["cn"]) user_dn = entry.entry_dn if username == form_data.user.lower(): - connection_user = Connection(server, user_dn, form_data.password, auto_bind='NONE', authentication='SIMPLE') + connection_user = Connection( + server, + user_dn, + form_data.password, + auto_bind="NONE", + authentication="SIMPLE", + ) if not connection_user.bind(): raise HTTPException(400, f"Authentication failed for {form_data.user}") @@ -205,14 +240,12 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm): try: hashed = get_password_hash(form_data.password) - user = Auths.insert_new_auth( - mail, - hashed, - cn - ) + user = Auths.insert_new_auth(mail, hashed, cn) if not user: - raise HTTPException(500, detail=ERROR_MESSAGES.CREATE_USER_ERROR) + raise HTTPException( + 500, detail=ERROR_MESSAGES.CREATE_USER_ERROR + ) except HTTPException: raise @@ -224,7 +257,9 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm): if user: token = create_token( data={"id": user.id}, - expires_delta=parse_duration(request.app.state.config.JWT_EXPIRES_IN), + expires_delta=parse_duration( + request.app.state.config.JWT_EXPIRES_IN + ), ) # Set the cookie token @@ -246,7 +281,10 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm): else: raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED) else: - raise HTTPException(400, f"User {form_data.user} does not match the record. Search result: {str(entry[f'{LDAP_ATTRIBUTE_FOR_USERNAME}'])}") + raise HTTPException( + 400, + f"User {form_data.user} does not match the record. Search result: {str(entry[f'{LDAP_ATTRIBUTE_FOR_USERNAME}'])}", + ) except Exception as e: raise HTTPException(400, detail=str(e)) @@ -325,6 +363,10 @@ async def signin(request: Request, response: Response, form_data: SigninForm): secure=WEBUI_SESSION_COOKIE_SECURE, ) + user_permissions = get_permissions( + user.id, request.app.state.config.USER_PERMISSIONS + ) + return { "token": token, "token_type": "Bearer", @@ -334,6 +376,7 @@ async def signin(request: Request, response: Response, form_data: SigninForm): "name": user.name, "role": user.role, "profile_image_url": user.profile_image_url, + "permissions": user_permissions, } else: raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED) @@ -426,6 +469,10 @@ async def signup(request: Request, response: Response, form_data: SignupForm): }, ) + user_permissions = get_permissions( + user.id, request.app.state.config.USER_PERMISSIONS + ) + return { "token": token, "token_type": "Bearer", @@ -435,6 +482,7 @@ async def signup(request: Request, response: Response, form_data: SignupForm): "name": user.name, "role": user.role, "profile_image_url": user.profile_image_url, + "permissions": user_permissions, } else: raise HTTPException(500, detail=ERROR_MESSAGES.CREATE_USER_ERROR) @@ -583,19 +631,18 @@ class LdapServerConfig(BaseModel): label: str host: str port: Optional[int] = None - attribute_for_username: str = 'uid' + attribute_for_username: str = "uid" app_dn: str app_dn_password: str search_base: str - search_filters: str = '' + search_filters: str = "" use_tls: bool = True certificate_path: Optional[str] = None - ciphers: Optional[str] = 'ALL' + ciphers: Optional[str] = "ALL" + @router.get("/admin/config/ldap/server", response_model=LdapServerConfig) -async def get_ldap_server( - request: Request, user=Depends(get_admin_user) -): +async def get_ldap_server(request: Request, user=Depends(get_admin_user)): return { "label": request.app.state.config.LDAP_SERVER_LABEL, "host": request.app.state.config.LDAP_SERVER_HOST, @@ -607,26 +654,38 @@ async def get_ldap_server( "search_filters": request.app.state.config.LDAP_SEARCH_FILTERS, "use_tls": request.app.state.config.LDAP_USE_TLS, "certificate_path": request.app.state.config.LDAP_CA_CERT_FILE, - "ciphers": request.app.state.config.LDAP_CIPHERS + "ciphers": request.app.state.config.LDAP_CIPHERS, } + @router.post("/admin/config/ldap/server") async def update_ldap_server( request: Request, form_data: LdapServerConfig, user=Depends(get_admin_user) ): - required_fields = ['label', 'host', 'attribute_for_username', 'app_dn', 'app_dn_password', 'search_base'] + required_fields = [ + "label", + "host", + "attribute_for_username", + "app_dn", + "app_dn_password", + "search_base", + ] for key in required_fields: value = getattr(form_data, key) if not value: raise HTTPException(400, detail=f"Required field {key} is empty") if form_data.use_tls and not form_data.certificate_path: - raise HTTPException(400, detail="TLS is enabled but certificate file path is missing") + raise HTTPException( + 400, detail="TLS is enabled but certificate file path is missing" + ) request.app.state.config.LDAP_SERVER_LABEL = form_data.label request.app.state.config.LDAP_SERVER_HOST = form_data.host request.app.state.config.LDAP_SERVER_PORT = form_data.port - request.app.state.config.LDAP_ATTRIBUTE_FOR_USERNAME = form_data.attribute_for_username + request.app.state.config.LDAP_ATTRIBUTE_FOR_USERNAME = ( + form_data.attribute_for_username + ) request.app.state.config.LDAP_APP_DN = form_data.app_dn request.app.state.config.LDAP_APP_PASSWORD = form_data.app_dn_password request.app.state.config.LDAP_SEARCH_BASE = form_data.search_base @@ -646,18 +705,23 @@ async def update_ldap_server( "search_filters": request.app.state.config.LDAP_SEARCH_FILTERS, "use_tls": request.app.state.config.LDAP_USE_TLS, "certificate_path": request.app.state.config.LDAP_CA_CERT_FILE, - "ciphers": request.app.state.config.LDAP_CIPHERS + "ciphers": request.app.state.config.LDAP_CIPHERS, } + @router.get("/admin/config/ldap") async def get_ldap_config(request: Request, user=Depends(get_admin_user)): return {"ENABLE_LDAP": request.app.state.config.ENABLE_LDAP} + class LdapConfigForm(BaseModel): enable_ldap: Optional[bool] = None + @router.post("/admin/config/ldap") -async def update_ldap_config(request: Request, form_data: LdapConfigForm, user=Depends(get_admin_user)): +async def update_ldap_config( + request: Request, form_data: LdapConfigForm, user=Depends(get_admin_user) +): request.app.state.config.ENABLE_LDAP = form_data.enable_ldap return {"ENABLE_LDAP": request.app.state.config.ENABLE_LDAP} diff --git a/backend/open_webui/apps/webui/routers/chats.py b/backend/open_webui/apps/webui/routers/chats.py index b149b2eb48..db95337d53 100644 --- a/backend/open_webui/apps/webui/routers/chats.py +++ b/backend/open_webui/apps/webui/routers/chats.py @@ -17,7 +17,10 @@ from open_webui.constants import ERROR_MESSAGES from open_webui.env import SRC_LOG_LEVELS from fastapi import APIRouter, Depends, HTTPException, Request, status from pydantic import BaseModel + + from open_webui.utils.utils import get_admin_user, get_verified_user +from open_webui.utils.access_control import has_permission log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["MODELS"]) @@ -50,9 +53,10 @@ async def get_session_user_chat_list( @router.delete("/", response_model=bool) async def delete_all_user_chats(request: Request, user=Depends(get_verified_user)): - if user.role == "user" and not request.app.state.config.USER_PERMISSIONS.get( - "chat", {} - ).get("deletion", {}): + + if user.role == "user" and not has_permission( + user.id, "chat.delete", request.app.state.config.USER_PERMISSIONS + ): raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.ACCESS_PROHIBITED, @@ -385,8 +389,8 @@ async def delete_chat_by_id(request: Request, id: str, user=Depends(get_verified return result else: - if not request.app.state.config.USER_PERMISSIONS.get("chat", {}).get( - "deletion", {} + if not has_permission( + user.id, "chat.delete", request.app.state.config.USER_PERMISSIONS ): raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, diff --git a/backend/open_webui/apps/webui/routers/groups.py b/backend/open_webui/apps/webui/routers/groups.py new file mode 100644 index 0000000000..59d7d0052b --- /dev/null +++ b/backend/open_webui/apps/webui/routers/groups.py @@ -0,0 +1,120 @@ +import os +from pathlib import Path +from typing import Optional + +from open_webui.apps.webui.models.groups import ( + Groups, + GroupForm, + GroupUpdateForm, + GroupResponse, +) + +from open_webui.config import CACHE_DIR +from open_webui.constants import ERROR_MESSAGES +from fastapi import APIRouter, Depends, HTTPException, Request, status +from open_webui.utils.utils import get_admin_user, get_verified_user + +router = APIRouter() + +############################ +# GetFunctions +############################ + + +@router.get("/", response_model=list[GroupResponse]) +async def get_groups(user=Depends(get_verified_user)): + if user.role == "admin": + return Groups.get_groups() + else: + return Groups.get_groups_by_member_id(user.id) + + +############################ +# CreateNewGroup +############################ + + +@router.post("/create", response_model=Optional[GroupResponse]) +async def create_new_function(form_data: GroupForm, user=Depends(get_admin_user)): + try: + group = Groups.insert_new_group(user.id, form_data) + if group: + return group + else: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT("Error creating group"), + ) + except Exception as e: + print(e) + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT(e), + ) + + +############################ +# GetGroupById +############################ + + +@router.get("/id/{id}", response_model=Optional[GroupResponse]) +async def get_group_by_id(id: str, user=Depends(get_admin_user)): + group = Groups.get_group_by_id(id) + if group: + return group + else: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + + +############################ +# UpdateGroupById +############################ + + +@router.post("/id/{id}/update", response_model=Optional[GroupResponse]) +async def update_group_by_id( + id: str, form_data: GroupUpdateForm, user=Depends(get_admin_user) +): + try: + group = Groups.update_group_by_id(id, form_data) + if group: + return group + else: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT("Error updating group"), + ) + except Exception as e: + print(e) + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT(e), + ) + + +############################ +# DeleteGroupById +############################ + + +@router.delete("/id/{id}/delete", response_model=bool) +async def delete_group_by_id(id: str, user=Depends(get_admin_user)): + try: + result = Groups.delete_group_by_id(id) + if result: + return result + else: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT("Error deleting group"), + ) + except Exception as e: + print(e) + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT(e), + ) diff --git a/backend/open_webui/apps/webui/routers/knowledge.py b/backend/open_webui/apps/webui/routers/knowledge.py index 1b5381a745..1ffadeac23 100644 --- a/backend/open_webui/apps/webui/routers/knowledge.py +++ b/backend/open_webui/apps/webui/routers/knowledge.py @@ -6,7 +6,6 @@ import logging from open_webui.apps.webui.models.knowledge import ( Knowledges, - KnowledgeUpdateForm, KnowledgeForm, KnowledgeResponse, ) @@ -17,6 +16,9 @@ from open_webui.apps.retrieval.main import process_file, ProcessFileForm from open_webui.constants import ERROR_MESSAGES from open_webui.utils.utils import get_admin_user, get_verified_user +from open_webui.utils.access_control import has_access + + from open_webui.env import SRC_LOG_LEVELS @@ -26,64 +28,98 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"]) router = APIRouter() ############################ -# GetKnowledgeItems +# getKnowledgeBases ############################ -@router.get( - "/", response_model=Optional[Union[list[KnowledgeResponse], KnowledgeResponse]] -) -async def get_knowledge_items( - id: Optional[str] = None, user=Depends(get_verified_user) -): - if id: - knowledge = Knowledges.get_knowledge_by_id(id=id) +@router.get("/", response_model=list[KnowledgeResponse]) +async def get_knowledge(user=Depends(get_verified_user)): + knowledge_bases = [] - if knowledge: - return knowledge - else: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail=ERROR_MESSAGES.NOT_FOUND, - ) + if user.role == "admin": + knowledge_bases = Knowledges.get_knowledge_bases() else: - knowledge_bases = [] + knowledge_bases = Knowledges.get_knowledge_bases_by_user_id(user.id, "read") - for knowledge in Knowledges.get_knowledge_items(): - - files = [] - if knowledge.data: - files = Files.get_file_metadatas_by_ids( - knowledge.data.get("file_ids", []) - ) - - # Check if all files exist - if len(files) != len(knowledge.data.get("file_ids", [])): - missing_files = list( - set(knowledge.data.get("file_ids", [])) - - set([file.id for file in files]) - ) - if missing_files: - data = knowledge.data or {} - file_ids = data.get("file_ids", []) - - for missing_file in missing_files: - file_ids.remove(missing_file) - - data["file_ids"] = file_ids - Knowledges.update_knowledge_by_id( - id=knowledge.id, form_data=KnowledgeUpdateForm(data=data) - ) - - files = Files.get_file_metadatas_by_ids(file_ids) - - knowledge_bases.append( - KnowledgeResponse( - **knowledge.model_dump(), - files=files, - ) + # Get files for each knowledge base + for knowledge_base in knowledge_bases: + files = [] + if knowledge_base.data: + files = Files.get_file_metadatas_by_ids( + knowledge_base.data.get("file_ids", []) ) - return knowledge_bases + + # Check if all files exist + if len(files) != len(knowledge_base.data.get("file_ids", [])): + missing_files = list( + set(knowledge_base.data.get("file_ids", [])) + - set([file.id for file in files]) + ) + if missing_files: + data = knowledge_base.data or {} + file_ids = data.get("file_ids", []) + + for missing_file in missing_files: + file_ids.remove(missing_file) + + data["file_ids"] = file_ids + Knowledges.update_knowledge_data_by_id( + id=knowledge_base.id, data=data + ) + + files = Files.get_file_metadatas_by_ids(file_ids) + + knowledge_base = KnowledgeResponse( + **knowledge_base.model_dump(), + files=files, + ) + + return knowledge_bases + + +@router.get("/list", response_model=list[KnowledgeResponse]) +async def get_knowledge_list(user=Depends(get_verified_user)): + knowledge_bases = [] + + if user.role == "admin": + knowledge_bases = Knowledges.get_knowledge_bases() + else: + knowledge_bases = Knowledges.get_knowledge_bases_by_user_id(user.id, "write") + + # Get files for each knowledge base + for knowledge_base in knowledge_bases: + files = [] + if knowledge_base.data: + files = Files.get_file_metadatas_by_ids( + knowledge_base.data.get("file_ids", []) + ) + + # Check if all files exist + if len(files) != len(knowledge_base.data.get("file_ids", [])): + missing_files = list( + set(knowledge_base.data.get("file_ids", [])) + - set([file.id for file in files]) + ) + if missing_files: + data = knowledge_base.data or {} + file_ids = data.get("file_ids", []) + + for missing_file in missing_files: + file_ids.remove(missing_file) + + data["file_ids"] = file_ids + Knowledges.update_knowledge_data_by_id( + id=knowledge_base.id, data=data + ) + + files = Files.get_file_metadatas_by_ids(file_ids) + + knowledge_base = KnowledgeResponse( + **knowledge_base.model_dump(), + files=files, + ) + + return knowledge_bases ############################ @@ -92,7 +128,9 @@ async def get_knowledge_items( @router.post("/create", response_model=Optional[KnowledgeResponse]) -async def create_new_knowledge(form_data: KnowledgeForm, user=Depends(get_admin_user)): +async def create_new_knowledge( + form_data: KnowledgeForm, user=Depends(get_verified_user) +): knowledge = Knowledges.insert_new_knowledge(user.id, form_data) if knowledge: @@ -118,13 +156,20 @@ async def get_knowledge_by_id(id: str, user=Depends(get_verified_user)): knowledge = Knowledges.get_knowledge_by_id(id=id) if knowledge: - file_ids = knowledge.data.get("file_ids", []) if knowledge.data else [] - files = Files.get_files_by_ids(file_ids) - return KnowledgeFilesResponse( - **knowledge.model_dump(), - files=files, - ) + if ( + user.role == "admin" + or knowledge.user_id == user.id + or has_access(user.id, "read", knowledge.access_control) + ): + + file_ids = knowledge.data.get("file_ids", []) if knowledge.data else [] + files = Files.get_files_by_ids(file_ids) + + return KnowledgeFilesResponse( + **knowledge.model_dump(), + files=files, + ) else: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -140,11 +185,23 @@ async def get_knowledge_by_id(id: str, user=Depends(get_verified_user)): @router.post("/{id}/update", response_model=Optional[KnowledgeFilesResponse]) async def update_knowledge_by_id( id: str, - form_data: KnowledgeUpdateForm, - user=Depends(get_admin_user), + form_data: KnowledgeForm, + user=Depends(get_verified_user), ): - knowledge = Knowledges.update_knowledge_by_id(id=id, form_data=form_data) + knowledge = Knowledges.get_knowledge_by_id(id=id) + if not knowledge: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + if knowledge.user_id != user.id and user.role != "admin": + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.ACCESS_PROHIBITED, + ) + + knowledge = Knowledges.update_knowledge_by_id(id=id, form_data=form_data) if knowledge: file_ids = knowledge.data.get("file_ids", []) if knowledge.data else [] files = Files.get_files_by_ids(file_ids) @@ -173,9 +230,22 @@ class KnowledgeFileIdForm(BaseModel): def add_file_to_knowledge_by_id( id: str, form_data: KnowledgeFileIdForm, - user=Depends(get_admin_user), + user=Depends(get_verified_user), ): knowledge = Knowledges.get_knowledge_by_id(id=id) + + if not knowledge: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + + if knowledge.user_id != user.id and user.role != "admin": + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.ACCESS_PROHIBITED, + ) + file = Files.get_file_by_id(form_data.file_id) if not file: raise HTTPException( @@ -206,9 +276,7 @@ def add_file_to_knowledge_by_id( file_ids.append(form_data.file_id) data["file_ids"] = file_ids - knowledge = Knowledges.update_knowledge_by_id( - id=id, form_data=KnowledgeUpdateForm(data=data) - ) + knowledge = Knowledges.update_knowledge_data_by_id(id=id.id, data=data) if knowledge: files = Files.get_files_by_ids(file_ids) @@ -238,9 +306,21 @@ def add_file_to_knowledge_by_id( def update_file_from_knowledge_by_id( id: str, form_data: KnowledgeFileIdForm, - user=Depends(get_admin_user), + user=Depends(get_verified_user), ): knowledge = Knowledges.get_knowledge_by_id(id=id) + if not knowledge: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + + if knowledge.user_id != user.id and user.role != "admin": + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.ACCESS_PROHIBITED, + ) + file = Files.get_file_by_id(form_data.file_id) if not file: raise HTTPException( @@ -288,9 +368,21 @@ def update_file_from_knowledge_by_id( def remove_file_from_knowledge_by_id( id: str, form_data: KnowledgeFileIdForm, - user=Depends(get_admin_user), + user=Depends(get_verified_user), ): knowledge = Knowledges.get_knowledge_by_id(id=id) + if not knowledge: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + + if knowledge.user_id != user.id and user.role != "admin": + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.ACCESS_PROHIBITED, + ) + file = Files.get_file_by_id(form_data.file_id) if not file: raise HTTPException( @@ -318,9 +410,7 @@ def remove_file_from_knowledge_by_id( file_ids.remove(form_data.file_id) data["file_ids"] = file_ids - knowledge = Knowledges.update_knowledge_by_id( - id=id, form_data=KnowledgeUpdateForm(data=data) - ) + knowledge = Knowledges.update_knowledge_data_by_id(id=id.id, data=data) if knowledge: files = Files.get_files_by_ids(file_ids) @@ -346,32 +436,26 @@ def remove_file_from_knowledge_by_id( ) -############################ -# ResetKnowledgeById -############################ - - -@router.post("/{id}/reset", response_model=Optional[KnowledgeResponse]) -async def reset_knowledge_by_id(id: str, user=Depends(get_admin_user)): - try: - VECTOR_DB_CLIENT.delete_collection(collection_name=id) - except Exception as e: - log.debug(e) - pass - - knowledge = Knowledges.update_knowledge_by_id( - id=id, form_data=KnowledgeUpdateForm(data={"file_ids": []}) - ) - return knowledge - - ############################ # DeleteKnowledgeById ############################ @router.delete("/{id}/delete", response_model=bool) -async def delete_knowledge_by_id(id: str, user=Depends(get_admin_user)): +async def delete_knowledge_by_id(id: str, user=Depends(get_verified_user)): + knowledge = Knowledges.get_knowledge_by_id(id=id) + if not knowledge: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + + if knowledge.user_id != user.id and user.role != "admin": + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.ACCESS_PROHIBITED, + ) + try: VECTOR_DB_CLIENT.delete_collection(collection_name=id) except Exception as e: @@ -379,3 +463,34 @@ async def delete_knowledge_by_id(id: str, user=Depends(get_admin_user)): pass result = Knowledges.delete_knowledge_by_id(id=id) return result + + +############################ +# ResetKnowledgeById +############################ + + +@router.post("/{id}/reset", response_model=Optional[KnowledgeResponse]) +async def reset_knowledge_by_id(id: str, user=Depends(get_verified_user)): + knowledge = Knowledges.get_knowledge_by_id(id=id) + if not knowledge: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + + if knowledge.user_id != user.id and user.role != "admin": + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.ACCESS_PROHIBITED, + ) + + try: + VECTOR_DB_CLIENT.delete_collection(collection_name=id) + except Exception as e: + log.debug(e) + pass + + knowledge = Knowledges.update_knowledge_data_by_id(id=id.id, data={"file_ids": []}) + + return knowledge diff --git a/backend/open_webui/apps/webui/routers/models.py b/backend/open_webui/apps/webui/routers/models.py index a5cb2395ec..8d6d950966 100644 --- a/backend/open_webui/apps/webui/routers/models.py +++ b/backend/open_webui/apps/webui/routers/models.py @@ -8,49 +8,58 @@ from open_webui.apps.webui.models.models import ( ) from open_webui.constants import ERROR_MESSAGES from fastapi import APIRouter, Depends, HTTPException, Request, status + + from open_webui.utils.utils import get_admin_user, get_verified_user +from open_webui.utils.access_control import has_access + router = APIRouter() + ########################### -# getModels +# GetModels ########################### @router.get("/", response_model=list[ModelResponse]) async def get_models(id: Optional[str] = None, user=Depends(get_verified_user)): - if id: - model = Models.get_model_by_id(id) - if model: - return [model] - else: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail=ERROR_MESSAGES.NOT_FOUND, - ) + if user.role == "admin": + return Models.get_models() else: - return Models.get_all_models() + return Models.get_models_by_user_id(user.id) + + +########################### +# GetBaseModels +########################### + + +@router.get("/base", response_model=list[ModelResponse]) +async def get_base_models(user=Depends(get_admin_user)): + return Models.get_base_models() ############################ -# AddNewModel +# CreateNewModel ############################ -@router.post("/add", response_model=Optional[ModelModel]) -async def add_new_model( - request: Request, +@router.post("/create", response_model=Optional[ModelModel]) +async def create_new_model( form_data: ModelForm, - user=Depends(get_admin_user), + user=Depends(get_verified_user), ): - if form_data.id in request.app.state.MODELS: + + model = Models.get_model_by_id(form_data.id) + if model: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.MODEL_ID_TAKEN, ) + else: model = Models.insert_new_model(form_data, user.id) - if model: return model else: @@ -60,37 +69,84 @@ async def add_new_model( ) +########################### +# GetModelById +########################### + + +@router.get("/id/{id}", response_model=Optional[ModelResponse]) +async def get_model_by_id(id: str, user=Depends(get_verified_user)): + model = Models.get_model_by_id(id) + if model: + if ( + user.role == "admin" + or model.user_id == user.id + or has_access(user.id, "read", model.access_control) + ): + return model + else: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + + +############################ +# ToggelModelById +############################ + + +@router.post("/id/{id}/toggle", response_model=Optional[ModelResponse]) +async def toggle_model_by_id(id: str, user=Depends(get_verified_user)): + model = Models.get_model_by_id(id) + if model: + if ( + user.role == "admin" + or model.user_id == user.id + or has_access(user.id, "write", model.access_control) + ): + model = Models.toggle_model_by_id(id) + + if model: + return model + else: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT("Error updating function"), + ) + else: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.UNAUTHORIZED, + ) + else: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + + ############################ # UpdateModelById ############################ -@router.post("/update", response_model=Optional[ModelModel]) +@router.post("/id/{id}/update", response_model=Optional[ModelModel]) async def update_model_by_id( - request: Request, id: str, form_data: ModelForm, - user=Depends(get_admin_user), + user=Depends(get_verified_user), ): model = Models.get_model_by_id(id) - if model: - model = Models.update_model_by_id(id, form_data) - return model - else: - if form_data.id in request.app.state.MODELS: - model = Models.insert_new_model(form_data, user.id) - if model: - return model - else: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail=ERROR_MESSAGES.DEFAULT(), - ) - else: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail=ERROR_MESSAGES.DEFAULT(), - ) + + if not model: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + + model = Models.update_model_by_id(id, form_data) + return model ############################ @@ -98,7 +154,20 @@ async def update_model_by_id( ############################ -@router.delete("/delete", response_model=bool) -async def delete_model_by_id(id: str, user=Depends(get_admin_user)): +@router.delete("/id/{id}/delete", response_model=bool) +async def delete_model_by_id(id: str, user=Depends(get_verified_user)): + model = Models.get_model_by_id(id) + if not model: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + + if model.user_id != user.id and user.role != "admin": + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.UNAUTHORIZED, + ) + result = Models.delete_model_by_id(id) return result diff --git a/backend/open_webui/apps/webui/routers/prompts.py b/backend/open_webui/apps/webui/routers/prompts.py index 593c643b97..ec65932910 100644 --- a/backend/open_webui/apps/webui/routers/prompts.py +++ b/backend/open_webui/apps/webui/routers/prompts.py @@ -4,6 +4,7 @@ from open_webui.apps.webui.models.prompts import PromptForm, PromptModel, Prompt from open_webui.constants import ERROR_MESSAGES from fastapi import APIRouter, Depends, HTTPException, status from open_webui.utils.utils import get_admin_user, get_verified_user +from open_webui.utils.access_control import has_access router = APIRouter() @@ -14,7 +15,22 @@ router = APIRouter() @router.get("/", response_model=list[PromptModel]) async def get_prompts(user=Depends(get_verified_user)): - return Prompts.get_prompts() + if user.role == "admin": + prompts = Prompts.get_prompts() + else: + prompts = Prompts.get_prompts_by_user_id(user.id, "read") + + return prompts + + +@router.get("/list", response_model=list[PromptModel]) +async def get_prompt_list(user=Depends(get_verified_user)): + if user.role == "admin": + prompts = Prompts.get_prompts() + else: + prompts = Prompts.get_prompts_by_user_id(user.id, "write") + + return prompts ############################ @@ -23,7 +39,7 @@ async def get_prompts(user=Depends(get_verified_user)): @router.post("/create", response_model=Optional[PromptModel]) -async def create_new_prompt(form_data: PromptForm, user=Depends(get_admin_user)): +async def create_new_prompt(form_data: PromptForm, user=Depends(get_verified_user)): prompt = Prompts.get_prompt_by_command(form_data.command) if prompt is None: prompt = Prompts.insert_new_prompt(user.id, form_data) @@ -50,7 +66,12 @@ async def get_prompt_by_command(command: str, user=Depends(get_verified_user)): prompt = Prompts.get_prompt_by_command(f"/{command}") if prompt: - return prompt + if ( + user.role == "admin" + or prompt.user_id == user.id + or has_access(user.id, "read", prompt.access_control) + ): + return prompt else: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -67,8 +88,21 @@ async def get_prompt_by_command(command: str, user=Depends(get_verified_user)): async def update_prompt_by_command( command: str, form_data: PromptForm, - user=Depends(get_admin_user), + user=Depends(get_verified_user), ): + prompt = Prompts.get_prompt_by_command(f"/{command}") + if not prompt: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + + if prompt.user_id != user.id and user.role != "admin": + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.ACCESS_PROHIBITED, + ) + prompt = Prompts.update_prompt_by_command(f"/{command}", form_data) if prompt: return prompt @@ -85,6 +119,19 @@ async def update_prompt_by_command( @router.delete("/command/{command}/delete", response_model=bool) -async def delete_prompt_by_command(command: str, user=Depends(get_admin_user)): +async def delete_prompt_by_command(command: str, user=Depends(get_verified_user)): + prompt = Prompts.get_prompt_by_command(f"/{command}") + if not prompt: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + + if prompt.user_id != user.id and user.role != "admin": + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.ACCESS_PROHIBITED, + ) + result = Prompts.delete_prompt_by_command(f"/{command}") return result diff --git a/backend/open_webui/apps/webui/routers/tools.py b/backend/open_webui/apps/webui/routers/tools.py index d1ad89deae..c34e7681b9 100644 --- a/backend/open_webui/apps/webui/routers/tools.py +++ b/backend/open_webui/apps/webui/routers/tools.py @@ -3,48 +3,66 @@ from pathlib import Path from typing import Optional from open_webui.apps.webui.models.tools import ToolForm, ToolModel, ToolResponse, Tools -from open_webui.apps.webui.utils import load_toolkit_module_by_id, replace_imports +from open_webui.apps.webui.utils import load_tools_module_by_id, replace_imports from open_webui.config import CACHE_DIR, DATA_DIR from open_webui.constants import ERROR_MESSAGES from fastapi import APIRouter, Depends, HTTPException, Request, status from open_webui.utils.tools import get_tools_specs from open_webui.utils.utils import get_admin_user, get_verified_user +from open_webui.utils.access_control import has_access router = APIRouter() ############################ -# GetToolkits +# GetTools ############################ @router.get("/", response_model=list[ToolResponse]) -async def get_toolkits(user=Depends(get_verified_user)): - toolkits = [toolkit for toolkit in Tools.get_tools()] - return toolkits +async def get_tools(user=Depends(get_verified_user)): + if user.role == "admin": + tools = Tools.get_tools() + else: + tools = Tools.get_tools_by_user_id(user.id, "read") + return tools ############################ -# ExportToolKits +# GetToolList +############################ + + +@router.get("/list", response_model=list[ToolResponse]) +async def get_tool_list(user=Depends(get_verified_user)): + if user.role == "admin": + tools = Tools.get_tools() + else: + tools = Tools.get_tools_by_user_id(user.id, "write") + return tools + + +############################ +# ExportTools ############################ @router.get("/export", response_model=list[ToolModel]) -async def get_toolkits(user=Depends(get_admin_user)): - toolkits = [toolkit for toolkit in Tools.get_tools()] - return toolkits +async def export_tools(user=Depends(get_admin_user)): + tools = Tools.get_tools() + return tools ############################ -# CreateNewToolKit +# CreateNewTools ############################ @router.post("/create", response_model=Optional[ToolResponse]) -async def create_new_toolkit( +async def create_new_tools( request: Request, form_data: ToolForm, - user=Depends(get_admin_user), + user=Depends(get_verified_user), ): if not form_data.id.isidentifier(): raise HTTPException( @@ -54,30 +72,30 @@ async def create_new_toolkit( form_data.id = form_data.id.lower() - toolkit = Tools.get_tool_by_id(form_data.id) - if toolkit is None: + tools = Tools.get_tool_by_id(form_data.id) + if tools is None: try: form_data.content = replace_imports(form_data.content) - toolkit_module, frontmatter = load_toolkit_module_by_id( + tools_module, frontmatter = load_tools_module_by_id( form_data.id, content=form_data.content ) form_data.meta.manifest = frontmatter TOOLS = request.app.state.TOOLS - TOOLS[form_data.id] = toolkit_module + TOOLS[form_data.id] = tools_module specs = get_tools_specs(TOOLS[form_data.id]) - toolkit = Tools.insert_new_tool(user.id, form_data, specs) + tools = Tools.insert_new_tool(user.id, form_data, specs) tool_cache_dir = Path(CACHE_DIR) / "tools" / form_data.id tool_cache_dir.mkdir(parents=True, exist_ok=True) - if toolkit: - return toolkit + if tools: + return tools else: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail=ERROR_MESSAGES.DEFAULT("Error creating toolkit"), + detail=ERROR_MESSAGES.DEFAULT("Error creating tools"), ) except Exception as e: print(e) @@ -93,16 +111,21 @@ async def create_new_toolkit( ############################ -# GetToolkitById +# GetToolsById ############################ @router.get("/id/{id}", response_model=Optional[ToolModel]) -async def get_toolkit_by_id(id: str, user=Depends(get_admin_user)): - toolkit = Tools.get_tool_by_id(id) +async def get_tools_by_id(id: str, user=Depends(get_verified_user)): + tools = Tools.get_tool_by_id(id) - if toolkit: - return toolkit + if tools: + if ( + user.role == "admin" + or tools.user_id == user.id + or has_access(user.id, "read", tools.access_control) + ): + return tools else: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -111,26 +134,39 @@ async def get_toolkit_by_id(id: str, user=Depends(get_admin_user)): ############################ -# UpdateToolkitById +# UpdateToolsById ############################ @router.post("/id/{id}/update", response_model=Optional[ToolModel]) -async def update_toolkit_by_id( +async def update_tools_by_id( request: Request, id: str, form_data: ToolForm, - user=Depends(get_admin_user), + user=Depends(get_verified_user), ): + tools = Tools.get_tool_by_id(id) + if not tools: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + + if tools.user_id != user.id and user.role != "admin": + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.UNAUTHORIZED, + ) + try: form_data.content = replace_imports(form_data.content) - toolkit_module, frontmatter = load_toolkit_module_by_id( + tools_module, frontmatter = load_tools_module_by_id( id, content=form_data.content ) form_data.meta.manifest = frontmatter TOOLS = request.app.state.TOOLS - TOOLS[id] = toolkit_module + TOOLS[id] = tools_module specs = get_tools_specs(TOOLS[id]) @@ -140,14 +176,14 @@ async def update_toolkit_by_id( } print(updated) - toolkit = Tools.update_tool_by_id(id, updated) + tools = Tools.update_tool_by_id(id, updated) - if toolkit: - return toolkit + if tools: + return tools else: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail=ERROR_MESSAGES.DEFAULT("Error updating toolkit"), + detail=ERROR_MESSAGES.DEFAULT("Error updating tools"), ) except Exception as e: @@ -158,14 +194,28 @@ async def update_toolkit_by_id( ############################ -# DeleteToolkitById +# DeleteToolsById ############################ @router.delete("/id/{id}/delete", response_model=bool) -async def delete_toolkit_by_id(request: Request, id: str, user=Depends(get_admin_user)): - result = Tools.delete_tool_by_id(id) +async def delete_tools_by_id( + request: Request, id: str, user=Depends(get_verified_user) +): + tools = Tools.get_tool_by_id(id) + if not tools: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + if tools.user_id != user.id and user.role != "admin": + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.UNAUTHORIZED, + ) + + result = Tools.delete_tool_by_id(id) if result: TOOLS = request.app.state.TOOLS if id in TOOLS: @@ -180,9 +230,9 @@ async def delete_toolkit_by_id(request: Request, id: str, user=Depends(get_admin @router.get("/id/{id}/valves", response_model=Optional[dict]) -async def get_toolkit_valves_by_id(id: str, user=Depends(get_admin_user)): - toolkit = Tools.get_tool_by_id(id) - if toolkit: +async def get_tools_valves_by_id(id: str, user=Depends(get_verified_user)): + tools = Tools.get_tool_by_id(id) + if tools: try: valves = Tools.get_tool_valves_by_id(id) return valves @@ -204,19 +254,19 @@ async def get_toolkit_valves_by_id(id: str, user=Depends(get_admin_user)): @router.get("/id/{id}/valves/spec", response_model=Optional[dict]) -async def get_toolkit_valves_spec_by_id( - request: Request, id: str, user=Depends(get_admin_user) +async def get_tools_valves_spec_by_id( + request: Request, id: str, user=Depends(get_verified_user) ): - toolkit = Tools.get_tool_by_id(id) - if toolkit: + tools = Tools.get_tool_by_id(id) + if tools: if id in request.app.state.TOOLS: - toolkit_module = request.app.state.TOOLS[id] + tools_module = request.app.state.TOOLS[id] else: - toolkit_module, _ = load_toolkit_module_by_id(id) - request.app.state.TOOLS[id] = toolkit_module + tools_module, _ = load_tools_module_by_id(id) + request.app.state.TOOLS[id] = tools_module - if hasattr(toolkit_module, "Valves"): - Valves = toolkit_module.Valves + if hasattr(tools_module, "Valves"): + Valves = tools_module.Valves return Valves.schema() return None else: @@ -232,19 +282,19 @@ async def get_toolkit_valves_spec_by_id( @router.post("/id/{id}/valves/update", response_model=Optional[dict]) -async def update_toolkit_valves_by_id( - request: Request, id: str, form_data: dict, user=Depends(get_admin_user) +async def update_tools_valves_by_id( + request: Request, id: str, form_data: dict, user=Depends(get_verified_user) ): - toolkit = Tools.get_tool_by_id(id) - if toolkit: + tools = Tools.get_tool_by_id(id) + if tools: if id in request.app.state.TOOLS: - toolkit_module = request.app.state.TOOLS[id] + tools_module = request.app.state.TOOLS[id] else: - toolkit_module, _ = load_toolkit_module_by_id(id) - request.app.state.TOOLS[id] = toolkit_module + tools_module, _ = load_tools_module_by_id(id) + request.app.state.TOOLS[id] = tools_module - if hasattr(toolkit_module, "Valves"): - Valves = toolkit_module.Valves + if hasattr(tools_module, "Valves"): + Valves = tools_module.Valves try: form_data = {k: v for k, v in form_data.items() if v is not None} @@ -276,9 +326,9 @@ async def update_toolkit_valves_by_id( @router.get("/id/{id}/valves/user", response_model=Optional[dict]) -async def get_toolkit_user_valves_by_id(id: str, user=Depends(get_verified_user)): - toolkit = Tools.get_tool_by_id(id) - if toolkit: +async def get_tools_user_valves_by_id(id: str, user=Depends(get_verified_user)): + tools = Tools.get_tool_by_id(id) + if tools: try: user_valves = Tools.get_user_valves_by_id_and_user_id(id, user.id) return user_valves @@ -295,19 +345,19 @@ async def get_toolkit_user_valves_by_id(id: str, user=Depends(get_verified_user) @router.get("/id/{id}/valves/user/spec", response_model=Optional[dict]) -async def get_toolkit_user_valves_spec_by_id( +async def get_tools_user_valves_spec_by_id( request: Request, id: str, user=Depends(get_verified_user) ): - toolkit = Tools.get_tool_by_id(id) - if toolkit: + tools = Tools.get_tool_by_id(id) + if tools: if id in request.app.state.TOOLS: - toolkit_module = request.app.state.TOOLS[id] + tools_module = request.app.state.TOOLS[id] else: - toolkit_module, _ = load_toolkit_module_by_id(id) - request.app.state.TOOLS[id] = toolkit_module + tools_module, _ = load_tools_module_by_id(id) + request.app.state.TOOLS[id] = tools_module - if hasattr(toolkit_module, "UserValves"): - UserValves = toolkit_module.UserValves + if hasattr(tools_module, "UserValves"): + UserValves = tools_module.UserValves return UserValves.schema() return None else: @@ -318,20 +368,20 @@ async def get_toolkit_user_valves_spec_by_id( @router.post("/id/{id}/valves/user/update", response_model=Optional[dict]) -async def update_toolkit_user_valves_by_id( +async def update_tools_user_valves_by_id( request: Request, id: str, form_data: dict, user=Depends(get_verified_user) ): - toolkit = Tools.get_tool_by_id(id) + tools = Tools.get_tool_by_id(id) - if toolkit: + if tools: if id in request.app.state.TOOLS: - toolkit_module = request.app.state.TOOLS[id] + tools_module = request.app.state.TOOLS[id] else: - toolkit_module, _ = load_toolkit_module_by_id(id) - request.app.state.TOOLS[id] = toolkit_module + tools_module, _ = load_tools_module_by_id(id) + request.app.state.TOOLS[id] = tools_module - if hasattr(toolkit_module, "UserValves"): - UserValves = toolkit_module.UserValves + if hasattr(tools_module, "UserValves"): + UserValves = tools_module.UserValves try: form_data = {k: v for k, v in form_data.items() if v is not None} diff --git a/backend/open_webui/apps/webui/routers/users.py b/backend/open_webui/apps/webui/routers/users.py index abc540efa8..b6b91a5c30 100644 --- a/backend/open_webui/apps/webui/routers/users.py +++ b/backend/open_webui/apps/webui/routers/users.py @@ -31,21 +31,58 @@ async def get_users(skip: int = 0, limit: int = 50, user=Depends(get_admin_user) return Users.get_users(skip, limit) +############################ +# User Groups +############################ + + +@router.get("/groups") +async def get_user_groups(user=Depends(get_verified_user)): + return Users.get_user_groups(user.id) + + ############################ # User Permissions ############################ -@router.get("/permissions/user") +@router.get("/permissions") +async def get_user_permissisions(user=Depends(get_verified_user)): + return Users.get_user_groups(user.id) + + +############################ +# User Default Permissions +############################ +class WorkspacePermissions(BaseModel): + models: bool + knowledge: bool + prompts: bool + tools: bool + + +class ChatPermissions(BaseModel): + file_upload: bool + delete: bool + edit: bool + temporary: bool + + +class UserPermissions(BaseModel): + workspace: WorkspacePermissions + chat: ChatPermissions + + +@router.get("/default/permissions") async def get_user_permissions(request: Request, user=Depends(get_admin_user)): return request.app.state.config.USER_PERMISSIONS -@router.post("/permissions/user") +@router.post("/default/permissions") async def update_user_permissions( - request: Request, form_data: dict, user=Depends(get_admin_user) + request: Request, form_data: UserPermissions, user=Depends(get_admin_user) ): - request.app.state.config.USER_PERMISSIONS = form_data + request.app.state.config.USER_PERMISSIONS = form_data.model_dump() return request.app.state.config.USER_PERMISSIONS diff --git a/backend/open_webui/apps/webui/utils.py b/backend/open_webui/apps/webui/utils.py index 51d3796568..6bfddd0728 100644 --- a/backend/open_webui/apps/webui/utils.py +++ b/backend/open_webui/apps/webui/utils.py @@ -63,7 +63,7 @@ def replace_imports(content): return content -def load_toolkit_module_by_id(toolkit_id, content=None): +def load_tools_module_by_id(toolkit_id, content=None): if content is None: tool = Tools.get_tool_by_id(toolkit_id) diff --git a/backend/open_webui/config.py b/backend/open_webui/config.py index 83efa89faf..cbe85ce6d8 100644 --- a/backend/open_webui/config.py +++ b/backend/open_webui/config.py @@ -739,12 +739,36 @@ DEFAULT_USER_ROLE = PersistentConfig( os.getenv("DEFAULT_USER_ROLE", "pending"), ) -USER_PERMISSIONS_CHAT_DELETION = ( - os.environ.get("USER_PERMISSIONS_CHAT_DELETION", "True").lower() == "true" + +USER_PERMISSIONS_WORKSPACE_MODELS_ACCESS = ( + os.environ.get("USER_PERMISSIONS_WORKSPACE_MODELS_ACCESS", "False").lower() + == "true" ) -USER_PERMISSIONS_CHAT_EDITING = ( - os.environ.get("USER_PERMISSIONS_CHAT_EDITING", "True").lower() == "true" +USER_PERMISSIONS_WORKSPACE_KNOWLEDGE_ACCESS = ( + os.environ.get("USER_PERMISSIONS_WORKSPACE_KNOWLEDGE_ACCESS", "False").lower() + == "true" +) + +USER_PERMISSIONS_WORKSPACE_PROMPTS_ACCESS = ( + os.environ.get("USER_PERMISSIONS_WORKSPACE_PROMPTS_ACCESS", "False").lower() + == "true" +) + +USER_PERMISSIONS_WORKSPACE_TOOLS_ACCESS = ( + os.environ.get("USER_PERMISSIONS_WORKSPACE_TOOLS_ACCESS", "False").lower() == "true" +) + +USER_PERMISSIONS_CHAT_FILE_UPLOAD = ( + os.environ.get("USER_PERMISSIONS_CHAT_FILE_UPLOAD", "True").lower() == "true" +) + +USER_PERMISSIONS_CHAT_DELETE = ( + os.environ.get("USER_PERMISSIONS_CHAT_DELETE", "True").lower() == "true" +) + +USER_PERMISSIONS_CHAT_EDIT = ( + os.environ.get("USER_PERMISSIONS_CHAT_EDIT", "True").lower() == "true" ) USER_PERMISSIONS_CHAT_TEMPORARY = ( @@ -753,13 +777,20 @@ USER_PERMISSIONS_CHAT_TEMPORARY = ( USER_PERMISSIONS = PersistentConfig( "USER_PERMISSIONS", - "ui.user_permissions", + "user.permissions", { + "workspace": { + "models": USER_PERMISSIONS_WORKSPACE_MODELS_ACCESS, + "knowledge": USER_PERMISSIONS_WORKSPACE_KNOWLEDGE_ACCESS, + "prompts": USER_PERMISSIONS_WORKSPACE_PROMPTS_ACCESS, + "tools": USER_PERMISSIONS_WORKSPACE_TOOLS_ACCESS, + }, "chat": { - "deletion": USER_PERMISSIONS_CHAT_DELETION, - "editing": USER_PERMISSIONS_CHAT_EDITING, + "file_upload": USER_PERMISSIONS_CHAT_FILE_UPLOAD, + "delete": USER_PERMISSIONS_CHAT_DELETE, + "edit": USER_PERMISSIONS_CHAT_EDIT, "temporary": USER_PERMISSIONS_CHAT_TEMPORARY, - } + }, }, ) @@ -785,18 +816,6 @@ DEFAULT_ARENA_MODEL = { }, } -ENABLE_MODEL_FILTER = PersistentConfig( - "ENABLE_MODEL_FILTER", - "model_filter.enable", - os.environ.get("ENABLE_MODEL_FILTER", "False").lower() == "true", -) -MODEL_FILTER_LIST = os.environ.get("MODEL_FILTER_LIST", "") -MODEL_FILTER_LIST = PersistentConfig( - "MODEL_FILTER_LIST", - "model_filter.list", - [model.strip() for model in MODEL_FILTER_LIST.split(";")], -) - WEBHOOK_URL = PersistentConfig( "WEBHOOK_URL", "webhook_url", os.environ.get("WEBHOOK_URL", "") ) diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index dd7504d5f2..7fdd45c971 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -11,6 +11,7 @@ import random from contextlib import asynccontextmanager from typing import Optional +from aiocache import cached import aiohttp import requests from fastapi import ( @@ -45,6 +46,7 @@ from open_webui.apps.openai.main import ( app as openai_app, generate_chat_completion as generate_openai_chat_completion, get_all_models as get_openai_models, + get_all_models_responses as get_openai_models_responses, ) from open_webui.apps.retrieval.main import app as retrieval_app from open_webui.apps.retrieval.utils import get_rag_context, rag_template @@ -70,13 +72,11 @@ from open_webui.config import ( DEFAULT_LOCALE, ENABLE_ADMIN_CHAT_ACCESS, ENABLE_ADMIN_EXPORT, - ENABLE_MODEL_FILTER, ENABLE_OLLAMA_API, ENABLE_OPENAI_API, ENABLE_TAGS_GENERATION, ENV, FRONTEND_BUILD_DIR, - MODEL_FILTER_LIST, OAUTH_PROVIDERS, ENABLE_SEARCH_QUERY, SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE, @@ -135,6 +135,7 @@ from open_webui.utils.utils import ( get_http_authorization_cred, get_verified_user, ) +from open_webui.utils.access_control import has_access if SAFE_MODE: print("SAFE MODE ENABLED") @@ -183,7 +184,10 @@ async def lifespan(app: FastAPI): app = FastAPI( - docs_url="/docs" if ENV == "dev" else None, openapi_url="/openapi.json" if ENV == "dev" else None, redoc_url=None, lifespan=lifespan + docs_url="/docs" if ENV == "dev" else None, + openapi_url="/openapi.json" if ENV == "dev" else None, + redoc_url=None, + lifespan=lifespan, ) app.state.config = AppConfig() @@ -191,27 +195,26 @@ app.state.config = AppConfig() app.state.config.ENABLE_OPENAI_API = ENABLE_OPENAI_API app.state.config.ENABLE_OLLAMA_API = ENABLE_OLLAMA_API -app.state.config.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER -app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST - app.state.config.WEBHOOK_URL = WEBHOOK_URL app.state.config.TASK_MODEL = TASK_MODEL app.state.config.TASK_MODEL_EXTERNAL = TASK_MODEL_EXTERNAL + app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = TITLE_GENERATION_PROMPT_TEMPLATE -app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE = TAGS_GENERATION_PROMPT_TEMPLATE + app.state.config.ENABLE_TAGS_GENERATION = ENABLE_TAGS_GENERATION +app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE = TAGS_GENERATION_PROMPT_TEMPLATE + + +app.state.config.ENABLE_SEARCH_QUERY = ENABLE_SEARCH_QUERY app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE = ( SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE ) -app.state.config.ENABLE_SEARCH_QUERY = ENABLE_SEARCH_QUERY + app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = ( TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE ) -app.state.MODELS = {} - - ################################## # # ChatCompletion Middleware @@ -219,26 +222,6 @@ app.state.MODELS = {} ################################## -def get_task_model_id(default_model_id): - # Set the task model - task_model_id = default_model_id - # Check if the user has a custom task model and use that model - if app.state.MODELS[task_model_id]["owned_by"] == "ollama": - if ( - app.state.config.TASK_MODEL - and app.state.config.TASK_MODEL in app.state.MODELS - ): - task_model_id = app.state.config.TASK_MODEL - else: - if ( - app.state.config.TASK_MODEL_EXTERNAL - and app.state.config.TASK_MODEL_EXTERNAL in app.state.MODELS - ): - task_model_id = app.state.config.TASK_MODEL_EXTERNAL - - return task_model_id - - def get_filter_function_ids(model): def get_priority(function_id): function = Functions.get_function_by_id(function_id) @@ -368,8 +351,24 @@ async def get_content_from_response(response) -> Optional[str]: return content +def get_task_model_id( + default_model_id: str, task_model: str, task_model_external: str, models +) -> str: + # Set the task model + task_model_id = default_model_id + # Check if the user has a custom task model and use that model + if models[task_model_id]["owned_by"] == "ollama": + if task_model and task_model in models: + task_model_id = task_model + else: + if task_model_external and task_model_external in models: + task_model_id = task_model_external + + return task_model_id + + async def chat_completion_tools_handler( - body: dict, user: UserModel, extra_params: dict + body: dict, user: UserModel, models, extra_params: dict ) -> tuple[dict, dict]: # If tool_ids field is present, call the functions metadata = body.get("metadata", {}) @@ -383,14 +382,19 @@ async def chat_completion_tools_handler( contexts = [] citations = [] - task_model_id = get_task_model_id(body["model"]) + task_model_id = get_task_model_id( + body["model"], + app.state.config.TASK_MODEL, + app.state.config.TASK_MODEL_EXTERNAL, + models, + ) tools = get_tools( webui_app, tool_ids, user, { **extra_params, - "__model__": app.state.MODELS[task_model_id], + "__model__": models[task_model_id], "__messages__": body["messages"], "__files__": metadata.get("files", []), }, @@ -414,7 +418,7 @@ async def chat_completion_tools_handler( ) try: - payload = filter_pipeline(payload, user) + payload = filter_pipeline(payload, user, models) except Exception as e: raise e @@ -515,16 +519,16 @@ def is_chat_completion_request(request): ) -async def get_body_and_model_and_user(request): +async def get_body_and_model_and_user(request, models): # Read the original request body body = await request.body() body_str = body.decode("utf-8") body = json.loads(body_str) if body_str else {} model_id = body["model"] - if model_id not in app.state.MODELS: + if model_id not in models: raise Exception("Model not found") - model = app.state.MODELS[model_id] + model = models[model_id] user = get_current_user( request, @@ -540,14 +544,27 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): return await call_next(request) log.debug(f"request.url.path: {request.url.path}") + model_list = await get_all_models() + models = {model["id"]: model for model in model_list} + try: - body, model, user = await get_body_and_model_and_user(request) + body, model, user = await get_body_and_model_and_user(request, models) except Exception as e: return JSONResponse( status_code=status.HTTP_400_BAD_REQUEST, content={"detail": str(e)}, ) + model_info = Models.get_model_by_id(model["id"]) + if user.role == "user": + if model_info and not has_access( + user.id, type="read", access_control=model_info.access_control + ): + return JSONResponse( + status_code=status.HTTP_403_FORBIDDEN, + content={"detail": "User does not have access to the model"}, + ) + metadata = { "chat_id": body.pop("chat_id", None), "message_id": body.pop("id", None), @@ -584,15 +601,20 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): content={"detail": str(e)}, ) + tool_ids = body.pop("tool_ids", None) + files = body.pop("files", None) + metadata = { **metadata, - "tool_ids": body.pop("tool_ids", None), - "files": body.pop("files", None), + "tool_ids": tool_ids, + "files": files, } body["metadata"] = metadata try: - body, flags = await chat_completion_tools_handler(body, user, extra_params) + body, flags = await chat_completion_tools_handler( + body, user, models, extra_params + ) contexts.extend(flags.get("contexts", [])) citations.extend(flags.get("citations", [])) except Exception as e: @@ -689,10 +711,10 @@ app.add_middleware(ChatCompletionMiddleware) ################################## -def get_sorted_filters(model_id): +def get_sorted_filters(model_id, models): filters = [ model - for model in app.state.MODELS.values() + for model in models.values() if "pipeline" in model and "type" in model["pipeline"] and model["pipeline"]["type"] == "filter" @@ -708,12 +730,12 @@ def get_sorted_filters(model_id): return sorted_filters -def filter_pipeline(payload, user): +def filter_pipeline(payload, user, models): user = {"id": user.id, "email": user.email, "name": user.name, "role": user.role} model_id = payload["model"] - sorted_filters = get_sorted_filters(model_id) - model = app.state.MODELS[model_id] + sorted_filters = get_sorted_filters(model_id, models) + model = models[model_id] if "pipeline" in model: sorted_filters.append(model) @@ -784,8 +806,11 @@ class PipelineMiddleware(BaseHTTPMiddleware): content={"detail": "Not authenticated"}, ) + model_list = await get_all_models() + models = {model["id"]: model for model in model_list} + try: - data = filter_pipeline(data, user) + data = filter_pipeline(data, user, models) except Exception as e: if len(e.args) > 1: return JSONResponse( @@ -864,16 +889,10 @@ async def commit_session_after_request(request: Request, call_next): @app.middleware("http") async def check_url(request: Request, call_next): - if len(app.state.MODELS) == 0: - await get_all_models() - else: - pass - start_time = int(time.time()) response = await call_next(request) process_time = int(time.time()) - start_time response.headers["X-Process-Time"] = str(process_time) - return response @@ -913,12 +932,10 @@ app.mount("/retrieval/api/v1", retrieval_app) app.mount("/api/v1", webui_app) - webui_app.state.EMBEDDING_FUNCTION = retrieval_app.state.EMBEDDING_FUNCTION -async def get_all_models(): - # TODO: Optimize this function +async def get_all_base_models(): open_webui_models = [] openai_models = [] ollama_models = [] @@ -944,9 +961,15 @@ async def get_all_models(): open_webui_models = await get_open_webui_models() models = open_webui_models + openai_models + ollama_models + return models + + +@cached(ttl=1) +async def get_all_models(): + models = await get_all_base_models() # If there are no models, return an empty list - if len([model for model in models if model["owned_by"] != "arena"]) == 0: + if len([model for model in models if not model.get("arena", False)]) == 0: return [] global_action_ids = [ @@ -965,15 +988,23 @@ async def get_all_models(): custom_model.id == model["id"] or custom_model.id == model["id"].split(":")[0] ): - model["name"] = custom_model.name - model["info"] = custom_model.model_dump() + if custom_model.is_active: + model["name"] = custom_model.name + model["info"] = custom_model.model_dump() - action_ids = [] - if "info" in model and "meta" in model["info"]: - action_ids.extend(model["info"]["meta"].get("actionIds", [])) + action_ids = [] + if "info" in model and "meta" in model["info"]: + action_ids.extend( + model["info"]["meta"].get("actionIds", []) + ) - model["action_ids"] = action_ids - else: + model["action_ids"] = action_ids + else: + models.remove(model) + + elif custom_model.is_active and ( + custom_model.id not in [model["id"] for model in models] + ): owned_by = "openai" pipe = None action_ids = [] @@ -995,7 +1026,7 @@ async def get_all_models(): models.append( { - "id": custom_model.id, + "id": f"{custom_model.id}", "name": custom_model.name, "object": "model", "created": custom_model.created_at, @@ -1007,66 +1038,54 @@ async def get_all_models(): } ) - for model in models: - action_ids = [] - if "action_ids" in model: - action_ids = model["action_ids"] - del model["action_ids"] + # Process action_ids to get the actions + def get_action_items_from_module(module): + actions = [] + if hasattr(module, "actions"): + actions = module.actions + return [ + { + "id": f"{module.id}.{action['id']}", + "name": action.get("name", f"{module.name} ({action['id']})"), + "description": module.meta.description, + "icon_url": action.get( + "icon_url", module.meta.manifest.get("icon_url", None) + ), + } + for action in actions + ] + else: + return [ + { + "id": module.id, + "name": module.name, + "description": module.meta.description, + "icon_url": module.meta.manifest.get("icon_url", None), + } + ] - action_ids = action_ids + global_action_ids - action_ids = list(set(action_ids)) + def get_function_module_by_id(function_id): + if function_id in webui_app.state.FUNCTIONS: + function_module = webui_app.state.FUNCTIONS[function_id] + else: + function_module, _, _ = load_function_module_by_id(function_id) + webui_app.state.FUNCTIONS[function_id] = function_module + + for model in models: action_ids = [ - action_id for action_id in action_ids if action_id in enabled_action_ids + action_id + for action_id in list(set(model.pop("action_ids", []) + global_action_ids)) + if action_id in enabled_action_ids ] model["actions"] = [] for action_id in action_ids: - action = Functions.get_function_by_id(action_id) - if action is None: + action_function = Functions.get_function_by_id(action_id) + if action_function is None: raise Exception(f"Action not found: {action_id}") - if action_id in webui_app.state.FUNCTIONS: - function_module = webui_app.state.FUNCTIONS[action_id] - else: - function_module, _, _ = load_function_module_by_id(action_id) - webui_app.state.FUNCTIONS[action_id] = function_module - - __webui__ = False - if hasattr(function_module, "__webui__"): - __webui__ = function_module.__webui__ - - if hasattr(function_module, "actions"): - actions = function_module.actions - model["actions"].extend( - [ - { - "id": f"{action_id}.{_action['id']}", - "name": _action.get( - "name", f"{action.name} ({_action['id']})" - ), - "description": action.meta.description, - "icon_url": _action.get( - "icon_url", action.meta.manifest.get("icon_url", None) - ), - **({"__webui__": __webui__} if __webui__ else {}), - } - for _action in actions - ] - ) - else: - model["actions"].append( - { - "id": action_id, - "name": action.name, - "description": action.meta.description, - "icon_url": action.meta.manifest.get("icon_url", None), - **({"__webui__": __webui__} if __webui__ else {}), - } - ) - - app.state.MODELS = {model["id"]: model for model in models} - webui_app.state.MODELS = app.state.MODELS - + function_module = get_function_module_by_id(action_id) + model["actions"].extend(get_action_items_from_module(function_module)) return models @@ -1081,40 +1100,58 @@ async def get_models(user=Depends(get_verified_user)): if "pipeline" not in model or model["pipeline"].get("type", None) != "filter" ] - if app.state.config.ENABLE_MODEL_FILTER: - if user.role == "user": - models = list( - filter( - lambda model: model["id"] in app.state.config.MODEL_FILTER_LIST, - models, - ) - ) - return {"data": models} + # Filter out models that the user does not have access to + if user.role == "user": + filtered_models = [] + for model in models: + model_info = Models.get_model_by_id(model["id"]) + if model_info: + if has_access( + user.id, type="read", access_control=model_info.access_control + ): + filtered_models.append(model) + else: + filtered_models.append(model) + models = filtered_models return {"data": models} +@app.get("/api/models/base") +async def get_base_models(user=Depends(get_admin_user)): + models = await get_all_base_models() + + # Filter out arena models + models = [model for model in models if not model.get("arena", False)] + return {"data": models} + + @app.post("/api/chat/completions") async def generate_chat_completions( form_data: dict, user=Depends(get_verified_user), bypass_filter: bool = False ): - model_id = form_data["model"] + model_list = await get_all_models() + models = {model["id"]: model for model in model_list} - if model_id not in app.state.MODELS: + model_id = form_data["model"] + if model_id not in models: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail="Model not found", ) - if not bypass_filter and app.state.config.ENABLE_MODEL_FILTER: - if user.role == "user" and model_id not in app.state.config.MODEL_FILTER_LIST: + model = models[model_id] + # Check if user has access to the model + if user.role == "user": + model_info = Models.get_model_by_id(model_id) + if not has_access( + user.id, type="read", access_control=model_info.access_control + ): raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, + status_code=403, detail="Model not found", ) - model = app.state.MODELS[model_id] - if model["owned_by"] == "arena": model_ids = model.get("info", {}).get("meta", {}).get("model_ids") filter_mode = model.get("info", {}).get("meta", {}).get("filter_mode") @@ -1161,14 +1198,18 @@ async def generate_chat_completions( ), "selected_model_id": selected_model_id, } + if model.get("pipe"): - return await generate_function_chat_completion(form_data, user=user) + # Below does not require bypass_filter because this is the only route the uses this function and it is already bypassing the filter + return await generate_function_chat_completion( + form_data, user=user, models=models + ) if model["owned_by"] == "ollama": # Using /ollama/api/chat endpoint form_data = convert_payload_openai_to_ollama(form_data) form_data = GenerateChatCompletionForm(**form_data) response = await generate_ollama_chat_completion( - form_data=form_data, user=user, bypass_filter=True + form_data=form_data, user=user, bypass_filter=bypass_filter ) if form_data.stream: response.headers["content-type"] = "text/event-stream" @@ -1179,21 +1220,27 @@ async def generate_chat_completions( else: return convert_response_ollama_to_openai(response) else: - return await generate_openai_chat_completion(form_data, user=user) + return await generate_openai_chat_completion( + form_data, user=user, bypass_filter=bypass_filter + ) @app.post("/api/chat/completed") async def chat_completed(form_data: dict, user=Depends(get_verified_user)): + + model_list = await get_all_models() + models = {model["id"]: model for model in model_list} + data = form_data model_id = data["model"] - if model_id not in app.state.MODELS: + if model_id not in models: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail="Model not found", ) - model = app.state.MODELS[model_id] - sorted_filters = get_sorted_filters(model_id) + model = models[model_id] + sorted_filters = get_sorted_filters(model_id, models) if "pipeline" in model: sorted_filters = [model] + sorted_filters @@ -1368,14 +1415,18 @@ async def chat_action(action_id: str, form_data: dict, user=Depends(get_verified detail="Action not found", ) + model_list = await get_all_models() + models = {model["id"]: model for model in model_list} + data = form_data model_id = data["model"] - if model_id not in app.state.MODELS: + + if model_id not in models: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail="Model not found", ) - model = app.state.MODELS[model_id] + model = models[model_id] __event_emitter__ = get_event_emitter( { @@ -1529,8 +1580,11 @@ async def update_task_config(form_data: TaskConfigForm, user=Depends(get_admin_u async def generate_title(form_data: dict, user=Depends(get_verified_user)): print("generate_title") + model_list = await get_all_models() + models = {model["id"]: model for model in model_list} + model_id = form_data["model"] - if model_id not in app.state.MODELS: + if model_id not in models: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail="Model not found", @@ -1538,10 +1592,16 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)): # Check if the user has a custom task model # If the user has a custom task model, use that model - task_model_id = get_task_model_id(model_id) + task_model_id = get_task_model_id( + model_id, + app.state.config.TASK_MODEL, + app.state.config.TASK_MODEL_EXTERNAL, + models, + ) + print(task_model_id) - model = app.state.MODELS[task_model_id] + model = models[task_model_id] if app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE != "": template = app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE @@ -1575,7 +1635,7 @@ Artificial Intelligence in Healthcare "stream": False, **( {"max_tokens": 50} - if app.state.MODELS[task_model_id]["owned_by"] == "ollama" + if models[task_model_id]["owned_by"] == "ollama" else { "max_completion_tokens": 50, } @@ -1587,7 +1647,7 @@ Artificial Intelligence in Healthcare # Handle pipeline filters try: - payload = filter_pipeline(payload, user) + payload = filter_pipeline(payload, user, models) except Exception as e: if len(e.args) > 1: return JSONResponse( @@ -1614,8 +1674,11 @@ async def generate_chat_tags(form_data: dict, user=Depends(get_verified_user)): content={"detail": "Tags generation is disabled"}, ) + model_list = await get_all_models() + models = {model["id"]: model for model in model_list} + model_id = form_data["model"] - if model_id not in app.state.MODELS: + if model_id not in models: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail="Model not found", @@ -1623,7 +1686,12 @@ async def generate_chat_tags(form_data: dict, user=Depends(get_verified_user)): # Check if the user has a custom task model # If the user has a custom task model, use that model - task_model_id = get_task_model_id(model_id) + task_model_id = get_task_model_id( + model_id, + app.state.config.TASK_MODEL, + app.state.config.TASK_MODEL_EXTERNAL, + models, + ) print(task_model_id) if app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE != "": @@ -1661,7 +1729,7 @@ JSON format: { "tags": ["tag1", "tag2", "tag3"] } # Handle pipeline filters try: - payload = filter_pipeline(payload, user) + payload = filter_pipeline(payload, user, models) except Exception as e: if len(e.args) > 1: return JSONResponse( @@ -1688,8 +1756,11 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user) detail=f"Search query generation is disabled", ) + model_list = await get_all_models() + models = {model["id"]: model for model in model_list} + model_id = form_data["model"] - if model_id not in app.state.MODELS: + if model_id not in models: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail="Model not found", @@ -1697,10 +1768,15 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user) # Check if the user has a custom task model # If the user has a custom task model, use that model - task_model_id = get_task_model_id(model_id) + task_model_id = get_task_model_id( + model_id, + app.state.config.TASK_MODEL, + app.state.config.TASK_MODEL_EXTERNAL, + models, + ) print(task_model_id) - model = app.state.MODELS[task_model_id] + model = models[task_model_id] if app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE != "": template = app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE @@ -1727,7 +1803,7 @@ Search Query:""" "stream": False, **( {"max_tokens": 30} - if app.state.MODELS[task_model_id]["owned_by"] == "ollama" + if models[task_model_id]["owned_by"] == "ollama" else { "max_completion_tokens": 30, } @@ -1738,7 +1814,7 @@ Search Query:""" # Handle pipeline filters try: - payload = filter_pipeline(payload, user) + payload = filter_pipeline(payload, user, models) except Exception as e: if len(e.args) > 1: return JSONResponse( @@ -1760,8 +1836,11 @@ Search Query:""" async def generate_emoji(form_data: dict, user=Depends(get_verified_user)): print("generate_emoji") + model_list = await get_all_models() + models = {model["id"]: model for model in model_list} + model_id = form_data["model"] - if model_id not in app.state.MODELS: + if model_id not in models: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail="Model not found", @@ -1769,10 +1848,15 @@ async def generate_emoji(form_data: dict, user=Depends(get_verified_user)): # Check if the user has a custom task model # If the user has a custom task model, use that model - task_model_id = get_task_model_id(model_id) + task_model_id = get_task_model_id( + model_id, + app.state.config.TASK_MODEL, + app.state.config.TASK_MODEL_EXTERNAL, + models, + ) print(task_model_id) - model = app.state.MODELS[task_model_id] + model = models[task_model_id] template = ''' Your task is to reflect the speaker's likely facial expression through a fitting emoji. Interpret emotions from the message and reflect their facial expression using fitting, diverse emojis (e.g., 😊, 😢, 😡, 😱). @@ -1794,7 +1878,7 @@ Message: """{{prompt}}""" "stream": False, **( {"max_tokens": 4} - if app.state.MODELS[task_model_id]["owned_by"] == "ollama" + if models[task_model_id]["owned_by"] == "ollama" else { "max_completion_tokens": 4, } @@ -1806,7 +1890,7 @@ Message: """{{prompt}}""" # Handle pipeline filters try: - payload = filter_pipeline(payload, user) + payload = filter_pipeline(payload, user, models) except Exception as e: if len(e.args) > 1: return JSONResponse( @@ -1828,8 +1912,11 @@ Message: """{{prompt}}""" async def generate_moa_response(form_data: dict, user=Depends(get_verified_user)): print("generate_moa_response") + model_list = await get_all_models() + models = {model["id"]: model for model in model_list} + model_id = form_data["model"] - if model_id not in app.state.MODELS: + if model_id not in models: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail="Model not found", @@ -1837,10 +1924,15 @@ async def generate_moa_response(form_data: dict, user=Depends(get_verified_user) # Check if the user has a custom task model # If the user has a custom task model, use that model - task_model_id = get_task_model_id(model_id) + task_model_id = get_task_model_id( + model_id, + app.state.config.TASK_MODEL, + app.state.config.TASK_MODEL_EXTERNAL, + models, + ) print(task_model_id) - model = app.state.MODELS[task_model_id] + model = models[task_model_id] template = """You have been provided with a set of responses from various models to the latest user query: "{{prompt}}" @@ -1867,7 +1959,7 @@ Responses from models: {{responses}}""" log.debug(payload) try: - payload = filter_pipeline(payload, user) + payload = filter_pipeline(payload, user, models) except Exception as e: if len(e.args) > 1: return JSONResponse( @@ -1897,7 +1989,7 @@ Responses from models: {{responses}}""" @app.get("/api/pipelines/list") async def get_pipelines_list(user=Depends(get_admin_user)): - responses = await get_openai_models(raw=True) + responses = await get_openai_models_responses() print(responses) urlIdxs = [ @@ -2297,32 +2389,6 @@ async def get_app_config(request: Request): } -@app.get("/api/config/model/filter") -async def get_model_filter_config(user=Depends(get_admin_user)): - return { - "enabled": app.state.config.ENABLE_MODEL_FILTER, - "models": app.state.config.MODEL_FILTER_LIST, - } - - -class ModelFilterConfigForm(BaseModel): - enabled: bool - models: list[str] - - -@app.post("/api/config/model/filter") -async def update_model_filter_config( - form_data: ModelFilterConfigForm, user=Depends(get_admin_user) -): - app.state.config.ENABLE_MODEL_FILTER = form_data.enabled - app.state.config.MODEL_FILTER_LIST = form_data.models - - return { - "enabled": app.state.config.ENABLE_MODEL_FILTER, - "models": app.state.config.MODEL_FILTER_LIST, - } - - # TODO: webhook endpoint should be under config endpoints diff --git a/backend/open_webui/migrations/versions/922e7a387820_add_group_table.py b/backend/open_webui/migrations/versions/922e7a387820_add_group_table.py new file mode 100644 index 0000000000..a752115844 --- /dev/null +++ b/backend/open_webui/migrations/versions/922e7a387820_add_group_table.py @@ -0,0 +1,85 @@ +"""Add group table + +Revision ID: 922e7a387820 +Revises: 4ace53fd72c8 +Create Date: 2024-11-14 03:00:00.000000 + +""" + +from alembic import op +import sqlalchemy as sa + +revision = "922e7a387820" +down_revision = "4ace53fd72c8" +branch_labels = None +depends_on = None + + +def upgrade(): + op.create_table( + "group", + sa.Column("id", sa.Text(), nullable=False, primary_key=True, unique=True), + sa.Column("user_id", sa.Text(), nullable=True), + sa.Column("name", sa.Text(), nullable=True), + sa.Column("description", sa.Text(), nullable=True), + sa.Column("data", sa.JSON(), nullable=True), + sa.Column("meta", sa.JSON(), nullable=True), + sa.Column("permissions", sa.JSON(), nullable=True), + sa.Column("user_ids", sa.JSON(), nullable=True), + sa.Column("created_at", sa.BigInteger(), nullable=True), + sa.Column("updated_at", sa.BigInteger(), nullable=True), + ) + + # Add 'access_control' column to 'model' table + op.add_column( + "model", + sa.Column("access_control", sa.JSON(), nullable=True), + ) + + # Add 'is_active' column to 'model' table + op.add_column( + "model", + sa.Column( + "is_active", + sa.Boolean(), + nullable=False, + server_default=sa.sql.expression.true(), + ), + ) + + # Add 'access_control' column to 'knowledge' table + op.add_column( + "knowledge", + sa.Column("access_control", sa.JSON(), nullable=True), + ) + + # Add 'access_control' column to 'prompt' table + op.add_column( + "prompt", + sa.Column("access_control", sa.JSON(), nullable=True), + ) + + # Add 'access_control' column to 'tools' table + op.add_column( + "tool", + sa.Column("access_control", sa.JSON(), nullable=True), + ) + + +def downgrade(): + op.drop_table("group") + + # Drop 'access_control' column from 'model' table + op.drop_column("model", "access_control") + + # Drop 'is_active' column from 'model' table + op.drop_column("model", "is_active") + + # Drop 'access_control' column from 'knowledge' table + op.drop_column("knowledge", "access_control") + + # Drop 'access_control' column from 'prompt' table + op.drop_column("prompt", "access_control") + + # Drop 'access_control' column from 'tools' table + op.drop_column("tool", "access_control") diff --git a/backend/open_webui/utils/access_control.py b/backend/open_webui/utils/access_control.py new file mode 100644 index 0000000000..270b28bcc2 --- /dev/null +++ b/backend/open_webui/utils/access_control.py @@ -0,0 +1,95 @@ +from typing import Optional, Union, List, Dict, Any +from open_webui.apps.webui.models.groups import Groups +import json + + +def get_permissions( + user_id: str, + default_permissions: Dict[str, Any], +) -> Dict[str, Any]: + """ + Get all permissions for a user by combining the permissions of all groups the user is a member of. + If a permission is defined in multiple groups, the most permissive value is used (True > False). + Permissions are nested in a dict with the permission key as the key and a boolean as the value. + """ + + def combine_permissions( + permissions: Dict[str, Any], group_permissions: Dict[str, Any] + ) -> Dict[str, Any]: + """Combine permissions from multiple groups by taking the most permissive value.""" + for key, value in group_permissions.items(): + if isinstance(value, dict): + if key not in permissions: + permissions[key] = {} + permissions[key] = combine_permissions(permissions[key], value) + else: + if key not in permissions: + permissions[key] = value + else: + permissions[key] = permissions[key] or value + return permissions + + user_groups = Groups.get_groups_by_member_id(user_id) + + # deep copy default permissions to avoid modifying the original dict + permissions = json.loads(json.dumps(default_permissions)) + + for group in user_groups: + group_permissions = group.permissions + permissions = combine_permissions(permissions, group_permissions) + + return permissions + + +def has_permission( + user_id: str, + permission_key: str, + default_permissions: Dict[str, bool] = {}, +) -> bool: + """ + Check if a user has a specific permission by checking the group permissions + and falls back to default permissions if not found in any group. + + Permission keys can be hierarchical and separated by dots ('.'). + """ + + def get_permission(permissions: Dict[str, bool], keys: List[str]) -> bool: + """Traverse permissions dict using a list of keys (from dot-split permission_key).""" + for key in keys: + if key not in permissions: + return False # If any part of the hierarchy is missing, deny access + permissions = permissions[key] # Go one level deeper + + return bool(permissions) # Return the boolean at the final level + + permission_hierarchy = permission_key.split(".") + + # Retrieve user group permissions + user_groups = Groups.get_groups_by_member_id(user_id) + + for group in user_groups: + group_permissions = group.permissions + if get_permission(group_permissions, permission_hierarchy): + return True + + # Check default permissions afterwards if the group permissions don't allow it + return get_permission(default_permissions, permission_hierarchy) + + +def has_access( + user_id: str, + type: str = "write", + access_control: Optional[dict] = None, +) -> bool: + if access_control is None: + return type == "read" + + user_groups = Groups.get_groups_by_member_id(user_id) + user_group_ids = [group.id for group in user_groups] + permission_access = access_control.get(type, {}) + permitted_group_ids = permission_access.get("group_ids", []) + permitted_user_ids = permission_access.get("user_ids", []) + + return user_id in permitted_user_ids or any( + group_id in permitted_group_ids for group_id in user_group_ids + ) diff --git a/backend/open_webui/utils/tools.py b/backend/open_webui/utils/tools.py index 0b57eb35b6..e77386ac4b 100644 --- a/backend/open_webui/utils/tools.py +++ b/backend/open_webui/utils/tools.py @@ -4,7 +4,7 @@ from typing import Awaitable, Callable, get_type_hints from open_webui.apps.webui.models.tools import Tools from open_webui.apps.webui.models.users import UserModel -from open_webui.apps.webui.utils import load_toolkit_module_by_id +from open_webui.apps.webui.utils import load_tools_module_by_id from open_webui.utils.schemas import json_schema_to_model log = logging.getLogger(__name__) @@ -32,15 +32,16 @@ def apply_extra_params_to_tool_function( def get_tools( webui_app, tool_ids: list[str], user: UserModel, extra_params: dict ) -> dict[str, dict]: - tools = {} + tools_dict = {} + for tool_id in tool_ids: - toolkit = Tools.get_tool_by_id(tool_id) - if toolkit is None: + tools = Tools.get_tool_by_id(tool_id) + if tools is None: continue module = webui_app.state.TOOLS.get(tool_id, None) if module is None: - module, _ = load_toolkit_module_by_id(tool_id) + module, _ = load_tools_module_by_id(tool_id) webui_app.state.TOOLS[tool_id] = module extra_params["__id__"] = tool_id @@ -53,11 +54,19 @@ def get_tools( **Tools.get_user_valves_by_id_and_user_id(tool_id, user.id) ) - for spec in toolkit.specs: + for spec in tools.specs: # TODO: Fix hack for OpenAI API for val in spec.get("parameters", {}).get("properties", {}).values(): if val["type"] == "str": val["type"] = "string" + + # Remove internal parameters + spec["parameters"]["properties"] = { + key: val + for key, val in spec["parameters"]["properties"].items() + if not key.startswith("__") + } + function_name = spec["name"] # convert to function that takes only model params and inserts custom params @@ -77,13 +86,14 @@ def get_tools( } # TODO: if collision, prepend toolkit name - if function_name in tools: - log.warning(f"Tool {function_name} already exists in another toolkit!") - log.warning(f"Collision between {toolkit} and {tool_id}.") - log.warning(f"Discarding {toolkit}.{function_name}") + if function_name in tools_dict: + log.warning(f"Tool {function_name} already exists in another tools!") + log.warning(f"Collision between {tools} and {tool_id}.") + log.warning(f"Discarding {tools}.{function_name}") else: - tools[function_name] = tool_dict - return tools + tools_dict[function_name] = tool_dict + + return tools_dict def doc_to_dict(docstring): diff --git a/backend/open_webui/utils/utils.py b/backend/open_webui/utils/utils.py index 31fe227ede..1c2205ebff 100644 --- a/backend/open_webui/utils/utils.py +++ b/backend/open_webui/utils/utils.py @@ -1,12 +1,17 @@ import logging import uuid -from datetime import UTC, datetime, timedelta -from typing import Optional, Union - import jwt + +from datetime import UTC, datetime, timedelta +from typing import Optional, Union, List, Dict + + from open_webui.apps.webui.models.users import Users + from open_webui.constants import ERROR_MESSAGES from open_webui.env import WEBUI_SECRET_KEY + + from fastapi import Depends, HTTPException, Request, Response, status from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer from passlib.context import CryptContext diff --git a/backend/requirements.txt b/backend/requirements.txt index 44838dd36f..a5bfae5855 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -13,6 +13,7 @@ passlib[bcrypt]==1.7.4 requests==2.32.3 aiohttp==3.10.8 async-timeout +aiocache sqlalchemy==2.0.32 alembic==1.13.2 diff --git a/pyproject.toml b/pyproject.toml index 305ced6eb1..fa16381f21 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,6 +21,7 @@ dependencies = [ "requests==2.32.3", "aiohttp==3.10.8", "async-timeout", + "aiocache", "sqlalchemy==2.0.32", "alembic==1.13.2", diff --git a/src/lib/apis/groups/index.ts b/src/lib/apis/groups/index.ts new file mode 100644 index 0000000000..a846fc46f7 --- /dev/null +++ b/src/lib/apis/groups/index.ts @@ -0,0 +1,163 @@ +import { WEBUI_API_BASE_URL } from '$lib/constants'; + +export const createNewGroup = async (token: string, group: object) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/groups/create`, { + method: 'POST', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + }, + body: JSON.stringify({ + ...group + }) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + error = err.detail; + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const getGroups = async (token: string = '') => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/groups/`, { + method: 'GET', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err.detail; + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + + +export const getGroupById = async (token: string, id: string) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/groups/id/${id}`, { + method: 'GET', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err.detail; + + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const updateGroupById = async (token: string, id: string, group: object) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/groups/id/${id}/update`, { + method: 'POST', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + }, + body: JSON.stringify({ + ...group + }) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err.detail; + + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const deleteGroupById = async (token: string, id: string) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/groups/id/${id}/delete`, { + method: 'DELETE', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err.detail; + + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; diff --git a/src/lib/apis/index.ts b/src/lib/apis/index.ts index 40d0e03924..7d7ca0e2df 100644 --- a/src/lib/apis/index.ts +++ b/src/lib/apis/index.ts @@ -1,9 +1,10 @@ import { WEBUI_API_BASE_URL, WEBUI_BASE_URL } from '$lib/constants'; -export const getModels = async (token: string = '') => { +export const getModels = async (token: string = '', base: boolean = false) => { let error = null; - - const res = await fetch(`${WEBUI_BASE_URL}/api/models`, { + const res = await fetch(`${WEBUI_BASE_URL}/api/models${ + base ? '/base' : '' + }`, { method: 'GET', headers: { Accept: 'application/json', @@ -16,36 +17,21 @@ export const getModels = async (token: string = '') => { return res.json(); }) .catch((err) => { - console.log(err); error = err; + console.log(err); return null; }); + if (error) { throw error; } let models = res?.data ?? []; - models = models .filter((models) => models) // Sort the models .sort((a, b) => { - // Check if models have position property - const aHasPosition = a.info?.meta?.position !== undefined; - const bHasPosition = b.info?.meta?.position !== undefined; - - // If both a and b have the position property - if (aHasPosition && bHasPosition) { - return a.info.meta.position - b.info.meta.position; - } - - // If only a has the position property, it should come first - if (aHasPosition) return -1; - - // If only b has the position property, it should come first - if (bHasPosition) return 1; - // Compare case-insensitively by name for models without position property const lowerA = a.name.toLowerCase(); const lowerB = b.name.toLowerCase(); diff --git a/src/lib/apis/knowledge/index.ts b/src/lib/apis/knowledge/index.ts index 8428668996..da2b9d530e 100644 --- a/src/lib/apis/knowledge/index.ts +++ b/src/lib/apis/knowledge/index.ts @@ -1,6 +1,6 @@ import { WEBUI_API_BASE_URL } from '$lib/constants'; -export const createNewKnowledge = async (token: string, name: string, description: string) => { +export const createNewKnowledge = async (token: string, name: string, description: string, accessControl: null|object) => { let error = null; const res = await fetch(`${WEBUI_API_BASE_URL}/knowledge/create`, { @@ -12,7 +12,8 @@ export const createNewKnowledge = async (token: string, name: string, descriptio }, body: JSON.stringify({ name: name, - description: description + description: description, + access_control: accessControl }) }) .then(async (res) => { @@ -32,7 +33,7 @@ export const createNewKnowledge = async (token: string, name: string, descriptio return res; }; -export const getKnowledgeItems = async (token: string = '') => { +export const getKnowledgeBases = async (token: string = '') => { let error = null; const res = await fetch(`${WEBUI_API_BASE_URL}/knowledge/`, { @@ -63,6 +64,37 @@ export const getKnowledgeItems = async (token: string = '') => { return res; }; +export const getKnowledgeBaseList = async (token: string = '') => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/knowledge/list`, { + method: 'GET', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err.detail; + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + export const getKnowledgeById = async (token: string, id: string) => { let error = null; @@ -99,6 +131,7 @@ type KnowledgeUpdateForm = { name?: string; description?: string; data?: object; + access_control?: null|object; }; export const updateKnowledgeById = async (token: string, id: string, form: KnowledgeUpdateForm) => { @@ -114,7 +147,8 @@ export const updateKnowledgeById = async (token: string, id: string, form: Knowl body: JSON.stringify({ name: form?.name ? form.name : undefined, description: form?.description ? form.description : undefined, - data: form?.data ? form.data : undefined + data: form?.data ? form.data : undefined, + access_control: form.access_control }) }) .then(async (res) => { diff --git a/src/lib/apis/models/index.ts b/src/lib/apis/models/index.ts index 9faa358d33..90ec3286db 100644 --- a/src/lib/apis/models/index.ts +++ b/src/lib/apis/models/index.ts @@ -1,35 +1,7 @@ import { WEBUI_API_BASE_URL } from '$lib/constants'; -export const addNewModel = async (token: string, model: object) => { - let error = null; - const res = await fetch(`${WEBUI_API_BASE_URL}/models/add`, { - method: 'POST', - headers: { - Accept: 'application/json', - 'Content-Type': 'application/json', - authorization: `Bearer ${token}` - }, - body: JSON.stringify(model) - }) - .then(async (res) => { - if (!res.ok) throw await res.json(); - return res.json(); - }) - .catch((err) => { - error = err.detail; - console.log(err); - return null; - }); - - if (error) { - throw error; - } - - return res; -}; - -export const getModelInfos = async (token: string = '') => { +export const getModels = async (token: string = '') => { let error = null; const res = await fetch(`${WEBUI_API_BASE_URL}/models`, { @@ -60,13 +32,79 @@ export const getModelInfos = async (token: string = '') => { return res; }; + + +export const getBaseModels = async (token: string = '') => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/models/base`, { + method: 'GET', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err; + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + + + +export const createNewModel = async (token: string, model: object) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/models/create`, { + method: 'POST', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + }, + body: JSON.stringify(model) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + error = err.detail; + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + + + export const getModelById = async (token: string, id: string) => { let error = null; const searchParams = new URLSearchParams(); searchParams.append('id', id); - const res = await fetch(`${WEBUI_API_BASE_URL}/models?${searchParams.toString()}`, { + const res = await fetch(`${WEBUI_API_BASE_URL}/models/id/${id}`, { method: 'GET', headers: { Accept: 'application/json', @@ -95,13 +133,50 @@ export const getModelById = async (token: string, id: string) => { return res; }; + +export const toggleModelById = async (token: string, id: string) => { + let error = null; + + const searchParams = new URLSearchParams(); + searchParams.append('id', id); + + const res = await fetch(`${WEBUI_API_BASE_URL}/models/id/${id}/toggle`, { + method: 'POST', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err; + + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + + export const updateModelById = async (token: string, id: string, model: object) => { let error = null; const searchParams = new URLSearchParams(); searchParams.append('id', id); - const res = await fetch(`${WEBUI_API_BASE_URL}/models/update?${searchParams.toString()}`, { + const res = await fetch(`${WEBUI_API_BASE_URL}/models/id/${id}/update`, { method: 'POST', headers: { Accept: 'application/json', @@ -137,7 +212,7 @@ export const deleteModelById = async (token: string, id: string) => { const searchParams = new URLSearchParams(); searchParams.append('id', id); - const res = await fetch(`${WEBUI_API_BASE_URL}/models/delete?${searchParams.toString()}`, { + const res = await fetch(`${WEBUI_API_BASE_URL}/models/id/${id}/delete`, { method: 'DELETE', headers: { Accept: 'application/json', diff --git a/src/lib/apis/ollama/index.ts b/src/lib/apis/ollama/index.ts index 4ad2ba5968..84f3106c27 100644 --- a/src/lib/apis/ollama/index.ts +++ b/src/lib/apis/ollama/index.ts @@ -211,10 +211,12 @@ export const getOllamaVersion = async (token: string, urlIdx?: number) => { return res?.version ?? false; }; -export const getOllamaModels = async (token: string = '') => { +export const getOllamaModels = async (token: string = '', urlIdx: null|number = null) => { let error = null; - const res = await fetch(`${OLLAMA_API_BASE_URL}/api/tags`, { + const res = await fetch(`${OLLAMA_API_BASE_URL}/api/tags${ + urlIdx !== null ? `/${urlIdx}` : '' + }`, { method: 'GET', headers: { Accept: 'application/json', diff --git a/src/lib/apis/prompts/index.ts b/src/lib/apis/prompts/index.ts index ca9c7d543d..bd7741e481 100644 --- a/src/lib/apis/prompts/index.ts +++ b/src/lib/apis/prompts/index.ts @@ -1,10 +1,18 @@ import { WEBUI_API_BASE_URL } from '$lib/constants'; + +type PromptItem = { + command: string; + title: string; + content: string; + access_control: null|object; +} + + + export const createNewPrompt = async ( token: string, - command: string, - title: string, - content: string + prompt: PromptItem ) => { let error = null; @@ -16,9 +24,8 @@ export const createNewPrompt = async ( authorization: `Bearer ${token}` }, body: JSON.stringify({ - command: `/${command}`, - title: title, - content: content + ...prompt, + command: `/${prompt.command}`, }) }) .then(async (res) => { @@ -69,6 +76,39 @@ export const getPrompts = async (token: string = '') => { return res; }; + +export const getPromptList = async (token: string = '') => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/prompts/list`, { + method: 'GET', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err.detail; + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + + export const getPromptByCommand = async (token: string, command: string) => { let error = null; @@ -101,15 +141,15 @@ export const getPromptByCommand = async (token: string, command: string) => { return res; }; + + export const updatePromptByCommand = async ( token: string, - command: string, - title: string, - content: string + prompt: PromptItem ) => { let error = null; - const res = await fetch(`${WEBUI_API_BASE_URL}/prompts/command/${command}/update`, { + const res = await fetch(`${WEBUI_API_BASE_URL}/prompts/command/${prompt.command}/update`, { method: 'POST', headers: { Accept: 'application/json', @@ -117,9 +157,8 @@ export const updatePromptByCommand = async ( authorization: `Bearer ${token}` }, body: JSON.stringify({ - command: `/${command}`, - title: title, - content: content + ...prompt, + command: `/${prompt.command}`, }) }) .then(async (res) => { diff --git a/src/lib/apis/tools/index.ts b/src/lib/apis/tools/index.ts index 28e8dde86c..25dcf03aa0 100644 --- a/src/lib/apis/tools/index.ts +++ b/src/lib/apis/tools/index.ts @@ -62,6 +62,39 @@ export const getTools = async (token: string = '') => { return res; }; + +export const getToolList = async (token: string = '') => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/tools/list`, { + method: 'GET', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err.detail; + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + + export const exportTools = async (token: string = '') => { let error = null; diff --git a/src/lib/apis/users/index.ts b/src/lib/apis/users/index.ts index 0b22b71715..5c95c7cdf2 100644 --- a/src/lib/apis/users/index.ts +++ b/src/lib/apis/users/index.ts @@ -1,10 +1,11 @@ import { WEBUI_API_BASE_URL } from '$lib/constants'; import { getUserPosition } from '$lib/utils'; -export const getUserPermissions = async (token: string) => { + +export const getUserGroups = async (token: string) => { let error = null; - const res = await fetch(`${WEBUI_API_BASE_URL}/users/permissions/user`, { + const res = await fetch(`${WEBUI_API_BASE_URL}/users/groups`, { method: 'GET', headers: { 'Content-Type': 'application/json', @@ -28,10 +29,39 @@ export const getUserPermissions = async (token: string) => { return res; }; -export const updateUserPermissions = async (token: string, permissions: object) => { + + +export const getUserDefaultPermissions = async (token: string) => { let error = null; - const res = await fetch(`${WEBUI_API_BASE_URL}/users/permissions/user`, { + const res = await fetch(`${WEBUI_API_BASE_URL}/users/default/permissions`, { + method: 'GET', + headers: { + 'Content-Type': 'application/json', + Authorization: `Bearer ${token}` + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + console.log(err); + error = err.detail; + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const updateUserDefaultPermissions = async (token: string, permissions: object) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/users/default/permissions`, { method: 'POST', headers: { 'Content-Type': 'application/json', diff --git a/src/lib/components/workspace/Functions.svelte b/src/lib/components/admin/Functions.svelte similarity index 93% rename from src/lib/components/workspace/Functions.svelte rename to src/lib/components/admin/Functions.svelte index 08e1e28a12..0dfcbb73bc 100644 --- a/src/lib/components/workspace/Functions.svelte +++ b/src/lib/components/admin/Functions.svelte @@ -5,7 +5,6 @@ import { WEBUI_NAME, config, functions, models } from '$lib/stores'; import { onMount, getContext, tick } from 'svelte'; - import { createNewPrompt, deletePromptByCommand, getPrompts } from '$lib/apis/prompts'; import { goto } from '$app/navigation'; import { @@ -25,13 +24,14 @@ import FunctionMenu from './Functions/FunctionMenu.svelte'; import EllipsisHorizontal from '../icons/EllipsisHorizontal.svelte'; import Switch from '../common/Switch.svelte'; - import ValvesModal from './common/ValvesModal.svelte'; - import ManifestModal from './common/ManifestModal.svelte'; + import ValvesModal from '../workspace/common/ValvesModal.svelte'; + import ManifestModal from '../workspace/common/ManifestModal.svelte'; import Heart from '../icons/Heart.svelte'; import DeleteConfirmDialog from '$lib/components/common/ConfirmDialog.svelte'; import GarbageBin from '../icons/GarbageBin.svelte'; import Search from '../icons/Search.svelte'; import Plus from '../icons/Plus.svelte'; + import ChevronRight from '../icons/ChevronRight.svelte'; const i18n = getContext('i18n'); @@ -98,7 +98,7 @@ id: `${_function.id}_clone`, name: `${_function.name} (Clone)` }); - goto('/workspace/functions/create'); + goto('/admin/functions/create'); } }; @@ -210,7 +210,7 @@