diff --git a/backend/open_webui/config.py b/backend/open_webui/config.py index 701dda856f..c262d857ba 100644 --- a/backend/open_webui/config.py +++ b/backend/open_webui/config.py @@ -1203,6 +1203,12 @@ DEFAULT_USER_ROLE = PersistentConfig( os.getenv("DEFAULT_USER_ROLE", "pending"), ) +DEFAULT_GROUP_ID = PersistentConfig( + "DEFAULT_GROUP_ID", + "ui.default_group_id", + os.environ.get("DEFAULT_GROUP_ID", ""), +) + PENDING_USER_OVERLAY_TITLE = PersistentConfig( "PENDING_USER_OVERLAY_TITLE", "ui.pending_user_overlay_title", @@ -1270,6 +1276,12 @@ USER_PERMISSIONS_WORKSPACE_TOOLS_EXPORT = ( os.environ.get("USER_PERMISSIONS_WORKSPACE_TOOLS_EXPORT", "False").lower() == "true" ) + +USER_PERMISSIONS_WORKSPACE_MODELS_ALLOW_SHARING = ( + os.environ.get("USER_PERMISSIONS_WORKSPACE_MODELS_ALLOW_SHARING", "False").lower() + == "true" +) + USER_PERMISSIONS_WORKSPACE_MODELS_ALLOW_PUBLIC_SHARING = ( os.environ.get( "USER_PERMISSIONS_WORKSPACE_MODELS_ALLOW_PUBLIC_SHARING", "False" @@ -1277,8 +1289,10 @@ USER_PERMISSIONS_WORKSPACE_MODELS_ALLOW_PUBLIC_SHARING = ( == "true" ) -USER_PERMISSIONS_NOTES_ALLOW_PUBLIC_SHARING = ( - os.environ.get("USER_PERMISSIONS_NOTES_ALLOW_PUBLIC_SHARING", "False").lower() +USER_PERMISSIONS_WORKSPACE_KNOWLEDGE_ALLOW_SHARING = ( + os.environ.get( + "USER_PERMISSIONS_WORKSPACE_KNOWLEDGE_ALLOW_PUBLIC_SHARING", "False" + ).lower() == "true" ) @@ -1289,6 +1303,11 @@ USER_PERMISSIONS_WORKSPACE_KNOWLEDGE_ALLOW_PUBLIC_SHARING = ( == "true" ) +USER_PERMISSIONS_WORKSPACE_PROMPTS_ALLOW_SHARING = ( + os.environ.get("USER_PERMISSIONS_WORKSPACE_PROMPTS_ALLOW_SHARING", "False").lower() + == "true" +) + USER_PERMISSIONS_WORKSPACE_PROMPTS_ALLOW_PUBLIC_SHARING = ( os.environ.get( "USER_PERMISSIONS_WORKSPACE_PROMPTS_ALLOW_PUBLIC_SHARING", "False" @@ -1296,6 +1315,12 @@ USER_PERMISSIONS_WORKSPACE_PROMPTS_ALLOW_PUBLIC_SHARING = ( == "true" ) + +USER_PERMISSIONS_WORKSPACE_TOOLS_ALLOW_SHARING = ( + os.environ.get("USER_PERMISSIONS_WORKSPACE_TOOLS_ALLOW_SHARING", "False").lower() + == "true" +) + USER_PERMISSIONS_WORKSPACE_TOOLS_ALLOW_PUBLIC_SHARING = ( os.environ.get( "USER_PERMISSIONS_WORKSPACE_TOOLS_ALLOW_PUBLIC_SHARING", "False" @@ -1304,6 +1329,17 @@ USER_PERMISSIONS_WORKSPACE_TOOLS_ALLOW_PUBLIC_SHARING = ( ) +USER_PERMISSIONS_NOTES_ALLOW_SHARING = ( + os.environ.get("USER_PERMISSIONS_NOTES_ALLOW_PUBLIC_SHARING", "False").lower() + == "true" +) + +USER_PERMISSIONS_NOTES_ALLOW_PUBLIC_SHARING = ( + os.environ.get("USER_PERMISSIONS_NOTES_ALLOW_PUBLIC_SHARING", "False").lower() + == "true" +) + + USER_PERMISSIONS_CHAT_CONTROLS = ( os.environ.get("USER_PERMISSIONS_CHAT_CONTROLS", "True").lower() == "true" ) @@ -1425,10 +1461,15 @@ DEFAULT_USER_PERMISSIONS = { "tools_export": USER_PERMISSIONS_WORKSPACE_TOOLS_EXPORT, }, "sharing": { + "models": USER_PERMISSIONS_WORKSPACE_MODELS_ALLOW_SHARING, "public_models": USER_PERMISSIONS_WORKSPACE_MODELS_ALLOW_PUBLIC_SHARING, + "knowledge": USER_PERMISSIONS_WORKSPACE_KNOWLEDGE_ALLOW_SHARING, "public_knowledge": USER_PERMISSIONS_WORKSPACE_KNOWLEDGE_ALLOW_PUBLIC_SHARING, + "prompts": USER_PERMISSIONS_WORKSPACE_PROMPTS_ALLOW_SHARING, "public_prompts": USER_PERMISSIONS_WORKSPACE_PROMPTS_ALLOW_PUBLIC_SHARING, + "tools": USER_PERMISSIONS_WORKSPACE_TOOLS_ALLOW_SHARING, "public_tools": USER_PERMISSIONS_WORKSPACE_TOOLS_ALLOW_PUBLIC_SHARING, + "notes": USER_PERMISSIONS_NOTES_ALLOW_SHARING, "public_notes": USER_PERMISSIONS_NOTES_ALLOW_PUBLIC_SHARING, }, "chat": { @@ -2145,6 +2186,11 @@ ENABLE_QDRANT_MULTITENANCY_MODE = ( ) QDRANT_COLLECTION_PREFIX = os.environ.get("QDRANT_COLLECTION_PREFIX", "open-webui") +WEAVIATE_HTTP_HOST = os.environ.get("WEAVIATE_HTTP_HOST", "") +WEAVIATE_HTTP_PORT = int(os.environ.get("WEAVIATE_HTTP_PORT", "8080")) +WEAVIATE_GRPC_PORT = int(os.environ.get("WEAVIATE_GRPC_PORT", "50051")) +WEAVIATE_API_KEY = os.environ.get("WEAVIATE_API_KEY") + # OpenSearch OPENSEARCH_URI = os.environ.get("OPENSEARCH_URI", "https://localhost:9200") OPENSEARCH_SSL = os.environ.get("OPENSEARCH_SSL", "true").lower() == "true" @@ -3499,6 +3545,11 @@ IMAGES_GEMINI_ENDPOINT_METHOD = PersistentConfig( os.getenv("IMAGES_GEMINI_ENDPOINT_METHOD", ""), ) +ENABLE_IMAGE_EDIT = PersistentConfig( + "ENABLE_IMAGE_EDIT", + "images.edit.enable", + os.environ.get("ENABLE_IMAGE_EDIT", "").lower() == "true", +) IMAGE_EDIT_ENGINE = PersistentConfig( "IMAGE_EDIT_ENGINE", diff --git a/backend/open_webui/constants.py b/backend/open_webui/constants.py index 6d63295ab8..4d39d16cdb 100644 --- a/backend/open_webui/constants.py +++ b/backend/open_webui/constants.py @@ -45,7 +45,7 @@ class ERROR_MESSAGES(str, Enum): ) INVALID_CRED = "The email or password provided is incorrect. Please check for typos and try logging in again." INVALID_EMAIL_FORMAT = "The email format you entered is invalid. Please double-check and make sure you're using a valid email address (e.g., yourname@example.com)." - INVALID_PASSWORD = ( + INCORRECT_PASSWORD = ( "The password provided is incorrect. Please check for typos and try again." ) INVALID_TRUSTED_HEADER = "Your provider has not provided a trusted header. Please contact your administrator for assistance." @@ -105,6 +105,10 @@ class ERROR_MESSAGES(str, Enum): ) FILE_NOT_PROCESSED = "Extracted content is not available for this file. Please ensure that the file is processed before proceeding." + INVALID_PASSWORD = lambda err="": ( + err if err else "The password does not meet the required validation criteria." + ) + class TASKS(str, Enum): def __str__(self) -> str: diff --git a/backend/open_webui/env.py b/backend/open_webui/env.py index 7059780220..651629b950 100644 --- a/backend/open_webui/env.py +++ b/backend/open_webui/env.py @@ -8,6 +8,8 @@ import shutil from uuid import uuid4 from pathlib import Path from cryptography.hazmat.primitives import serialization +import re + import markdown from bs4 import BeautifulSoup @@ -135,6 +137,9 @@ else: PACKAGE_DATA = {"version": "0.0.0"} VERSION = PACKAGE_DATA["version"] + + +DEPLOYMENT_ID = os.environ.get("DEPLOYMENT_ID", "") INSTANCE_ID = os.environ.get("INSTANCE_ID", str(uuid4())) @@ -426,6 +431,17 @@ WEBUI_AUTH_TRUSTED_GROUPS_HEADER = os.environ.get( ) +ENABLE_PASSWORD_VALIDATION = ( + os.environ.get("ENABLE_PASSWORD_VALIDATION", "False").lower() == "true" +) +PASSWORD_VALIDATION_REGEX_PATTERN = os.environ.get( + "PASSWORD_VALIDATION_REGEX_PATTERN", + "^(?=.*[a-z])(?=.*[A-Z])(?=.*\d)(?=.*[^\w\s]).{8,}$", +) + +PASSWORD_VALIDATION_REGEX_PATTERN = re.compile(PASSWORD_VALIDATION_REGEX_PATTERN) + + BYPASS_MODEL_ACCESS_CONTROL = ( os.environ.get("BYPASS_MODEL_ACCESS_CONTROL", "False").lower() == "true" ) diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index 18b0dee1f9..13bcc360ea 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -164,6 +164,7 @@ from open_webui.config import ( IMAGES_GEMINI_API_BASE_URL, IMAGES_GEMINI_API_KEY, IMAGES_GEMINI_ENDPOINT_METHOD, + ENABLE_IMAGE_EDIT, IMAGE_EDIT_ENGINE, IMAGE_EDIT_MODEL, IMAGE_EDIT_SIZE, @@ -369,6 +370,7 @@ from open_webui.config import ( BYPASS_ADMIN_ACCESS_CONTROL, USER_PERMISSIONS, DEFAULT_USER_ROLE, + DEFAULT_GROUP_ID, PENDING_USER_OVERLAY_CONTENT, PENDING_USER_OVERLAY_TITLE, DEFAULT_PROMPT_SUGGESTIONS, @@ -455,6 +457,7 @@ from open_webui.env import ( SAFE_MODE, SRC_LOG_LEVELS, VERSION, + DEPLOYMENT_ID, INSTANCE_ID, WEBUI_BUILD_HASH, WEBUI_SECRET_KEY, @@ -762,6 +765,7 @@ app.state.config.MODEL_ORDER_LIST = MODEL_ORDER_LIST app.state.config.DEFAULT_PROMPT_SUGGESTIONS = DEFAULT_PROMPT_SUGGESTIONS app.state.config.DEFAULT_USER_ROLE = DEFAULT_USER_ROLE +app.state.config.DEFAULT_GROUP_ID = DEFAULT_GROUP_ID app.state.config.PENDING_USER_OVERLAY_CONTENT = PENDING_USER_OVERLAY_CONTENT app.state.config.PENDING_USER_OVERLAY_TITLE = PENDING_USER_OVERLAY_TITLE @@ -1116,6 +1120,7 @@ app.state.config.COMFYUI_WORKFLOW = COMFYUI_WORKFLOW app.state.config.COMFYUI_WORKFLOW_NODES = COMFYUI_WORKFLOW_NODES +app.state.config.ENABLE_IMAGE_EDIT = ENABLE_IMAGE_EDIT app.state.config.IMAGE_EDIT_ENGINE = IMAGE_EDIT_ENGINE app.state.config.IMAGE_EDIT_MODEL = IMAGE_EDIT_MODEL app.state.config.IMAGE_EDIT_SIZE = IMAGE_EDIT_SIZE @@ -1451,6 +1456,10 @@ async def get_models( if "pipeline" in model and model["pipeline"].get("type", None) == "filter": continue + # Remove profile image URL to reduce payload size + if model.get("info", {}).get("meta", {}).get("profile_image_url"): + model["info"]["meta"].pop("profile_image_url", None) + try: model_tags = [ tag.get("name") @@ -1986,6 +1995,7 @@ async def update_webhook_url(form_data: UrlForm, user=Depends(get_admin_user)): async def get_app_version(): return { "version": VERSION, + "deployment_id": DEPLOYMENT_ID, } diff --git a/backend/open_webui/models/feedbacks.py b/backend/open_webui/models/feedbacks.py index 215e36aa24..33f7f6179a 100644 --- a/backend/open_webui/models/feedbacks.py +++ b/backend/open_webui/models/feedbacks.py @@ -4,7 +4,7 @@ import uuid from typing import Optional from open_webui.internal.db import Base, get_db -from open_webui.models.chats import Chats +from open_webui.models.users import User from open_webui.env import SRC_LOG_LEVELS from pydantic import BaseModel, ConfigDict @@ -92,6 +92,28 @@ class FeedbackForm(BaseModel): model_config = ConfigDict(extra="allow") +class UserResponse(BaseModel): + id: str + name: str + email: str + role: str = "pending" + + last_active_at: int # timestamp in epoch + updated_at: int # timestamp in epoch + created_at: int # timestamp in epoch + + model_config = ConfigDict(from_attributes=True) + + +class FeedbackUserResponse(FeedbackResponse): + user: Optional[UserResponse] = None + + +class FeedbackListResponse(BaseModel): + items: list[FeedbackUserResponse] + total: int + + class FeedbackTable: def insert_new_feedback( self, user_id: str, form_data: FeedbackForm @@ -143,6 +165,70 @@ class FeedbackTable: except Exception: return None + def get_feedback_items( + self, filter: dict = {}, skip: int = 0, limit: int = 30 + ) -> FeedbackListResponse: + with get_db() as db: + query = db.query(Feedback, User).join(User, Feedback.user_id == User.id) + + if filter: + order_by = filter.get("order_by") + direction = filter.get("direction") + + if order_by == "username": + if direction == "asc": + query = query.order_by(User.name.asc()) + else: + query = query.order_by(User.name.desc()) + elif order_by == "model_id": + # it's stored in feedback.data['model_id'] + if direction == "asc": + query = query.order_by( + Feedback.data["model_id"].as_string().asc() + ) + else: + query = query.order_by( + Feedback.data["model_id"].as_string().desc() + ) + elif order_by == "rating": + # it's stored in feedback.data['rating'] + if direction == "asc": + query = query.order_by( + Feedback.data["rating"].as_string().asc() + ) + else: + query = query.order_by( + Feedback.data["rating"].as_string().desc() + ) + elif order_by == "updated_at": + if direction == "asc": + query = query.order_by(Feedback.updated_at.asc()) + else: + query = query.order_by(Feedback.updated_at.desc()) + + else: + query = query.order_by(Feedback.created_at.desc()) + + # Count BEFORE pagination + total = query.count() + + if skip: + query = query.offset(skip) + if limit: + query = query.limit(limit) + + items = query.all() + + feedbacks = [] + for feedback, user in items: + feedback_model = FeedbackModel.model_validate(feedback) + user_model = UserResponse.model_validate(user) + feedbacks.append( + FeedbackUserResponse(**feedback_model.model_dump(), user=user_model) + ) + + return FeedbackListResponse(items=feedbacks, total=total) + def get_all_feedbacks(self) -> list[FeedbackModel]: with get_db() as db: return [ diff --git a/backend/open_webui/models/groups.py b/backend/open_webui/models/groups.py index faf2769a8a..1d96f5cfaa 100644 --- a/backend/open_webui/models/groups.py +++ b/backend/open_webui/models/groups.py @@ -101,6 +101,7 @@ class GroupForm(BaseModel): name: str description: str permissions: Optional[dict] = None + data: Optional[dict] = None class UserIdsForm(BaseModel): diff --git a/backend/open_webui/models/models.py b/backend/open_webui/models/models.py index 93dafe0f05..f5964c0579 100755 --- a/backend/open_webui/models/models.py +++ b/backend/open_webui/models/models.py @@ -244,11 +244,9 @@ class ModelsTable: try: with get_db() as db: # update only the fields that are present in the model - result = ( - db.query(Model) - .filter_by(id=id) - .update(model.model_dump(exclude={"id"})) - ) + data = model.model_dump(exclude={"id"}) + result = db.query(Model).filter_by(id=id).update(data) + db.commit() model = db.get(Model, id) diff --git a/backend/open_webui/retrieval/vector/dbs/weaviate.py b/backend/open_webui/retrieval/vector/dbs/weaviate.py new file mode 100644 index 0000000000..b90d24b499 --- /dev/null +++ b/backend/open_webui/retrieval/vector/dbs/weaviate.py @@ -0,0 +1,301 @@ +import weaviate +import re +import uuid +from typing import Any, Dict, List, Optional, Union + +from open_webui.retrieval.vector.main import ( + VectorDBBase, + VectorItem, + SearchResult, + GetResult, +) +from open_webui.retrieval.vector.utils import process_metadata +from open_webui.config import WEAVIATE_HTTP_HOST, WEAVIATE_HTTP_PORT, WEAVIATE_GRPC_PORT, WEAVIATE_API_KEY + + +def _convert_uuids_to_strings(obj: Any) -> Any: + """ + Recursively convert UUID objects to strings in nested data structures. + + This function handles: + - UUID objects -> string + - Dictionaries with UUID values + - Lists/Tuples with UUID values + - Nested combinations of the above + + Args: + obj: Any object that might contain UUIDs + + Returns: + The same object structure with UUIDs converted to strings + """ + if isinstance(obj, uuid.UUID): + return str(obj) + elif isinstance(obj, dict): + return {key: _convert_uuids_to_strings(value) for key, value in obj.items()} + elif isinstance(obj, (list, tuple)): + return type(obj)(_convert_uuids_to_strings(item) for item in obj) + elif isinstance(obj, (str, int, float, bool, type(None))): + return obj + else: + return obj + + + + +class WeaviateClient(VectorDBBase): + def __init__(self): + self.url = WEAVIATE_HTTP_HOST + try: + # Build connection parameters + connection_params = { + "host": WEAVIATE_HTTP_HOST, + "port": WEAVIATE_HTTP_PORT, + "grpc_port": WEAVIATE_GRPC_PORT, + } + + # Only add auth_credentials if WEAVIATE_API_KEY exists and is not empty + if WEAVIATE_API_KEY: + connection_params["auth_credentials"] = weaviate.classes.init.Auth.api_key(WEAVIATE_API_KEY) + + self.client = weaviate.connect_to_local(**connection_params) + self.client.connect() + except Exception as e: + raise ConnectionError(f"Failed to connect to Weaviate: {e}") from e + + def _sanitize_collection_name(self, collection_name: str) -> str: + """Sanitize collection name to be a valid Weaviate class name.""" + if not isinstance(collection_name, str) or not collection_name.strip(): + raise ValueError("Collection name must be a non-empty string") + + # Requirements for a valid Weaviate class name: + # The collection name must begin with a capital letter. + # The name can only contain letters, numbers, and the underscore (_) character. Spaces are not allowed. + + # Replace hyphens with underscores and keep only alphanumeric characters + name = re.sub(r'[^a-zA-Z0-9_]', '', collection_name.replace("-", "_")) + name = name.strip("_") + + if not name: + raise ValueError("Could not sanitize collection name to be a valid Weaviate class name") + + # Ensure it starts with a letter and is capitalized + if not name[0].isalpha(): + name = "C" + name + + return name[0].upper() + name[1:] + + def has_collection(self, collection_name: str) -> bool: + sane_collection_name = self._sanitize_collection_name(collection_name) + return self.client.collections.exists(sane_collection_name) + + def delete_collection(self, collection_name: str) -> None: + sane_collection_name = self._sanitize_collection_name(collection_name) + if self.client.collections.exists(sane_collection_name): + self.client.collections.delete(sane_collection_name) + + def _create_collection(self, collection_name: str) -> None: + self.client.collections.create( + name=collection_name, + vector_config=weaviate.classes.config.Configure.Vectors.self_provided(), + properties=[ + weaviate.classes.config.Property(name="text", data_type=weaviate.classes.config.DataType.TEXT), + ] + ) + + def insert(self, collection_name: str, items: List[VectorItem]) -> None: + sane_collection_name = self._sanitize_collection_name(collection_name) + if not self.client.collections.exists(sane_collection_name): + self._create_collection(sane_collection_name) + + collection = self.client.collections.get(sane_collection_name) + + with collection.batch.fixed_size(batch_size=100) as batch: + for item in items: + item_uuid = str(uuid.uuid4()) if not item["id"] else str(item["id"]) + + properties = {"text": item["text"]} + if item["metadata"]: + clean_metadata = _convert_uuids_to_strings(process_metadata(item["metadata"])) + clean_metadata.pop("text", None) + properties.update(clean_metadata) + + batch.add_object( + properties=properties, + uuid=item_uuid, + vector=item["vector"] + ) + + def upsert(self, collection_name: str, items: List[VectorItem]) -> None: + sane_collection_name = self._sanitize_collection_name(collection_name) + if not self.client.collections.exists(sane_collection_name): + self._create_collection(sane_collection_name) + + collection = self.client.collections.get(sane_collection_name) + + with collection.batch.fixed_size(batch_size=100) as batch: + for item in items: + item_uuid = str(item["id"]) if item["id"] else None + + properties = {"text": item["text"]} + if item["metadata"]: + clean_metadata = _convert_uuids_to_strings(process_metadata(item["metadata"])) + clean_metadata.pop("text", None) + properties.update(clean_metadata) + + batch.add_object( + properties=properties, + uuid=item_uuid, + vector=item["vector"] + ) + + def search( + self, collection_name: str, vectors: List[List[Union[float, int]]], limit: int + ) -> Optional[SearchResult]: + sane_collection_name = self._sanitize_collection_name(collection_name) + if not self.client.collections.exists(sane_collection_name): + return None + + collection = self.client.collections.get(sane_collection_name) + + result_ids, result_documents, result_metadatas, result_distances = [], [], [], [] + + for vector_embedding in vectors: + try: + response = collection.query.near_vector( + near_vector=vector_embedding, + limit=limit, + return_metadata=weaviate.classes.query.MetadataQuery(distance=True), + ) + + ids = [str(obj.uuid) for obj in response.objects] + documents = [] + metadatas = [] + distances = [] + + for obj in response.objects: + properties = dict(obj.properties) if obj.properties else {} + documents.append(properties.pop("text", "")) + metadatas.append(_convert_uuids_to_strings(properties)) + + # Weaviate has cosine distance, 2 (worst) -> 0 (best). Re-ordering to 0 -> 1 + raw_distances = [obj.metadata.distance if obj.metadata and obj.metadata.distance else 2.0 for obj in response.objects] + distances = [(2 - dist) / 2 for dist in raw_distances] + + result_ids.append(ids) + result_documents.append(documents) + result_metadatas.append(metadatas) + result_distances.append(distances) + except Exception: + result_ids.append([]) + result_documents.append([]) + result_metadatas.append([]) + result_distances.append([]) + + return SearchResult( + **{ + "ids": result_ids, + "documents": result_documents, + "metadatas": result_metadatas, + "distances": result_distances, + } + ) + + def query( + self, collection_name: str, filter: Dict, limit: Optional[int] = None + ) -> Optional[GetResult]: + sane_collection_name = self._sanitize_collection_name(collection_name) + if not self.client.collections.exists(sane_collection_name): + return None + + collection = self.client.collections.get(sane_collection_name) + + weaviate_filter = None + if filter: + for key, value in filter.items(): + prop_filter = weaviate.classes.query.Filter.by_property(name=key).equal(value) + weaviate_filter = prop_filter if weaviate_filter is None else weaviate.classes.query.Filter.all_of([weaviate_filter, prop_filter]) + + try: + response = collection.query.fetch_objects(filters=weaviate_filter, limit=limit) + + ids = [str(obj.uuid) for obj in response.objects] + documents = [] + metadatas = [] + + for obj in response.objects: + properties = dict(obj.properties) if obj.properties else {} + documents.append(properties.pop("text", "")) + metadatas.append(_convert_uuids_to_strings(properties)) + + return GetResult( + **{ + "ids": [ids], + "documents": [documents], + "metadatas": [metadatas], + } + ) + except Exception: + return None + + def get(self, collection_name: str) -> Optional[GetResult]: + sane_collection_name = self._sanitize_collection_name(collection_name) + if not self.client.collections.exists(sane_collection_name): + return None + + collection = self.client.collections.get(sane_collection_name) + ids, documents, metadatas = [], [], [] + + try: + for item in collection.iterator(): + ids.append(str(item.uuid)) + properties = dict(item.properties) if item.properties else {} + documents.append(properties.pop("text", "")) + metadatas.append(_convert_uuids_to_strings(properties)) + + if not ids: + return None + + return GetResult( + **{ + "ids": [ids], + "documents": [documents], + "metadatas": [metadatas], + } + ) + except Exception: + return None + + def delete( + self, + collection_name: str, + ids: Optional[List[str]] = None, + filter: Optional[Dict] = None, + ) -> None: + sane_collection_name = self._sanitize_collection_name(collection_name) + if not self.client.collections.exists(sane_collection_name): + return + + collection = self.client.collections.get(sane_collection_name) + + try: + if ids: + for item_id in ids: + collection.data.delete_by_id(uuid=item_id) + elif filter: + weaviate_filter = None + for key, value in filter.items(): + prop_filter = weaviate.classes.query.Filter.by_property(name=key).equal(value) + weaviate_filter = prop_filter if weaviate_filter is None else weaviate.classes.query.Filter.all_of([weaviate_filter, prop_filter]) + + if weaviate_filter: + collection.data.delete_many(where=weaviate_filter) + except Exception: + pass + + def reset(self) -> None: + try: + for collection_name in self.client.collections.list_all().keys(): + self.client.collections.delete(collection_name) + except Exception: + pass diff --git a/backend/open_webui/retrieval/vector/factory.py b/backend/open_webui/retrieval/vector/factory.py index 7888c22be8..b843e0926d 100644 --- a/backend/open_webui/retrieval/vector/factory.py +++ b/backend/open_webui/retrieval/vector/factory.py @@ -67,6 +67,10 @@ class Vector: from open_webui.retrieval.vector.dbs.oracle23ai import Oracle23aiClient return Oracle23aiClient() + case VectorType.WEAVIATE: + from open_webui.retrieval.vector.dbs.weaviate import WeaviateClient + + return WeaviateClient() case _: raise ValueError(f"Unsupported vector type: {vector_type}") diff --git a/backend/open_webui/retrieval/vector/type.py b/backend/open_webui/retrieval/vector/type.py index 7e517c169c..292cad1e78 100644 --- a/backend/open_webui/retrieval/vector/type.py +++ b/backend/open_webui/retrieval/vector/type.py @@ -11,3 +11,4 @@ class VectorType(StrEnum): PGVECTOR = "pgvector" ORACLE23AI = "oracle23ai" S3VECTOR = "s3vector" + WEAVIATE = "weaviate" diff --git a/backend/open_webui/routers/audio.py b/backend/open_webui/routers/audio.py index 45b4f1e692..1edf31fa9c 100644 --- a/backend/open_webui/routers/audio.py +++ b/backend/open_webui/routers/audio.py @@ -35,6 +35,7 @@ from pydantic import BaseModel from open_webui.utils.auth import get_admin_user, get_verified_user +from open_webui.utils.headers import include_user_info_headers from open_webui.config import ( WHISPER_MODEL_AUTO_UPDATE, WHISPER_MODEL_DIR, @@ -364,23 +365,17 @@ async def speech(request: Request, user=Depends(get_verified_user)): **(request.app.state.config.TTS_OPENAI_PARAMS or {}), } + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {request.app.state.config.TTS_OPENAI_API_KEY}", + } + if ENABLE_FORWARD_USER_INFO_HEADERS: + headers = include_user_info_headers(headers, user) + r = await session.post( url=f"{request.app.state.config.TTS_OPENAI_API_BASE_URL}/audio/speech", json=payload, - headers={ - "Content-Type": "application/json", - "Authorization": f"Bearer {request.app.state.config.TTS_OPENAI_API_KEY}", - **( - { - "X-OpenWebUI-User-Name": quote(user.name, safe=" "), - "X-OpenWebUI-User-Id": user.id, - "X-OpenWebUI-User-Email": user.email, - "X-OpenWebUI-User-Role": user.role, - } - if ENABLE_FORWARD_USER_INFO_HEADERS - else {} - ), - }, + headers=headers, ssl=AIOHTTP_CLIENT_SESSION_SSL, ) @@ -570,7 +565,7 @@ async def speech(request: Request, user=Depends(get_verified_user)): return FileResponse(file_path) -def transcription_handler(request, file_path, metadata): +def transcription_handler(request, file_path, metadata, user=None): filename = os.path.basename(file_path) file_dir = os.path.dirname(file_path) id = filename.split(".")[0] @@ -621,11 +616,15 @@ def transcription_handler(request, file_path, metadata): if language: payload["language"] = language + headers = { + "Authorization": f"Bearer {request.app.state.config.STT_OPENAI_API_KEY}" + } + if user and ENABLE_FORWARD_USER_INFO_HEADERS: + headers = include_user_info_headers(headers, user) + r = requests.post( url=f"{request.app.state.config.STT_OPENAI_API_BASE_URL}/audio/transcriptions", - headers={ - "Authorization": f"Bearer {request.app.state.config.STT_OPENAI_API_KEY}" - }, + headers=headers, files={"file": (filename, open(file_path, "rb"))}, data=payload, ) @@ -1027,7 +1026,7 @@ def transcription_handler(request, file_path, metadata): ) -def transcribe(request: Request, file_path: str, metadata: Optional[dict] = None): +def transcribe(request: Request, file_path: str, metadata: Optional[dict] = None, user=None): log.info(f"transcribe: {file_path} {metadata}") if is_audio_conversion_required(file_path): @@ -1054,7 +1053,7 @@ def transcribe(request: Request, file_path: str, metadata: Optional[dict] = None with ThreadPoolExecutor() as executor: # Submit tasks for each chunk_path futures = [ - executor.submit(transcription_handler, request, chunk_path, metadata) + executor.submit(transcription_handler, request, chunk_path, metadata, user) for chunk_path in chunk_paths ] # Gather results as they complete @@ -1189,7 +1188,7 @@ def transcription( if language: metadata = {"language": language} - result = transcribe(request, file_path, metadata) + result = transcribe(request, file_path, metadata, user) return { **result, diff --git a/backend/open_webui/routers/auths.py b/backend/open_webui/routers/auths.py index 7de1175cc1..c30c1d48d4 100644 --- a/backend/open_webui/routers/auths.py +++ b/backend/open_webui/routers/auths.py @@ -45,6 +45,7 @@ from pydantic import BaseModel from open_webui.utils.misc import parse_duration, validate_email_format from open_webui.utils.auth import ( + validate_password, verify_password, decode_token, invalidate_token, @@ -181,10 +182,14 @@ async def update_password( ) if user: + try: + validate_password(form_data.password) + except Exception as e: + raise HTTPException(400, detail=str(e)) hashed = get_password_hash(form_data.new_password) return Auths.update_user_password_by_id(user.id, hashed) else: - raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_PASSWORD) + raise HTTPException(400, detail=ERROR_MESSAGES.INCORRECT_PASSWORD) else: raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED) @@ -627,16 +632,14 @@ async def signup(request: Request, response: Response, form_data: SignupForm): raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN) try: - role = "admin" if not has_users else request.app.state.config.DEFAULT_USER_ROLE - - # The password passed to bcrypt must be 72 bytes or fewer. If it is longer, it will be truncated before hashing. - if len(form_data.password.encode("utf-8")) > 72: - raise HTTPException( - status.HTTP_400_BAD_REQUEST, - detail=ERROR_MESSAGES.PASSWORD_TOO_LONG, - ) + try: + validate_password(form_data.password) + except Exception as e: + raise HTTPException(400, detail=str(e)) hashed = get_password_hash(form_data.password) + + role = "admin" if not has_users else request.app.state.config.DEFAULT_USER_ROLE user = Auths.insert_new_auth( form_data.email.lower(), hashed, @@ -691,7 +694,11 @@ async def signup(request: Request, response: Response, form_data: SignupForm): if not has_users: # Disable signup after the first user is created request.app.state.config.ENABLE_SIGNUP = False - + + default_group_id = getattr(request.app.state.config, 'DEFAULT_GROUP_ID', "") + if default_group_id and default_group_id: + Groups.add_users_to_group(default_group_id, [user.id]) + return { "token": token, "token_type": "Bearer", @@ -805,6 +812,11 @@ async def add_user(form_data: AddUserForm, user=Depends(get_admin_user)): raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN) try: + try: + validate_password(form_data.password) + except Exception as e: + raise HTTPException(400, detail=str(e)) + hashed = get_password_hash(form_data.password) user = Auths.insert_new_auth( form_data.email.lower(), @@ -880,6 +892,7 @@ async def get_admin_config(request: Request, user=Depends(get_admin_user)): "ENABLE_API_KEYS_ENDPOINT_RESTRICTIONS": request.app.state.config.ENABLE_API_KEYS_ENDPOINT_RESTRICTIONS, "API_KEYS_ALLOWED_ENDPOINTS": request.app.state.config.API_KEYS_ALLOWED_ENDPOINTS, "DEFAULT_USER_ROLE": request.app.state.config.DEFAULT_USER_ROLE, + "DEFAULT_GROUP_ID": request.app.state.config.DEFAULT_GROUP_ID, "JWT_EXPIRES_IN": request.app.state.config.JWT_EXPIRES_IN, "ENABLE_COMMUNITY_SHARING": request.app.state.config.ENABLE_COMMUNITY_SHARING, "ENABLE_MESSAGE_RATING": request.app.state.config.ENABLE_MESSAGE_RATING, @@ -900,6 +913,7 @@ class AdminConfig(BaseModel): ENABLE_API_KEYS_ENDPOINT_RESTRICTIONS: bool API_KEYS_ALLOWED_ENDPOINTS: str DEFAULT_USER_ROLE: str + DEFAULT_GROUP_ID: str JWT_EXPIRES_IN: str ENABLE_COMMUNITY_SHARING: bool ENABLE_MESSAGE_RATING: bool @@ -914,7 +928,7 @@ class AdminConfig(BaseModel): @router.post("/admin/config") async def update_admin_config( request: Request, form_data: AdminConfig, user=Depends(get_admin_user) -): +): request.app.state.config.SHOW_ADMIN_DETAILS = form_data.SHOW_ADMIN_DETAILS request.app.state.config.WEBUI_URL = form_data.WEBUI_URL request.app.state.config.ENABLE_SIGNUP = form_data.ENABLE_SIGNUP @@ -933,6 +947,8 @@ async def update_admin_config( if form_data.DEFAULT_USER_ROLE in ["pending", "user", "admin"]: request.app.state.config.DEFAULT_USER_ROLE = form_data.DEFAULT_USER_ROLE + request.app.state.config.DEFAULT_GROUP_ID = form_data.DEFAULT_GROUP_ID + pattern = r"^(-1|0|(-?\d+(\.\d+)?)(ms|s|m|h|d|w))$" # Check if the input string matches the pattern @@ -963,6 +979,7 @@ async def update_admin_config( "ENABLE_API_KEYS_ENDPOINT_RESTRICTIONS": request.app.state.config.ENABLE_API_KEYS_ENDPOINT_RESTRICTIONS, "API_KEYS_ALLOWED_ENDPOINTS": request.app.state.config.API_KEYS_ALLOWED_ENDPOINTS, "DEFAULT_USER_ROLE": request.app.state.config.DEFAULT_USER_ROLE, + "DEFAULT_GROUP_ID": request.app.state.config.DEFAULT_GROUP_ID, "JWT_EXPIRES_IN": request.app.state.config.JWT_EXPIRES_IN, "ENABLE_COMMUNITY_SHARING": request.app.state.config.ENABLE_COMMUNITY_SHARING, "ENABLE_MESSAGE_RATING": request.app.state.config.ENABLE_MESSAGE_RATING, diff --git a/backend/open_webui/routers/evaluations.py b/backend/open_webui/routers/evaluations.py index c76a1f6915..3e5e14801c 100644 --- a/backend/open_webui/routers/evaluations.py +++ b/backend/open_webui/routers/evaluations.py @@ -7,6 +7,8 @@ from open_webui.models.feedbacks import ( FeedbackModel, FeedbackResponse, FeedbackForm, + FeedbackUserResponse, + FeedbackListResponse, Feedbacks, ) @@ -56,35 +58,10 @@ async def update_config( } -class UserResponse(BaseModel): - id: str - name: str - email: str - role: str = "pending" - - last_active_at: int # timestamp in epoch - updated_at: int # timestamp in epoch - created_at: int # timestamp in epoch - - -class FeedbackUserResponse(FeedbackResponse): - user: Optional[UserResponse] = None - - -@router.get("/feedbacks/all", response_model=list[FeedbackUserResponse]) +@router.get("/feedbacks/all", response_model=list[FeedbackResponse]) async def get_all_feedbacks(user=Depends(get_admin_user)): feedbacks = Feedbacks.get_all_feedbacks() - - feedback_list = [] - for feedback in feedbacks: - user = Users.get_user_by_id(feedback.user_id) - feedback_list.append( - FeedbackUserResponse( - **feedback.model_dump(), - user=UserResponse(**user.model_dump()) if user else None, - ) - ) - return feedback_list + return feedbacks @router.delete("/feedbacks/all") @@ -111,6 +88,31 @@ async def delete_feedbacks(user=Depends(get_verified_user)): return success +PAGE_ITEM_COUNT = 30 + + +@router.get("/feedbacks/list", response_model=FeedbackListResponse) +async def get_feedbacks( + order_by: Optional[str] = None, + direction: Optional[str] = None, + page: Optional[int] = 1, + user=Depends(get_admin_user), +): + limit = PAGE_ITEM_COUNT + + page = max(1, page) + skip = (page - 1) * limit + + filter = {} + if order_by: + filter["order_by"] = order_by + if direction: + filter["direction"] = direction + + result = Feedbacks.get_feedback_items(filter=filter, skip=skip, limit=limit) + return result + + @router.post("/feedback", response_model=FeedbackModel) async def create_feedback( request: Request, diff --git a/backend/open_webui/routers/files.py b/backend/open_webui/routers/files.py index 2a5c3e5bb1..54084941fe 100644 --- a/backend/open_webui/routers/files.py +++ b/backend/open_webui/routers/files.py @@ -102,7 +102,7 @@ def process_uploaded_file(request, file, file_path, file_item, file_metadata, us ) ): file_path = Storage.get_file(file_path) - result = transcribe(request, file_path, file_metadata) + result = transcribe(request, file_path, file_metadata, user) process_file( request, diff --git a/backend/open_webui/routers/groups.py b/backend/open_webui/routers/groups.py index 331c831153..2b531b462b 100755 --- a/backend/open_webui/routers/groups.py +++ b/backend/open_webui/routers/groups.py @@ -31,20 +31,32 @@ router = APIRouter() @router.get("/", response_model=list[GroupResponse]) -async def get_groups(user=Depends(get_verified_user)): +async def get_groups(share: Optional[bool] = None, user=Depends(get_verified_user)): if user.role == "admin": groups = Groups.get_groups() else: groups = Groups.get_groups_by_member_id(user.id) - return [ - GroupResponse( - **group.model_dump(), - member_count=Groups.get_group_member_count_by_id(group.id), + group_list = [] + + for group in groups: + if share is not None: + # Check if the group has data and a config with share key + if ( + group.data + and "share" in group.data.get("config", {}) + and group.data["config"]["share"] != share + ): + continue + + group_list.append( + GroupResponse( + **group.model_dump(), + member_count=Groups.get_group_member_count_by_id(group.id), + ) ) - for group in groups - if group - ] + + return group_list ############################ diff --git a/backend/open_webui/routers/images.py b/backend/open_webui/routers/images.py index 5f035695a9..c4e67ae9ea 100644 --- a/backend/open_webui/routers/images.py +++ b/backend/open_webui/routers/images.py @@ -126,6 +126,7 @@ class ImagesConfig(BaseModel): IMAGES_GEMINI_API_KEY: str IMAGES_GEMINI_ENDPOINT_METHOD: str + ENABLE_IMAGE_EDIT: bool IMAGE_EDIT_ENGINE: str IMAGE_EDIT_MODEL: str IMAGE_EDIT_SIZE: Optional[str] @@ -164,6 +165,7 @@ async def get_config(request: Request, user=Depends(get_admin_user)): "IMAGES_GEMINI_API_BASE_URL": request.app.state.config.IMAGES_GEMINI_API_BASE_URL, "IMAGES_GEMINI_API_KEY": request.app.state.config.IMAGES_GEMINI_API_KEY, "IMAGES_GEMINI_ENDPOINT_METHOD": request.app.state.config.IMAGES_GEMINI_ENDPOINT_METHOD, + "ENABLE_IMAGE_EDIT": request.app.state.config.ENABLE_IMAGE_EDIT, "IMAGE_EDIT_ENGINE": request.app.state.config.IMAGE_EDIT_ENGINE, "IMAGE_EDIT_MODEL": request.app.state.config.IMAGE_EDIT_MODEL, "IMAGE_EDIT_SIZE": request.app.state.config.IMAGE_EDIT_SIZE, @@ -253,6 +255,7 @@ async def update_config( ) # Edit Image + request.app.state.config.ENABLE_IMAGE_EDIT = form_data.ENABLE_IMAGE_EDIT request.app.state.config.IMAGE_EDIT_ENGINE = form_data.IMAGE_EDIT_ENGINE request.app.state.config.IMAGE_EDIT_MODEL = form_data.IMAGE_EDIT_MODEL request.app.state.config.IMAGE_EDIT_SIZE = form_data.IMAGE_EDIT_SIZE @@ -308,6 +311,7 @@ async def update_config( "IMAGES_GEMINI_API_BASE_URL": request.app.state.config.IMAGES_GEMINI_API_BASE_URL, "IMAGES_GEMINI_API_KEY": request.app.state.config.IMAGES_GEMINI_API_KEY, "IMAGES_GEMINI_ENDPOINT_METHOD": request.app.state.config.IMAGES_GEMINI_ENDPOINT_METHOD, + "ENABLE_IMAGE_EDIT": request.app.state.config.ENABLE_IMAGE_EDIT, "IMAGE_EDIT_ENGINE": request.app.state.config.IMAGE_EDIT_ENGINE, "IMAGE_EDIT_MODEL": request.app.state.config.IMAGE_EDIT_MODEL, "IMAGE_EDIT_SIZE": request.app.state.config.IMAGE_EDIT_SIZE, diff --git a/backend/open_webui/routers/models.py b/backend/open_webui/routers/models.py index 183bd5f55b..a689d26e98 100644 --- a/backend/open_webui/routers/models.py +++ b/backend/open_webui/routers/models.py @@ -253,6 +253,7 @@ async def get_model_profile_image(id: str, user=Depends(get_verified_user)): ) except Exception as e: pass + return FileResponse(f"{STATIC_DIR}/favicon.png") else: return FileResponse(f"{STATIC_DIR}/favicon.png") @@ -320,7 +321,7 @@ async def update_model_by_id( detail=ERROR_MESSAGES.ACCESS_PROHIBITED, ) - model = Models.update_model_by_id(form_data.id, form_data) + model = Models.update_model_by_id(form_data.id, ModelForm(**form_data.model_dump())) return model diff --git a/backend/open_webui/routers/users.py b/backend/open_webui/routers/users.py index d615a28634..f53b0e2749 100644 --- a/backend/open_webui/routers/users.py +++ b/backend/open_webui/routers/users.py @@ -36,7 +36,12 @@ from open_webui.constants import ERROR_MESSAGES from open_webui.env import SRC_LOG_LEVELS, STATIC_DIR -from open_webui.utils.auth import get_admin_user, get_password_hash, get_verified_user +from open_webui.utils.auth import ( + get_admin_user, + get_password_hash, + get_verified_user, + validate_password, +) from open_webui.utils.access_control import get_permissions, has_permission @@ -178,10 +183,15 @@ class WorkspacePermissions(BaseModel): class SharingPermissions(BaseModel): - public_models: bool = True - public_knowledge: bool = True - public_prompts: bool = True + models: bool = False + public_models: bool = False + knowledge: bool = False + public_knowledge: bool = False + prompts: bool = False + public_prompts: bool = False + tools: bool = False public_tools: bool = True + notes: bool = False public_notes: bool = True @@ -497,8 +507,12 @@ async def update_user_by_id( ) if form_data.password: + try: + validate_password(form_data.password) + except Exception as e: + raise HTTPException(400, detail=str(e)) + hashed = get_password_hash(form_data.password) - log.debug(f"hashed: {hashed}") Auths.update_user_password_by_id(user_id, hashed) Auths.update_email_by_id(user_id, form_data.email.lower()) diff --git a/backend/open_webui/utils/auth.py b/backend/open_webui/utils/auth.py index 8689cd99c2..61b8fb13a4 100644 --- a/backend/open_webui/utils/auth.py +++ b/backend/open_webui/utils/auth.py @@ -28,8 +28,10 @@ from open_webui.models.users import Users from open_webui.constants import ERROR_MESSAGES from open_webui.env import ( + ENABLE_PASSWORD_VALIDATION, OFFLINE_MODE, LICENSE_BLOB, + PASSWORD_VALIDATION_REGEX_PATTERN, REDIS_KEY_PREFIX, pk, WEBUI_SECRET_KEY, @@ -162,6 +164,20 @@ def get_password_hash(password: str) -> str: return bcrypt.hashpw(password.encode("utf-8"), bcrypt.gensalt()).decode("utf-8") +def validate_password(password: str) -> bool: + # The password passed to bcrypt must be 72 bytes or fewer. If it is longer, it will be truncated before hashing. + if len(password.encode("utf-8")) > 72: + raise Exception( + ERROR_MESSAGES.PASSWORD_TOO_LONG, + ) + + if ENABLE_PASSWORD_VALIDATION: + if not PASSWORD_VALIDATION_REGEX_PATTERN.match(password): + raise Exception(ERROR_MESSAGES.INVALID_PASSWORD()) + + return True + + def verify_password(plain_password: str, hashed_password: str) -> bool: """Verify a password against its hash""" return ( diff --git a/backend/open_webui/utils/middleware.py b/backend/open_webui/utils/middleware.py index 7653536e34..5095bb418b 100644 --- a/backend/open_webui/utils/middleware.py +++ b/backend/open_webui/utils/middleware.py @@ -791,42 +791,13 @@ async def chat_image_generation_handler( input_images = get_last_images(message_list) system_message_content = "" - if len(input_images) == 0: - # Create image(s) - if request.app.state.config.ENABLE_IMAGE_PROMPT_GENERATION: - try: - res = await generate_image_prompt( - request, - { - "model": form_data["model"], - "messages": form_data["messages"], - }, - user, - ) - - response = res["choices"][0]["message"]["content"] - - try: - bracket_start = response.find("{") - bracket_end = response.rfind("}") + 1 - - if bracket_start == -1 or bracket_end == -1: - raise Exception("No JSON object found in the response") - - response = response[bracket_start:bracket_end] - response = json.loads(response) - prompt = response.get("prompt", []) - except Exception as e: - prompt = user_message - - except Exception as e: - log.exception(e) - prompt = user_message + if len(input_images) > 0 and request.app.state.config.ENABLE_IMAGE_EDIT: + # Edit image(s) try: - images = await image_generations( + images = await image_edits( request=request, - form_data=CreateImageForm(**{"prompt": prompt}), + form_data=EditImageForm(**{"prompt": prompt, "image": input_images}), user=user, ) @@ -874,12 +845,43 @@ async def chat_image_generation_handler( ) system_message_content = f"Image generation was attempted but failed. The system is currently unable to generate the image. Tell the user that an error occurred: {error_message}" + else: - # Edit image(s) + # Create image(s) + if request.app.state.config.ENABLE_IMAGE_PROMPT_GENERATION: + try: + res = await generate_image_prompt( + request, + { + "model": form_data["model"], + "messages": form_data["messages"], + }, + user, + ) + + response = res["choices"][0]["message"]["content"] + + try: + bracket_start = response.find("{") + bracket_end = response.rfind("}") + 1 + + if bracket_start == -1 or bracket_end == -1: + raise Exception("No JSON object found in the response") + + response = response[bracket_start:bracket_end] + response = json.loads(response) + prompt = response.get("prompt", []) + except Exception as e: + prompt = user_message + + except Exception as e: + log.exception(e) + prompt = user_message + try: - images = await image_edits( + images = await image_generations( request=request, - form_data=EditImageForm(**{"prompt": prompt, "image": input_images}), + form_data=CreateImageForm(**{"prompt": prompt}), user=user, ) diff --git a/backend/requirements.txt b/backend/requirements.txt index 8d6758e456..9936eeaa95 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -49,6 +49,7 @@ langchain-community==0.3.29 fake-useragent==2.2.0 chromadb==1.1.0 +weaviate-client==4.17.0 opensearch-py==2.8.0 transformers diff --git a/pyproject.toml b/pyproject.toml index e7822eee71..2d3f6a2835 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -147,6 +147,7 @@ all = [ "elasticsearch==9.1.0", "qdrant-client==1.14.3", + "weaviate-client==4.17.0", "pymilvus==2.6.2", "pinecone==6.0.2", "oracledb==3.2.0", diff --git a/src/lib/apis/evaluations/index.ts b/src/lib/apis/evaluations/index.ts index 96a689fcb1..1f48c7bfbf 100644 --- a/src/lib/apis/evaluations/index.ts +++ b/src/lib/apis/evaluations/index.ts @@ -93,6 +93,45 @@ export const getAllFeedbacks = async (token: string = '') => { return res; }; +export const getFeedbackItems = async (token: string = '', orderBy, direction, page) => { + let error = null; + + const searchParams = new URLSearchParams(); + if (orderBy) searchParams.append('order_by', orderBy); + if (direction) searchParams.append('direction', direction); + if (page) searchParams.append('page', page.toString()); + + const res = await fetch( + `${WEBUI_API_BASE_URL}/evaluations/feedbacks/list?${searchParams.toString()}`, + { + method: 'GET', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + } + } + ) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err.detail; + console.error(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + export const exportAllFeedbacks = async (token: string = '') => { let error = null; diff --git a/src/lib/apis/groups/index.ts b/src/lib/apis/groups/index.ts index 51b49bf4d9..a74c61b83d 100644 --- a/src/lib/apis/groups/index.ts +++ b/src/lib/apis/groups/index.ts @@ -31,10 +31,15 @@ export const createNewGroup = async (token: string, group: object) => { return res; }; -export const getGroups = async (token: string = '') => { +export const getGroups = async (token: string = '', share?: boolean) => { let error = null; - const res = await fetch(`${WEBUI_API_BASE_URL}/groups/`, { + const searchParams = new URLSearchParams(); + if (share !== undefined) { + searchParams.append('share', String(share)); + } + + const res = await fetch(`${WEBUI_API_BASE_URL}/groups/?${searchParams.toString()}`, { method: 'GET', headers: { Accept: 'application/json', diff --git a/src/lib/apis/index.ts b/src/lib/apis/index.ts index 126c59ad2f..e865e9ba0e 100644 --- a/src/lib/apis/index.ts +++ b/src/lib/apis/index.ts @@ -1425,7 +1425,7 @@ export const getVersion = async (token: string) => { throw error; } - return res?.version ?? null; + return res; }; export const getVersionUpdates = async (token: string) => { diff --git a/src/lib/components/admin/Evaluations.svelte b/src/lib/components/admin/Evaluations.svelte index d29dee746c..e2849bd98b 100644 --- a/src/lib/components/admin/Evaluations.svelte +++ b/src/lib/components/admin/Evaluations.svelte @@ -33,7 +33,9 @@ let feedbacks = []; onMount(async () => { + // TODO: feedbacks elo rating calculation should be done in the backend; remove below line later feedbacks = await getAllFeedbacks(localStorage.token); + loaded = true; const containerElement = document.getElementById('users-tabs-container'); @@ -117,7 +119,7 @@ {#if selectedTab === 'leaderboard'} {:else if selectedTab === 'feedbacks'} - + {/if} diff --git a/src/lib/components/admin/Evaluations/Feedbacks.svelte b/src/lib/components/admin/Evaluations/Feedbacks.svelte index 62304088ed..c47524eef4 100644 --- a/src/lib/components/admin/Evaluations/Feedbacks.svelte +++ b/src/lib/components/admin/Evaluations/Feedbacks.svelte @@ -10,7 +10,7 @@ import { onMount, getContext } from 'svelte'; const i18n = getContext('i18n'); - import { deleteFeedbackById, exportAllFeedbacks, getAllFeedbacks } from '$lib/apis/evaluations'; + import { deleteFeedbackById, exportAllFeedbacks, getFeedbackItems } from '$lib/apis/evaluations'; import Tooltip from '$lib/components/common/Tooltip.svelte'; import Download from '$lib/components/icons/Download.svelte'; @@ -23,78 +23,25 @@ import ChevronUp from '$lib/components/icons/ChevronUp.svelte'; import ChevronDown from '$lib/components/icons/ChevronDown.svelte'; - import { WEBUI_BASE_URL } from '$lib/constants'; + import { WEBUI_API_BASE_URL, WEBUI_BASE_URL } from '$lib/constants'; import { config } from '$lib/stores'; - - export let feedbacks = []; + import Spinner from '$lib/components/common/Spinner.svelte'; let page = 1; - $: paginatedFeedbacks = sortedFeedbacks.slice((page - 1) * 10, page * 10); + let items = null; + let total = null; let orderBy: string = 'updated_at'; let direction: 'asc' | 'desc' = 'desc'; - type Feedback = { - id: string; - data: { - rating: number; - model_id: string; - sibling_model_ids: string[] | null; - reason: string; - comment: string; - tags: string[]; - }; - user: { - name: string; - profile_image_url: string; - }; - updated_at: number; - }; - - type ModelStats = { - rating: number; - won: number; - lost: number; - }; - - function setSortKey(key: string) { + const setSortKey = (key) => { if (orderBy === key) { direction = direction === 'asc' ? 'desc' : 'asc'; } else { orderBy = key; - if (key === 'user' || key === 'model_id') { - direction = 'asc'; - } else { - direction = 'desc'; - } + direction = 'asc'; } - page = 1; - } - - $: sortedFeedbacks = [...feedbacks].sort((a, b) => { - let aVal, bVal; - - switch (orderBy) { - case 'user': - aVal = a.user?.name || ''; - bVal = b.user?.name || ''; - return direction === 'asc' ? aVal.localeCompare(bVal) : bVal.localeCompare(aVal); - case 'model_id': - aVal = a.data.model_id || ''; - bVal = b.data.model_id || ''; - return direction === 'asc' ? aVal.localeCompare(bVal) : bVal.localeCompare(aVal); - case 'rating': - aVal = a.data.rating; - bVal = b.data.rating; - return direction === 'asc' ? aVal - bVal : bVal - aVal; - case 'updated_at': - aVal = a.updated_at; - bVal = b.updated_at; - return direction === 'asc' ? aVal - bVal : bVal - aVal; - default: - return 0; - } - }); + }; let showFeedbackModal = false; let selectedFeedback = null; @@ -115,13 +62,41 @@ // ////////////////////// + const getFeedbacks = async () => { + try { + const res = await getFeedbackItems(localStorage.token, orderBy, direction, page).catch( + (error) => { + toast.error(`${error}`); + return null; + } + ); + + if (res) { + items = res.items; + total = res.total; + } + } catch (err) { + console.error(err); + } + }; + + $: if (page) { + getFeedbacks(); + } + + $: if (orderBy && direction) { + getFeedbacks(); + } + const deleteFeedbackHandler = async (feedbackId: string) => { const response = await deleteFeedbackById(localStorage.token, feedbackId).catch((err) => { toast.error(err); return null; }); if (response) { - feedbacks = feedbacks.filter((f) => f.id !== feedbackId); + toast.success($i18n.t('Feedback deleted successfully')); + page = 1; + getFeedbacks(); } }; @@ -169,256 +144,266 @@ -
-
- {$i18n.t('Feedback History')} +{#if items === null || total === null} +
+ +
+{:else} +
+
+
+ {$i18n.t('Feedback History')} +
-
+
+ {total} +
+
- {feedbacks.length} + {#if total > 0} +
+ + + +
+ {/if}
- {#if feedbacks.length > 0} -
- - - -
- {/if} -
- -
- {#if (feedbacks ?? []).length === 0} -
- {$i18n.t('No feedbacks found')} -
- {:else} - - - - + {/each} + +
setSortKey('user')} - > -
- {$i18n.t('User')} - {#if orderBy === 'user'} - - {#if direction === 'asc'} +
+ {#if (items ?? []).length === 0} +
+ {$i18n.t('No feedbacks found')} +
+ {:else} + + + + - - - - - - - - - - - - {#each paginatedFeedbacks as feedback (feedback.id)} - openFeedbackModal(feedback)} - > - + - - {#if feedback?.data?.rating} - + + + + + + + + {#each items as feedback (feedback.id)} + openFeedbackModal(feedback)} + > + - {/if} - + + + {#if feedback?.data?.rating} + + {/if} + + + + - - {/each} - -
setSortKey('user')} + > +
+ {$i18n.t('User')} + {#if orderBy === 'user'} + + {#if direction === 'asc'} + + {:else} + + {/if} + + {:else} + - {:else} - - {/if} -
-
setSortKey('model_id')} - > -
- {$i18n.t('Models')} - {#if orderBy === 'model_id'} - - {#if direction === 'asc'} - - {:else} - - {/if} - - {:else} - - {/if} -
-
setSortKey('rating')} - > -
- {$i18n.t('Result')} - {#if orderBy === 'rating'} - - {#if direction === 'asc'} - - {:else} - - {/if} - - {:else} - - {/if} -
-
setSortKey('updated_at')} - > -
- {$i18n.t('Updated At')} - {#if orderBy === 'updated_at'} - - {#if direction === 'asc'} - - {:else} - - {/if} - - {:else} - - {/if} -
-
-
- -
- {feedback?.user?.name} -
-
+ + {/if}
-
-
-
- {#if feedback.data?.sibling_model_ids} -
- {feedback.data?.model_id} -
- - -
- {#if feedback.data.sibling_model_ids.length > 2} - - {feedback.data.sibling_model_ids.slice(0, 2).join(', ')}, {$i18n.t( - 'and {{COUNT}} more', - { COUNT: feedback.data.sibling_model_ids.length - 2 } - )} - {:else} - {feedback.data.sibling_model_ids.join(', ')} - {/if} -
-
- {:else} -
- {feedback.data?.model_id} -
- {/if} -
+
setSortKey('model_id')} + > +
+ {$i18n.t('Models')} + {#if orderBy === 'model_id'} + + {#if direction === 'asc'} + + {:else} + + {/if} + + {:else} + + {/if}
- +
-
- {#if feedback?.data?.rating.toString() === '1'} - - {:else if feedback?.data?.rating.toString() === '0'} - - {:else if feedback?.data?.rating.toString() === '-1'} - - {/if} +
setSortKey('rating')} + > +
+ {$i18n.t('Result')} + {#if orderBy === 'rating'} + + {#if direction === 'asc'} + + {:else} + + {/if} + + {:else} + + {/if} +
+
setSortKey('updated_at')} + > +
+ {$i18n.t('Updated At')} + {#if orderBy === 'updated_at'} + + {#if direction === 'asc'} + + {:else} + + {/if} + + {:else} + + {/if} +
+
+
+ +
+ {feedback?.user?.name} +
+
- {dayjs(feedback.updated_at * 1000).fromNow()} - +
+
+ {#if feedback.data?.sibling_model_ids} +
+ {feedback.data?.model_id} +
-
e.stopPropagation()}> - { - deleteFeedbackHandler(feedback.id); - }} - > - +
+ {#if feedback?.data?.rating.toString() === '1'} + + {:else if feedback?.data?.rating.toString() === '0'} + + {:else if feedback?.data?.rating.toString() === '-1'} + + {/if} +
+
+ {dayjs(feedback.updated_at * 1000).fromNow()} + e.stopPropagation()}> + { + deleteFeedbackHandler(feedback.id); + }} > - - - -
- {/if} -
- -{#if feedbacks.length > 0 && $config?.features?.enable_community_sharing} -
-
- {$i18n.t('Help us create the best community leaderboard by sharing your feedback history!')} -
- -
- - - -
+ + + +
+ {/if}
-{/if} -{#if feedbacks.length > 10} - + {#if total > 0 && $config?.features?.enable_community_sharing} +
+
+ {$i18n.t('Help us create the best community leaderboard by sharing your feedback history!')} +
+ +
+ + + +
+
+ {/if} + + {#if total > 30} + + {/if} {/if} diff --git a/src/lib/components/admin/Evaluations/Leaderboard.svelte b/src/lib/components/admin/Evaluations/Leaderboard.svelte index d66fbf7821..948551cde2 100644 --- a/src/lib/components/admin/Evaluations/Leaderboard.svelte +++ b/src/lib/components/admin/Evaluations/Leaderboard.svelte @@ -10,7 +10,7 @@ import ChevronUp from '$lib/components/icons/ChevronUp.svelte'; import ChevronDown from '$lib/components/icons/ChevronDown.svelte'; - import { WEBUI_BASE_URL } from '$lib/constants'; + import { WEBUI_API_BASE_URL, WEBUI_BASE_URL } from '$lib/constants'; const i18n = getContext('i18n'); @@ -339,16 +339,14 @@
-
-
+
+
{$i18n.t('Leaderboard')}
-
- - {rankedModels.length} +
+ {rankedModels.length} +
@@ -517,7 +515,7 @@
{model.name} diff --git a/src/lib/components/admin/Settings/Documents.svelte b/src/lib/components/admin/Settings/Documents.svelte index aae75dfbd8..83369ac993 100644 --- a/src/lib/components/admin/Settings/Documents.svelte +++ b/src/lib/components/admin/Settings/Documents.svelte @@ -757,18 +757,18 @@ {/if} {/if} -
-
- +
+
+
{$i18n.t('Parameters')} - -
-
-