Merge pull request #19030 from open-webui/dev

0.6.37
Authored by Tim Baek on 2025-11-23 22:10:05 -05:00; committed by GitHub
commit fe6783c166
250 changed files with 12186 additions and 6451 deletions


@ -5,6 +5,96 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## [0.6.37] - 2025-11-24
### Added
- 🔐 Granular sharing permissions are now available with two-tiered control that separates group sharing from public sharing, allowing administrators to independently configure whether users can share workspace items with groups or make them publicly accessible, with separate permission toggles for models, knowledge bases, prompts, tools, and notes. The toggles are configurable via "USER_PERMISSIONS_WORKSPACE_MODELS_ALLOW_SHARING", "USER_PERMISSIONS_WORKSPACE_MODELS_ALLOW_PUBLIC_SHARING", and the corresponding environment variables for the other workspace item types, and groups can now be configured to opt out of sharing via the "Allow Group Sharing" setting. [Commit](https://github.com/open-webui/open-webui/commit/7be750bcbb40da91912a0a66b7ab791effdcc3b6), [Commit](https://github.com/open-webui/open-webui/commit/f69e37a8507d6d57382d6670641b367f3127f90a)
- 🔐 Password policy enforcement is now available with configurable validation rules, allowing administrators to require specific password complexity requirements via "ENABLE_PASSWORD_VALIDATION" and "PASSWORD_VALIDATION_REGEX_PATTERN" environment variables, with the default pattern requiring a minimum of 8 characters including uppercase, lowercase, digit, and special character (see the sketch after this list). [#17794](https://github.com/open-webui/open-webui/pull/17794)
- 🔐 Granular import and export permissions are now available for workspace items, introducing six separate permission toggles for models, prompts, and tools that are disabled by default for enhanced security. [#19242](https://github.com/open-webui/open-webui/pull/19242)
- 👥 Default group assignment is now available for new users, allowing administrators to automatically assign newly registered users to a specified group for streamlined access control to models, prompts, and tools, particularly useful for organizations with group-based model access policies. [#19325](https://github.com/open-webui/open-webui/pull/19325), [#17842](https://github.com/open-webui/open-webui/issues/17842)
- 🔒 Password-based authentication can now be fully disabled via "ENABLE_PASSWORD_AUTH" environment variable, enforcing SSO-only authentication and preventing password login fallback when SSO is configured. [#19113](https://github.com/open-webui/open-webui/pull/19113)
- 🖼️ Large stream chunk handling was implemented to support models that generate images directly in their output responses, with configurable buffer size via "CHAT_STREAM_RESPONSE_CHUNK_MAX_BUFFER_SIZE" environment variable, resolving compatibility issues with models like Gemini 2.5 Flash Image. [#18884](https://github.com/open-webui/open-webui/pull/18884), [#17626](https://github.com/open-webui/open-webui/issues/17626)
- 🖼️ Streaming response middleware now handles images in delta updates with automatic base64 conversion, enabling proper display of images from models using the "choices[0].delta.images.image_url" format such as Gemini 2.5 Flash Image Preview on OpenRouter. [#19073](https://github.com/open-webui/open-webui/pull/19073), [#19019](https://github.com/open-webui/open-webui/issues/19019)
- 📈 Model list API performance was optimized by pre-fetching user group memberships and removing profile image URLs from response payloads, significantly reducing both database queries and payload size for instances with large model lists, with profile images now served dynamically via dedicated endpoints. [#19097](https://github.com/open-webui/open-webui/pull/19097), [#18950](https://github.com/open-webui/open-webui/issues/18950)
- ⏩ Batch file processing performance was improved by reducing database queries by 67% while ensuring data consistency between vector and relational databases. [#18953](https://github.com/open-webui/open-webui/pull/18953)
- 🚀 Chat import performance was dramatically improved by replacing individual per-chat API requests with a bulk import endpoint, reducing import time by up to 95% for large chat collections and providing user feedback via toast notifications displaying the number of successfully imported chats. [#17861](https://github.com/open-webui/open-webui/pull/17861)
- ⚡ Socket event broadcasting performance was optimized by implementing user-specific rooms, significantly reducing server overhead particularly for users with multiple concurrent sessions. [#18996](https://github.com/open-webui/open-webui/pull/18996)
- 🗄️ Weaviate is now supported as a vector database option, providing an additional choice for RAG document storage alongside existing ChromaDB, Milvus, Qdrant, and OpenSearch integrations. [#14747](https://github.com/open-webui/open-webui/pull/14747)
- 🗄️ PostgreSQL pgvector now supports HNSW index types and large dimensional embeddings exceeding 2000 dimensions through automatic halfvec type selection, with configurable index methods via "PGVECTOR_INDEX_METHOD", "PGVECTOR_HNSW_M", "PGVECTOR_HNSW_EF_CONSTRUCTION", and "PGVECTOR_IVFFLAT_LISTS" environment variables. [#19158](https://github.com/open-webui/open-webui/pull/19158), [#16890](https://github.com/open-webui/open-webui/issues/16890)
- 🔍 Azure AI Search is now supported as a web search provider, enabling integration with Azure's cognitive search services via "AZURE_AI_SEARCH_API_KEY", "AZURE_AI_SEARCH_ENDPOINT", and "AZURE_AI_SEARCH_INDEX_NAME" configuration. [#19104](https://github.com/open-webui/open-webui/pull/19104)
- ⚡ External embedding generation now processes API requests in parallel instead of sequential batches, reducing document processing time by 10-50x when using OpenAI, Azure OpenAI, or Ollama embedding providers, with large PDFs now processing in seconds instead of minutes. [#19296](https://github.com/open-webui/open-webui/pull/19296)
- 💨 Base64 image conversion is now available for markdown content in chat responses, automatically uploading embedded images exceeding 1KB and replacing them with file URLs to reduce payload size and resource consumption, configurable via "REPLACE_IMAGE_URLS_IN_CHAT_RESPONSE" environment variable. [#19076](https://github.com/open-webui/open-webui/pull/19076)
- 🎨 OpenAI image generation now supports additional API parameters including quality settings for GPT Image 1, configurable via "IMAGES_OPENAI_API_PARAMS" environment variable or through the admin interface, enabling cost-effective image generation with low, medium, or high quality options. [#19228](https://github.com/open-webui/open-webui/issues/19228)
- 🖼️ Image editing can now be independently enabled or disabled via admin settings, allowing administrators to control whether sequential image prompts trigger image editing or new image generation, configurable via "ENABLE_IMAGE_EDIT" environment variable. [#19284](https://github.com/open-webui/open-webui/issues/19284)
- 🔐 SSRF protection was implemented with a configurable URL blocklist that prevents access to cloud metadata endpoints and private networks, with default protections for AWS, Google Cloud, Azure, and Alibaba Cloud metadata services, customizable via "WEB_FETCH_FILTER_LIST" environment variable. [#19201](https://github.com/open-webui/open-webui/pull/19201)
- ⚡ Workspace models page now supports server-side pagination, dramatically improving load times and usability for instances with large numbers of workspace models.
- 🔍 Hybrid search now indexes file metadata including filenames, titles, headings, sources, and snippets alongside document content, enabling keyword queries to surface documents where search terms appear only in metadata, configurable via "ENABLE_RAG_HYBRID_SEARCH_ENRICHED_TEXTS" environment variable. [#19095](https://github.com/open-webui/open-webui/pull/19095)
- 📂 Knowledge base upload page now supports folder drag-and-drop with recursive directory handling, enabling batch uploads of entire directory structures instead of requiring individual file selection. [#19320](https://github.com/open-webui/open-webui/pull/19320)
- 🤖 Model cloning is now available in admin settings, allowing administrators to quickly create workspace models based on existing base models through a "Clone" option in the model dropdown menu. [#17937](https://github.com/open-webui/open-webui/pull/17937)
- 🎨 UI scale adjustment is now available in interface settings, allowing users to increase the size of the entire interface from 1.0x to 1.5x for improved accessibility and readability, particularly beneficial for users with visual impairments. [#19186](https://github.com/open-webui/open-webui/pull/19186)
- 📌 Default pinned models can now be configured by administrators for all new users, mirroring the behavior of default models where admin-configured defaults apply only to users who haven't customized their pinned models, configurable via "DEFAULT_PINNED_MODELS" environment variable. [#19273](https://github.com/open-webui/open-webui/pull/19273)
- 🎙️ Text-to-Speech and Speech-to-Text services now receive user information headers when "ENABLE_FORWARD_USER_INFO_HEADERS" is enabled, allowing external TTS and STT providers to implement user-specific personalization, rate limiting, and usage tracking. [#19323](https://github.com/open-webui/open-webui/pull/19323), [#19312](https://github.com/open-webui/open-webui/issues/19312)
- 🎙️ Voice mode now supports custom system prompts via "VOICE_MODE_PROMPT_TEMPLATE" configuration, allowing administrators to control response style and behavior for voice interactions. [#18607](https://github.com/open-webui/open-webui/pull/18607)
- 🔧 WebSocket and Redis configuration options are now available including debug logging controls, custom ping timeout and interval settings, and arbitrary Redis connection options via "WEBSOCKET_SERVER_LOGGING", "WEBSOCKET_SERVER_ENGINEIO_LOGGING", "WEBSOCKET_SERVER_PING_TIMEOUT", "WEBSOCKET_SERVER_PING_INTERVAL", and "WEBSOCKET_REDIS_OPTIONS" environment variables. [#19091](https://github.com/open-webui/open-webui/pull/19091)
- 🔧 MCP OAuth dynamic client registration now automatically detects and uses the appropriate token endpoint authentication method from server-supported options, enabling compatibility with OAuth servers that only support "client_secret_basic" instead of "client_secret_post". [#19193](https://github.com/open-webui/open-webui/issues/19193)
- 🔧 Custom headers can now be configured for remote MCP and OpenAPI tool server connections, enabling integration with services that require additional authentication headers. [#18918](https://github.com/open-webui/open-webui/issues/18918)
- 🔍 Perplexity Search now supports custom API endpoints via "PERPLEXITY_SEARCH_API_URL" configuration and automatically forwards user information headers to enable personalized search experiences. [#19147](https://github.com/open-webui/open-webui/pull/19147)
- 🔍 User information headers can now be optionally forwarded to external web search engines when "ENABLE_FORWARD_USER_INFO_HEADERS" is enabled. [#19043](https://github.com/open-webui/open-webui/pull/19043)
- 📊 Daily active user metric is now available for monitoring, tracking unique users active since midnight UTC via the "webui.users.active.today" Prometheus gauge. [#19236](https://github.com/open-webui/open-webui/pull/19236), [#19234](https://github.com/open-webui/open-webui/issues/19234)
- 📊 Audit log file path is now configurable via "AUDIT_LOGS_FILE_PATH" environment variable, enabling storage in separate volumes or custom locations. [#19173](https://github.com/open-webui/open-webui/pull/19173)
- 🎨 Sidebar collapse states for model lists and group information are now persistent across page refreshes, remembering user preferences through browser-based storage. [#19159](https://github.com/open-webui/open-webui/issues/19159)
- 🎨 Background image display was enhanced with semi-transparent overlays for navbar and sidebar, creating a seamless and visually cohesive design across the entire interface. [#19157](https://github.com/open-webui/open-webui/issues/19157)
- 📋 Tables in chat messages now include a copy button that appears on hover, enabling quick copying of table content alongside the existing CSV export functionality. [#19162](https://github.com/open-webui/open-webui/issues/19162)
- 📝 Notes can now be created directly via the "/notes/new" URL endpoint with optional title and content query parameters, enabling faster note creation through bookmarks and shortcuts. [#19195](https://github.com/open-webui/open-webui/issues/19195)
- 🏷️ Tag suggestions are now context-aware, displaying only relevant tags when creating or editing models versus chat conversations, preventing confusion between model and chat tags. [#19135](https://github.com/open-webui/open-webui/issues/19135)
- ✍️ Prompt autocompletion is now available independently of the rich text input setting, improving accessibility to the feature. [#19150](https://github.com/open-webui/open-webui/issues/19150)
- 🔄 Various improvements were implemented across the frontend and backend to enhance performance, stability, and security.
- 🌐 Translations for Simplified Chinese, Traditional Chinese, Portuguese (Brazil), Catalan, Spanish (Spain), Finnish, Irish, Farsi, Swedish, Danish, German, Korean, and Thai were improved and expanded.
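
A minimal sketch of the new password policy in action, using the default "PASSWORD_VALIDATION_REGEX_PATTERN" shipped in this release; the sample passwords are illustrative only:

```python
import re

# Default pattern: at least 8 characters with lowercase, uppercase, digit, and special character.
pattern = re.compile(r"^(?=.*[a-z])(?=.*[A-Z])(?=.*\d)(?=.*[^\w\s]).{8,}$")

print(bool(pattern.match("Sunny-Day7")))  # True: satisfies all four requirements
print(bool(pattern.match("sunnyday7")))   # False: missing uppercase and special character
print(bool(pattern.match("Sd7!")))        # False: shorter than 8 characters
```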
### Fixed
- 🤖 Model update functionality now works correctly, resolving a database parameter binding error that prevented saving changes to model configurations via the Save & Update button. [#19335](https://github.com/open-webui/open-webui/issues/19335)
- 🖼️ Multiple input images for image editing and generation are now correctly passed as an array using the "image[]" parameter syntax, enabling proper multi-image reference functionality with models like GPT Image 1. [#19339](https://github.com/open-webui/open-webui/issues/19339)
- 📱 PWA installations on iOS now properly refresh after server container restarts, resolving freezing issues by automatically unregistering service workers when version or deployment changes are detected. [#19316](https://github.com/open-webui/open-webui/pull/19316)
- 🗄️ S3 Vectors collection detection now correctly handles buckets with more than 2000 indexes by using direct index lookup instead of paginated list scanning, improving performance by approximately 8x and enabling RAG queries to work reliably at scale. [#19238](https://github.com/open-webui/open-webui/pull/19238), [#19233](https://github.com/open-webui/open-webui/issues/19233)
- 📈 Feedback retrieval performance was optimized by eliminating N+1 query patterns through database joins, adding server-side pagination and sorting, significantly reducing database load for instances with large feedback datasets. [#17976](https://github.com/open-webui/open-webui/pull/17976)
- 🔍 Chat search now works correctly with PostgreSQL when chat data contains null bytes, with comprehensive sanitization preventing null bytes during data writes, cleaning existing data on read, and stripping null bytes during search queries to ensure reliable search functionality. [#15616](https://github.com/open-webui/open-webui/issues/15616)
- 🔍 Hybrid search with reranking now correctly handles attribute validation, preventing errors when collection results lack expected structure. [#19025](https://github.com/open-webui/open-webui/pull/19025), [#17046](https://github.com/open-webui/open-webui/issues/17046)
- 🔎 Reranking functionality now works correctly after recent refactoring, resolving crashes caused by incorrect function argument handling. [#19270](https://github.com/open-webui/open-webui/pull/19270)
- 🤖 Azure OpenAI models now support the "reasoning_effort" parameter, enabling proper configuration of reasoning capabilities for models like GPT-5.1 which default to no reasoning without this setting. [#19290](https://github.com/open-webui/open-webui/issues/19290)
- 🤖 Models with very long IDs can now be deleted correctly, resolving URL length limitations that previously prevented management operations on such models. [#18230](https://github.com/open-webui/open-webui/pull/18230)
- 🤖 Model-level streaming settings now correctly apply to API requests, ensuring "Stream Chat Response" toggle properly controls the streaming parameter. [#19154](https://github.com/open-webui/open-webui/issues/19154)
- 🖼️ Image editing configuration now correctly preserves independent OpenAI API endpoints and keys, preventing them from being overwritten by image generation settings. [#19003](https://github.com/open-webui/open-webui/issues/19003)
- 🎨 Gemini image edit settings now display correctly in the admin panel, fixing an incorrect configuration key reference that prevented proper rendering of edit options. [#19200](https://github.com/open-webui/open-webui/pull/19200)
- 🖌️ Image generation settings menu now loads correctly, resolving validation errors with AUTOMATIC1111 API authentication parameters. [#19187](https://github.com/open-webui/open-webui/issues/19187), [#19246](https://github.com/open-webui/open-webui/issues/19246)
- 📅 Date formatting in chat search and admin user chat search now correctly respects the "DEFAULT_LOCALE" environment variable, displaying dates according to the configured locale instead of always using MM/DD/YYYY format. [#19305](https://github.com/open-webui/open-webui/pull/19305), [#19020](https://github.com/open-webui/open-webui/issues/19020)
- 📝 RAG template query placeholder escaping logic was corrected to prevent unintended replacements of context values when query placeholders appear in retrieved content. [#19102](https://github.com/open-webui/open-webui/pull/19102), [#19101](https://github.com/open-webui/open-webui/issues/19101)
- 📄 RAG template prompt duplication was eliminated by removing redundant user query section from the default template. [#19099](https://github.com/open-webui/open-webui/pull/19099), [#19098](https://github.com/open-webui/open-webui/issues/19098)
- 📋 MinerU local mode configuration no longer incorrectly requires an API key, allowing proper use of local content extraction without external API credentials. [#19258](https://github.com/open-webui/open-webui/issues/19258)
- 📊 Excel file uploads now work correctly with the addition of the missing msoffcrypto-tool dependency, resolving import errors introduced by the unstructured package upgrade. [#19153](https://github.com/open-webui/open-webui/issues/19153)
- 📑 Docling parameters now properly handle JSON serialization, preventing exceptions and ensuring configuration changes are saved correctly. [#19072](https://github.com/open-webui/open-webui/pull/19072)
- 🛠️ UserValves configuration now correctly isolates settings per tool, preventing configuration contamination when multiple tools with UserValves are used simultaneously. [#19185](https://github.com/open-webui/open-webui/pull/19185), [#15569](https://github.com/open-webui/open-webui/issues/15569)
- 🔧 Tool selection prompt now correctly handles user messages without duplication, removing redundant query prefixes and improving prompt clarity. [#19122](https://github.com/open-webui/open-webui/pull/19122), [#19121](https://github.com/open-webui/open-webui/issues/19121)
- 📝 Notes chat feature now correctly submits messages to the completions endpoint, resolving errors that prevented AI model interactions. [#19079](https://github.com/open-webui/open-webui/pull/19079)
- 📝 Note PDF downloads now sanitize HTML content using DOMPurify before rendering, preventing potential DOM-based XSS attacks from malicious content in notes. [Commit](https://github.com/open-webui/open-webui/commit/03cc6ce8eb5c055115406e2304fbf7e3338b8dce)
- 📁 Archived chats now have their folder associations automatically removed to prevent unintended deletion when their previous folder is deleted. [#14578](https://github.com/open-webui/open-webui/issues/14578)
- 🔐 ElevenLabs API key is now properly obfuscated in the admin settings page, preventing plain text exposure of sensitive credentials. [#19262](https://github.com/open-webui/open-webui/pull/19262), [#19260](https://github.com/open-webui/open-webui/issues/19260)
- 🔧 MCP OAuth server metadata discovery now follows the correct specification order, ensuring proper authentication flow compliance. [#19244](https://github.com/open-webui/open-webui/pull/19244)
- 🔒 API key endpoint restrictions now properly enforce access controls for all endpoints including SCIM, preventing unintended access when "API_KEY_ALLOWED_ENDPOINTS" is configured. [#19168](https://github.com/open-webui/open-webui/issues/19168)
- 🔓 OAuth role claim parsing now supports both flat and nested claim structures, enabling compatibility with OAuth providers that deliver claims as direct properties on the user object rather than nested structures. [#19286](https://github.com/open-webui/open-webui/pull/19286)
- 🔑 OAuth MCP server verification now correctly extracts the access token value for authorization headers instead of sending the entire token dictionary. [#19149](https://github.com/open-webui/open-webui/pull/19149), [#19148](https://github.com/open-webui/open-webui/issues/19148)
- ⚙️ OAuth dynamic client registration now correctly converts empty strings to None for optional fields, preventing validation failures in MCP package integration. [#19144](https://github.com/open-webui/open-webui/pull/19144), [#19129](https://github.com/open-webui/open-webui/issues/19129)
- 🔐 OIDC authentication now correctly passes client credentials in access token requests, ensuring compatibility with providers that require these parameters per RFC 6749. [#19132](https://github.com/open-webui/open-webui/pull/19132), [#19131](https://github.com/open-webui/open-webui/issues/19131)
- 🔗 OAuth client creation now respects configured token endpoint authentication methods instead of defaulting to basic authentication, preventing failures with servers that don't support basic auth. [#19165](https://github.com/open-webui/open-webui/pull/19165)
- 📋 Text copied from chat responses in Chrome now pastes without background formatting, improving readability when pasting into word processors. [#19083](https://github.com/open-webui/open-webui/issues/19083)
### Changed
- 🗄️ Group membership data storage was refactored from JSON arrays to a dedicated relational database table, significantly improving query performance and scalability for instances with large numbers of users and groups, while API responses now return member counts instead of full user ID arrays. [#19239](https://github.com/open-webui/open-webui/pull/19239)
- 📄 MinerU parameter handling was refactored to pass parameters directly to the API, improving flexibility and fixing VLM backend configuration. [#19105](https://github.com/open-webui/open-webui/pull/19105), [#18446](https://github.com/open-webui/open-webui/discussions/18446)
- 🔐 API key creation is now controlled by granular user and group permissions, with the "ENABLE_API_KEY" environment variable renamed to "ENABLE_API_KEYS" and disabled by default, requiring explicit configuration at both the global and user permission levels, while related environment variables "ENABLE_API_KEY_ENDPOINT_RESTRICTIONS" and "API_KEY_ALLOWED_ENDPOINTS" were renamed to "ENABLE_API_KEYS_ENDPOINT_RESTRICTIONS" and "API_KEYS_ALLOWED_ENDPOINTS" respectively. [#18336](https://github.com/open-webui/open-webui/pull/18336)
## [0.6.36] - 2025-11-07
### Added


@ -31,32 +31,44 @@ For more information, be sure to check out our [Open WebUI Documentation](https:
- 🛡️ **Granular Permissions and User Groups**: By allowing administrators to create detailed user roles and permissions, we ensure a secure user environment. This granularity not only enhances security but also allows for customized user experiences, fostering a sense of ownership and responsibility amongst users.
- 🔄 **SCIM 2.0 Support**: Enterprise-grade user and group provisioning through SCIM 2.0 protocol, enabling seamless integration with identity providers like Okta, Azure AD, and Google Workspace for automated user lifecycle management.
- 📱 **Responsive Design**: Enjoy a seamless experience across Desktop PC, Laptop, and Mobile devices.
- 📱 **Progressive Web App (PWA) for Mobile**: Enjoy a native app-like experience on your mobile device with our PWA, providing offline access on localhost and a seamless user interface.
- ✒️🔢 **Full Markdown and LaTeX Support**: Elevate your LLM experience with comprehensive Markdown and LaTeX capabilities for enriched interaction.
- 🎤📹 **Hands-Free Voice/Video Call**: Experience seamless communication with integrated hands-free voice and video call features using multiple Speech-to-Text providers (Local Whisper, OpenAI, Deepgram, Azure) and Text-to-Speech engines (Azure, ElevenLabs, OpenAI, Transformers, WebAPI), allowing for dynamic and interactive chat environments.
- 🛠️ **Model Builder**: Easily create Ollama models via the Web UI. Create and add custom characters/agents, customize chat elements, and import models effortlessly through [Open WebUI Community](https://openwebui.com/) integration.
- 🐍 **Native Python Function Calling Tool**: Enhance your LLMs with built-in code editor support in the tools workspace. Bring Your Own Function (BYOF) by simply adding your pure Python functions, enabling seamless integration with LLMs.
- 💾 **Persistent Artifact Storage**: Built-in key-value storage API for artifacts, enabling features like journals, trackers, leaderboards, and collaborative tools with both personal and shared data scopes across sessions.
- 📚 **Local RAG Integration**: Dive into the future of chat interactions with groundbreaking Retrieval Augmented Generation (RAG) support using your choice of 9 vector databases and multiple content extraction engines (Tika, Docling, Document Intelligence, Mistral OCR, External loaders). Load documents directly into chat or add files to your document library, effortlessly accessing them using the `#` command before a query.
- 🔍 **Web Search for RAG**: Perform web searches using 15+ providers including `SearXNG`, `Google PSE`, `Brave Search`, `Kagi`, `Mojeek`, `Tavily`, `Perplexity`, `serpstack`, `serper`, `Serply`, `DuckDuckGo`, `SearchApi`, `SerpApi`, `Bing`, `Jina`, `Exa`, `Sougou`, `Azure AI Search`, and `Ollama Cloud`, injecting results directly into your chat experience.
- 🌐 **Web Browsing Capability**: Seamlessly integrate websites into your chat experience using the `#` command followed by a URL. This feature allows you to incorporate web content directly into your conversations, enhancing the richness and depth of your interactions.
- 🎨 **Image Generation & Editing Integration**: Create and edit images using multiple engines including OpenAI's DALL-E, Gemini, ComfyUI (local), and AUTOMATIC1111 (local), with support for both generation and prompt-based editing workflows.
- ⚙️ **Many Models Conversations**: Effortlessly engage with various models simultaneously, harnessing their unique strengths for optimal responses. Enhance your experience by leveraging a diverse set of models in parallel.
- 🔐 **Role-Based Access Control (RBAC)**: Ensure secure access with restricted permissions; only authorized individuals can access your Ollama, and exclusive model creation/pulling rights are reserved for administrators.
- 🗄️ **Flexible Database & Storage Options**: Choose from SQLite (with optional encryption), PostgreSQL, or configure cloud storage backends (S3, Google Cloud Storage, Azure Blob Storage) for scalable deployments.
- 🔍 **Advanced Vector Database Support**: Select from 9 vector database options including ChromaDB, PGVector, Qdrant, Milvus, Elasticsearch, OpenSearch, Pinecone, S3Vector, and Oracle 23ai for optimal RAG performance.
- 🔐 **Enterprise Authentication**: Full support for LDAP/Active Directory integration, SCIM 2.0 automated provisioning, and SSO via trusted headers alongside OAuth providers, enabling seamless integration with identity providers like Okta, Azure AD, and Google Workspace for automated user lifecycle management.
- ☁️ **Cloud-Native Integration**: Native support for Google Drive and OneDrive/SharePoint file picking, enabling seamless document import from enterprise cloud storage.
- 📊 **Production Observability**: Built-in OpenTelemetry support for traces, metrics, and logs, enabling comprehensive monitoring with your existing observability stack.
- ⚖️ **Horizontal Scalability**: Redis-backed session management and WebSocket support for multi-worker and multi-node deployments behind load balancers.
- 🌐🌍 **Multilingual Support**: Experience Open WebUI in your preferred language with our internationalization (i18n) support. Join us in expanding our supported languages! We're actively seeking contributors!
- 🧩 **Pipelines, Open WebUI Plugin Support**: Seamlessly integrate custom logic and Python libraries into Open WebUI using [Pipelines Plugin Framework](https://github.com/open-webui/pipelines). Launch your Pipelines instance, set the OpenAI URL to the Pipelines URL, and explore endless possibilities. [Examples](https://github.com/open-webui/pipelines/tree/main/examples) include **Function Calling**, User **Rate Limiting** to control access, **Usage Monitoring** with tools like Langfuse, **Live Translation with LibreTranslate** for multilingual support, **Toxic Message Filtering** and much more.


@ -287,25 +287,30 @@ class AppConfig:
# WEBUI_AUTH (Required for security)
####################################
ENABLE_API_KEYS = PersistentConfig(
"ENABLE_API_KEYS",
"auth.enable_api_keys",
os.environ.get("ENABLE_API_KEYS", "False").lower() == "true",
)
ENABLE_API_KEYS_ENDPOINT_RESTRICTIONS = PersistentConfig(
"ENABLE_API_KEYS_ENDPOINT_RESTRICTIONS",
"auth.api_key.endpoint_restrictions",
os.environ.get(
"ENABLE_API_KEYS_ENDPOINT_RESTRICTIONS",
os.environ.get("ENABLE_API_KEY_ENDPOINT_RESTRICTIONS", "False"),
).lower()
== "true",
)
API_KEYS_ALLOWED_ENDPOINTS = PersistentConfig(
"API_KEYS_ALLOWED_ENDPOINTS",
"auth.api_key.allowed_endpoints",
os.environ.get(
"API_KEYS_ALLOWED_ENDPOINTS", os.environ.get("API_KEY_ALLOWED_ENDPOINTS", "")
),
)
JWT_EXPIRES_IN = PersistentConfig(
"JWT_EXPIRES_IN", "auth.jwt_expiry", os.environ.get("JWT_EXPIRES_IN", "4w")
)
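A minimal sketch (not part of the file above) of how the renamed API-key settings behave: the endpoint-restriction and allowed-endpoint lookups fall back to the legacy ENABLE_API_KEY_* / API_KEY_* variables, while ENABLE_API_KEYS itself only reads the new name and now defaults to disabled. The environment values below are hypothetical.

import os

os.environ["ENABLE_API_KEY_ENDPOINT_RESTRICTIONS"] = "true"  # legacy name, still honored as a fallback

restricted = os.environ.get(
    "ENABLE_API_KEYS_ENDPOINT_RESTRICTIONS",
    os.environ.get("ENABLE_API_KEY_ENDPOINT_RESTRICTIONS", "False"),
).lower() == "true"
enabled = os.environ.get("ENABLE_API_KEYS", "False").lower() == "true"

print(restricted)  # True, picked up from the legacy variable
print(enabled)     # False, the legacy ENABLE_API_KEY value is not consulted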
@ -1124,6 +1129,7 @@ ENABLE_LOGIN_FORM = PersistentConfig(
os.environ.get("ENABLE_LOGIN_FORM", "True").lower() == "true", os.environ.get("ENABLE_LOGIN_FORM", "True").lower() == "true",
) )
ENABLE_PASSWORD_AUTH = os.environ.get("ENABLE_PASSWORD_AUTH", "True").lower() == "true"
DEFAULT_LOCALE = PersistentConfig(
"DEFAULT_LOCALE",
@ -1135,6 +1141,12 @@ DEFAULT_MODELS = PersistentConfig(
"DEFAULT_MODELS", "ui.default_models", os.environ.get("DEFAULT_MODELS", None) "DEFAULT_MODELS", "ui.default_models", os.environ.get("DEFAULT_MODELS", None)
) )
DEFAULT_PINNED_MODELS = PersistentConfig(
"DEFAULT_PINNED_MODELS",
"ui.default_pinned_models",
os.environ.get("DEFAULT_PINNED_MODELS", None),
)
try:
default_prompt_suggestions = json.loads(
os.environ.get("DEFAULT_PROMPT_SUGGESTIONS", "[]")
@ -1191,6 +1203,12 @@ DEFAULT_USER_ROLE = PersistentConfig(
os.getenv("DEFAULT_USER_ROLE", "pending"), os.getenv("DEFAULT_USER_ROLE", "pending"),
) )
DEFAULT_GROUP_ID = PersistentConfig(
"DEFAULT_GROUP_ID",
"ui.default_group_id",
os.environ.get("DEFAULT_GROUP_ID", ""),
)
PENDING_USER_OVERLAY_TITLE = PersistentConfig(
"PENDING_USER_OVERLAY_TITLE",
"ui.pending_user_overlay_title",
@ -1230,6 +1248,40 @@ USER_PERMISSIONS_WORKSPACE_TOOLS_ACCESS = (
os.environ.get("USER_PERMISSIONS_WORKSPACE_TOOLS_ACCESS", "False").lower() == "true" os.environ.get("USER_PERMISSIONS_WORKSPACE_TOOLS_ACCESS", "False").lower() == "true"
) )
USER_PERMISSIONS_WORKSPACE_MODELS_IMPORT = (
os.environ.get("USER_PERMISSIONS_WORKSPACE_MODELS_IMPORT", "False").lower()
== "true"
)
USER_PERMISSIONS_WORKSPACE_MODELS_EXPORT = (
os.environ.get("USER_PERMISSIONS_WORKSPACE_MODELS_EXPORT", "False").lower()
== "true"
)
USER_PERMISSIONS_WORKSPACE_PROMPTS_IMPORT = (
os.environ.get("USER_PERMISSIONS_WORKSPACE_PROMPTS_IMPORT", "False").lower()
== "true"
)
USER_PERMISSIONS_WORKSPACE_PROMPTS_EXPORT = (
os.environ.get("USER_PERMISSIONS_WORKSPACE_PROMPTS_EXPORT", "False").lower()
== "true"
)
USER_PERMISSIONS_WORKSPACE_TOOLS_IMPORT = (
os.environ.get("USER_PERMISSIONS_WORKSPACE_TOOLS_IMPORT", "False").lower() == "true"
)
USER_PERMISSIONS_WORKSPACE_TOOLS_EXPORT = (
os.environ.get("USER_PERMISSIONS_WORKSPACE_TOOLS_EXPORT", "False").lower() == "true"
)
USER_PERMISSIONS_WORKSPACE_MODELS_ALLOW_SHARING = (
os.environ.get("USER_PERMISSIONS_WORKSPACE_MODELS_ALLOW_SHARING", "False").lower()
== "true"
)
USER_PERMISSIONS_WORKSPACE_MODELS_ALLOW_PUBLIC_SHARING = (
os.environ.get(
"USER_PERMISSIONS_WORKSPACE_MODELS_ALLOW_PUBLIC_SHARING", "False"
@ -1237,8 +1289,10 @@ USER_PERMISSIONS_WORKSPACE_MODELS_ALLOW_PUBLIC_SHARING = (
== "true"
)
USER_PERMISSIONS_WORKSPACE_KNOWLEDGE_ALLOW_SHARING = (
os.environ.get(
"USER_PERMISSIONS_WORKSPACE_KNOWLEDGE_ALLOW_PUBLIC_SHARING", "False"
).lower()
== "true"
)
@ -1249,6 +1303,11 @@ USER_PERMISSIONS_WORKSPACE_KNOWLEDGE_ALLOW_PUBLIC_SHARING = (
== "true"
)
USER_PERMISSIONS_WORKSPACE_PROMPTS_ALLOW_SHARING = (
os.environ.get("USER_PERMISSIONS_WORKSPACE_PROMPTS_ALLOW_SHARING", "False").lower()
== "true"
)
USER_PERMISSIONS_WORKSPACE_PROMPTS_ALLOW_PUBLIC_SHARING = (
os.environ.get(
"USER_PERMISSIONS_WORKSPACE_PROMPTS_ALLOW_PUBLIC_SHARING", "False"
@ -1256,6 +1315,12 @@ USER_PERMISSIONS_WORKSPACE_PROMPTS_ALLOW_PUBLIC_SHARING = (
== "true"
)
USER_PERMISSIONS_WORKSPACE_TOOLS_ALLOW_SHARING = (
os.environ.get("USER_PERMISSIONS_WORKSPACE_TOOLS_ALLOW_SHARING", "False").lower()
== "true"
)
USER_PERMISSIONS_WORKSPACE_TOOLS_ALLOW_PUBLIC_SHARING = (
os.environ.get(
"USER_PERMISSIONS_WORKSPACE_TOOLS_ALLOW_PUBLIC_SHARING", "False"
@ -1264,6 +1329,17 @@ USER_PERMISSIONS_WORKSPACE_TOOLS_ALLOW_PUBLIC_SHARING = (
)
USER_PERMISSIONS_NOTES_ALLOW_SHARING = (
os.environ.get("USER_PERMISSIONS_NOTES_ALLOW_PUBLIC_SHARING", "False").lower()
== "true"
)
USER_PERMISSIONS_NOTES_ALLOW_PUBLIC_SHARING = (
os.environ.get("USER_PERMISSIONS_NOTES_ALLOW_PUBLIC_SHARING", "False").lower()
== "true"
)
USER_PERMISSIONS_CHAT_CONTROLS = (
os.environ.get("USER_PERMISSIONS_CHAT_CONTROLS", "True").lower() == "true"
)
@ -1366,6 +1442,10 @@ USER_PERMISSIONS_FEATURES_NOTES = (
os.environ.get("USER_PERMISSIONS_FEATURES_NOTES", "True").lower() == "true"
)
USER_PERMISSIONS_FEATURES_API_KEYS = (
os.environ.get("USER_PERMISSIONS_FEATURES_API_KEYS", "False").lower() == "true"
)
DEFAULT_USER_PERMISSIONS = {
"workspace": {
@ -1373,12 +1453,23 @@ DEFAULT_USER_PERMISSIONS = {
"knowledge": USER_PERMISSIONS_WORKSPACE_KNOWLEDGE_ACCESS,
"prompts": USER_PERMISSIONS_WORKSPACE_PROMPTS_ACCESS,
"tools": USER_PERMISSIONS_WORKSPACE_TOOLS_ACCESS,
"models_import": USER_PERMISSIONS_WORKSPACE_MODELS_IMPORT,
"models_export": USER_PERMISSIONS_WORKSPACE_MODELS_EXPORT,
"prompts_import": USER_PERMISSIONS_WORKSPACE_PROMPTS_IMPORT,
"prompts_export": USER_PERMISSIONS_WORKSPACE_PROMPTS_EXPORT,
"tools_import": USER_PERMISSIONS_WORKSPACE_TOOLS_IMPORT,
"tools_export": USER_PERMISSIONS_WORKSPACE_TOOLS_EXPORT,
},
"sharing": {
"models": USER_PERMISSIONS_WORKSPACE_MODELS_ALLOW_SHARING,
"public_models": USER_PERMISSIONS_WORKSPACE_MODELS_ALLOW_PUBLIC_SHARING,
"knowledge": USER_PERMISSIONS_WORKSPACE_KNOWLEDGE_ALLOW_SHARING,
"public_knowledge": USER_PERMISSIONS_WORKSPACE_KNOWLEDGE_ALLOW_PUBLIC_SHARING,
"prompts": USER_PERMISSIONS_WORKSPACE_PROMPTS_ALLOW_SHARING,
"public_prompts": USER_PERMISSIONS_WORKSPACE_PROMPTS_ALLOW_PUBLIC_SHARING,
"tools": USER_PERMISSIONS_WORKSPACE_TOOLS_ALLOW_SHARING,
"public_tools": USER_PERMISSIONS_WORKSPACE_TOOLS_ALLOW_PUBLIC_SHARING,
"notes": USER_PERMISSIONS_NOTES_ALLOW_SHARING,
"public_notes": USER_PERMISSIONS_NOTES_ALLOW_PUBLIC_SHARING,
},
"chat": {
@ -1403,6 +1494,7 @@ DEFAULT_USER_PERMISSIONS = {
"temporary_enforced": USER_PERMISSIONS_CHAT_TEMPORARY_ENFORCED,
},
"features": {
"api_keys": USER_PERMISSIONS_FEATURES_API_KEYS,
"direct_tool_servers": USER_PERMISSIONS_FEATURES_DIRECT_TOOL_SERVERS,
"web_search": USER_PERMISSIONS_FEATURES_WEB_SEARCH,
"image_generation": USER_PERMISSIONS_FEATURES_IMAGE_GENERATION,
@ -1816,6 +1908,38 @@ Output:
#### Output:
"""
VOICE_MODE_PROMPT_TEMPLATE = PersistentConfig(
"VOICE_MODE_PROMPT_TEMPLATE",
"task.voice.prompt_template",
os.environ.get("VOICE_MODE_PROMPT_TEMPLATE", ""),
)
DEFAULT_VOICE_MODE_PROMPT_TEMPLATE = """You are a friendly, concise voice assistant.
Everything you say will be spoken aloud.
Keep responses short, clear, and natural.
STYLE:
- Use simple words and short sentences.
- Sound warm and conversational.
- Avoid long explanations, lists, or complex phrasing.
BEHAVIOR:
- Give the quickest helpful answer first.
- Offer extra detail only if needed.
- Ask for clarification only when necessary.
VOICE OPTIMIZATION:
- Break information into small, easy-to-hear chunks.
- Avoid dense wording or anything that sounds like reading text.
ERROR HANDLING:
- If unsure, say so briefly and offer options.
- If something is unsafe or impossible, decline kindly and suggest a safe alternative.
Stay consistent, helpful, and easy to listen to."""
TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = PersistentConfig(
"TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE",
"task.tools.prompt_template",
@ -2056,6 +2180,11 @@ ENABLE_QDRANT_MULTITENANCY_MODE = (
)
QDRANT_COLLECTION_PREFIX = os.environ.get("QDRANT_COLLECTION_PREFIX", "open-webui")
WEAVIATE_HTTP_HOST = os.environ.get("WEAVIATE_HTTP_HOST", "")
WEAVIATE_HTTP_PORT = int(os.environ.get("WEAVIATE_HTTP_PORT", "8080"))
WEAVIATE_GRPC_PORT = int(os.environ.get("WEAVIATE_GRPC_PORT", "50051"))
WEAVIATE_API_KEY = os.environ.get("WEAVIATE_API_KEY")
# OpenSearch
OPENSEARCH_URI = os.environ.get("OPENSEARCH_URI", "https://localhost:9200")
OPENSEARCH_SSL = os.environ.get("OPENSEARCH_SSL", "true").lower() == "true"
@ -2086,6 +2215,16 @@ PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH = int(
os.environ.get("PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH", "1536")
)
PGVECTOR_USE_HALFVEC = os.getenv("PGVECTOR_USE_HALFVEC", "false").lower() == "true"
if PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH > 2000 and not PGVECTOR_USE_HALFVEC:
raise ValueError(
"PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH is set to "
f"{PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH}, which exceeds the 2000 dimension limit of the "
"'vector' type. Set PGVECTOR_USE_HALFVEC=true to enable the 'halfvec' "
"type required for high-dimensional embeddings."
)
PGVECTOR_CREATE_EXTENSION = (
os.getenv("PGVECTOR_CREATE_EXTENSION", "true").lower() == "true"
)
@ -2135,6 +2274,40 @@ else:
except Exception:
PGVECTOR_POOL_RECYCLE = 3600
PGVECTOR_INDEX_METHOD = os.getenv("PGVECTOR_INDEX_METHOD", "").strip().lower()
if PGVECTOR_INDEX_METHOD not in ("ivfflat", "hnsw", ""):
PGVECTOR_INDEX_METHOD = ""
PGVECTOR_HNSW_M = os.environ.get("PGVECTOR_HNSW_M", 16)
if PGVECTOR_HNSW_M == "":
PGVECTOR_HNSW_M = 16
else:
try:
PGVECTOR_HNSW_M = int(PGVECTOR_HNSW_M)
except Exception:
PGVECTOR_HNSW_M = 16
PGVECTOR_HNSW_EF_CONSTRUCTION = os.environ.get("PGVECTOR_HNSW_EF_CONSTRUCTION", 64)
if PGVECTOR_HNSW_EF_CONSTRUCTION == "":
PGVECTOR_HNSW_EF_CONSTRUCTION = 64
else:
try:
PGVECTOR_HNSW_EF_CONSTRUCTION = int(PGVECTOR_HNSW_EF_CONSTRUCTION)
except Exception:
PGVECTOR_HNSW_EF_CONSTRUCTION = 64
PGVECTOR_IVFFLAT_LISTS = os.environ.get("PGVECTOR_IVFFLAT_LISTS", 100)
if PGVECTOR_IVFFLAT_LISTS == "":
PGVECTOR_IVFFLAT_LISTS = 100
else:
try:
PGVECTOR_IVFFLAT_LISTS = int(PGVECTOR_IVFFLAT_LISTS)
except Exception:
PGVECTOR_IVFFLAT_LISTS = 100
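The three settings above repeat the same parse-an-integer-with-fallback pattern; a small helper like this hypothetical one (not in the actual codebase) expresses the same behavior:

import os

def _int_env(name: str, default: int) -> int:
    # Unset, empty, or non-integer values all fall back to the default.
    value = os.environ.get(name, "")
    try:
        return int(value) if value != "" else default
    except ValueError:
        return default

# Under that assumption, the blocks above are equivalent to:
# PGVECTOR_HNSW_M = _int_env("PGVECTOR_HNSW_M", 16)
# PGVECTOR_HNSW_EF_CONSTRUCTION = _int_env("PGVECTOR_HNSW_EF_CONSTRUCTION", 64)
# PGVECTOR_IVFFLAT_LISTS = _int_env("PGVECTOR_IVFFLAT_LISTS", 100)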
# Pinecone
PINECONE_API_KEY = os.environ.get("PINECONE_API_KEY", None)
PINECONE_ENVIRONMENT = os.environ.get("PINECONE_ENVIRONMENT", None)
@ -2510,6 +2683,13 @@ ENABLE_RAG_HYBRID_SEARCH = PersistentConfig(
os.environ.get("ENABLE_RAG_HYBRID_SEARCH", "").lower() == "true",
)
ENABLE_RAG_HYBRID_SEARCH_ENRICHED_TEXTS = PersistentConfig(
"ENABLE_RAG_HYBRID_SEARCH_ENRICHED_TEXTS",
"rag.enable_hybrid_search_enriched_texts",
os.environ.get("ENABLE_RAG_HYBRID_SEARCH_ENRICHED_TEXTS", "False").lower()
== "true",
)
RAG_FULL_CONTEXT = PersistentConfig(
"RAG_FULL_CONTEXT",
"rag.full_context",
@ -2697,10 +2877,6 @@ Provide a clear and direct response to the user's query, including inline citati
<context>
{{CONTEXT}}
</context>
<user_query>
{{QUERY}}
</user_query>
"""
RAG_TEMPLATE = PersistentConfig(
@ -2753,6 +2929,26 @@ ENABLE_RAG_LOCAL_WEB_FETCH = (
os.getenv("ENABLE_RAG_LOCAL_WEB_FETCH", "False").lower() == "true"
)
DEFAULT_WEB_FETCH_FILTER_LIST = [
"!169.254.169.254",
"!fd00:ec2::254",
"!metadata.google.internal",
"!metadata.azure.com",
"!100.100.100.200",
]
web_fetch_filter_list = os.getenv("WEB_FETCH_FILTER_LIST", "")
if web_fetch_filter_list == "":
web_fetch_filter_list = []
else:
web_fetch_filter_list = [
item.strip() for item in web_fetch_filter_list.split(",") if item.strip()
]
WEB_FETCH_FILTER_LIST = list(set(DEFAULT_WEB_FETCH_FILTER_LIST + web_fetch_filter_list))
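To illustrate the merge above with example values only, and assuming custom entries use the same "!host" form as the defaults: values supplied via WEB_FETCH_FILTER_LIST are combined with the default metadata-endpoint blocklist, duplicates are collapsed by set(), so ordering is not preserved.

# Example: WEB_FETCH_FILTER_LIST="!internal.example.com, !192.0.2.10"
custom = [item.strip() for item in "!internal.example.com, !192.0.2.10".split(",") if item.strip()]
merged = list(set(DEFAULT_WEB_FETCH_FILTER_LIST + custom))
# merged now holds the cloud metadata defaults plus the two custom "!" entries, in arbitrary order.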
YOUTUBE_LOADER_LANGUAGE = PersistentConfig(
"YOUTUBE_LOADER_LANGUAGE",
"rag.youtube_loader_language",
@ -2811,6 +3007,7 @@ WEB_SEARCH_DOMAIN_FILTER_LIST = PersistentConfig(
# "wikipedia.com",
# "wikimedia.org",
# "wikidata.org",
# "!stackoverflow.com",
],
)
@ -2982,6 +3179,24 @@ BING_SEARCH_V7_SUBSCRIPTION_KEY = PersistentConfig(
os.environ.get("BING_SEARCH_V7_SUBSCRIPTION_KEY", ""),
)
AZURE_AI_SEARCH_API_KEY = PersistentConfig(
"AZURE_AI_SEARCH_API_KEY",
"rag.web.search.azure_ai_search_api_key",
os.environ.get("AZURE_AI_SEARCH_API_KEY", ""),
)
AZURE_AI_SEARCH_ENDPOINT = PersistentConfig(
"AZURE_AI_SEARCH_ENDPOINT",
"rag.web.search.azure_ai_search_endpoint",
os.environ.get("AZURE_AI_SEARCH_ENDPOINT", ""),
)
AZURE_AI_SEARCH_INDEX_NAME = PersistentConfig(
"AZURE_AI_SEARCH_INDEX_NAME",
"rag.web.search.azure_ai_search_index_name",
os.environ.get("AZURE_AI_SEARCH_INDEX_NAME", ""),
)
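A sketch of the environment needed for the Azure AI Search provider configured above; the endpoint, index name, and key are placeholders, and the web search provider selection itself is configured separately.

import os

os.environ["AZURE_AI_SEARCH_ENDPOINT"] = "https://example-search.search.windows.net"  # placeholder
os.environ["AZURE_AI_SEARCH_INDEX_NAME"] = "my-index"  # placeholder
os.environ["AZURE_AI_SEARCH_API_KEY"] = "<api-key>"  # placeholder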
EXA_API_KEY = PersistentConfig(
"EXA_API_KEY",
"rag.web.search.exa_api_key",
@ -3006,6 +3221,12 @@ PERPLEXITY_SEARCH_CONTEXT_USAGE = PersistentConfig(
os.getenv("PERPLEXITY_SEARCH_CONTEXT_USAGE", "medium"),
)
PERPLEXITY_SEARCH_API_URL = PersistentConfig(
"PERPLEXITY_SEARCH_API_URL",
"rag.web.search.perplexity_search_api_url",
os.getenv("PERPLEXITY_SEARCH_API_URL", "https://api.perplexity.ai/search"),
)
SOUGOU_API_SID = PersistentConfig(
"SOUGOU_API_SID",
"rag.web.search.sougou_api_sid",
@ -3131,10 +3352,9 @@ try:
except json.JSONDecodeError:
automatic1111_params = {}
AUTOMATIC1111_PARAMS = PersistentConfig(
"AUTOMATIC1111_PARAMS",
"image_generation.automatic1111.api_params",
automatic1111_params,
)
@ -3290,6 +3510,18 @@ IMAGES_OPENAI_API_KEY = PersistentConfig(
os.getenv("IMAGES_OPENAI_API_KEY", OPENAI_API_KEY),
)
images_openai_params = os.getenv("IMAGES_OPENAI_PARAMS", "")
try:
images_openai_params = json.loads(images_openai_params)
except json.JSONDecodeError:
images_openai_params = {}
IMAGES_OPENAI_API_PARAMS = PersistentConfig(
"IMAGES_OPENAI_API_PARAMS", "image_generation.openai.params", images_openai_params
)
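For illustration, the IMAGES_OPENAI_PARAMS value read above is a JSON object of additional OpenAI image API parameters and is stored under the IMAGES_OPENAI_API_PARAMS config entry; the changelog mentions quality settings for GPT Image 1, so a plausible value (example only) is:

import json

raw = '{"quality": "low"}'  # example: request lower-cost generations
params = json.loads(raw)    # -> {"quality": "low"}, passed as additional API parameters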
IMAGES_GEMINI_API_BASE_URL = PersistentConfig(
"IMAGES_GEMINI_API_BASE_URL",
"image_generation.gemini.api_base_url",
@ -3307,6 +3539,11 @@ IMAGES_GEMINI_ENDPOINT_METHOD = PersistentConfig(
os.getenv("IMAGES_GEMINI_ENDPOINT_METHOD", ""),
)
ENABLE_IMAGE_EDIT = PersistentConfig(
"ENABLE_IMAGE_EDIT",
"images.edit.enable",
os.environ.get("ENABLE_IMAGE_EDIT", "").lower() == "true",
)
IMAGE_EDIT_ENGINE = PersistentConfig(
"IMAGE_EDIT_ENGINE",


@ -45,7 +45,7 @@ class ERROR_MESSAGES(str, Enum):
)
INVALID_CRED = "The email or password provided is incorrect. Please check for typos and try logging in again."
INVALID_EMAIL_FORMAT = "The email format you entered is invalid. Please double-check and make sure you're using a valid email address (e.g., yourname@example.com)."
INCORRECT_PASSWORD = (
"The password provided is incorrect. Please check for typos and try again."
)
INVALID_TRUSTED_HEADER = "Your provider has not provided a trusted header. Please contact your administrator for assistance."
@ -105,6 +105,10 @@ class ERROR_MESSAGES(str, Enum):
)
FILE_NOT_PROCESSED = "Extracted content is not available for this file. Please ensure that the file is processed before proceeding."
INVALID_PASSWORD = lambda err="": (
err if err else "The password does not meet the required validation criteria."
)
class TASKS(str, Enum):
def __str__(self) -> str:


@ -8,6 +8,8 @@ import shutil
from uuid import uuid4
from pathlib import Path
from cryptography.hazmat.primitives import serialization
import re
import markdown
from bs4 import BeautifulSoup
@ -135,6 +137,9 @@ else:
PACKAGE_DATA = {"version": "0.0.0"}
VERSION = PACKAGE_DATA["version"]
DEPLOYMENT_ID = os.environ.get("DEPLOYMENT_ID", "")
INSTANCE_ID = os.environ.get("INSTANCE_ID", str(uuid4()))
@ -426,6 +431,17 @@ WEBUI_AUTH_TRUSTED_GROUPS_HEADER = os.environ.get(
)
ENABLE_PASSWORD_VALIDATION = (
os.environ.get("ENABLE_PASSWORD_VALIDATION", "False").lower() == "true"
)
PASSWORD_VALIDATION_REGEX_PATTERN = os.environ.get(
"PASSWORD_VALIDATION_REGEX_PATTERN",
"^(?=.*[a-z])(?=.*[A-Z])(?=.*\d)(?=.*[^\w\s]).{8,}$",
)
PASSWORD_VALIDATION_REGEX_PATTERN = re.compile(PASSWORD_VALIDATION_REGEX_PATTERN)
BYPASS_MODEL_ACCESS_CONTROL = (
os.environ.get("BYPASS_MODEL_ACCESS_CONTROL", "False").lower() == "true"
)
@ -493,7 +509,10 @@ OAUTH_SESSION_TOKEN_ENCRYPTION_KEY = os.environ.get(
# SCIM Configuration
####################################
ENABLE_SCIM = (
os.environ.get("ENABLE_SCIM", os.environ.get("SCIM_ENABLED", "False")).lower()
== "true"
)
SCIM_TOKEN = os.environ.get("SCIM_TOKEN", "")
####################################
@ -541,6 +560,10 @@ else:
# CHAT
####################################
ENABLE_CHAT_RESPONSE_BASE64_IMAGE_URL_CONVERSION = (
os.environ.get("REPLACE_IMAGE_URLS_IN_CHAT_RESPONSE", "False").lower() == "true"
)
CHAT_RESPONSE_STREAM_DELTA_CHUNK_SIZE = os.environ.get(
"CHAT_RESPONSE_STREAM_DELTA_CHUNK_SIZE", "1"
)
@ -569,6 +592,21 @@ else:
CHAT_RESPONSE_MAX_TOOL_CALL_RETRIES = 30
CHAT_STREAM_RESPONSE_CHUNK_MAX_BUFFER_SIZE = os.environ.get(
"CHAT_STREAM_RESPONSE_CHUNK_MAX_BUFFER_SIZE", ""
)
if CHAT_STREAM_RESPONSE_CHUNK_MAX_BUFFER_SIZE == "":
CHAT_STREAM_RESPONSE_CHUNK_MAX_BUFFER_SIZE = None
else:
try:
CHAT_STREAM_RESPONSE_CHUNK_MAX_BUFFER_SIZE = int(
CHAT_STREAM_RESPONSE_CHUNK_MAX_BUFFER_SIZE
)
except Exception:
CHAT_STREAM_RESPONSE_CHUNK_MAX_BUFFER_SIZE = None
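A minimal illustration of the parsing above: leaving CHAT_STREAM_RESPONSE_CHUNK_MAX_BUFFER_SIZE unset or invalid leaves the explicit buffer limit disabled (None), while any integer string is used as-is; the sample size below is arbitrary.

def parse_buffer_size(raw: str):
    # Mirrors the logic above: empty or non-integer values mean no explicit limit.
    if raw == "":
        return None
    try:
        return int(raw)
    except Exception:
        return None

print(parse_buffer_size(""))         # None
print(parse_buffer_size("1048576"))  # 1048576 (example: 1 MiB)
print(parse_buffer_size("lots"))     # None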
####################################
# WEBSOCKET SUPPORT
####################################
@ -580,6 +618,17 @@ ENABLE_WEBSOCKET_SUPPORT = (
WEBSOCKET_MANAGER = os.environ.get("WEBSOCKET_MANAGER", "")
WEBSOCKET_REDIS_OPTIONS = os.environ.get("WEBSOCKET_REDIS_OPTIONS", "")
if WEBSOCKET_REDIS_OPTIONS == "":
log.debug("No WEBSOCKET_REDIS_OPTIONS provided, defaulting to None")
WEBSOCKET_REDIS_OPTIONS = None
else:
try:
WEBSOCKET_REDIS_OPTIONS = json.loads(WEBSOCKET_REDIS_OPTIONS)
except Exception:
log.warning("Invalid WEBSOCKET_REDIS_OPTIONS, defaulting to None")
WEBSOCKET_REDIS_OPTIONS = None
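An example value that parses cleanly with the json.loads call above; the specific keys are illustrative, and the resulting dict would typically be forwarded to the Redis client as keyword arguments.

import json

# e.g. WEBSOCKET_REDIS_OPTIONS='{"ssl": true, "socket_timeout": 5}'
options = json.loads('{"ssl": true, "socket_timeout": 5}')
# -> {"ssl": True, "socket_timeout": 5}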
WEBSOCKET_REDIS_URL = os.environ.get("WEBSOCKET_REDIS_URL", REDIS_URL) WEBSOCKET_REDIS_URL = os.environ.get("WEBSOCKET_REDIS_URL", REDIS_URL)
WEBSOCKET_REDIS_CLUSTER = ( WEBSOCKET_REDIS_CLUSTER = (
os.environ.get("WEBSOCKET_REDIS_CLUSTER", str(REDIS_CLUSTER)).lower() == "true" os.environ.get("WEBSOCKET_REDIS_CLUSTER", str(REDIS_CLUSTER)).lower() == "true"
@ -594,6 +643,23 @@ except ValueError:
WEBSOCKET_SENTINEL_HOSTS = os.environ.get("WEBSOCKET_SENTINEL_HOSTS", "") WEBSOCKET_SENTINEL_HOSTS = os.environ.get("WEBSOCKET_SENTINEL_HOSTS", "")
WEBSOCKET_SENTINEL_PORT = os.environ.get("WEBSOCKET_SENTINEL_PORT", "26379") WEBSOCKET_SENTINEL_PORT = os.environ.get("WEBSOCKET_SENTINEL_PORT", "26379")
WEBSOCKET_SERVER_LOGGING = (
os.environ.get("WEBSOCKET_SERVER_LOGGING", "False").lower() == "true"
)
WEBSOCKET_SERVER_ENGINEIO_LOGGING = (
os.environ.get("WEBSOCKET_SERVER_LOGGING", "False").lower() == "true"
)
WEBSOCKET_SERVER_PING_TIMEOUT = os.environ.get("WEBSOCKET_SERVER_PING_TIMEOUT", "20")
try:
WEBSOCKET_SERVER_PING_TIMEOUT = int(WEBSOCKET_SERVER_PING_TIMEOUT)
except ValueError:
WEBSOCKET_SERVER_PING_TIMEOUT = 20
WEBSOCKET_SERVER_PING_INTERVAL = os.environ.get("WEBSOCKET_SERVER_PING_INTERVAL", "25")
try:
WEBSOCKET_SERVER_PING_INTERVAL = int(WEBSOCKET_SERVER_PING_INTERVAL)
except ValueError:
WEBSOCKET_SERVER_PING_INTERVAL = 25
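These settings map onto knobs that python-socketio exposes directly; a hedged sketch of how they would reach the Socket.IO server, assuming the values above are passed to the server constructor (the actual wiring is not shown in this excerpt).

import socketio

sio = socketio.AsyncServer(
    async_mode="asgi",
    ping_timeout=WEBSOCKET_SERVER_PING_TIMEOUT,
    ping_interval=WEBSOCKET_SERVER_PING_INTERVAL,
    logger=WEBSOCKET_SERVER_LOGGING,
    engineio_logger=WEBSOCKET_SERVER_ENGINEIO_LOGGING,
)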
AIOHTTP_CLIENT_TIMEOUT = os.environ.get("AIOHTTP_CLIENT_TIMEOUT", "") AIOHTTP_CLIENT_TIMEOUT = os.environ.get("AIOHTTP_CLIENT_TIMEOUT", "")
@ -706,7 +772,9 @@ if OFFLINE_MODE:
# AUDIT LOGGING # AUDIT LOGGING
#################################### ####################################
# Where to store log file # Where to store log file
AUDIT_LOGS_FILE_PATH = f"{DATA_DIR}/audit.log" # Defaults to the DATA_DIR/audit.log. To set AUDIT_LOGS_FILE_PATH you need to
# provide the whole path, like: /app/audit.log
AUDIT_LOGS_FILE_PATH = os.getenv("AUDIT_LOGS_FILE_PATH", f"{DATA_DIR}/audit.log")
# Maximum size of a file before rotating into a new log file # Maximum size of a file before rotating into a new log file
AUDIT_LOG_FILE_ROTATION_SIZE = os.getenv("AUDIT_LOG_FILE_ROTATION_SIZE", "10MB") AUDIT_LOG_FILE_ROTATION_SIZE = os.getenv("AUDIT_LOG_FILE_ROTATION_SIZE", "10MB")

View file

@ -160,9 +160,11 @@ from open_webui.config import (
IMAGES_OPENAI_API_BASE_URL, IMAGES_OPENAI_API_BASE_URL,
IMAGES_OPENAI_API_VERSION, IMAGES_OPENAI_API_VERSION,
IMAGES_OPENAI_API_KEY, IMAGES_OPENAI_API_KEY,
IMAGES_OPENAI_API_PARAMS,
IMAGES_GEMINI_API_BASE_URL, IMAGES_GEMINI_API_BASE_URL,
IMAGES_GEMINI_API_KEY, IMAGES_GEMINI_API_KEY,
IMAGES_GEMINI_ENDPOINT_METHOD, IMAGES_GEMINI_ENDPOINT_METHOD,
ENABLE_IMAGE_EDIT,
IMAGE_EDIT_ENGINE, IMAGE_EDIT_ENGINE,
IMAGE_EDIT_MODEL, IMAGE_EDIT_MODEL,
IMAGE_EDIT_SIZE, IMAGE_EDIT_SIZE,
@ -319,6 +321,7 @@ from open_webui.config import (
PERPLEXITY_API_KEY, PERPLEXITY_API_KEY,
PERPLEXITY_MODEL, PERPLEXITY_MODEL,
PERPLEXITY_SEARCH_CONTEXT_USAGE, PERPLEXITY_SEARCH_CONTEXT_USAGE,
PERPLEXITY_SEARCH_API_URL,
SOUGOU_API_SID, SOUGOU_API_SID,
SOUGOU_API_SK, SOUGOU_API_SK,
KAGI_SEARCH_API_KEY, KAGI_SEARCH_API_KEY,
@ -336,6 +339,7 @@ from open_webui.config import (
ENABLE_ONEDRIVE_PERSONAL, ENABLE_ONEDRIVE_PERSONAL,
ENABLE_ONEDRIVE_BUSINESS, ENABLE_ONEDRIVE_BUSINESS,
ENABLE_RAG_HYBRID_SEARCH, ENABLE_RAG_HYBRID_SEARCH,
ENABLE_RAG_HYBRID_SEARCH_ENRICHED_TEXTS,
ENABLE_RAG_LOCAL_WEB_FETCH, ENABLE_RAG_LOCAL_WEB_FETCH,
ENABLE_WEB_LOADER_SSL_VERIFICATION, ENABLE_WEB_LOADER_SSL_VERIFICATION,
ENABLE_GOOGLE_DRIVE_INTEGRATION, ENABLE_GOOGLE_DRIVE_INTEGRATION,
@ -354,9 +358,9 @@ from open_webui.config import (
JWT_EXPIRES_IN, JWT_EXPIRES_IN,
ENABLE_SIGNUP, ENABLE_SIGNUP,
ENABLE_LOGIN_FORM, ENABLE_LOGIN_FORM,
ENABLE_API_KEY, ENABLE_API_KEYS,
ENABLE_API_KEY_ENDPOINT_RESTRICTIONS, ENABLE_API_KEYS_ENDPOINT_RESTRICTIONS,
API_KEY_ALLOWED_ENDPOINTS, API_KEYS_ALLOWED_ENDPOINTS,
ENABLE_CHANNELS, ENABLE_CHANNELS,
ENABLE_NOTES, ENABLE_NOTES,
ENABLE_COMMUNITY_SHARING, ENABLE_COMMUNITY_SHARING,
@ -366,10 +370,12 @@ from open_webui.config import (
BYPASS_ADMIN_ACCESS_CONTROL, BYPASS_ADMIN_ACCESS_CONTROL,
USER_PERMISSIONS, USER_PERMISSIONS,
DEFAULT_USER_ROLE, DEFAULT_USER_ROLE,
DEFAULT_GROUP_ID,
PENDING_USER_OVERLAY_CONTENT, PENDING_USER_OVERLAY_CONTENT,
PENDING_USER_OVERLAY_TITLE, PENDING_USER_OVERLAY_TITLE,
DEFAULT_PROMPT_SUGGESTIONS, DEFAULT_PROMPT_SUGGESTIONS,
DEFAULT_MODELS, DEFAULT_MODELS,
DEFAULT_PINNED_MODELS,
DEFAULT_ARENA_MODEL, DEFAULT_ARENA_MODEL,
MODEL_ORDER_LIST, MODEL_ORDER_LIST,
EVALUATION_ARENA_MODELS, EVALUATION_ARENA_MODELS,
@ -428,6 +434,7 @@ from open_webui.config import (
TAGS_GENERATION_PROMPT_TEMPLATE, TAGS_GENERATION_PROMPT_TEMPLATE,
IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE, IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE,
TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE, TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
VOICE_MODE_PROMPT_TEMPLATE,
QUERY_GENERATION_PROMPT_TEMPLATE, QUERY_GENERATION_PROMPT_TEMPLATE,
AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE, AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE,
AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH, AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH,
@ -449,6 +456,7 @@ from open_webui.env import (
SAFE_MODE, SAFE_MODE,
SRC_LOG_LEVELS, SRC_LOG_LEVELS,
VERSION, VERSION,
DEPLOYMENT_ID,
INSTANCE_ID, INSTANCE_ID,
WEBUI_BUILD_HASH, WEBUI_BUILD_HASH,
WEBUI_SECRET_KEY, WEBUI_SECRET_KEY,
@ -459,7 +467,7 @@ from open_webui.env import (
WEBUI_AUTH_TRUSTED_NAME_HEADER, WEBUI_AUTH_TRUSTED_NAME_HEADER,
WEBUI_AUTH_SIGNOUT_REDIRECT_URL, WEBUI_AUTH_SIGNOUT_REDIRECT_URL,
# SCIM # SCIM
SCIM_ENABLED, ENABLE_SCIM,
SCIM_TOKEN, SCIM_TOKEN,
ENABLE_COMPRESSION_MIDDLEWARE, ENABLE_COMPRESSION_MIDDLEWARE,
ENABLE_WEBSOCKET_SUPPORT, ENABLE_WEBSOCKET_SUPPORT,
@ -715,7 +723,7 @@ app.state.config.ENABLE_DIRECT_CONNECTIONS = ENABLE_DIRECT_CONNECTIONS
# #
######################################## ########################################
app.state.SCIM_ENABLED = SCIM_ENABLED app.state.ENABLE_SCIM = ENABLE_SCIM
app.state.SCIM_TOKEN = SCIM_TOKEN app.state.SCIM_TOKEN = SCIM_TOKEN
######################################## ########################################
@ -737,11 +745,11 @@ app.state.config.WEBUI_URL = WEBUI_URL
app.state.config.ENABLE_SIGNUP = ENABLE_SIGNUP app.state.config.ENABLE_SIGNUP = ENABLE_SIGNUP
app.state.config.ENABLE_LOGIN_FORM = ENABLE_LOGIN_FORM app.state.config.ENABLE_LOGIN_FORM = ENABLE_LOGIN_FORM
app.state.config.ENABLE_API_KEY = ENABLE_API_KEY app.state.config.ENABLE_API_KEYS = ENABLE_API_KEYS
app.state.config.ENABLE_API_KEY_ENDPOINT_RESTRICTIONS = ( app.state.config.ENABLE_API_KEYS_ENDPOINT_RESTRICTIONS = (
ENABLE_API_KEY_ENDPOINT_RESTRICTIONS ENABLE_API_KEYS_ENDPOINT_RESTRICTIONS
) )
app.state.config.API_KEY_ALLOWED_ENDPOINTS = API_KEY_ALLOWED_ENDPOINTS app.state.config.API_KEYS_ALLOWED_ENDPOINTS = API_KEYS_ALLOWED_ENDPOINTS
app.state.config.JWT_EXPIRES_IN = JWT_EXPIRES_IN app.state.config.JWT_EXPIRES_IN = JWT_EXPIRES_IN
@ -750,8 +758,13 @@ app.state.config.ADMIN_EMAIL = ADMIN_EMAIL
app.state.config.DEFAULT_MODELS = DEFAULT_MODELS app.state.config.DEFAULT_MODELS = DEFAULT_MODELS
app.state.config.DEFAULT_PINNED_MODELS = DEFAULT_PINNED_MODELS
app.state.config.MODEL_ORDER_LIST = MODEL_ORDER_LIST
app.state.config.DEFAULT_PROMPT_SUGGESTIONS = DEFAULT_PROMPT_SUGGESTIONS app.state.config.DEFAULT_PROMPT_SUGGESTIONS = DEFAULT_PROMPT_SUGGESTIONS
app.state.config.DEFAULT_USER_ROLE = DEFAULT_USER_ROLE app.state.config.DEFAULT_USER_ROLE = DEFAULT_USER_ROLE
app.state.config.DEFAULT_GROUP_ID = DEFAULT_GROUP_ID
app.state.config.PENDING_USER_OVERLAY_CONTENT = PENDING_USER_OVERLAY_CONTENT app.state.config.PENDING_USER_OVERLAY_CONTENT = PENDING_USER_OVERLAY_CONTENT
app.state.config.PENDING_USER_OVERLAY_TITLE = PENDING_USER_OVERLAY_TITLE app.state.config.PENDING_USER_OVERLAY_TITLE = PENDING_USER_OVERLAY_TITLE
@ -761,7 +774,6 @@ app.state.config.RESPONSE_WATERMARK = RESPONSE_WATERMARK
app.state.config.USER_PERMISSIONS = USER_PERMISSIONS app.state.config.USER_PERMISSIONS = USER_PERMISSIONS
app.state.config.WEBHOOK_URL = WEBHOOK_URL app.state.config.WEBHOOK_URL = WEBHOOK_URL
app.state.config.BANNERS = WEBUI_BANNERS app.state.config.BANNERS = WEBUI_BANNERS
app.state.config.MODEL_ORDER_LIST = MODEL_ORDER_LIST
app.state.config.ENABLE_CHANNELS = ENABLE_CHANNELS app.state.config.ENABLE_CHANNELS = ENABLE_CHANNELS
@ -839,6 +851,9 @@ app.state.config.FILE_IMAGE_COMPRESSION_HEIGHT = FILE_IMAGE_COMPRESSION_HEIGHT
app.state.config.RAG_FULL_CONTEXT = RAG_FULL_CONTEXT app.state.config.RAG_FULL_CONTEXT = RAG_FULL_CONTEXT
app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL = BYPASS_EMBEDDING_AND_RETRIEVAL app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL = BYPASS_EMBEDDING_AND_RETRIEVAL
app.state.config.ENABLE_RAG_HYBRID_SEARCH = ENABLE_RAG_HYBRID_SEARCH app.state.config.ENABLE_RAG_HYBRID_SEARCH = ENABLE_RAG_HYBRID_SEARCH
app.state.config.ENABLE_RAG_HYBRID_SEARCH_ENRICHED_TEXTS = (
ENABLE_RAG_HYBRID_SEARCH_ENRICHED_TEXTS
)
app.state.config.ENABLE_WEB_LOADER_SSL_VERIFICATION = ENABLE_WEB_LOADER_SSL_VERIFICATION app.state.config.ENABLE_WEB_LOADER_SSL_VERIFICATION = ENABLE_WEB_LOADER_SSL_VERIFICATION
app.state.config.CONTENT_EXTRACTION_ENGINE = CONTENT_EXTRACTION_ENGINE app.state.config.CONTENT_EXTRACTION_ENGINE = CONTENT_EXTRACTION_ENGINE
@ -958,6 +973,7 @@ app.state.config.EXA_API_KEY = EXA_API_KEY
app.state.config.PERPLEXITY_API_KEY = PERPLEXITY_API_KEY app.state.config.PERPLEXITY_API_KEY = PERPLEXITY_API_KEY
app.state.config.PERPLEXITY_MODEL = PERPLEXITY_MODEL app.state.config.PERPLEXITY_MODEL = PERPLEXITY_MODEL
app.state.config.PERPLEXITY_SEARCH_CONTEXT_USAGE = PERPLEXITY_SEARCH_CONTEXT_USAGE app.state.config.PERPLEXITY_SEARCH_CONTEXT_USAGE = PERPLEXITY_SEARCH_CONTEXT_USAGE
app.state.config.PERPLEXITY_SEARCH_API_URL = PERPLEXITY_SEARCH_API_URL
app.state.config.SOUGOU_API_SID = SOUGOU_API_SID app.state.config.SOUGOU_API_SID = SOUGOU_API_SID
app.state.config.SOUGOU_API_SK = SOUGOU_API_SK app.state.config.SOUGOU_API_SK = SOUGOU_API_SK
app.state.config.EXTERNAL_WEB_SEARCH_URL = EXTERNAL_WEB_SEARCH_URL app.state.config.EXTERNAL_WEB_SEARCH_URL = EXTERNAL_WEB_SEARCH_URL
@ -1087,6 +1103,7 @@ app.state.config.IMAGE_STEPS = IMAGE_STEPS
app.state.config.IMAGES_OPENAI_API_BASE_URL = IMAGES_OPENAI_API_BASE_URL app.state.config.IMAGES_OPENAI_API_BASE_URL = IMAGES_OPENAI_API_BASE_URL
app.state.config.IMAGES_OPENAI_API_VERSION = IMAGES_OPENAI_API_VERSION app.state.config.IMAGES_OPENAI_API_VERSION = IMAGES_OPENAI_API_VERSION
app.state.config.IMAGES_OPENAI_API_KEY = IMAGES_OPENAI_API_KEY app.state.config.IMAGES_OPENAI_API_KEY = IMAGES_OPENAI_API_KEY
app.state.config.IMAGES_OPENAI_API_PARAMS = IMAGES_OPENAI_API_PARAMS
app.state.config.IMAGES_GEMINI_API_BASE_URL = IMAGES_GEMINI_API_BASE_URL app.state.config.IMAGES_GEMINI_API_BASE_URL = IMAGES_GEMINI_API_BASE_URL
app.state.config.IMAGES_GEMINI_API_KEY = IMAGES_GEMINI_API_KEY app.state.config.IMAGES_GEMINI_API_KEY = IMAGES_GEMINI_API_KEY
@ -1102,6 +1119,7 @@ app.state.config.COMFYUI_WORKFLOW = COMFYUI_WORKFLOW
app.state.config.COMFYUI_WORKFLOW_NODES = COMFYUI_WORKFLOW_NODES app.state.config.COMFYUI_WORKFLOW_NODES = COMFYUI_WORKFLOW_NODES
app.state.config.ENABLE_IMAGE_EDIT = ENABLE_IMAGE_EDIT
app.state.config.IMAGE_EDIT_ENGINE = IMAGE_EDIT_ENGINE app.state.config.IMAGE_EDIT_ENGINE = IMAGE_EDIT_ENGINE
app.state.config.IMAGE_EDIT_MODEL = IMAGE_EDIT_MODEL app.state.config.IMAGE_EDIT_MODEL = IMAGE_EDIT_MODEL
app.state.config.IMAGE_EDIT_SIZE = IMAGE_EDIT_SIZE app.state.config.IMAGE_EDIT_SIZE = IMAGE_EDIT_SIZE
@ -1206,6 +1224,7 @@ app.state.config.AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE = (
app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH = ( app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH = (
AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH
) )
app.state.config.VOICE_MODE_PROMPT_TEMPLATE = VOICE_MODE_PROMPT_TEMPLATE
######################################## ########################################
@ -1216,6 +1235,10 @@ app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH = (
app.state.MODELS = {} app.state.MODELS = {}
# Add the middleware to the app
if ENABLE_COMPRESSION_MIDDLEWARE:
app.add_middleware(CompressMiddleware)
class RedirectMiddleware(BaseHTTPMiddleware): class RedirectMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next): async def dispatch(self, request: Request, call_next):
@ -1257,14 +1280,53 @@ class RedirectMiddleware(BaseHTTPMiddleware):
return response return response
# Add the middleware to the app
if ENABLE_COMPRESSION_MIDDLEWARE:
app.add_middleware(CompressMiddleware)
app.add_middleware(RedirectMiddleware) app.add_middleware(RedirectMiddleware)
app.add_middleware(SecurityHeadersMiddleware) app.add_middleware(SecurityHeadersMiddleware)
class APIKeyRestrictionMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next):
auth_header = request.headers.get("Authorization")
token = None
if auth_header:
            scheme, _, token = auth_header.partition(" ")
# Only apply restrictions if an sk- API key is used
if token and token.startswith("sk-"):
# Check if restrictions are enabled
if request.app.state.config.ENABLE_API_KEYS_ENDPOINT_RESTRICTIONS:
allowed_paths = [
path.strip()
for path in str(
request.app.state.config.API_KEYS_ALLOWED_ENDPOINTS
).split(",")
if path.strip()
]
request_path = request.url.path
# Match exact path or prefix path
is_allowed = any(
request_path == allowed or request_path.startswith(allowed + "/")
for allowed in allowed_paths
)
if not is_allowed:
return JSONResponse(
status_code=status.HTTP_403_FORBIDDEN,
content={
"detail": "API key not allowed to access this endpoint."
},
)
response = await call_next(request)
return response
app.add_middleware(APIKeyRestrictionMiddleware)
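A worked example of the exact-or-prefix matching rule implemented above; the configured endpoint list is hypothetical.

# Suppose API_KEYS_ALLOWED_ENDPOINTS = "/api/models,/api/v1/chats"
allowed_paths = ["/api/models", "/api/v1/chats"]


def is_allowed(request_path: str) -> bool:
    # Same match as the middleware: exact path, or the path plus a "/" suffix.
    return any(
        request_path == allowed or request_path.startswith(allowed + "/")
        for allowed in allowed_paths
    )


assert is_allowed("/api/v1/chats/123")   # prefix match -> allowed
assert not is_allowed("/api/v1/users")   # not listed -> 403 for sk- keys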
@app.middleware("http") @app.middleware("http")
async def commit_session_after_request(request: Request, call_next): async def commit_session_after_request(request: Request, call_next):
response = await call_next(request) response = await call_next(request)
@ -1280,7 +1342,7 @@ async def check_url(request: Request, call_next):
request.headers.get("Authorization") request.headers.get("Authorization")
) )
request.state.enable_api_key = app.state.config.ENABLE_API_KEY request.state.enable_api_keys = app.state.config.ENABLE_API_KEYS
response = await call_next(request) response = await call_next(request)
process_time = int(time.time()) - start_time process_time = int(time.time()) - start_time
response.headers["X-Process-Time"] = str(process_time) response.headers["X-Process-Time"] = str(process_time)
@ -1355,7 +1417,7 @@ app.include_router(
app.include_router(utils.router, prefix="/api/v1/utils", tags=["utils"]) app.include_router(utils.router, prefix="/api/v1/utils", tags=["utils"])
# SCIM 2.0 API for identity management # SCIM 2.0 API for identity management
if SCIM_ENABLED: if ENABLE_SCIM:
app.include_router(scim.router, prefix="/api/v1/scim/v2", tags=["scim"]) app.include_router(scim.router, prefix="/api/v1/scim/v2", tags=["scim"])
@ -1392,6 +1454,10 @@ async def get_models(
if "pipeline" in model and model["pipeline"].get("type", None) == "filter": if "pipeline" in model and model["pipeline"].get("type", None) == "filter":
continue continue
# Remove profile image URL to reduce payload size
if model.get("info", {}).get("meta", {}).get("profile_image_url"):
model["info"]["meta"].pop("profile_image_url", None)
try: try:
model_tags = [ model_tags = [
tag.get("name") tag.get("name")
@ -1514,6 +1580,9 @@ async def chat_completion(
reasoning_tags = form_data.get("params", {}).get("reasoning_tags") reasoning_tags = form_data.get("params", {}).get("reasoning_tags")
# Model Params # Model Params
if model_info_params.get("stream_response") is not None:
form_data["stream"] = model_info_params.get("stream_response")
if model_info_params.get("stream_delta_chunk_size"): if model_info_params.get("stream_delta_chunk_size"):
stream_delta_chunk_size = model_info_params.get("stream_delta_chunk_size") stream_delta_chunk_size = model_info_params.get("stream_delta_chunk_size")
@ -1783,7 +1852,7 @@ async def get_app_config(request: Request):
"auth_trusted_header": bool(app.state.AUTH_TRUSTED_EMAIL_HEADER), "auth_trusted_header": bool(app.state.AUTH_TRUSTED_EMAIL_HEADER),
"enable_signup_password_confirmation": ENABLE_SIGNUP_PASSWORD_CONFIRMATION, "enable_signup_password_confirmation": ENABLE_SIGNUP_PASSWORD_CONFIRMATION,
"enable_ldap": app.state.config.ENABLE_LDAP, "enable_ldap": app.state.config.ENABLE_LDAP,
"enable_api_key": app.state.config.ENABLE_API_KEY, "enable_api_keys": app.state.config.ENABLE_API_KEYS,
"enable_signup": app.state.config.ENABLE_SIGNUP, "enable_signup": app.state.config.ENABLE_SIGNUP,
"enable_login_form": app.state.config.ENABLE_LOGIN_FORM, "enable_login_form": app.state.config.ENABLE_LOGIN_FORM,
"enable_websocket": ENABLE_WEBSOCKET_SUPPORT, "enable_websocket": ENABLE_WEBSOCKET_SUPPORT,
@ -1821,6 +1890,7 @@ async def get_app_config(request: Request):
**( **(
{ {
"default_models": app.state.config.DEFAULT_MODELS, "default_models": app.state.config.DEFAULT_MODELS,
"default_pinned_models": app.state.config.DEFAULT_PINNED_MODELS,
"default_prompt_suggestions": app.state.config.DEFAULT_PROMPT_SUGGESTIONS, "default_prompt_suggestions": app.state.config.DEFAULT_PROMPT_SUGGESTIONS,
"user_count": user_count, "user_count": user_count,
"code": { "code": {
@ -1922,6 +1992,7 @@ async def update_webhook_url(form_data: UrlForm, user=Depends(get_admin_user)):
async def get_app_version(): async def get_app_version():
return { return {
"version": VERSION, "version": VERSION,
"deployment_id": DEPLOYMENT_ID,
} }

View file

@ -0,0 +1,146 @@
"""add_group_member_table
Revision ID: 37f288994c47
Revises: a5c220713937
Create Date: 2025-11-17 03:45:25.123939
"""
import uuid
import time
import json
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = "37f288994c47"
down_revision: Union[str, None] = "a5c220713937"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# 1. Create new table
op.create_table(
"group_member",
sa.Column("id", sa.Text(), primary_key=True, unique=True, nullable=False),
sa.Column(
"group_id",
sa.Text(),
sa.ForeignKey("group.id", ondelete="CASCADE"),
nullable=False,
),
sa.Column(
"user_id",
sa.Text(),
sa.ForeignKey("user.id", ondelete="CASCADE"),
nullable=False,
),
sa.Column("created_at", sa.BigInteger(), nullable=True),
sa.Column("updated_at", sa.BigInteger(), nullable=True),
sa.UniqueConstraint("group_id", "user_id", name="uq_group_member_group_user"),
)
connection = op.get_bind()
# 2. Read existing group with user_ids JSON column
group_table = sa.Table(
"group",
sa.MetaData(),
sa.Column("id", sa.Text()),
sa.Column("user_ids", sa.JSON()), # JSON stored as text in SQLite + PG
)
results = connection.execute(
sa.select(group_table.c.id, group_table.c.user_ids)
).fetchall()
print(results)
# 3. Insert members into group_member table
gm_table = sa.Table(
"group_member",
sa.MetaData(),
sa.Column("id", sa.Text()),
sa.Column("group_id", sa.Text()),
sa.Column("user_id", sa.Text()),
sa.Column("created_at", sa.BigInteger()),
sa.Column("updated_at", sa.BigInteger()),
)
now = int(time.time())
for group_id, user_ids in results:
if not user_ids:
continue
if isinstance(user_ids, str):
try:
user_ids = json.loads(user_ids)
except Exception:
continue # skip invalid JSON
if not isinstance(user_ids, list):
continue
rows = [
{
"id": str(uuid.uuid4()),
"group_id": group_id,
"user_id": uid,
"created_at": now,
"updated_at": now,
}
for uid in user_ids
]
if rows:
connection.execute(gm_table.insert(), rows)
# 4. Optionally drop the old column
with op.batch_alter_table("group") as batch:
batch.drop_column("user_ids")
def downgrade():
# Reverse: restore user_ids column
with op.batch_alter_table("group") as batch:
batch.add_column(sa.Column("user_ids", sa.JSON()))
connection = op.get_bind()
gm_table = sa.Table(
"group_member",
sa.MetaData(),
sa.Column("group_id", sa.Text()),
sa.Column("user_id", sa.Text()),
sa.Column("created_at", sa.BigInteger()),
sa.Column("updated_at", sa.BigInteger()),
)
group_table = sa.Table(
"group",
sa.MetaData(),
sa.Column("id", sa.Text()),
sa.Column("user_ids", sa.JSON()),
)
# Build JSON arrays again
results = connection.execute(sa.select(group_table.c.id)).fetchall()
for (group_id,) in results:
members = connection.execute(
sa.select(gm_table.c.user_id).where(gm_table.c.group_id == group_id)
).fetchall()
member_ids = [m[0] for m in members]
connection.execute(
group_table.update()
.where(group_table.c.id == group_id)
.values(user_ids=member_ids)
)
# Drop the new table
op.drop_table("group_member")

View file

@ -7,7 +7,6 @@ from open_webui.models.users import UserModel, Users
from open_webui.env import SRC_LOG_LEVELS from open_webui.env import SRC_LOG_LEVELS
from pydantic import BaseModel from pydantic import BaseModel
from sqlalchemy import Boolean, Column, String, Text from sqlalchemy import Boolean, Column, String, Text
from open_webui.utils.auth import verify_password
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"]) log.setLevel(SRC_LOG_LEVELS["MODELS"])
@ -20,7 +19,7 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"])
class Auth(Base): class Auth(Base):
__tablename__ = "auth" __tablename__ = "auth"
id = Column(String, primary_key=True) id = Column(String, primary_key=True, unique=True)
email = Column(String) email = Column(String)
password = Column(Text) password = Column(Text)
active = Column(Boolean) active = Column(Boolean)
@ -122,7 +121,9 @@ class AuthsTable:
else: else:
return None return None
    def authenticate_user(
        self, email: str, verify_password: callable
    ) -> Optional[UserModel]:
log.info(f"authenticate_user: {email}") log.info(f"authenticate_user: {email}")
user = Users.get_user_by_email(email) user = Users.get_user_by_email(email)
@ -133,7 +134,7 @@ class AuthsTable:
with get_db() as db: with get_db() as db:
auth = db.query(Auth).filter_by(id=user.id, active=True).first() auth = db.query(Auth).filter_by(id=user.id, active=True).first()
if auth: if auth:
if verify_password(password, auth.password): if verify_password(auth.password):
return user return user
else: else:
return None return None
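The signature change means the caller now supplies the password check as a callable instead of the raw password. A hedged sketch of the new call site; the Auths instance name and form_data fields are assumptions based on the surrounding codebase.

from open_webui.utils.auth import verify_password

user = Auths.authenticate_user(
    form_data.email.lower(),
    lambda hashed: verify_password(form_data.password, hashed),
)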

View file

@ -19,7 +19,7 @@ from sqlalchemy.sql import exists
class Channel(Base): class Channel(Base):
__tablename__ = "channel" __tablename__ = "channel"
id = Column(Text, primary_key=True) id = Column(Text, primary_key=True, unique=True)
user_id = Column(Text) user_id = Column(Text)
type = Column(Text, nullable=True) type = Column(Text, nullable=True)

View file

@ -26,7 +26,7 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"])
class Chat(Base): class Chat(Base):
__tablename__ = "chat" __tablename__ = "chat"
id = Column(String, primary_key=True) id = Column(String, primary_key=True, unique=True)
user_id = Column(String) user_id = Column(String)
title = Column(Text) title = Column(Text)
chat = Column(JSON) chat = Column(JSON)
@ -92,6 +92,10 @@ class ChatImportForm(ChatForm):
updated_at: Optional[int] = None updated_at: Optional[int] = None
class ChatsImportForm(BaseModel):
chats: list[ChatImportForm]
class ChatTitleMessagesForm(BaseModel): class ChatTitleMessagesForm(BaseModel):
title: str title: str
messages: list[dict] messages: list[dict]
@ -123,6 +127,43 @@ class ChatTitleIdResponse(BaseModel):
class ChatTable: class ChatTable:
def _clean_null_bytes(self, obj):
"""
Recursively remove actual null bytes (\x00) and unicode escape \\u0000
from strings inside dict/list structures.
Safe for JSON objects.
"""
if isinstance(obj, str):
return obj.replace("\x00", "").replace("\u0000", "")
elif isinstance(obj, dict):
return {k: self._clean_null_bytes(v) for k, v in obj.items()}
elif isinstance(obj, list):
return [self._clean_null_bytes(v) for v in obj]
return obj
def _sanitize_chat_row(self, chat_item):
"""
Clean a Chat SQLAlchemy model's title + chat JSON,
and return True if anything changed.
"""
changed = False
# Clean title
if chat_item.title:
cleaned = self._clean_null_bytes(chat_item.title)
if cleaned != chat_item.title:
chat_item.title = cleaned
changed = True
# Clean JSON
if chat_item.chat:
cleaned = self._clean_null_bytes(chat_item.chat)
if cleaned != chat_item.chat:
chat_item.chat = cleaned
changed = True
return changed
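The effect of the null-byte helper, for illustration only:

table = ChatTable()
table._clean_null_bytes({"title": "hi\x00there", "tags": ["a\u0000b"]})
# -> {"title": "hithere", "tags": ["ab"]}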
def insert_new_chat(self, user_id: str, form_data: ChatForm) -> Optional[ChatModel]: def insert_new_chat(self, user_id: str, form_data: ChatForm) -> Optional[ChatModel]:
with get_db() as db: with get_db() as db:
id = str(uuid.uuid4()) id = str(uuid.uuid4())
@ -130,68 +171,76 @@ class ChatTable:
**{ **{
"id": id, "id": id,
"user_id": user_id, "user_id": user_id,
"title": ( "title": self._clean_null_bytes(
form_data.chat["title"] form_data.chat["title"]
if "title" in form_data.chat if "title" in form_data.chat
else "New Chat" else "New Chat"
), ),
"chat": form_data.chat, "chat": self._clean_null_bytes(form_data.chat),
"folder_id": form_data.folder_id, "folder_id": form_data.folder_id,
"created_at": int(time.time()), "created_at": int(time.time()),
"updated_at": int(time.time()), "updated_at": int(time.time()),
} }
) )
result = Chat(**chat.model_dump()) chat_item = Chat(**chat.model_dump())
db.add(result) db.add(chat_item)
db.commit() db.commit()
db.refresh(result) db.refresh(chat_item)
return ChatModel.model_validate(result) if result else None return ChatModel.model_validate(chat_item) if chat_item else None
def import_chat( def _chat_import_form_to_chat_model(
self, user_id: str, form_data: ChatImportForm self, user_id: str, form_data: ChatImportForm
) -> Optional[ChatModel]: ) -> ChatModel:
with get_db() as db: id = str(uuid.uuid4())
id = str(uuid.uuid4()) chat = ChatModel(
chat = ChatModel( **{
**{ "id": id,
"id": id, "user_id": user_id,
"user_id": user_id, "title": self._clean_null_bytes(
"title": ( form_data.chat["title"] if "title" in form_data.chat else "New Chat"
form_data.chat["title"] ),
if "title" in form_data.chat "chat": self._clean_null_bytes(form_data.chat),
else "New Chat" "meta": form_data.meta,
), "pinned": form_data.pinned,
"chat": form_data.chat, "folder_id": form_data.folder_id,
"meta": form_data.meta, "created_at": (
"pinned": form_data.pinned, form_data.created_at if form_data.created_at else int(time.time())
"folder_id": form_data.folder_id, ),
"created_at": ( "updated_at": (
form_data.created_at form_data.updated_at if form_data.updated_at else int(time.time())
if form_data.created_at ),
else int(time.time()) }
), )
"updated_at": ( return chat
form_data.updated_at
if form_data.updated_at
else int(time.time())
),
}
)
result = Chat(**chat.model_dump()) def import_chats(
db.add(result) self, user_id: str, chat_import_forms: list[ChatImportForm]
) -> list[ChatModel]:
with get_db() as db:
chats = []
for form_data in chat_import_forms:
chat = self._chat_import_form_to_chat_model(user_id, form_data)
chats.append(Chat(**chat.model_dump()))
db.add_all(chats)
db.commit() db.commit()
db.refresh(result) return [ChatModel.model_validate(chat) for chat in chats]
return ChatModel.model_validate(result) if result else None
def update_chat_by_id(self, id: str, chat: dict) -> Optional[ChatModel]: def update_chat_by_id(self, id: str, chat: dict) -> Optional[ChatModel]:
try: try:
with get_db() as db: with get_db() as db:
chat_item = db.get(Chat, id) chat_item = db.get(Chat, id)
chat_item.chat = chat chat_item.chat = self._clean_null_bytes(chat)
chat_item.title = chat["title"] if "title" in chat else "New Chat" chat_item.title = (
self._clean_null_bytes(chat["title"])
if "title" in chat
else "New Chat"
)
chat_item.updated_at = int(time.time()) chat_item.updated_at = int(time.time())
db.commit() db.commit()
db.refresh(chat_item) db.refresh(chat_item)
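A sketch of how the new bulk path is likely driven from the chats router; the ChatsImportForm payload shape is defined earlier in this file, while the Chats singleton and endpoint wiring are assumptions.

# form_data: ChatsImportForm with a list of ChatImportForm entries
imported = Chats.import_chats(user.id, form_data.chats)
log.info(f"imported {len(imported)} chats for user {user.id}")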
@ -297,6 +346,27 @@ class ChatTable:
chat["history"] = history chat["history"] = history
return self.update_chat_by_id(id, chat) return self.update_chat_by_id(id, chat)
def add_message_files_by_id_and_message_id(
self, id: str, message_id: str, files: list[dict]
) -> list[dict]:
chat = self.get_chat_by_id(id)
if chat is None:
return None
chat = chat.chat
history = chat.get("history", {})
message_files = []
if message_id in history.get("messages", {}):
message_files = history["messages"][message_id].get("files", [])
message_files = message_files + files
history["messages"][message_id]["files"] = message_files
chat["history"] = history
self.update_chat_by_id(id, chat)
return message_files
def insert_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]: def insert_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]:
with get_db() as db: with get_db() as db:
# Get the existing chat to share # Get the existing chat to share
@ -405,6 +475,7 @@ class ChatTable:
with get_db() as db: with get_db() as db:
chat = db.get(Chat, id) chat = db.get(Chat, id)
chat.archived = not chat.archived chat.archived = not chat.archived
chat.folder_id = None
chat.updated_at = int(time.time()) chat.updated_at = int(time.time())
db.commit() db.commit()
db.refresh(chat) db.refresh(chat)
@ -561,8 +632,15 @@ class ChatTable:
def get_chat_by_id(self, id: str) -> Optional[ChatModel]: def get_chat_by_id(self, id: str) -> Optional[ChatModel]:
try: try:
with get_db() as db: with get_db() as db:
                chat_item = db.get(Chat, id)
                if chat_item is None:
                    return None

                if self._sanitize_chat_row(chat_item):
                    db.commit()
                    db.refresh(chat_item)

                return ChatModel.model_validate(chat_item)
except Exception: except Exception:
return None return None
@ -767,24 +845,30 @@ class ChatTable:
elif dialect_name == "postgresql": elif dialect_name == "postgresql":
# PostgreSQL doesn't allow null bytes in text. We filter those out by checking # PostgreSQL doesn't allow null bytes in text. We filter those out by checking
# the JSON representation for \u0000 before attempting text extraction # the JSON representation for \u0000 before attempting text extraction
                # Safety filter: JSON field must not contain \u0000
                query = query.filter(text("Chat.chat::text NOT LIKE '%\\\\u0000%'"))

                # Safety filter: title must not contain actual null bytes
                query = query.filter(text("Chat.title::text NOT LIKE '%\\x00%'"))

                postgres_content_sql = """
                    EXISTS (
                        SELECT 1
                        FROM json_array_elements(Chat.chat->'messages') AS message
                        WHERE json_typeof(message->'content') = 'string'
                        AND LOWER(message->>'content') LIKE '%' || :content_key || '%'
                    )
                """
                postgres_content_clause = text(postgres_content_sql)

                query = query.filter(
                    or_(
                        Chat.title.ilike(bindparam("title_key")),
                        postgres_content_clause,
                    )
                ).params(title_key=f"%{search_text}%", content_key=search_text.lower())
# Check if there are any tags to filter, it should have all the tags # Check if there are any tags to filter, it should have all the tags
if "none" in tag_ids: if "none" in tag_ids:
@ -1059,6 +1143,20 @@ class ChatTable:
except Exception: except Exception:
return False return False
def move_chats_by_user_id_and_folder_id(
self, user_id: str, folder_id: str, new_folder_id: Optional[str]
) -> bool:
try:
with get_db() as db:
db.query(Chat).filter_by(user_id=user_id, folder_id=folder_id).update(
{"folder_id": new_folder_id}
)
db.commit()
return True
except Exception:
return False
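A hedged example of the intended use, e.g. re-homing chats when a folder is deleted; the caller shown is an assumption.

# Move everything from the deleted folder back to the root (folder_id=None).
Chats.move_chats_by_user_id_and_folder_id(user.id, folder_id, None)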
def delete_shared_chats_by_user_id(self, user_id: str) -> bool: def delete_shared_chats_by_user_id(self, user_id: str) -> bool:
try: try:
with get_db() as db: with get_db() as db:

View file

@ -4,7 +4,7 @@ import uuid
from typing import Optional from typing import Optional
from open_webui.internal.db import Base, get_db from open_webui.internal.db import Base, get_db
from open_webui.models.chats import Chats from open_webui.models.users import User
from open_webui.env import SRC_LOG_LEVELS from open_webui.env import SRC_LOG_LEVELS
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict
@ -21,7 +21,7 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"])
class Feedback(Base): class Feedback(Base):
__tablename__ = "feedback" __tablename__ = "feedback"
id = Column(Text, primary_key=True) id = Column(Text, primary_key=True, unique=True)
user_id = Column(Text) user_id = Column(Text)
version = Column(BigInteger, default=0) version = Column(BigInteger, default=0)
type = Column(Text) type = Column(Text)
@ -92,6 +92,28 @@ class FeedbackForm(BaseModel):
model_config = ConfigDict(extra="allow") model_config = ConfigDict(extra="allow")
class UserResponse(BaseModel):
id: str
name: str
email: str
role: str = "pending"
last_active_at: int # timestamp in epoch
updated_at: int # timestamp in epoch
created_at: int # timestamp in epoch
model_config = ConfigDict(from_attributes=True)
class FeedbackUserResponse(FeedbackResponse):
user: Optional[UserResponse] = None
class FeedbackListResponse(BaseModel):
items: list[FeedbackUserResponse]
total: int
class FeedbackTable: class FeedbackTable:
def insert_new_feedback( def insert_new_feedback(
self, user_id: str, form_data: FeedbackForm self, user_id: str, form_data: FeedbackForm
@ -143,6 +165,70 @@ class FeedbackTable:
except Exception: except Exception:
return None return None
def get_feedback_items(
self, filter: dict = {}, skip: int = 0, limit: int = 30
) -> FeedbackListResponse:
with get_db() as db:
query = db.query(Feedback, User).join(User, Feedback.user_id == User.id)
if filter:
order_by = filter.get("order_by")
direction = filter.get("direction")
if order_by == "username":
if direction == "asc":
query = query.order_by(User.name.asc())
else:
query = query.order_by(User.name.desc())
elif order_by == "model_id":
# it's stored in feedback.data['model_id']
if direction == "asc":
query = query.order_by(
Feedback.data["model_id"].as_string().asc()
)
else:
query = query.order_by(
Feedback.data["model_id"].as_string().desc()
)
elif order_by == "rating":
# it's stored in feedback.data['rating']
if direction == "asc":
query = query.order_by(
Feedback.data["rating"].as_string().asc()
)
else:
query = query.order_by(
Feedback.data["rating"].as_string().desc()
)
elif order_by == "updated_at":
if direction == "asc":
query = query.order_by(Feedback.updated_at.asc())
else:
query = query.order_by(Feedback.updated_at.desc())
else:
query = query.order_by(Feedback.created_at.desc())
# Count BEFORE pagination
total = query.count()
if skip:
query = query.offset(skip)
if limit:
query = query.limit(limit)
items = query.all()
feedbacks = []
for feedback, user in items:
feedback_model = FeedbackModel.model_validate(feedback)
user_model = UserResponse.model_validate(user)
feedbacks.append(
FeedbackUserResponse(**feedback_model.model_dump(), user=user_model)
)
return FeedbackListResponse(items=feedbacks, total=total)
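A usage sketch for the new paginated, sortable listing; the Feedbacks instance name mirrors the other table classes and should be treated as an assumption.

page = Feedbacks.get_feedback_items(
    filter={"order_by": "rating", "direction": "desc"},
    skip=0,
    limit=30,
)
print(page.total, [f.user.name for f in page.items if f.user])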
def get_all_feedbacks(self) -> list[FeedbackModel]: def get_all_feedbacks(self) -> list[FeedbackModel]:
with get_db() as db: with get_db() as db:
return [ return [

View file

@ -17,7 +17,7 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"])
class File(Base): class File(Base):
__tablename__ = "file" __tablename__ = "file"
id = Column(String, primary_key=True) id = Column(String, primary_key=True, unique=True)
user_id = Column(String) user_id = Column(String)
hash = Column(Text, nullable=True) hash = Column(Text, nullable=True)
@ -98,6 +98,12 @@ class FileForm(BaseModel):
access_control: Optional[dict] = None access_control: Optional[dict] = None
class FileUpdateForm(BaseModel):
hash: Optional[str] = None
data: Optional[dict] = None
meta: Optional[dict] = None
class FilesTable: class FilesTable:
def insert_new_file(self, user_id: str, form_data: FileForm) -> Optional[FileModel]: def insert_new_file(self, user_id: str, form_data: FileForm) -> Optional[FileModel]:
with get_db() as db: with get_db() as db:
@ -204,6 +210,29 @@ class FilesTable:
for file in db.query(File).filter_by(user_id=user_id).all() for file in db.query(File).filter_by(user_id=user_id).all()
] ]
def update_file_by_id(
self, id: str, form_data: FileUpdateForm
) -> Optional[FileModel]:
with get_db() as db:
try:
file = db.query(File).filter_by(id=id).first()
if form_data.hash is not None:
file.hash = form_data.hash
if form_data.data is not None:
file.data = {**(file.data if file.data else {}), **form_data.data}
if form_data.meta is not None:
file.meta = {**(file.meta if file.meta else {}), **form_data.meta}
file.updated_at = int(time.time())
db.commit()
return FileModel.model_validate(file)
except Exception as e:
log.exception(f"Error updating file completely by id: {e}")
return None
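A sketch of a partial update with the new form; note that data and meta are merged into the existing values rather than replaced. The Files instance name is an assumption.

updated = Files.update_file_by_id(
    file_id,
    FileUpdateForm(hash=None, data={"status": "processed"}, meta={"pages": 12}),
)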
def update_file_hash_by_id(self, id: str, hash: str) -> Optional[FileModel]: def update_file_hash_by_id(self, id: str, hash: str) -> Optional[FileModel]:
with get_db() as db: with get_db() as db:
try: try:

View file

@ -23,7 +23,7 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"])
class Folder(Base): class Folder(Base):
__tablename__ = "folder" __tablename__ = "folder"
id = Column(Text, primary_key=True) id = Column(Text, primary_key=True, unique=True)
parent_id = Column(Text, nullable=True) parent_id = Column(Text, nullable=True)
user_id = Column(Text) user_id = Column(Text)
name = Column(Text) name = Column(Text)

View file

@ -19,7 +19,7 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"])
class Function(Base): class Function(Base):
__tablename__ = "function" __tablename__ = "function"
id = Column(String, primary_key=True) id = Column(String, primary_key=True, unique=True)
user_id = Column(String) user_id = Column(String)
name = Column(Text) name = Column(Text)
type = Column(Text) type = Column(Text)

View file

@ -11,7 +11,7 @@ from open_webui.models.files import FileMetadataResponse
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Column, String, Text, JSON, func from sqlalchemy import BigInteger, Column, String, Text, JSON, func, ForeignKey
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -35,7 +35,6 @@ class Group(Base):
meta = Column(JSON, nullable=True) meta = Column(JSON, nullable=True)
permissions = Column(JSON, nullable=True) permissions = Column(JSON, nullable=True)
user_ids = Column(JSON, nullable=True)
created_at = Column(BigInteger) created_at = Column(BigInteger)
updated_at = Column(BigInteger) updated_at = Column(BigInteger)
@ -53,12 +52,33 @@ class GroupModel(BaseModel):
meta: Optional[dict] = None meta: Optional[dict] = None
permissions: Optional[dict] = None permissions: Optional[dict] = None
user_ids: list[str] = []
created_at: int # timestamp in epoch created_at: int # timestamp in epoch
updated_at: int # timestamp in epoch updated_at: int # timestamp in epoch
class GroupMember(Base):
__tablename__ = "group_member"
id = Column(Text, unique=True, primary_key=True)
group_id = Column(
Text,
ForeignKey("group.id", ondelete="CASCADE"),
nullable=False,
)
user_id = Column(Text, nullable=False)
created_at = Column(BigInteger, nullable=True)
updated_at = Column(BigInteger, nullable=True)
class GroupMemberModel(BaseModel):
id: str
group_id: str
user_id: str
created_at: Optional[int] = None # timestamp in epoch
updated_at: Optional[int] = None # timestamp in epoch
#################### ####################
# Forms # Forms
#################### ####################
@ -72,7 +92,7 @@ class GroupResponse(BaseModel):
permissions: Optional[dict] = None permissions: Optional[dict] = None
data: Optional[dict] = None data: Optional[dict] = None
meta: Optional[dict] = None meta: Optional[dict] = None
user_ids: list[str] = [] member_count: Optional[int] = None
created_at: int # timestamp in epoch created_at: int # timestamp in epoch
updated_at: int # timestamp in epoch updated_at: int # timestamp in epoch
@ -81,13 +101,14 @@ class GroupForm(BaseModel):
name: str name: str
description: str description: str
permissions: Optional[dict] = None permissions: Optional[dict] = None
data: Optional[dict] = None
class UserIdsForm(BaseModel): class UserIdsForm(BaseModel):
user_ids: Optional[list[str]] = None user_ids: Optional[list[str]] = None
class GroupUpdateForm(GroupForm, UserIdsForm): class GroupUpdateForm(GroupForm):
pass pass
@ -131,12 +152,8 @@ class GroupTable:
return [ return [
GroupModel.model_validate(group) GroupModel.model_validate(group)
for group in db.query(Group) for group in db.query(Group)
                .join(GroupMember, GroupMember.group_id == Group.id)
                .filter(GroupMember.user_id == user_id)
.order_by(Group.updated_at.desc()) .order_by(Group.updated_at.desc())
.all() .all()
] ]
@ -149,12 +166,46 @@ class GroupTable:
except Exception: except Exception:
return None return None
    def get_group_user_ids_by_id(self, id: str) -> Optional[list[str]]:
        with get_db() as db:
            members = (
                db.query(GroupMember.user_id).filter(GroupMember.group_id == id).all()
            )

            if not members:
                return None
            return [m[0] for m in members]
def set_group_user_ids_by_id(self, group_id: str, user_ids: list[str]) -> None:
with get_db() as db:
# Delete existing members
db.query(GroupMember).filter(GroupMember.group_id == group_id).delete()
# Insert new members
now = int(time.time())
new_members = [
GroupMember(
id=str(uuid.uuid4()),
group_id=group_id,
user_id=user_id,
created_at=now,
updated_at=now,
)
for user_id in user_ids
]
db.add_all(new_members)
db.commit()
def get_group_member_count_by_id(self, id: str) -> int:
with get_db() as db:
count = (
db.query(func.count(GroupMember.user_id))
.filter(GroupMember.group_id == id)
.scalar()
)
return count if count else 0
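A short sketch exercising the new membership helpers backed by the group_member table; the Groups instance name is an assumption, and the order of returned ids is not guaranteed.

Groups.set_group_user_ids_by_id(group_id, ["user-a", "user-b"])
Groups.get_group_member_count_by_id(group_id)   # -> 2
Groups.get_group_user_ids_by_id(group_id)       # -> ["user-a", "user-b"]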
def update_group_by_id( def update_group_by_id(
self, id: str, form_data: GroupUpdateForm, overwrite: bool = False self, id: str, form_data: GroupUpdateForm, overwrite: bool = False
@ -195,20 +246,29 @@ class GroupTable:
def remove_user_from_all_groups(self, user_id: str) -> bool: def remove_user_from_all_groups(self, user_id: str) -> bool:
with get_db() as db: with get_db() as db:
try: try:
                # Find all groups the user belongs to
                groups = (
                    db.query(Group)
                    .join(GroupMember, GroupMember.group_id == Group.id)
                    .filter(GroupMember.user_id == user_id)
                    .all()
                )

                # Remove the user from each group
                for group in groups:
                    db.query(GroupMember).filter(
                        GroupMember.group_id == group.id, GroupMember.user_id == user_id
                    ).delete()

                    db.query(Group).filter_by(id=group.id).update(
                        {"updated_at": int(time.time())}
                    )

                db.commit()
return True return True
except Exception: except Exception:
db.rollback()
return False return False
def create_groups_by_group_names( def create_groups_by_group_names(
@ -246,37 +306,61 @@ class GroupTable:
def sync_groups_by_group_names(self, user_id: str, group_names: list[str]) -> bool: def sync_groups_by_group_names(self, user_id: str, group_names: list[str]) -> bool:
with get_db() as db: with get_db() as db:
try: try:
                now = int(time.time())

                # 1. Groups that SHOULD contain the user
                target_groups = (
                    db.query(Group).filter(Group.name.in_(group_names)).all()
                )
                target_group_ids = {g.id for g in target_groups}

                # 2. Groups the user is CURRENTLY in
                existing_group_ids = {
                    g.id
                    for g in db.query(Group)
                    .join(GroupMember, GroupMember.group_id == Group.id)
                    .filter(GroupMember.user_id == user_id)
                    .all()
                }

                # 3. Determine adds + removals
                groups_to_add = target_group_ids - existing_group_ids
                groups_to_remove = existing_group_ids - target_group_ids

                # 4. Remove in one bulk delete
                if groups_to_remove:
                    db.query(GroupMember).filter(
                        GroupMember.user_id == user_id,
                        GroupMember.group_id.in_(groups_to_remove),
                    ).delete(synchronize_session=False)

                    db.query(Group).filter(Group.id.in_(groups_to_remove)).update(
                        {"updated_at": now}, synchronize_session=False
                    )

                # 5. Bulk insert missing memberships
                for group_id in groups_to_add:
                    db.add(
                        GroupMember(
                            id=str(uuid.uuid4()),
                            group_id=group_id,
                            user_id=user_id,
                            created_at=now,
                            updated_at=now,
                        )
                    )

                if groups_to_add:
                    db.query(Group).filter(Group.id.in_(groups_to_add)).update(
                        {"updated_at": now}, synchronize_session=False
                    )
db.commit() db.commit()
return True return True
except Exception as e: except Exception as e:
log.exception(e) log.exception(e)
db.rollback()
return False return False
def add_users_to_group( def add_users_to_group(
@ -288,21 +372,31 @@ class GroupTable:
if not group: if not group:
return None return None
                now = int(time.time())

                for user_id in user_ids or []:
                    try:
                        db.add(
                            GroupMember(
                                id=str(uuid.uuid4()),
                                group_id=id,
                                user_id=user_id,
                                created_at=now,
                                updated_at=now,
                            )
                        )
                        db.flush()  # Detect unique constraint violation early
                    except Exception:
                        db.rollback()  # Clear failed INSERT
                        db.begin()  # Start a new transaction
                        continue  # Duplicate → ignore

                group.updated_at = now
db.commit() db.commit()
db.refresh(group) db.refresh(group)
return GroupModel.model_validate(group) return GroupModel.model_validate(group)
except Exception as e: except Exception as e:
log.exception(e) log.exception(e)
return None return None
@ -316,23 +410,22 @@ class GroupTable:
if not group: if not group:
return None return None
                if not user_ids:
                    return GroupModel.model_validate(group)

                # Remove each user from group_member
                for user_id in user_ids:
                    db.query(GroupMember).filter(
                        GroupMember.group_id == id, GroupMember.user_id == user_id
                    ).delete()

                # Update group timestamp
                group.updated_at = int(time.time())
db.commit() db.commit()
db.refresh(group) db.refresh(group)
return GroupModel.model_validate(group) return GroupModel.model_validate(group)
except Exception as e: except Exception as e:
log.exception(e) log.exception(e)
return None return None

View file

@ -14,7 +14,7 @@ from sqlalchemy import BigInteger, Column, String, Text
class Memory(Base): class Memory(Base):
__tablename__ = "memory" __tablename__ = "memory"
id = Column(String, primary_key=True) id = Column(String, primary_key=True, unique=True)
user_id = Column(String) user_id = Column(String)
content = Column(Text) content = Column(Text)
updated_at = Column(BigInteger) updated_at = Column(BigInteger)

View file

@ -20,7 +20,7 @@ from sqlalchemy.sql import exists
class MessageReaction(Base): class MessageReaction(Base):
__tablename__ = "message_reaction" __tablename__ = "message_reaction"
id = Column(Text, primary_key=True) id = Column(Text, primary_key=True, unique=True)
user_id = Column(Text) user_id = Column(Text)
message_id = Column(Text) message_id = Column(Text)
name = Column(Text) name = Column(Text)

View file

@ -6,12 +6,12 @@ from open_webui.internal.db import Base, JSONField, get_db
from open_webui.env import SRC_LOG_LEVELS from open_webui.env import SRC_LOG_LEVELS
from open_webui.models.groups import Groups from open_webui.models.groups import Groups
from open_webui.models.users import Users, UserResponse from open_webui.models.users import User, UserModel, Users, UserResponse
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict
from sqlalchemy import or_, and_, func from sqlalchemy import String, cast, or_, and_, func
from sqlalchemy.dialects import postgresql, sqlite from sqlalchemy.dialects import postgresql, sqlite
from sqlalchemy import BigInteger, Column, Text, JSON, Boolean from sqlalchemy import BigInteger, Column, Text, JSON, Boolean
@ -133,6 +133,11 @@ class ModelResponse(ModelModel):
pass pass
class ModelListResponse(BaseModel):
items: list[ModelUserResponse]
total: int
class ModelForm(BaseModel): class ModelForm(BaseModel):
id: str id: str
base_model_id: Optional[str] = None base_model_id: Optional[str] = None
@ -215,6 +220,84 @@ class ModelsTable:
or has_access(user_id, permission, model.access_control, user_group_ids) or has_access(user_id, permission, model.access_control, user_group_ids)
] ]
def search_models(
self, user_id: str, filter: dict = {}, skip: int = 0, limit: int = 30
) -> ModelListResponse:
with get_db() as db:
# Join GroupMember so we can order by group_id when requested
query = db.query(Model, User).outerjoin(User, User.id == Model.user_id)
query = query.filter(Model.base_model_id != None)
if filter:
query_key = filter.get("query")
if query_key:
query = query.filter(
or_(
Model.name.ilike(f"%{query_key}%"),
Model.base_model_id.ilike(f"%{query_key}%"),
)
)
if filter.get("user_id"):
query = query.filter(Model.user_id == filter.get("user_id"))
view_option = filter.get("view_option")
if view_option == "created":
query = query.filter(Model.user_id == user_id)
elif view_option == "shared":
query = query.filter(Model.user_id != user_id)
tag = filter.get("tag")
if tag:
# TODO: This is a simple implementation and should be improved for performance
like_pattern = f'%"{tag.lower()}"%' # `"tag"` inside JSON array
meta_text = func.lower(cast(Model.meta, String))
query = query.filter(meta_text.like(like_pattern))
order_by = filter.get("order_by")
direction = filter.get("direction")
if order_by == "name":
if direction == "asc":
query = query.order_by(Model.name.asc())
else:
query = query.order_by(Model.name.desc())
elif order_by == "created_at":
if direction == "asc":
query = query.order_by(Model.created_at.asc())
else:
query = query.order_by(Model.created_at.desc())
elif order_by == "updated_at":
if direction == "asc":
query = query.order_by(Model.updated_at.asc())
else:
query = query.order_by(Model.updated_at.desc())
else:
query = query.order_by(Model.created_at.desc())
# Count BEFORE pagination
total = query.count()
if skip:
query = query.offset(skip)
if limit:
query = query.limit(limit)
items = query.all()
models = []
for model, user in items:
model_model = ModelModel.model_validate(model)
user_model = UserResponse(**UserModel.model_validate(user).model_dump())
models.append(
ModelUserResponse(**model_model.model_dump(), user=user_model)
)
return ModelListResponse(items=models, total=total)
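A usage sketch for the new workspace model search; the filter keys follow the implementation above, while the Models instance name is an assumption.

result = Models.search_models(
    user.id,
    filter={
        "query": "llama",
        "tag": "vision",
        "order_by": "updated_at",
        "direction": "desc",
    },
    skip=0,
    limit=30,
)
print(result.total, [m.name for m in result.items])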
def get_model_by_id(self, id: str) -> Optional[ModelModel]: def get_model_by_id(self, id: str) -> Optional[ModelModel]:
try: try:
with get_db() as db: with get_db() as db:
@ -244,11 +327,9 @@ class ModelsTable:
try: try:
with get_db() as db: with get_db() as db:
# update only the fields that are present in the model # update only the fields that are present in the model
                data = model.model_dump(exclude={"id"})
                result = db.query(Model).filter_by(id=id).update(data)
db.commit() db.commit()
model = db.get(Model, id) model = db.get(Model, id)

View file

@ -6,7 +6,7 @@ from open_webui.internal.db import Base, JSONField, get_db
from open_webui.env import DATABASE_USER_ACTIVE_STATUS_UPDATE_INTERVAL from open_webui.env import DATABASE_USER_ACTIVE_STATUS_UPDATE_INTERVAL
from open_webui.models.chats import Chats from open_webui.models.chats import Chats
from open_webui.models.groups import Groups from open_webui.models.groups import Groups, GroupMember
from open_webui.utils.misc import throttle from open_webui.utils.misc import throttle
@ -95,8 +95,12 @@ class UpdateProfileForm(BaseModel):
date_of_birth: Optional[datetime.date] = None date_of_birth: Optional[datetime.date] = None
class UserGroupIdsModel(UserModel):
group_ids: list[str] = []
class UserListResponse(BaseModel): class UserListResponse(BaseModel):
users: list[UserModel] users: list[UserGroupIdsModel]
total: int total: int
@ -222,7 +226,10 @@ class UsersTable:
limit: Optional[int] = None, limit: Optional[int] = None,
) -> dict: ) -> dict:
with get_db() as db: with get_db() as db:
            # Join GroupMember so we can order by group_id when requested
            query = db.query(User).outerjoin(
                GroupMember, GroupMember.user_id == User.id
            )
if filter: if filter:
query_key = filter.get("query") query_key = filter.get("query")
@ -237,7 +244,16 @@ class UsersTable:
order_by = filter.get("order_by") order_by = filter.get("order_by")
direction = filter.get("direction") direction = filter.get("direction")
if order_by == "name": if order_by and order_by.startswith("group_id:"):
group_id = order_by.split(":", 1)[1]
if direction == "asc":
query = query.order_by((GroupMember.group_id == group_id).asc())
else:
query = query.order_by(
(GroupMember.group_id == group_id).desc()
)
elif order_by == "name":
if direction == "asc": if direction == "asc":
query = query.order_by(User.name.asc()) query = query.order_by(User.name.asc())
else: else:
@ -274,6 +290,9 @@ class UsersTable:
else: else:
query = query.order_by(User.created_at.desc()) query = query.order_by(User.created_at.desc())
# Count BEFORE pagination
total = query.count()
if skip: if skip:
query = query.offset(skip) query = query.offset(skip)
if limit: if limit:
@ -282,7 +301,7 @@ class UsersTable:
users = query.all() users = query.all()
return { return {
"users": [UserModel.model_validate(user) for user in users], "users": [UserModel.model_validate(user) for user in users],
"total": db.query(User).count(), "total": total,
} }
def get_users_by_user_ids(self, user_ids: list[str]) -> list[UserModel]: def get_users_by_user_ids(self, user_ids: list[str]) -> list[UserModel]:
@ -322,6 +341,15 @@ class UsersTable:
except Exception: except Exception:
return None return None
def get_num_users_active_today(self) -> Optional[int]:
with get_db() as db:
current_timestamp = int(datetime.datetime.now().timestamp())
today_midnight_timestamp = current_timestamp - (current_timestamp % 86400)
query = db.query(User).filter(
User.last_active_at > today_midnight_timestamp
)
return query.count()
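The "today" cutoff above is the start of the current UTC day, obtained by rounding the epoch timestamp down to a multiple of 86400 seconds; for example:

ts = 1732490000        # 2024-11-24 23:13:20 UTC
ts - (ts % 86400)      # -> 1732406400, i.e. 2024-11-24 00:00:00 UTC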
def update_user_role_by_id(self, id: str, role: str) -> Optional[UserModel]: def update_user_role_by_id(self, id: str, role: str) -> Optional[UserModel]:
try: try:
with get_db() as db: with get_db() as db:

View file

@ -33,13 +33,14 @@ class MinerULoader:
self.api_key = api_key self.api_key = api_key
# Parse params dict with defaults # Parse params dict with defaults
params = params or {} self.params = params or {}
self.enable_ocr = params.get("enable_ocr", False) self.enable_ocr = params.get("enable_ocr", False)
self.enable_formula = params.get("enable_formula", True) self.enable_formula = params.get("enable_formula", True)
self.enable_table = params.get("enable_table", True) self.enable_table = params.get("enable_table", True)
self.language = params.get("language", "en") self.language = params.get("language", "en")
self.model_version = params.get("model_version", "pipeline") self.model_version = params.get("model_version", "pipeline")
self.page_ranges = params.get("page_ranges", "")
self.page_ranges = self.params.pop("page_ranges", "")
# Validate API mode # Validate API mode
if self.api_mode not in ["local", "cloud"]: if self.api_mode not in ["local", "cloud"]:
@ -76,27 +77,10 @@ class MinerULoader:
# Build form data for Local API # Build form data for Local API
form_data = { form_data = {
**self.params,
"return_md": "true", "return_md": "true",
"formula_enable": str(self.enable_formula).lower(),
"table_enable": str(self.enable_table).lower(),
} }
# Parse method based on OCR setting
if self.enable_ocr:
form_data["parse_method"] = "ocr"
else:
form_data["parse_method"] = "auto"
# Language configuration (Local API uses lang_list array)
if self.language:
form_data["lang_list"] = self.language
# Backend/model version (Local API uses "backend" parameter)
if self.model_version == "vlm":
form_data["backend"] = "vlm-vllm-engine"
else:
form_data["backend"] = "pipeline"
# Page ranges (Local API uses start_page_id and end_page_id) # Page ranges (Local API uses start_page_id and end_page_id)
if self.page_ranges: if self.page_ranges:
# For simplicity, if page_ranges is specified, log a warning # For simplicity, if page_ranges is specified, log a warning
@ -236,10 +220,7 @@ class MinerULoader:
# Build request body # Build request body
request_body = { request_body = {
"enable_formula": self.enable_formula, **self.params,
"enable_table": self.enable_table,
"language": self.language,
"model_version": self.model_version,
"files": [ "files": [
{ {
"name": filename, "name": filename,

View file

@ -6,6 +6,7 @@ from urllib.parse import quote
from open_webui.env import ENABLE_FORWARD_USER_INFO_HEADERS, SRC_LOG_LEVELS from open_webui.env import ENABLE_FORWARD_USER_INFO_HEADERS, SRC_LOG_LEVELS
from open_webui.retrieval.models.base_reranker import BaseReranker from open_webui.retrieval.models.base_reranker import BaseReranker
from open_webui.utils.headers import include_user_info_headers
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -40,22 +41,17 @@ class ExternalReranker(BaseReranker):
log.info(f"ExternalReranker:predict:model {self.model}") log.info(f"ExternalReranker:predict:model {self.model}")
log.info(f"ExternalReranker:predict:query {query}") log.info(f"ExternalReranker:predict:query {query}")
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {self.api_key}",
}
if ENABLE_FORWARD_USER_INFO_HEADERS and user:
headers = include_user_info_headers(headers, user)
r = requests.post( r = requests.post(
f"{self.url}", f"{self.url}",
headers={ headers=headers,
"Content-Type": "application/json",
"Authorization": f"Bearer {self.api_key}",
**(
{
"X-OpenWebUI-User-Name": quote(user.name, safe=" "),
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
}
if ENABLE_FORWARD_USER_INFO_HEADERS and user
else {}
),
},
json=payload, json=payload,
) )
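include_user_info_headers lives in open_webui.utils.headers and is not shown in this hunk; a plausible sketch of what it does, inferred from the inline header block it replaces (treat this as an assumption, not the actual implementation):

from urllib.parse import quote

def include_user_info_headers(headers: dict, user) -> dict:
    # Hypothetical sketch: forward the requesting user's identity as
    # X-OpenWebUI-* headers, mirroring the inline dict this refactor removes.
    return {
        **headers,
        "X-OpenWebUI-User-Name": quote(user.name, safe=" "),
        "X-OpenWebUI-User-Id": user.id,
        "X-OpenWebUI-User-Email": user.email,
        "X-OpenWebUI-User-Role": user.role,
    }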

View file

@ -1,8 +1,10 @@
import logging import logging
import os import os
from typing import Optional, Union from typing import Awaitable, Optional, Union
import requests import requests
import aiohttp
import asyncio
import hashlib import hashlib
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
import time import time
@ -27,6 +29,7 @@ from open_webui.models.notes import Notes
from open_webui.retrieval.vector.main import GetResult from open_webui.retrieval.vector.main import GetResult
from open_webui.utils.access_control import has_access from open_webui.utils.access_control import has_access
from open_webui.utils.headers import include_user_info_headers
from open_webui.utils.misc import get_message_list from open_webui.utils.misc import get_message_list
from open_webui.retrieval.web.utils import get_web_loader from open_webui.retrieval.web.utils import get_web_loader
@ -87,15 +90,16 @@ class VectorSearchRetriever(BaseRetriever):
embedding_function: Any embedding_function: Any
top_k: int top_k: int
def _get_relevant_documents( async def _aget_relevant_documents(
self, self,
query: str, query: str,
*, *,
run_manager: CallbackManagerForRetrieverRun, run_manager: CallbackManagerForRetrieverRun,
) -> list[Document]: ) -> list[Document]:
embedding = await self.embedding_function(query, RAG_EMBEDDING_QUERY_PREFIX)
result = VECTOR_DB_CLIENT.search( result = VECTOR_DB_CLIENT.search(
collection_name=self.collection_name, collection_name=self.collection_name,
vectors=[self.embedding_function(query, RAG_EMBEDDING_QUERY_PREFIX)], vectors=[embedding],
limit=self.top_k, limit=self.top_k,
) )
@ -148,7 +152,45 @@ def get_doc(collection_name: str, user: UserModel = None):
raise e raise e
def query_doc_with_hybrid_search( def get_enriched_texts(collection_result: GetResult) -> list[str]:
enriched_texts = []
for idx, text in enumerate(collection_result.documents[0]):
metadata = collection_result.metadatas[0][idx]
metadata_parts = [text]
# Add filename (repeat twice for extra weight in BM25 scoring)
if metadata.get("name"):
filename = metadata["name"]
filename_tokens = (
filename.replace("_", " ").replace("-", " ").replace(".", " ")
)
metadata_parts.append(
f"Filename: {filename} {filename_tokens} {filename_tokens}"
)
# Add title if available
if metadata.get("title"):
metadata_parts.append(f"Title: {metadata['title']}")
# Add document section headings if available (from markdown splitter)
if metadata.get("headings") and isinstance(metadata["headings"], list):
headings = " > ".join(str(h) for h in metadata["headings"])
metadata_parts.append(f"Section: {headings}")
# Add source URL/path if available
if metadata.get("source"):
metadata_parts.append(f"Source: {metadata['source']}")
# Add snippet for web search results
if metadata.get("snippet"):
metadata_parts.append(f"Snippet: {metadata['snippet']}")
enriched_texts.append(" ".join(metadata_parts))
return enriched_texts
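For illustration, a chunk whose metadata carries a filename, headings, and a source path would be expanded roughly like this (a standalone sketch; the GetResult shape is assumed from the loop above):

text = "Quarterly revenue grew 12%."
metadata = {
    "name": "finance_report-2024.pdf",
    "headings": ["Results", "Revenue"],
    "source": "uploads/finance_report-2024.pdf",
}

parts = [text]
filename = metadata["name"]
tokens = filename.replace("_", " ").replace("-", " ").replace(".", " ")
# Filename tokens are repeated to give them extra weight in BM25 scoring.
parts.append(f"Filename: {filename} {tokens} {tokens}")
parts.append("Section: " + " > ".join(metadata["headings"]))
parts.append(f"Source: {metadata['source']}")
print(" ".join(parts))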
async def query_doc_with_hybrid_search(
collection_name: str, collection_name: str,
collection_result: GetResult, collection_result: GetResult,
query: str, query: str,
@ -158,12 +200,21 @@ def query_doc_with_hybrid_search(
k_reranker: int, k_reranker: int,
r: float, r: float,
hybrid_bm25_weight: float, hybrid_bm25_weight: float,
enable_enriched_texts: bool = False,
) -> dict: ) -> dict:
try: try:
# First check if collection_result has the required attributes
if ( if (
not collection_result not collection_result
or not hasattr(collection_result, "documents") or not hasattr(collection_result, "documents")
or not collection_result.documents or not hasattr(collection_result, "metadatas")
):
log.warning(f"query_doc_with_hybrid_search:no_docs {collection_name}")
return {"documents": [], "metadatas": [], "distances": []}
# Now safely check the documents content after confirming attributes exist
if (
not collection_result.documents
or len(collection_result.documents) == 0 or len(collection_result.documents) == 0
or not collection_result.documents[0] or not collection_result.documents[0]
): ):
@ -172,8 +223,14 @@ def query_doc_with_hybrid_search(
log.debug(f"query_doc_with_hybrid_search:doc {collection_name}") log.debug(f"query_doc_with_hybrid_search:doc {collection_name}")
bm25_texts = (
get_enriched_texts(collection_result)
if enable_enriched_texts
else collection_result.documents[0]
)
bm25_retriever = BM25Retriever.from_texts( bm25_retriever = BM25Retriever.from_texts(
texts=collection_result.documents[0], texts=bm25_texts,
metadatas=collection_result.metadatas[0], metadatas=collection_result.metadatas[0],
) )
bm25_retriever.k = k bm25_retriever.k = k
@ -209,7 +266,7 @@ def query_doc_with_hybrid_search(
base_compressor=compressor, base_retriever=ensemble_retriever base_compressor=compressor, base_retriever=ensemble_retriever
) )
result = compression_retriever.invoke(query) result = await compression_retriever.ainvoke(query)
distances = [d.metadata.get("score") for d in result] distances = [d.metadata.get("score") for d in result]
documents = [d.page_content for d in result] documents = [d.page_content for d in result]
@ -328,7 +385,7 @@ def get_all_items_from_collections(collection_names: list[str]) -> dict:
return merge_get_results(results) return merge_get_results(results)
def query_collection( async def query_collection(
collection_names: list[str], collection_names: list[str],
queries: list[str], queries: list[str],
embedding_function, embedding_function,
@ -353,7 +410,9 @@ def query_collection(
return None, e return None, e
# Generate all query embeddings (in one call) # Generate all query embeddings (in one call)
query_embeddings = embedding_function(queries, prefix=RAG_EMBEDDING_QUERY_PREFIX) query_embeddings = await embedding_function(
queries, prefix=RAG_EMBEDDING_QUERY_PREFIX
)
log.debug( log.debug(
f"query_collection: processing {len(queries)} queries across {len(collection_names)} collections" f"query_collection: processing {len(queries)} queries across {len(collection_names)} collections"
) )
@ -380,7 +439,7 @@ def query_collection(
return merge_and_sort_query_results(results, k=k) return merge_and_sort_query_results(results, k=k)
def query_collection_with_hybrid_search( async def query_collection_with_hybrid_search(
collection_names: list[str], collection_names: list[str],
queries: list[str], queries: list[str],
embedding_function, embedding_function,
@ -389,6 +448,7 @@ def query_collection_with_hybrid_search(
k_reranker: int, k_reranker: int,
r: float, r: float,
hybrid_bm25_weight: float, hybrid_bm25_weight: float,
enable_enriched_texts: bool = False,
) -> dict: ) -> dict:
results = [] results = []
error = False error = False
@ -411,9 +471,9 @@ def query_collection_with_hybrid_search(
f"Starting hybrid search for {len(queries)} queries in {len(collection_names)} collections..." f"Starting hybrid search for {len(queries)} queries in {len(collection_names)} collections..."
) )
def process_query(collection_name, query): async def process_query(collection_name, query):
try: try:
result = query_doc_with_hybrid_search( result = await query_doc_with_hybrid_search(
collection_name=collection_name, collection_name=collection_name,
collection_result=collection_results[collection_name], collection_result=collection_results[collection_name],
query=query, query=query,
@ -423,6 +483,7 @@ def query_collection_with_hybrid_search(
k_reranker=k_reranker, k_reranker=k_reranker,
r=r, r=r,
hybrid_bm25_weight=hybrid_bm25_weight, hybrid_bm25_weight=hybrid_bm25_weight,
enable_enriched_texts=enable_enriched_texts,
) )
return result, None return result, None
except Exception as e: except Exception as e:
@ -432,15 +493,16 @@ def query_collection_with_hybrid_search(
# Prepare tasks for all collections and queries # Prepare tasks for all collections and queries
# Avoid running any tasks for collections that failed to fetch data (have assigned None) # Avoid running any tasks for collections that failed to fetch data (have assigned None)
tasks = [ tasks = [
(cn, q) (collection_name, query)
for cn in collection_names for collection_name in collection_names
if collection_results[cn] is not None if collection_results[collection_name] is not None
for q in queries for query in queries
] ]
with ThreadPoolExecutor() as executor: # Run all queries in parallel using asyncio.gather
future_results = [executor.submit(process_query, cn, q) for cn, q in tasks] task_results = await asyncio.gather(
task_results = [future.result() for future in future_results] *[process_query(collection_name, query) for collection_name, query in tasks]
)
for result, err in task_results: for result, err in task_results:
if err is not None: if err is not None:
@ -456,6 +518,248 @@ def query_collection_with_hybrid_search(
return merge_and_sort_query_results(results, k=k) return merge_and_sort_query_results(results, k=k)
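The ThreadPoolExecutor fan-out above is replaced by asyncio.gather over coroutines that each catch their own exception and return a (result, error) pair, so one failing collection/query pair does not abort the rest. A minimal standalone sketch of that pattern:

import asyncio

async def process(task_id: int):
    try:
        if task_id == 2:
            raise ValueError("boom")
        return f"result-{task_id}", None
    except Exception as e:
        return None, e

async def main():
    outcomes = await asyncio.gather(*[process(i) for i in range(4)])
    for result, err in outcomes:
        if err is not None:
            print("skipped:", err)
        else:
            print("kept:", result)

asyncio.run(main())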
def generate_openai_batch_embeddings(
model: str,
texts: list[str],
url: str = "https://api.openai.com/v1",
key: str = "",
prefix: str = None,
user: UserModel = None,
) -> Optional[list[list[float]]]:
try:
log.debug(
f"generate_openai_batch_embeddings:model {model} batch size: {len(texts)}"
)
json_data = {"input": texts, "model": model}
if isinstance(RAG_EMBEDDING_PREFIX_FIELD_NAME, str) and isinstance(prefix, str):
json_data[RAG_EMBEDDING_PREFIX_FIELD_NAME] = prefix
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {key}",
}
if ENABLE_FORWARD_USER_INFO_HEADERS and user:
headers = include_user_info_headers(headers, user)
r = requests.post(
f"{url}/embeddings",
headers=headers,
json=json_data,
)
r.raise_for_status()
data = r.json()
if "data" in data:
return [elem["embedding"] for elem in data["data"]]
else:
raise "Something went wrong :/"
except Exception as e:
log.exception(f"Error generating openai batch embeddings: {e}")
return None
async def agenerate_openai_batch_embeddings(
model: str,
texts: list[str],
url: str = "https://api.openai.com/v1",
key: str = "",
prefix: str = None,
user: UserModel = None,
) -> Optional[list[list[float]]]:
try:
log.debug(
f"agenerate_openai_batch_embeddings:model {model} batch size: {len(texts)}"
)
form_data = {"input": texts, "model": model}
if isinstance(RAG_EMBEDDING_PREFIX_FIELD_NAME, str) and isinstance(prefix, str):
form_data[RAG_EMBEDDING_PREFIX_FIELD_NAME] = prefix
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {key}",
}
if ENABLE_FORWARD_USER_INFO_HEADERS and user:
headers = include_user_info_headers(headers, user)
async with aiohttp.ClientSession(trust_env=True) as session:
async with session.post(
f"{url}/embeddings", headers=headers, json=form_data
) as r:
r.raise_for_status()
data = await r.json()
if "data" in data:
return [item["embedding"] for item in data["data"]]
else:
raise Exception("Something went wrong :/")
except Exception as e:
log.exception(f"Error generating openai batch embeddings: {e}")
return None
def generate_azure_openai_batch_embeddings(
model: str,
texts: list[str],
url: str,
key: str = "",
version: str = "",
prefix: str = None,
user: UserModel = None,
) -> Optional[list[list[float]]]:
try:
log.debug(
f"generate_azure_openai_batch_embeddings:deployment {model} batch size: {len(texts)}"
)
json_data = {"input": texts}
if isinstance(RAG_EMBEDDING_PREFIX_FIELD_NAME, str) and isinstance(prefix, str):
json_data[RAG_EMBEDDING_PREFIX_FIELD_NAME] = prefix
url = f"{url}/openai/deployments/{model}/embeddings?api-version={version}"
for _ in range(5):
headers = {
"Content-Type": "application/json",
"api-key": key,
}
if ENABLE_FORWARD_USER_INFO_HEADERS and user:
headers = include_user_info_headers(headers, user)
r = requests.post(
url,
headers=headers,
json=json_data,
)
if r.status_code == 429:
retry = float(r.headers.get("Retry-After", "1"))
time.sleep(retry)
continue
r.raise_for_status()
data = r.json()
if "data" in data:
return [elem["embedding"] for elem in data["data"]]
else:
raise Exception("Something went wrong :/")
return None
except Exception as e:
log.exception(f"Error generating azure openai batch embeddings: {e}")
return None
async def agenerate_azure_openai_batch_embeddings(
model: str,
texts: list[str],
url: str,
key: str = "",
version: str = "",
prefix: str = None,
user: UserModel = None,
) -> Optional[list[list[float]]]:
try:
log.debug(
f"agenerate_azure_openai_batch_embeddings:deployment {model} batch size: {len(texts)}"
)
form_data = {"input": texts}
if isinstance(RAG_EMBEDDING_PREFIX_FIELD_NAME, str) and isinstance(prefix, str):
form_data[RAG_EMBEDDING_PREFIX_FIELD_NAME] = prefix
full_url = f"{url}/openai/deployments/{model}/embeddings?api-version={version}"
headers = {
"Content-Type": "application/json",
"api-key": key,
}
if ENABLE_FORWARD_USER_INFO_HEADERS and user:
headers = include_user_info_headers(headers, user)
async with aiohttp.ClientSession(trust_env=True) as session:
async with session.post(full_url, headers=headers, json=form_data) as r:
r.raise_for_status()
data = await r.json()
if "data" in data:
return [item["embedding"] for item in data["data"]]
else:
raise Exception("Something went wrong :/")
except Exception as e:
log.exception(f"Error generating azure openai batch embeddings: {e}")
return None
def generate_ollama_batch_embeddings(
model: str,
texts: list[str],
url: str,
key: str = "",
prefix: str = None,
user: UserModel = None,
) -> Optional[list[list[float]]]:
try:
log.debug(
f"generate_ollama_batch_embeddings:model {model} batch size: {len(texts)}"
)
json_data = {"input": texts, "model": model}
if isinstance(RAG_EMBEDDING_PREFIX_FIELD_NAME, str) and isinstance(prefix, str):
json_data[RAG_EMBEDDING_PREFIX_FIELD_NAME] = prefix
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {key}",
}
if ENABLE_FORWARD_USER_INFO_HEADERS and user:
headers = include_user_info_headers(headers, user)
r = requests.post(
f"{url}/api/embed",
headers=headers,
json=json_data,
)
r.raise_for_status()
data = r.json()
if "embeddings" in data:
return data["embeddings"]
else:
raise "Something went wrong :/"
except Exception as e:
log.exception(f"Error generating ollama batch embeddings: {e}")
return None
async def agenerate_ollama_batch_embeddings(
model: str,
texts: list[str],
url: str,
key: str = "",
prefix: str = None,
user: UserModel = None,
) -> Optional[list[list[float]]]:
try:
log.debug(
f"agenerate_ollama_batch_embeddings:model {model} batch size: {len(texts)}"
)
form_data = {"input": texts, "model": model}
if isinstance(RAG_EMBEDDING_PREFIX_FIELD_NAME, str) and isinstance(prefix, str):
form_data[RAG_EMBEDDING_PREFIX_FIELD_NAME] = prefix
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {key}",
}
if ENABLE_FORWARD_USER_INFO_HEADERS and user:
headers = include_user_info_headers(headers, user)
async with aiohttp.ClientSession(trust_env=True) as session:
async with session.post(
f"{url}/api/embed", headers=headers, json=form_data
) as r:
r.raise_for_status()
data = await r.json()
if "embeddings" in data:
return data["embeddings"]
else:
raise Exception("Something went wrong :/")
except Exception as e:
log.exception(f"Error generating ollama batch embeddings: {e}")
return None
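A hedged usage sketch of one of the new async backends; it assumes the function above is importable from this module and uses a placeholder local Ollama endpoint and model name:

import asyncio

from open_webui.retrieval.utils import agenerate_ollama_batch_embeddings  # module path assumed

async def main():
    embeddings = await agenerate_ollama_batch_embeddings(
        model="nomic-embed-text",            # placeholder model
        texts=["hello world", "open webui"],
        url="http://localhost:11434",        # placeholder Ollama endpoint
        key="",
    )
    if embeddings is not None:
        print(len(embeddings), "vectors, dim", len(embeddings[0]))

asyncio.run(main())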
def get_embedding_function( def get_embedding_function(
embedding_engine, embedding_engine,
embedding_model, embedding_model,
@ -464,13 +768,23 @@ def get_embedding_function(
key, key,
embedding_batch_size, embedding_batch_size,
azure_api_version=None, azure_api_version=None,
): ) -> Awaitable:
if embedding_engine == "": if embedding_engine == "":
return lambda query, prefix=None, user=None: embedding_function.encode( # Sentence transformers: CPU-bound sync operation
query, **({"prompt": prefix} if prefix else {}) async def async_embedding_function(query, prefix=None, user=None):
).tolist() return await asyncio.to_thread(
(
lambda query, prefix=None: embedding_function.encode(
query, **({"prompt": prefix} if prefix else {})
).tolist()
),
query,
prefix,
)
return async_embedding_function
elif embedding_engine in ["ollama", "openai", "azure_openai"]: elif embedding_engine in ["ollama", "openai", "azure_openai"]:
func = lambda query, prefix=None, user=None: generate_embeddings( embedding_function = lambda query, prefix=None, user=None: generate_embeddings(
engine=embedding_engine, engine=embedding_engine,
model=embedding_model, model=embedding_model,
text=query, text=query,
@ -481,41 +795,104 @@ def get_embedding_function(
azure_api_version=azure_api_version, azure_api_version=azure_api_version,
) )
def generate_multiple(query, prefix, user, func): async def async_embedding_function(query, prefix=None, user=None):
if isinstance(query, list): if isinstance(query, list):
embeddings = [] # Create batches
for i in range(0, len(query), embedding_batch_size): batches = [
batch_embeddings = func( query[i : i + embedding_batch_size]
query[i : i + embedding_batch_size], for i in range(0, len(query), embedding_batch_size)
prefix=prefix, ]
user=user, log.debug(
) f"generate_multiple_async: Processing {len(batches)} batches in parallel"
)
# Execute all batches in parallel
tasks = [
embedding_function(batch, prefix=prefix, user=user)
for batch in batches
]
batch_results = await asyncio.gather(*tasks)
# Flatten results
embeddings = []
for batch_embeddings in batch_results:
if isinstance(batch_embeddings, list): if isinstance(batch_embeddings, list):
embeddings.extend(batch_embeddings) embeddings.extend(batch_embeddings)
log.debug(
f"generate_multiple_async: Generated {len(embeddings)} embeddings from {len(batches)} parallel batches"
)
return embeddings return embeddings
else: else:
return func(query, prefix, user) return await embedding_function(query, prefix, user)
return lambda query, prefix=None, user=None: generate_multiple( return async_embedding_function
query, prefix, user, func
)
else: else:
raise ValueError(f"Unknown embedding engine: {embedding_engine}") raise ValueError(f"Unknown embedding engine: {embedding_engine}")
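Callers of the returned function must now await it; list inputs are split into batches of embedding_batch_size and the per-batch coroutines run under asyncio.gather. A hedged sketch of just the calling convention (construction via get_embedding_function is omitted because its full parameter list is elided in the hunk above):

import asyncio

async def embed_chunks(embed, chunks: list[str]):
    # `embed` is whatever get_embedding_function(...) returned.
    vectors = await embed(chunks, prefix=None, user=None)
    print(f"{len(vectors)} embeddings of dim {len(vectors[0])}")
    return vectors

# asyncio.run(embed_chunks(embed, ["first chunk", "second chunk"]))  # once `embed` is built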
async def generate_embeddings(
engine: str,
model: str,
text: Union[str, list[str]],
prefix: Union[str, None] = None,
**kwargs,
):
url = kwargs.get("url", "")
key = kwargs.get("key", "")
user = kwargs.get("user")
if prefix is not None and RAG_EMBEDDING_PREFIX_FIELD_NAME is None:
if isinstance(text, list):
text = [f"{prefix}{text_element}" for text_element in text]
else:
text = f"{prefix}{text}"
if engine == "ollama":
embeddings = await agenerate_ollama_batch_embeddings(
**{
"model": model,
"texts": text if isinstance(text, list) else [text],
"url": url,
"key": key,
"prefix": prefix,
"user": user,
}
)
return embeddings[0] if isinstance(text, str) else embeddings
elif engine == "openai":
embeddings = await agenerate_openai_batch_embeddings(
model, text if isinstance(text, list) else [text], url, key, prefix, user
)
return embeddings[0] if isinstance(text, str) else embeddings
elif engine == "azure_openai":
azure_api_version = kwargs.get("azure_api_version", "")
embeddings = await agenerate_azure_openai_batch_embeddings(
model,
text if isinstance(text, list) else [text],
url,
key,
azure_api_version,
prefix,
user,
)
return embeddings[0] if isinstance(text, str) else embeddings
def get_reranking_function(reranking_engine, reranking_model, reranking_function): def get_reranking_function(reranking_engine, reranking_model, reranking_function):
if reranking_function is None: if reranking_function is None:
return None return None
if reranking_engine == "external": if reranking_engine == "external":
return lambda sentences, user=None: reranking_function.predict( return lambda query, documents, user=None: reranking_function.predict(
sentences, user=user [(query, doc.page_content) for doc in documents], user=user
) )
else: else:
return lambda sentences, user=None: reranking_function.predict(sentences) return lambda query, documents, user=None: reranking_function.predict(
[(query, doc.page_content) for doc in documents]
)
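The returned callable now receives the query and the Document list separately and builds (query, passage) pairs itself. A standalone sketch of that calling convention, with a stand-in predictor in place of a real cross-encoder:

from dataclasses import dataclass

@dataclass
class Doc:
    page_content: str

def fake_predict(pairs):
    # Stand-in scorer: count words shared between query and passage.
    return [len(set(q.split()) & set(p.split())) for q, p in pairs]

rerank = lambda query, documents: fake_predict(
    [(query, d.page_content) for d in documents]
)

docs = [Doc("open webui hybrid search"), Doc("unrelated text")]
print(rerank("hybrid search", docs))  # -> [2, 0]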
def get_sources_from_items( async def get_sources_from_items(
request, request,
items, items,
queries, queries,
@ -743,7 +1120,7 @@ def get_sources_from_items(
query_result = None # Initialize to None query_result = None # Initialize to None
if hybrid_search: if hybrid_search:
try: try:
query_result = query_collection_with_hybrid_search( query_result = await query_collection_with_hybrid_search(
collection_names=collection_names, collection_names=collection_names,
queries=queries, queries=queries,
embedding_function=embedding_function, embedding_function=embedding_function,
@ -752,6 +1129,7 @@ def get_sources_from_items(
k_reranker=k_reranker, k_reranker=k_reranker,
r=r, r=r,
hybrid_bm25_weight=hybrid_bm25_weight, hybrid_bm25_weight=hybrid_bm25_weight,
enable_enriched_texts=request.app.state.config.ENABLE_RAG_HYBRID_SEARCH_ENRICHED_TEXTS,
) )
except Exception as e: except Exception as e:
log.debug( log.debug(
@ -760,7 +1138,7 @@ def get_sources_from_items(
# fallback to non-hybrid search # fallback to non-hybrid search
if not hybrid_search and query_result is None: if not hybrid_search and query_result is None:
query_result = query_collection( query_result = await query_collection(
collection_names=collection_names, collection_names=collection_names,
queries=queries, queries=queries,
embedding_function=embedding_function, embedding_function=embedding_function,
@ -836,199 +1214,6 @@ def get_model_path(model: str, update_model: bool = False):
return model return model
def generate_openai_batch_embeddings(
model: str,
texts: list[str],
url: str = "https://api.openai.com/v1",
key: str = "",
prefix: str = None,
user: UserModel = None,
) -> Optional[list[list[float]]]:
try:
log.debug(
f"generate_openai_batch_embeddings:model {model} batch size: {len(texts)}"
)
json_data = {"input": texts, "model": model}
if isinstance(RAG_EMBEDDING_PREFIX_FIELD_NAME, str) and isinstance(prefix, str):
json_data[RAG_EMBEDDING_PREFIX_FIELD_NAME] = prefix
r = requests.post(
f"{url}/embeddings",
headers={
"Content-Type": "application/json",
"Authorization": f"Bearer {key}",
**(
{
"X-OpenWebUI-User-Name": quote(user.name, safe=" "),
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
}
if ENABLE_FORWARD_USER_INFO_HEADERS and user
else {}
),
},
json=json_data,
)
r.raise_for_status()
data = r.json()
if "data" in data:
return [elem["embedding"] for elem in data["data"]]
else:
raise "Something went wrong :/"
except Exception as e:
log.exception(f"Error generating openai batch embeddings: {e}")
return None
def generate_azure_openai_batch_embeddings(
model: str,
texts: list[str],
url: str,
key: str = "",
version: str = "",
prefix: str = None,
user: UserModel = None,
) -> Optional[list[list[float]]]:
try:
log.debug(
f"generate_azure_openai_batch_embeddings:deployment {model} batch size: {len(texts)}"
)
json_data = {"input": texts}
if isinstance(RAG_EMBEDDING_PREFIX_FIELD_NAME, str) and isinstance(prefix, str):
json_data[RAG_EMBEDDING_PREFIX_FIELD_NAME] = prefix
url = f"{url}/openai/deployments/{model}/embeddings?api-version={version}"
for _ in range(5):
r = requests.post(
url,
headers={
"Content-Type": "application/json",
"api-key": key,
**(
{
"X-OpenWebUI-User-Name": quote(user.name, safe=" "),
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
}
if ENABLE_FORWARD_USER_INFO_HEADERS and user
else {}
),
},
json=json_data,
)
if r.status_code == 429:
retry = float(r.headers.get("Retry-After", "1"))
time.sleep(retry)
continue
r.raise_for_status()
data = r.json()
if "data" in data:
return [elem["embedding"] for elem in data["data"]]
else:
raise Exception("Something went wrong :/")
return None
except Exception as e:
log.exception(f"Error generating azure openai batch embeddings: {e}")
return None
def generate_ollama_batch_embeddings(
model: str,
texts: list[str],
url: str,
key: str = "",
prefix: str = None,
user: UserModel = None,
) -> Optional[list[list[float]]]:
try:
log.debug(
f"generate_ollama_batch_embeddings:model {model} batch size: {len(texts)}"
)
json_data = {"input": texts, "model": model}
if isinstance(RAG_EMBEDDING_PREFIX_FIELD_NAME, str) and isinstance(prefix, str):
json_data[RAG_EMBEDDING_PREFIX_FIELD_NAME] = prefix
r = requests.post(
f"{url}/api/embed",
headers={
"Content-Type": "application/json",
"Authorization": f"Bearer {key}",
**(
{
"X-OpenWebUI-User-Name": quote(user.name, safe=" "),
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
}
if ENABLE_FORWARD_USER_INFO_HEADERS
else {}
),
},
json=json_data,
)
r.raise_for_status()
data = r.json()
if "embeddings" in data:
return data["embeddings"]
else:
raise "Something went wrong :/"
except Exception as e:
log.exception(f"Error generating ollama batch embeddings: {e}")
return None
def generate_embeddings(
engine: str,
model: str,
text: Union[str, list[str]],
prefix: Union[str, None] = None,
**kwargs,
):
url = kwargs.get("url", "")
key = kwargs.get("key", "")
user = kwargs.get("user")
if prefix is not None and RAG_EMBEDDING_PREFIX_FIELD_NAME is None:
if isinstance(text, list):
text = [f"{prefix}{text_element}" for text_element in text]
else:
text = f"{prefix}{text}"
if engine == "ollama":
embeddings = generate_ollama_batch_embeddings(
**{
"model": model,
"texts": text if isinstance(text, list) else [text],
"url": url,
"key": key,
"prefix": prefix,
"user": user,
}
)
return embeddings[0] if isinstance(text, str) else embeddings
elif engine == "openai":
embeddings = generate_openai_batch_embeddings(
model, text if isinstance(text, list) else [text], url, key, prefix, user
)
return embeddings[0] if isinstance(text, str) else embeddings
elif engine == "azure_openai":
azure_api_version = kwargs.get("azure_api_version", "")
embeddings = generate_azure_openai_batch_embeddings(
model,
text if isinstance(text, list) else [text],
url,
key,
azure_api_version,
prefix,
user,
)
return embeddings[0] if isinstance(text, str) else embeddings
import operator import operator
from typing import Optional, Sequence from typing import Optional, Sequence
@ -1046,7 +1231,7 @@ class RerankCompressor(BaseDocumentCompressor):
extra = "forbid" extra = "forbid"
arbitrary_types_allowed = True arbitrary_types_allowed = True
def compress_documents( async def acompress_documents(
self, self,
documents: Sequence[Document], documents: Sequence[Document],
query: str, query: str,
@ -1062,8 +1247,10 @@ class RerankCompressor(BaseDocumentCompressor):
else: else:
from sentence_transformers import util from sentence_transformers import util
query_embedding = self.embedding_function(query, RAG_EMBEDDING_QUERY_PREFIX) query_embedding = await self.embedding_function(
document_embedding = self.embedding_function( query, RAG_EMBEDDING_QUERY_PREFIX
)
document_embedding = await self.embedding_function(
[doc.page_content for doc in documents], RAG_EMBEDDING_CONTENT_PREFIX [doc.page_content for doc in documents], RAG_EMBEDDING_CONTENT_PREFIX
) )
scores = util.cos_sim(query_embedding, document_embedding)[0] scores = util.cos_sim(query_embedding, document_embedding)[0]

View file

@ -1,4 +1,4 @@
from typing import Optional, List, Dict, Any from typing import Optional, List, Dict, Any, Tuple
import logging import logging
import json import json
from sqlalchemy import ( from sqlalchemy import (
@ -22,7 +22,7 @@ from sqlalchemy.pool import NullPool, QueuePool
from sqlalchemy.orm import declarative_base, scoped_session, sessionmaker from sqlalchemy.orm import declarative_base, scoped_session, sessionmaker
from sqlalchemy.dialects.postgresql import JSONB, array from sqlalchemy.dialects.postgresql import JSONB, array
from pgvector.sqlalchemy import Vector from pgvector.sqlalchemy import Vector, HALFVEC
from sqlalchemy.ext.mutable import MutableDict from sqlalchemy.ext.mutable import MutableDict
from sqlalchemy.exc import NoSuchTableError from sqlalchemy.exc import NoSuchTableError
@ -44,11 +44,20 @@ from open_webui.config import (
PGVECTOR_POOL_MAX_OVERFLOW, PGVECTOR_POOL_MAX_OVERFLOW,
PGVECTOR_POOL_TIMEOUT, PGVECTOR_POOL_TIMEOUT,
PGVECTOR_POOL_RECYCLE, PGVECTOR_POOL_RECYCLE,
PGVECTOR_INDEX_METHOD,
PGVECTOR_HNSW_M,
PGVECTOR_HNSW_EF_CONSTRUCTION,
PGVECTOR_IVFFLAT_LISTS,
PGVECTOR_USE_HALFVEC,
) )
from open_webui.env import SRC_LOG_LEVELS from open_webui.env import SRC_LOG_LEVELS
VECTOR_LENGTH = PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH VECTOR_LENGTH = PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH
USE_HALFVEC = PGVECTOR_USE_HALFVEC
VECTOR_TYPE_FACTORY = HALFVEC if USE_HALFVEC else Vector
VECTOR_OPCLASS = "halfvec_cosine_ops" if USE_HALFVEC else "vector_cosine_ops"
Base = declarative_base() Base = declarative_base()
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -67,7 +76,7 @@ class DocumentChunk(Base):
__tablename__ = "document_chunk" __tablename__ = "document_chunk"
id = Column(Text, primary_key=True) id = Column(Text, primary_key=True)
vector = Column(Vector(dim=VECTOR_LENGTH), nullable=True) vector = Column(VECTOR_TYPE_FACTORY(dim=VECTOR_LENGTH), nullable=True)
collection_name = Column(Text, nullable=False) collection_name = Column(Text, nullable=False)
if PGVECTOR_PGCRYPTO: if PGVECTOR_PGCRYPTO:
@ -157,13 +166,9 @@ class PgvectorClient(VectorDBBase):
connection = self.session.connection() connection = self.session.connection()
Base.metadata.create_all(bind=connection) Base.metadata.create_all(bind=connection)
# Create an index on the vector column if it doesn't exist index_method, index_options = self._vector_index_configuration()
self.session.execute( self._ensure_vector_index(index_method, index_options)
text(
"CREATE INDEX IF NOT EXISTS idx_document_chunk_vector "
"ON document_chunk USING ivfflat (vector vector_cosine_ops) WITH (lists = 100);"
)
)
self.session.execute( self.session.execute(
text( text(
"CREATE INDEX IF NOT EXISTS idx_document_chunk_collection_name " "CREATE INDEX IF NOT EXISTS idx_document_chunk_collection_name "
@ -177,6 +182,78 @@ class PgvectorClient(VectorDBBase):
log.exception(f"Error during initialization: {e}") log.exception(f"Error during initialization: {e}")
raise raise
@staticmethod
def _extract_index_method(index_def: Optional[str]) -> Optional[str]:
if not index_def:
return None
try:
after_using = index_def.lower().split("using ", 1)[1]
return after_using.split()[0]
except (IndexError, AttributeError):
return None
def _vector_index_configuration(self) -> Tuple[str, str]:
if PGVECTOR_INDEX_METHOD:
index_method = PGVECTOR_INDEX_METHOD
log.info(
"Using vector index method '%s' from PGVECTOR_INDEX_METHOD.",
index_method,
)
elif USE_HALFVEC:
index_method = "hnsw"
log.info(
"VECTOR_LENGTH=%s exceeds 2000; using halfvec column type with hnsw index.",
VECTOR_LENGTH,
)
else:
index_method = "ivfflat"
if index_method == "hnsw":
index_options = f"WITH (m = {PGVECTOR_HNSW_M}, ef_construction = {PGVECTOR_HNSW_EF_CONSTRUCTION})"
else:
index_options = f"WITH (lists = {PGVECTOR_IVFFLAT_LISTS})"
return index_method, index_options
def _ensure_vector_index(self, index_method: str, index_options: str) -> None:
index_name = "idx_document_chunk_vector"
existing_index_def = self.session.execute(
text(
"""
SELECT indexdef
FROM pg_indexes
WHERE schemaname = current_schema()
AND tablename = 'document_chunk'
AND indexname = :index_name
"""
),
{"index_name": index_name},
).scalar()
existing_method = self._extract_index_method(existing_index_def)
if existing_method and existing_method != index_method:
raise RuntimeError(
f"Existing pgvector index '{index_name}' uses method '{existing_method}' but configuration now "
f"requires '{index_method}'. Automatic rebuild is disabled to prevent long-running maintenance. "
"Drop the index manually (optionally after tuning maintenance_work_mem/max_parallel_maintenance_workers) "
"and recreate it with the new method before restarting Open WebUI."
)
if not existing_index_def:
index_sql = (
f"CREATE INDEX IF NOT EXISTS {index_name} "
f"ON document_chunk USING {index_method} (vector {VECTOR_OPCLASS})"
)
if index_options:
index_sql = f"{index_sql} {index_options}"
self.session.execute(text(index_sql))
log.info(
"Ensured vector index '%s' using %s%s.",
index_name,
index_method,
f" {index_options}" if index_options else "",
)
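Under this scheme the generated index DDL differs only in access method, operator class, and options. A small sketch of the resulting SQL strings (the m/ef_construction/lists values below are placeholders, not confirmed defaults):

def index_sql(method: str, opclass: str, options: str) -> str:
    sql = (
        "CREATE INDEX IF NOT EXISTS idx_document_chunk_vector "
        f"ON document_chunk USING {method} (vector {opclass})"
    )
    return f"{sql} {options}" if options else sql

# halfvec + HNSW (placeholder m / ef_construction values)
print(index_sql("hnsw", "halfvec_cosine_ops", "WITH (m = 16, ef_construction = 64)"))
# classic vector + IVFFlat (placeholder lists value)
print(index_sql("ivfflat", "vector_cosine_ops", "WITH (lists = 100)"))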
def check_vector_length(self) -> None: def check_vector_length(self) -> None:
""" """
Check if the VECTOR_LENGTH matches the existing vector column dimension in the database. Check if the VECTOR_LENGTH matches the existing vector column dimension in the database.
@ -196,16 +273,19 @@ class PgvectorClient(VectorDBBase):
if "vector" in document_chunk_table.columns: if "vector" in document_chunk_table.columns:
vector_column = document_chunk_table.columns["vector"] vector_column = document_chunk_table.columns["vector"]
vector_type = vector_column.type vector_type = vector_column.type
if isinstance(vector_type, Vector): expected_type = HALFVEC if USE_HALFVEC else Vector
db_vector_length = vector_type.dim
if db_vector_length != VECTOR_LENGTH: if not isinstance(vector_type, expected_type):
raise Exception(
f"VECTOR_LENGTH {VECTOR_LENGTH} does not match existing vector column dimension {db_vector_length}. "
"Cannot change vector size after initialization without migrating the data."
)
else:
raise Exception( raise Exception(
"The 'vector' column exists but is not of type 'Vector'." "The 'vector' column type does not match the expected type "
f"('{expected_type.__name__}') for VECTOR_LENGTH {VECTOR_LENGTH}."
)
db_vector_length = getattr(vector_type, "dim", None)
if db_vector_length is not None and db_vector_length != VECTOR_LENGTH:
raise Exception(
f"VECTOR_LENGTH {VECTOR_LENGTH} does not match existing vector column dimension {db_vector_length}. "
"Cannot change vector size after initialization without migrating the data."
) )
else: else:
raise Exception( raise Exception(
@ -360,11 +440,11 @@ class PgvectorClient(VectorDBBase):
num_queries = len(vectors) num_queries = len(vectors)
def vector_expr(vector): def vector_expr(vector):
return cast(array(vector), Vector(VECTOR_LENGTH)) return cast(array(vector), VECTOR_TYPE_FACTORY(VECTOR_LENGTH))
# Create the values for query vectors # Create the values for query vectors
qid_col = column("qid", Integer) qid_col = column("qid", Integer)
q_vector_col = column("q_vector", Vector(VECTOR_LENGTH)) q_vector_col = column("q_vector", VECTOR_TYPE_FACTORY(VECTOR_LENGTH))
query_vectors = ( query_vectors = (
values(qid_col, q_vector_col) values(qid_col, q_vector_col)
.data( .data(

View file

@ -117,15 +117,16 @@ class S3VectorClient(VectorDBBase):
def has_collection(self, collection_name: str) -> bool: def has_collection(self, collection_name: str) -> bool:
""" """
Check if a vector index (collection) exists in the S3 vector bucket. Check if a vector index exists using direct lookup.
This avoids pagination issues with list_indexes() and is significantly faster.
""" """
try: try:
response = self.client.list_indexes(vectorBucketName=self.bucket_name) self.client.get_index(
indexes = response.get("indexes", []) vectorBucketName=self.bucket_name, indexName=collection_name
return any(idx.get("indexName") == collection_name for idx in indexes) )
return True
except Exception as e: except Exception as e:
log.error(f"Error listing indexes: {e}") log.error(f"Error checking if index '{collection_name}' exists: {e}")
return False return False
def delete_collection(self, collection_name: str) -> None: def delete_collection(self, collection_name: str) -> None:

View file

@ -0,0 +1,340 @@
import weaviate
import re
import uuid
from typing import Any, Dict, List, Optional, Union
from open_webui.retrieval.vector.main import (
VectorDBBase,
VectorItem,
SearchResult,
GetResult,
)
from open_webui.retrieval.vector.utils import process_metadata
from open_webui.config import (
WEAVIATE_HTTP_HOST,
WEAVIATE_HTTP_PORT,
WEAVIATE_GRPC_PORT,
WEAVIATE_API_KEY,
)
def _convert_uuids_to_strings(obj: Any) -> Any:
"""
Recursively convert UUID objects to strings in nested data structures.
This function handles:
- UUID objects -> string
- Dictionaries with UUID values
- Lists/Tuples with UUID values
- Nested combinations of the above
Args:
obj: Any object that might contain UUIDs
Returns:
The same object structure with UUIDs converted to strings
"""
if isinstance(obj, uuid.UUID):
return str(obj)
elif isinstance(obj, dict):
return {key: _convert_uuids_to_strings(value) for key, value in obj.items()}
elif isinstance(obj, (list, tuple)):
return type(obj)(_convert_uuids_to_strings(item) for item in obj)
elif isinstance(obj, (str, int, float, bool, type(None))):
return obj
else:
return obj
class WeaviateClient(VectorDBBase):
def __init__(self):
self.url = WEAVIATE_HTTP_HOST
try:
# Build connection parameters
connection_params = {
"host": WEAVIATE_HTTP_HOST,
"port": WEAVIATE_HTTP_PORT,
"grpc_port": WEAVIATE_GRPC_PORT,
}
# Only add auth_credentials if WEAVIATE_API_KEY exists and is not empty
if WEAVIATE_API_KEY:
connection_params["auth_credentials"] = (
weaviate.classes.init.Auth.api_key(WEAVIATE_API_KEY)
)
self.client = weaviate.connect_to_local(**connection_params)
self.client.connect()
except Exception as e:
raise ConnectionError(f"Failed to connect to Weaviate: {e}") from e
def _sanitize_collection_name(self, collection_name: str) -> str:
"""Sanitize collection name to be a valid Weaviate class name."""
if not isinstance(collection_name, str) or not collection_name.strip():
raise ValueError("Collection name must be a non-empty string")
# Requirements for a valid Weaviate class name:
# The collection name must begin with a capital letter.
# The name can only contain letters, numbers, and the underscore (_) character. Spaces are not allowed.
# Replace hyphens with underscores and keep only alphanumeric characters
name = re.sub(r"[^a-zA-Z0-9_]", "", collection_name.replace("-", "_"))
name = name.strip("_")
if not name:
raise ValueError(
"Could not sanitize collection name to be a valid Weaviate class name"
)
# Ensure it starts with a letter and is capitalized
if not name[0].isalpha():
name = "C" + name
return name[0].upper() + name[1:]
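For example, an Open WebUI collection id such as "open-webui_file-1a2b" would be turned into a valid Weaviate class name roughly like this (a standalone re-implementation of the rules above, without the empty-name validation):

import re

def sanitize(collection_name: str) -> str:
    name = re.sub(r"[^a-zA-Z0-9_]", "", collection_name.replace("-", "_")).strip("_")
    if not name[0].isalpha():
        name = "C" + name
    return name[0].upper() + name[1:]

print(sanitize("open-webui_file-1a2b"))  # -> Open_webui_file_1a2b
print(sanitize("123-chunks"))            # -> C123_chunks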
def has_collection(self, collection_name: str) -> bool:
sane_collection_name = self._sanitize_collection_name(collection_name)
return self.client.collections.exists(sane_collection_name)
def delete_collection(self, collection_name: str) -> None:
sane_collection_name = self._sanitize_collection_name(collection_name)
if self.client.collections.exists(sane_collection_name):
self.client.collections.delete(sane_collection_name)
def _create_collection(self, collection_name: str) -> None:
self.client.collections.create(
name=collection_name,
vector_config=weaviate.classes.config.Configure.Vectors.self_provided(),
properties=[
weaviate.classes.config.Property(
name="text", data_type=weaviate.classes.config.DataType.TEXT
),
],
)
def insert(self, collection_name: str, items: List[VectorItem]) -> None:
sane_collection_name = self._sanitize_collection_name(collection_name)
if not self.client.collections.exists(sane_collection_name):
self._create_collection(sane_collection_name)
collection = self.client.collections.get(sane_collection_name)
with collection.batch.fixed_size(batch_size=100) as batch:
for item in items:
item_uuid = str(uuid.uuid4()) if not item["id"] else str(item["id"])
properties = {"text": item["text"]}
if item["metadata"]:
clean_metadata = _convert_uuids_to_strings(
process_metadata(item["metadata"])
)
clean_metadata.pop("text", None)
properties.update(clean_metadata)
batch.add_object(
properties=properties, uuid=item_uuid, vector=item["vector"]
)
def upsert(self, collection_name: str, items: List[VectorItem]) -> None:
sane_collection_name = self._sanitize_collection_name(collection_name)
if not self.client.collections.exists(sane_collection_name):
self._create_collection(sane_collection_name)
collection = self.client.collections.get(sane_collection_name)
with collection.batch.fixed_size(batch_size=100) as batch:
for item in items:
item_uuid = str(item["id"]) if item["id"] else None
properties = {"text": item["text"]}
if item["metadata"]:
clean_metadata = _convert_uuids_to_strings(
process_metadata(item["metadata"])
)
clean_metadata.pop("text", None)
properties.update(clean_metadata)
batch.add_object(
properties=properties, uuid=item_uuid, vector=item["vector"]
)
def search(
self, collection_name: str, vectors: List[List[Union[float, int]]], limit: int
) -> Optional[SearchResult]:
sane_collection_name = self._sanitize_collection_name(collection_name)
if not self.client.collections.exists(sane_collection_name):
return None
collection = self.client.collections.get(sane_collection_name)
result_ids, result_documents, result_metadatas, result_distances = (
[],
[],
[],
[],
)
for vector_embedding in vectors:
try:
response = collection.query.near_vector(
near_vector=vector_embedding,
limit=limit,
return_metadata=weaviate.classes.query.MetadataQuery(distance=True),
)
ids = [str(obj.uuid) for obj in response.objects]
documents = []
metadatas = []
distances = []
for obj in response.objects:
properties = dict(obj.properties) if obj.properties else {}
documents.append(properties.pop("text", ""))
metadatas.append(_convert_uuids_to_strings(properties))
# Weaviate returns cosine distance in [0, 2] (0 = best match, 2 = worst);
# convert it to a similarity-style score in [0, 1] where 1 is best.
raw_distances = [
(
obj.metadata.distance
if obj.metadata and obj.metadata.distance is not None
else 2.0
)
for obj in response.objects
]
distances = [(2 - dist) / 2 for dist in raw_distances]
result_ids.append(ids)
result_documents.append(documents)
result_metadatas.append(metadatas)
result_distances.append(distances)
except Exception:
result_ids.append([])
result_documents.append([])
result_metadatas.append([])
result_distances.append([])
return SearchResult(
**{
"ids": result_ids,
"documents": result_documents,
"metadatas": result_metadatas,
"distances": result_distances,
}
)
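The distance-to-score conversion used above can be checked in isolation; a tiny sketch of the mapping:

def distance_to_score(dist: float) -> float:
    # Weaviate cosine distance lies in [0, 2] (0 = identical, 2 = opposite);
    # map it to a similarity in [0, 1] where 1 is best.
    return (2 - dist) / 2

print(distance_to_score(0.0), distance_to_score(1.0), distance_to_score(2.0))
# -> 1.0 0.5 0.0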
def query(
self, collection_name: str, filter: Dict, limit: Optional[int] = None
) -> Optional[GetResult]:
sane_collection_name = self._sanitize_collection_name(collection_name)
if not self.client.collections.exists(sane_collection_name):
return None
collection = self.client.collections.get(sane_collection_name)
weaviate_filter = None
if filter:
for key, value in filter.items():
prop_filter = weaviate.classes.query.Filter.by_property(name=key).equal(
value
)
weaviate_filter = (
prop_filter
if weaviate_filter is None
else weaviate.classes.query.Filter.all_of(
[weaviate_filter, prop_filter]
)
)
try:
response = collection.query.fetch_objects(
filters=weaviate_filter, limit=limit
)
ids = [str(obj.uuid) for obj in response.objects]
documents = []
metadatas = []
for obj in response.objects:
properties = dict(obj.properties) if obj.properties else {}
documents.append(properties.pop("text", ""))
metadatas.append(_convert_uuids_to_strings(properties))
return GetResult(
**{
"ids": [ids],
"documents": [documents],
"metadatas": [metadatas],
}
)
except Exception:
return None
def get(self, collection_name: str) -> Optional[GetResult]:
sane_collection_name = self._sanitize_collection_name(collection_name)
if not self.client.collections.exists(sane_collection_name):
return None
collection = self.client.collections.get(sane_collection_name)
ids, documents, metadatas = [], [], []
try:
for item in collection.iterator():
ids.append(str(item.uuid))
properties = dict(item.properties) if item.properties else {}
documents.append(properties.pop("text", ""))
metadatas.append(_convert_uuids_to_strings(properties))
if not ids:
return None
return GetResult(
**{
"ids": [ids],
"documents": [documents],
"metadatas": [metadatas],
}
)
except Exception:
return None
def delete(
self,
collection_name: str,
ids: Optional[List[str]] = None,
filter: Optional[Dict] = None,
) -> None:
sane_collection_name = self._sanitize_collection_name(collection_name)
if not self.client.collections.exists(sane_collection_name):
return
collection = self.client.collections.get(sane_collection_name)
try:
if ids:
for item_id in ids:
collection.data.delete_by_id(uuid=item_id)
elif filter:
weaviate_filter = None
for key, value in filter.items():
prop_filter = weaviate.classes.query.Filter.by_property(
name=key
).equal(value)
weaviate_filter = (
prop_filter
if weaviate_filter is None
else weaviate.classes.query.Filter.all_of(
[weaviate_filter, prop_filter]
)
)
if weaviate_filter:
collection.data.delete_many(where=weaviate_filter)
except Exception:
pass
def reset(self) -> None:
try:
for collection_name in self.client.collections.list_all().keys():
self.client.collections.delete(collection_name)
except Exception:
pass

View file

@ -67,6 +67,10 @@ class Vector:
from open_webui.retrieval.vector.dbs.oracle23ai import Oracle23aiClient from open_webui.retrieval.vector.dbs.oracle23ai import Oracle23aiClient
return Oracle23aiClient() return Oracle23aiClient()
case VectorType.WEAVIATE:
from open_webui.retrieval.vector.dbs.weaviate import WeaviateClient
return WeaviateClient()
case _: case _:
raise ValueError(f"Unsupported vector type: {vector_type}") raise ValueError(f"Unsupported vector type: {vector_type}")

View file

@ -11,3 +11,4 @@ class VectorType(StrEnum):
PGVECTOR = "pgvector" PGVECTOR = "pgvector"
ORACLE23AI = "oracle23ai" ORACLE23AI = "oracle23ai"
S3VECTOR = "s3vector" S3VECTOR = "s3vector"
WEAVIATE = "weaviate"

View file

@ -0,0 +1,128 @@
import logging
from typing import Optional
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
"""
Azure AI Search integration for Open WebUI.
Documentation: https://learn.microsoft.com/en-us/python/api/overview/azure/search-documents-readme?view=azure-python
Required package: azure-search-documents
Install: pip install azure-search-documents
"""
def search_azure(
api_key: str,
endpoint: str,
index_name: str,
query: str,
count: int,
filter_list: Optional[list[str]] = None,
) -> list[SearchResult]:
"""
Search using Azure AI Search.
Args:
api_key: Azure Search API key (query key or admin key)
endpoint: Azure Search service endpoint (e.g., https://myservice.search.windows.net)
index_name: Name of the search index to query
query: Search query string
count: Number of results to return
filter_list: Optional list of domains to filter results
Returns:
List of SearchResult objects with link, title, and snippet
"""
try:
from azure.core.credentials import AzureKeyCredential
from azure.search.documents import SearchClient
except ImportError:
log.error(
"azure-search-documents package is not installed. "
"Install it with: pip install azure-search-documents"
)
raise ImportError(
"azure-search-documents is required for Azure AI Search. "
"Install it with: pip install azure-search-documents"
)
try:
# Create search client with API key authentication
credential = AzureKeyCredential(api_key)
search_client = SearchClient(
endpoint=endpoint, index_name=index_name, credential=credential
)
# Perform the search
results = search_client.search(search_text=query, top=count)
# Convert results to list and extract fields
search_results = []
for result in results:
# Azure AI Search returns documents with custom schemas
# We need to extract common fields that might represent URL, title, and content
# Common field names to look for:
result_dict = dict(result)
# Try to find URL field (common names)
link = (
result_dict.get("url")
or result_dict.get("link")
or result_dict.get("uri")
or result_dict.get("metadata_storage_path")
or ""
)
# Try to find title field (common names)
title = (
result_dict.get("title")
or result_dict.get("name")
or result_dict.get("metadata_title")
or result_dict.get("metadata_storage_name")
or None
)
# Try to find content/snippet field (common names)
snippet = (
result_dict.get("content")
or result_dict.get("snippet")
or result_dict.get("description")
or result_dict.get("summary")
or result_dict.get("text")
or None
)
# Truncate snippet if too long
if snippet and len(snippet) > 500:
snippet = snippet[:497] + "..."
if link: # Only add if we found a valid link
search_results.append(
{
"link": link,
"title": title,
"snippet": snippet,
}
)
# Apply domain filtering if specified
if filter_list:
search_results = get_filtered_results(search_results, filter_list)
# Convert to SearchResult objects
return [
SearchResult(
link=result["link"],
title=result.get("title"),
snippet=result.get("snippet"),
)
for result in search_results
]
except Exception as ex:
log.error(f"Azure AI Search error: {ex}")
raise ex
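A hedged usage sketch of the new Azure AI Search integration; the endpoint, index name, and key are placeholders, and in practice these values come from Open WebUI's web search configuration:

# assuming: from open_webui.retrieval.web.azure import search_azure  (module path assumed)
results = search_azure(
    api_key="<azure-search-query-key>",              # placeholder
    endpoint="https://myservice.search.windows.net", # placeholder service endpoint
    index_name="my-index",                           # placeholder index
    query="open webui rag",
    count=5,
    filter_list=["learn.microsoft.com"],             # optional domain filter
)
for r in results:
    print(r.link, r.title)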

View file

@ -2,27 +2,42 @@ import logging
from typing import Optional, List from typing import Optional, List
import requests import requests
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
from fastapi import Request
from open_webui.env import SRC_LOG_LEVELS from open_webui.env import SRC_LOG_LEVELS
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
from open_webui.utils.headers import include_user_info_headers
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"]) log.setLevel(SRC_LOG_LEVELS["RAG"])
def search_external( def search_external(
request: Request,
external_url: str, external_url: str,
external_api_key: str, external_api_key: str,
query: str, query: str,
count: int, count: int,
filter_list: Optional[List[str]] = None, filter_list: Optional[List[str]] = None,
user=None,
) -> List[SearchResult]: ) -> List[SearchResult]:
try: try:
headers = {
"User-Agent": "Open WebUI (https://github.com/open-webui/open-webui) RAG Bot",
"Authorization": f"Bearer {external_api_key}",
}
headers = include_user_info_headers(headers, user)
chat_id = getattr(request.state, "chat_id", None)
if chat_id:
headers["X-OpenWebUI-Chat-Id"] = str(chat_id)
response = requests.post( response = requests.post(
external_url, external_url,
headers={ headers=headers,
"User-Agent": "Open WebUI (https://github.com/open-webui/open-webui) RAG Bot",
"Authorization": f"Bearer {external_api_key}",
},
json={ json={
"query": query, "query": query,
"count": count, "count": count,

View file

@ -5,18 +5,37 @@ from urllib.parse import urlparse
from pydantic import BaseModel from pydantic import BaseModel
from open_webui.retrieval.web.utils import is_string_allowed, resolve_hostname
def get_filtered_results(results, filter_list): def get_filtered_results(results, filter_list):
if not filter_list: if not filter_list:
return results return results
filtered_results = [] filtered_results = []
for result in results: for result in results:
url = result.get("url") or result.get("link", "") or result.get("href", "") url = result.get("url") or result.get("link", "") or result.get("href", "")
if not validators.url(url): if not validators.url(url):
continue continue
domain = urlparse(url).netloc domain = urlparse(url).netloc
if any(domain.endswith(filtered_domain) for filtered_domain in filter_list): if not domain:
continue
hostnames = [domain]
try:
ipv4_addresses, ipv6_addresses = resolve_hostname(domain)
hostnames.extend(ipv4_addresses)
hostnames.extend(ipv6_addresses)
except Exception:
pass
if any(is_string_allowed(hostname, filter_list) for hostname in hostnames):
filtered_results.append(result) filtered_results.append(result)
continue
return filtered_results return filtered_results

View file

@ -3,6 +3,7 @@ from typing import Optional, Literal
import requests import requests
from open_webui.retrieval.web.main import SearchResult, get_filtered_results from open_webui.retrieval.web.main import SearchResult, get_filtered_results
from open_webui.utils.headers import include_user_info_headers
from open_webui.env import SRC_LOG_LEVELS from open_webui.env import SRC_LOG_LEVELS
@ -15,6 +16,8 @@ def search_perplexity_search(
query: str, query: str,
count: int, count: int,
filter_list: Optional[list[str]] = None, filter_list: Optional[list[str]] = None,
api_url: str = "https://api.perplexity.ai/search",
user=None,
) -> list[SearchResult]: ) -> list[SearchResult]:
"""Search using Perplexity API and return the results as a list of SearchResult objects. """Search using Perplexity API and return the results as a list of SearchResult objects.
@ -23,6 +26,8 @@ def search_perplexity_search(
query (str): The query to search for query (str): The query to search for
count (int): Maximum number of results to return count (int): Maximum number of results to return
filter_list (Optional[list[str]]): List of domains to filter results filter_list (Optional[list[str]]): List of domains to filter results
api_url (str): Custom API URL (defaults to https://api.perplexity.ai/search)
user: Optional user object for forwarding user info headers
""" """
@ -30,8 +35,11 @@ def search_perplexity_search(
if hasattr(api_key, "__str__"): if hasattr(api_key, "__str__"):
api_key = str(api_key) api_key = str(api_key)
if hasattr(api_url, "__str__"):
api_url = str(api_url)
try: try:
url = "https://api.perplexity.ai/search" url = api_url
# Create payload for the API call # Create payload for the API call
payload = { payload = {
@ -44,6 +52,10 @@ def search_perplexity_search(
"Content-Type": "application/json", "Content-Type": "application/json",
} }
# Forward user info headers if user is provided
if user is not None:
headers = include_user_info_headers(headers, user)
# Make the API request # Make the API request
response = requests.request("POST", url, json=payload, headers=headers) response = requests.request("POST", url, json=payload, headers=headers)
# Parse the JSON response # Parse the JSON response

View file

@ -24,6 +24,7 @@ import validators
from langchain_community.document_loaders import PlaywrightURLLoader, WebBaseLoader from langchain_community.document_loaders import PlaywrightURLLoader, WebBaseLoader
from langchain_community.document_loaders.base import BaseLoader from langchain_community.document_loaders.base import BaseLoader
from langchain_core.documents import Document from langchain_core.documents import Document
from open_webui.retrieval.loaders.tavily import TavilyLoader from open_webui.retrieval.loaders.tavily import TavilyLoader
from open_webui.retrieval.loaders.external_web import ExternalWebLoader from open_webui.retrieval.loaders.external_web import ExternalWebLoader
from open_webui.constants import ERROR_MESSAGES from open_webui.constants import ERROR_MESSAGES
@ -38,6 +39,7 @@ from open_webui.config import (
TAVILY_EXTRACT_DEPTH, TAVILY_EXTRACT_DEPTH,
EXTERNAL_WEB_LOADER_URL, EXTERNAL_WEB_LOADER_URL,
EXTERNAL_WEB_LOADER_API_KEY, EXTERNAL_WEB_LOADER_API_KEY,
WEB_FETCH_FILTER_LIST,
) )
from open_webui.env import SRC_LOG_LEVELS from open_webui.env import SRC_LOG_LEVELS
@ -46,10 +48,70 @@ log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"]) log.setLevel(SRC_LOG_LEVELS["RAG"])
def resolve_hostname(hostname):
# Get address information
addr_info = socket.getaddrinfo(hostname, None)
# Extract IP addresses from address information
ipv4_addresses = [info[4][0] for info in addr_info if info[0] == socket.AF_INET]
ipv6_addresses = [info[4][0] for info in addr_info if info[0] == socket.AF_INET6]
return ipv4_addresses, ipv6_addresses
def get_allow_block_lists(filter_list):
allow_list = []
block_list = []
if filter_list:
for d in filter_list:
if d.startswith("!"):
# Domains starting with "!" → blocked
block_list.append(d[1:])
else:
# Domains starting without "!" → allowed
allow_list.append(d)
return allow_list, block_list
def is_string_allowed(string: str, filter_list: Optional[list[str]] = None) -> bool:
if not filter_list:
return True
allow_list, block_list = get_allow_block_lists(filter_list)
# If allow list is non-empty, require domain to match one of them
if allow_list:
if not any(string.endswith(allowed) for allowed in allow_list):
return False
# Block list always removes matches
if any(string.endswith(blocked) for blocked in block_list):
return False
return True
def validate_url(url: Union[str, Sequence[str]]): def validate_url(url: Union[str, Sequence[str]]):
if isinstance(url, str): if isinstance(url, str):
if isinstance(validators.url(url), validators.ValidationError): if isinstance(validators.url(url), validators.ValidationError):
raise ValueError(ERROR_MESSAGES.INVALID_URL) raise ValueError(ERROR_MESSAGES.INVALID_URL)
parsed_url = urllib.parse.urlparse(url)
# Protocol validation - only allow http/https
if parsed_url.scheme not in ["http", "https"]:
log.warning(
f"Blocked non-HTTP(S) protocol: {parsed_url.scheme} in URL: {url}"
)
raise ValueError(ERROR_MESSAGES.INVALID_URL)
# Blocklist check using unified filtering logic
if WEB_FETCH_FILTER_LIST:
if not is_string_allowed(url, WEB_FETCH_FILTER_LIST):
log.warning(f"URL blocked by filter list: {url}")
raise ValueError(ERROR_MESSAGES.INVALID_URL)
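A standalone sketch of the kind of check `validate_url` now performs: restrict the scheme to http/https and apply a `WEB_FETCH_FILTER_LIST`-style allow/block list. This mirrors the logic for illustration only; the sample filter value is an assumption:

```python
# Standalone sketch of the validate_url-style checks added above.
from urllib.parse import urlparse

WEB_FETCH_FILTER_LIST = ["example.com", "!internal.example.com"]  # assumed sample value


def is_url_allowed(url: str) -> bool:
    parsed = urlparse(url)
    if parsed.scheme not in ("http", "https"):
        return False  # block file://, ftp://, javascript:, etc.
    allow = [d for d in WEB_FETCH_FILTER_LIST if not d.startswith("!")]
    block = [d[1:] for d in WEB_FETCH_FILTER_LIST if d.startswith("!")]
    if allow and not any(url.endswith(d) or parsed.netloc.endswith(d) for d in allow):
        return False
    return not any(url.endswith(d) or parsed.netloc.endswith(d) for d in block)


print(is_url_allowed("https://docs.example.com"))           # True
print(is_url_allowed("ftp://docs.example.com"))             # False (scheme)
print(is_url_allowed("https://internal.example.com/page"))  # False (blocked)
```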
if not ENABLE_RAG_LOCAL_WEB_FETCH: if not ENABLE_RAG_LOCAL_WEB_FETCH:
# Local web fetch is disabled, filter out any URLs that resolve to private IP addresses # Local web fetch is disabled, filter out any URLs that resolve to private IP addresses
parsed_url = urllib.parse.urlparse(url) parsed_url = urllib.parse.urlparse(url)
@ -82,17 +144,6 @@ def safe_validate_urls(url: Sequence[str]) -> Sequence[str]:
return valid_urls return valid_urls
def resolve_hostname(hostname):
# Get address information
addr_info = socket.getaddrinfo(hostname, None)
# Extract IP addresses from address information
ipv4_addresses = [info[4][0] for info in addr_info if info[0] == socket.AF_INET]
ipv6_addresses = [info[4][0] for info in addr_info if info[0] == socket.AF_INET6]
return ipv4_addresses, ipv6_addresses
def extract_metadata(soup, url): def extract_metadata(soup, url):
metadata = {"source": url} metadata = {"source": url}
if title := soup.find("title"): if title := soup.find("title"):
@ -642,6 +693,10 @@ def get_web_loader(
# Check if the URLs are valid # Check if the URLs are valid
safe_urls = safe_validate_urls([urls] if isinstance(urls, str) else urls) safe_urls = safe_validate_urls([urls] if isinstance(urls, str) else urls)
if not safe_urls:
log.warning(f"All provided URLs were blocked or invalid: {urls}")
raise ValueError(ERROR_MESSAGES.INVALID_URL)
web_loader_args = { web_loader_args = {
"web_paths": safe_urls, "web_paths": safe_urls,
"verify_ssl": verify_ssl, "verify_ssl": verify_ssl,

View file

@ -16,7 +16,6 @@ import aiohttp
import aiofiles import aiofiles
import requests import requests
import mimetypes import mimetypes
from urllib.parse import urljoin, quote
from fastapi import ( from fastapi import (
Depends, Depends,
@ -35,6 +34,7 @@ from pydantic import BaseModel
from open_webui.utils.auth import get_admin_user, get_verified_user from open_webui.utils.auth import get_admin_user, get_verified_user
from open_webui.utils.headers import include_user_info_headers
from open_webui.config import ( from open_webui.config import (
WHISPER_MODEL_AUTO_UPDATE, WHISPER_MODEL_AUTO_UPDATE,
WHISPER_MODEL_DIR, WHISPER_MODEL_DIR,
@ -364,23 +364,17 @@ async def speech(request: Request, user=Depends(get_verified_user)):
**(request.app.state.config.TTS_OPENAI_PARAMS or {}), **(request.app.state.config.TTS_OPENAI_PARAMS or {}),
} }
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {request.app.state.config.TTS_OPENAI_API_KEY}",
}
if ENABLE_FORWARD_USER_INFO_HEADERS:
headers = include_user_info_headers(headers, user)
r = await session.post( r = await session.post(
url=f"{request.app.state.config.TTS_OPENAI_API_BASE_URL}/audio/speech", url=f"{request.app.state.config.TTS_OPENAI_API_BASE_URL}/audio/speech",
json=payload, json=payload,
headers={ headers=headers,
"Content-Type": "application/json",
"Authorization": f"Bearer {request.app.state.config.TTS_OPENAI_API_KEY}",
**(
{
"X-OpenWebUI-User-Name": quote(user.name, safe=" "),
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
}
if ENABLE_FORWARD_USER_INFO_HEADERS
else {}
),
},
ssl=AIOHTTP_CLIENT_SESSION_SSL, ssl=AIOHTTP_CLIENT_SESSION_SSL,
) )
@ -570,7 +564,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
return FileResponse(file_path) return FileResponse(file_path)
def transcription_handler(request, file_path, metadata): def transcription_handler(request, file_path, metadata, user=None):
filename = os.path.basename(file_path) filename = os.path.basename(file_path)
file_dir = os.path.dirname(file_path) file_dir = os.path.dirname(file_path)
id = filename.split(".")[0] id = filename.split(".")[0]
@ -621,11 +615,15 @@ def transcription_handler(request, file_path, metadata):
if language: if language:
payload["language"] = language payload["language"] = language
headers = {
"Authorization": f"Bearer {request.app.state.config.STT_OPENAI_API_KEY}"
}
if user and ENABLE_FORWARD_USER_INFO_HEADERS:
headers = include_user_info_headers(headers, user)
r = requests.post( r = requests.post(
url=f"{request.app.state.config.STT_OPENAI_API_BASE_URL}/audio/transcriptions", url=f"{request.app.state.config.STT_OPENAI_API_BASE_URL}/audio/transcriptions",
headers={ headers=headers,
"Authorization": f"Bearer {request.app.state.config.STT_OPENAI_API_KEY}"
},
files={"file": (filename, open(file_path, "rb"))}, files={"file": (filename, open(file_path, "rb"))},
data=payload, data=payload,
) )
@ -1027,7 +1025,9 @@ def transcription_handler(request, file_path, metadata):
) )
def transcribe(request: Request, file_path: str, metadata: Optional[dict] = None): def transcribe(
request: Request, file_path: str, metadata: Optional[dict] = None, user=None
):
log.info(f"transcribe: {file_path} {metadata}") log.info(f"transcribe: {file_path} {metadata}")
if is_audio_conversion_required(file_path): if is_audio_conversion_required(file_path):
@ -1054,7 +1054,9 @@ def transcribe(request: Request, file_path: str, metadata: Optional[dict] = None
with ThreadPoolExecutor() as executor: with ThreadPoolExecutor() as executor:
# Submit tasks for each chunk_path # Submit tasks for each chunk_path
futures = [ futures = [
executor.submit(transcription_handler, request, chunk_path, metadata) executor.submit(
transcription_handler, request, chunk_path, metadata, user
)
for chunk_path in chunk_paths for chunk_path in chunk_paths
] ]
# Gather results as they complete # Gather results as they complete
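The chunk fan-out above follows the standard executor pattern, now threading the user through to each handler. A minimal self-contained sketch, with a dummy transcriber standing in for the real STT call:

```python
# Minimal sketch of the chunked-transcription fan-out used above. Illustrative only.
from concurrent.futures import ThreadPoolExecutor


def transcription_handler(chunk_path: str, metadata: dict, user=None) -> dict:
    # Placeholder: a real handler would POST the file to the STT backend,
    # optionally attaching user info headers when a user object is provided.
    return {"text": f"[transcript of {chunk_path}]"}


def transcribe_chunks(chunk_paths: list[str], metadata: dict, user=None) -> str:
    with ThreadPoolExecutor() as executor:
        futures = [
            executor.submit(transcription_handler, path, metadata, user)
            for path in chunk_paths
        ]
        # Reading futures in submission order keeps the transcript ordered.
        results = [future.result() for future in futures]
    return " ".join(part["text"] for part in results)


print(transcribe_chunks(["chunk_0.wav", "chunk_1.wav"], {"language": "en"}, user=None))
```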
@ -1189,7 +1191,7 @@ def transcription(
if language: if language:
metadata = {"language": language} metadata = {"language": language}
result = transcribe(request, file_path, metadata) result = transcribe(request, file_path, metadata, user)
return { return {
**result, **result,

View file

@ -4,6 +4,7 @@ import time
import datetime import datetime
import logging import logging
from aiohttp import ClientSession from aiohttp import ClientSession
import urllib
from open_webui.models.auths import ( from open_webui.models.auths import (
AddUserForm, AddUserForm,
@ -35,12 +36,20 @@ from open_webui.env import (
) )
from fastapi import APIRouter, Depends, HTTPException, Request, status from fastapi import APIRouter, Depends, HTTPException, Request, status
from fastapi.responses import RedirectResponse, Response, JSONResponse from fastapi.responses import RedirectResponse, Response, JSONResponse
from open_webui.config import OPENID_PROVIDER_URL, ENABLE_OAUTH_SIGNUP, ENABLE_LDAP from open_webui.config import (
OPENID_PROVIDER_URL,
ENABLE_OAUTH_SIGNUP,
ENABLE_LDAP,
ENABLE_PASSWORD_AUTH,
)
from pydantic import BaseModel from pydantic import BaseModel
from open_webui.utils.misc import parse_duration, validate_email_format from open_webui.utils.misc import parse_duration, validate_email_format
from open_webui.utils.auth import ( from open_webui.utils.auth import (
validate_password,
verify_password,
decode_token, decode_token,
invalidate_token,
create_api_key, create_api_key,
create_token, create_token,
get_admin_user, get_admin_user,
@ -50,7 +59,7 @@ from open_webui.utils.auth import (
get_http_authorization_cred, get_http_authorization_cred,
) )
from open_webui.utils.webhook import post_webhook from open_webui.utils.webhook import post_webhook
from open_webui.utils.access_control import get_permissions from open_webui.utils.access_control import get_permissions, has_permission
from typing import Optional, List from typing import Optional, List
@ -169,13 +178,19 @@ async def update_password(
if WEBUI_AUTH_TRUSTED_EMAIL_HEADER: if WEBUI_AUTH_TRUSTED_EMAIL_HEADER:
raise HTTPException(400, detail=ERROR_MESSAGES.ACTION_PROHIBITED) raise HTTPException(400, detail=ERROR_MESSAGES.ACTION_PROHIBITED)
if session_user: if session_user:
user = Auths.authenticate_user(session_user.email, form_data.password) user = Auths.authenticate_user(
session_user.email, lambda pw: verify_password(form_data.password, pw)
)
if user: if user:
try:
validate_password(form_data.password)
except Exception as e:
raise HTTPException(400, detail=str(e))
hashed = get_password_hash(form_data.new_password) hashed = get_password_hash(form_data.new_password)
return Auths.update_user_password_by_id(user.id, hashed) return Auths.update_user_password_by_id(user.id, hashed)
else: else:
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_PASSWORD) raise HTTPException(400, detail=ERROR_MESSAGES.INCORRECT_PASSWORD)
else: else:
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED) raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
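`validate_password` is now called before hashing in the password-change, signup, and add-user paths. A rough standalone equivalent is sketched below; the regex is an assumed example requiring 8+ characters with upper/lower/digit/special, not necessarily the project default, while the 72-byte bcrypt limit matches the check being replaced above:

```python
# Rough standalone equivalent of a validate_password-style check. Illustrative only.
import re

PASSWORD_VALIDATION_REGEX_PATTERN = (
    r"^(?=.*[a-z])(?=.*[A-Z])(?=.*\d)(?=.*[^A-Za-z0-9]).{8,}$"
)


def validate_password(password: str) -> None:
    if len(password.encode("utf-8")) > 72:
        # bcrypt truncates inputs longer than 72 bytes, so reject them up front
        raise ValueError("Password must be at most 72 bytes long.")
    if not re.match(PASSWORD_VALIDATION_REGEX_PATTERN, password):
        raise ValueError("Password does not meet the complexity requirements.")


try:
    validate_password("Str0ng!pass")
    print("ok")
except ValueError as e:
    print(e)
```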
@ -185,7 +200,17 @@ async def update_password(
############################ ############################
@router.post("/ldap", response_model=SessionUserResponse) @router.post("/ldap", response_model=SessionUserResponse)
async def ldap_auth(request: Request, response: Response, form_data: LdapForm): async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
ENABLE_LDAP = request.app.state.config.ENABLE_LDAP # Security checks FIRST - before loading any config
if not request.app.state.config.ENABLE_LDAP:
raise HTTPException(400, detail="LDAP authentication is not enabled")
if not ENABLE_PASSWORD_AUTH:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=ERROR_MESSAGES.ACTION_PROHIBITED,
)
# NOW load LDAP config variables
LDAP_SERVER_LABEL = request.app.state.config.LDAP_SERVER_LABEL LDAP_SERVER_LABEL = request.app.state.config.LDAP_SERVER_LABEL
LDAP_SERVER_HOST = request.app.state.config.LDAP_SERVER_HOST LDAP_SERVER_HOST = request.app.state.config.LDAP_SERVER_HOST
LDAP_SERVER_PORT = request.app.state.config.LDAP_SERVER_PORT LDAP_SERVER_PORT = request.app.state.config.LDAP_SERVER_PORT
@ -206,9 +231,6 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
else "ALL" else "ALL"
) )
if not ENABLE_LDAP:
raise HTTPException(400, detail="LDAP authentication is not enabled")
try: try:
tls = Tls( tls = Tls(
validate=LDAP_VALIDATE_CERT, validate=LDAP_VALIDATE_CERT,
@ -463,6 +485,12 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
@router.post("/signin", response_model=SessionUserResponse) @router.post("/signin", response_model=SessionUserResponse)
async def signin(request: Request, response: Response, form_data: SigninForm): async def signin(request: Request, response: Response, form_data: SigninForm):
if not ENABLE_PASSWORD_AUTH:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=ERROR_MESSAGES.ACTION_PROHIBITED,
)
if WEBUI_AUTH_TRUSTED_EMAIL_HEADER: if WEBUI_AUTH_TRUSTED_EMAIL_HEADER:
if WEBUI_AUTH_TRUSTED_EMAIL_HEADER not in request.headers: if WEBUI_AUTH_TRUSTED_EMAIL_HEADER not in request.headers:
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_TRUSTED_HEADER) raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_TRUSTED_HEADER)
@ -472,6 +500,10 @@ async def signin(request: Request, response: Response, form_data: SigninForm):
if WEBUI_AUTH_TRUSTED_NAME_HEADER: if WEBUI_AUTH_TRUSTED_NAME_HEADER:
name = request.headers.get(WEBUI_AUTH_TRUSTED_NAME_HEADER, email) name = request.headers.get(WEBUI_AUTH_TRUSTED_NAME_HEADER, email)
try:
name = urllib.parse.unquote(name, encoding="utf-8")
except Exception as e:
pass
if not Users.get_user_by_email(email.lower()): if not Users.get_user_by_email(email.lower()):
await signup( await signup(
@ -495,7 +527,9 @@ async def signin(request: Request, response: Response, form_data: SigninForm):
admin_password = "admin" admin_password = "admin"
if Users.get_user_by_email(admin_email.lower()): if Users.get_user_by_email(admin_email.lower()):
user = Auths.authenticate_user(admin_email.lower(), admin_password) user = Auths.authenticate_user(
admin_email.lower(), lambda pw: verify_password(admin_password, pw)
)
else: else:
if Users.has_users(): if Users.has_users():
raise HTTPException(400, detail=ERROR_MESSAGES.EXISTING_USERS) raise HTTPException(400, detail=ERROR_MESSAGES.EXISTING_USERS)
@ -506,7 +540,9 @@ async def signin(request: Request, response: Response, form_data: SigninForm):
SignupForm(email=admin_email, password=admin_password, name="User"), SignupForm(email=admin_email, password=admin_password, name="User"),
) )
user = Auths.authenticate_user(admin_email.lower(), admin_password) user = Auths.authenticate_user(
admin_email.lower(), lambda pw: verify_password(admin_password, pw)
)
else: else:
password_bytes = form_data.password.encode("utf-8") password_bytes = form_data.password.encode("utf-8")
if len(password_bytes) > 72: if len(password_bytes) > 72:
@ -517,7 +553,9 @@ async def signin(request: Request, response: Response, form_data: SigninForm):
# decode safely — ignore incomplete UTF-8 sequences # decode safely — ignore incomplete UTF-8 sequences
form_data.password = password_bytes.decode("utf-8", errors="ignore") form_data.password = password_bytes.decode("utf-8", errors="ignore")
user = Auths.authenticate_user(form_data.email.lower(), form_data.password) user = Auths.authenticate_user(
form_data.email.lower(), lambda pw: verify_password(form_data.password, pw)
)
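The signin paths above now hand `Auths.authenticate_user` a verification callback (`lambda pw: verify_password(...)`) instead of the raw password, so the hash comparison happens next to the stored hash. A toy sketch of that inversion, with an in-memory user table and a fake comparison standing in for bcrypt:

```python
# Toy sketch of the callback-based authentication shape used above. Illustrative only.
import hmac
from typing import Callable, Optional

USERS = {"ada@example.com": {"id": "u1", "password_hash": "hash::s3cret"}}


def verify_password(candidate: str, stored_hash: str) -> bool:
    # Stand-in comparison; a real implementation would use bcrypt.checkpw.
    return hmac.compare_digest(stored_hash, f"hash::{candidate}")


def authenticate_user(email: str, verify: Callable[[str], bool]) -> Optional[dict]:
    record = USERS.get(email.lower())
    if record and verify(record["password_hash"]):
        return {"id": record["id"], "email": email.lower()}
    return None


password = "s3cret"
print(authenticate_user("ada@example.com", lambda pw: verify_password(password, pw)))
```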
if user: if user:
@ -599,16 +637,14 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN) raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
try: try:
role = "admin" if not has_users else request.app.state.config.DEFAULT_USER_ROLE try:
validate_password(form_data.password)
# The password passed to bcrypt must be 72 bytes or fewer. If it is longer, it will be truncated before hashing. except Exception as e:
if len(form_data.password.encode("utf-8")) > 72: raise HTTPException(400, detail=str(e))
raise HTTPException(
status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.PASSWORD_TOO_LONG,
)
hashed = get_password_hash(form_data.password) hashed = get_password_hash(form_data.password)
role = "admin" if not has_users else request.app.state.config.DEFAULT_USER_ROLE
user = Auths.insert_new_auth( user = Auths.insert_new_auth(
form_data.email.lower(), form_data.email.lower(),
hashed, hashed,
@ -664,6 +700,10 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
# Disable signup after the first user is created # Disable signup after the first user is created
request.app.state.config.ENABLE_SIGNUP = False request.app.state.config.ENABLE_SIGNUP = False
default_group_id = getattr(request.app.state.config, "DEFAULT_GROUP_ID", "")
if default_group_id:
Groups.add_users_to_group(default_group_id, [user.id])
return { return {
"token": token, "token": token,
"token_type": "Bearer", "token_type": "Bearer",
@ -684,6 +724,19 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
@router.get("/signout") @router.get("/signout")
async def signout(request: Request, response: Response): async def signout(request: Request, response: Response):
# get auth token from headers or cookies
token = None
auth_header = request.headers.get("Authorization")
if auth_header:
auth_cred = get_http_authorization_cred(auth_header)
token = auth_cred.credentials
else:
token = request.cookies.get("token")
if token:
await invalidate_token(request, token)
response.delete_cookie("token") response.delete_cookie("token")
response.delete_cookie("oui-session") response.delete_cookie("oui-session")
response.delete_cookie("oauth_id_token") response.delete_cookie("oauth_id_token")
@ -764,6 +817,11 @@ async def add_user(form_data: AddUserForm, user=Depends(get_admin_user)):
raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN) raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
try: try:
try:
validate_password(form_data.password)
except Exception as e:
raise HTTPException(400, detail=str(e))
hashed = get_password_hash(form_data.password) hashed = get_password_hash(form_data.password)
user = Auths.insert_new_auth( user = Auths.insert_new_auth(
form_data.email.lower(), form_data.email.lower(),
@ -835,10 +893,11 @@ async def get_admin_config(request: Request, user=Depends(get_admin_user)):
"SHOW_ADMIN_DETAILS": request.app.state.config.SHOW_ADMIN_DETAILS, "SHOW_ADMIN_DETAILS": request.app.state.config.SHOW_ADMIN_DETAILS,
"WEBUI_URL": request.app.state.config.WEBUI_URL, "WEBUI_URL": request.app.state.config.WEBUI_URL,
"ENABLE_SIGNUP": request.app.state.config.ENABLE_SIGNUP, "ENABLE_SIGNUP": request.app.state.config.ENABLE_SIGNUP,
"ENABLE_API_KEY": request.app.state.config.ENABLE_API_KEY, "ENABLE_API_KEYS": request.app.state.config.ENABLE_API_KEYS,
"ENABLE_API_KEY_ENDPOINT_RESTRICTIONS": request.app.state.config.ENABLE_API_KEY_ENDPOINT_RESTRICTIONS, "ENABLE_API_KEYS_ENDPOINT_RESTRICTIONS": request.app.state.config.ENABLE_API_KEYS_ENDPOINT_RESTRICTIONS,
"API_KEY_ALLOWED_ENDPOINTS": request.app.state.config.API_KEY_ALLOWED_ENDPOINTS, "API_KEYS_ALLOWED_ENDPOINTS": request.app.state.config.API_KEYS_ALLOWED_ENDPOINTS,
"DEFAULT_USER_ROLE": request.app.state.config.DEFAULT_USER_ROLE, "DEFAULT_USER_ROLE": request.app.state.config.DEFAULT_USER_ROLE,
"DEFAULT_GROUP_ID": request.app.state.config.DEFAULT_GROUP_ID,
"JWT_EXPIRES_IN": request.app.state.config.JWT_EXPIRES_IN, "JWT_EXPIRES_IN": request.app.state.config.JWT_EXPIRES_IN,
"ENABLE_COMMUNITY_SHARING": request.app.state.config.ENABLE_COMMUNITY_SHARING, "ENABLE_COMMUNITY_SHARING": request.app.state.config.ENABLE_COMMUNITY_SHARING,
"ENABLE_MESSAGE_RATING": request.app.state.config.ENABLE_MESSAGE_RATING, "ENABLE_MESSAGE_RATING": request.app.state.config.ENABLE_MESSAGE_RATING,
@ -855,10 +914,11 @@ class AdminConfig(BaseModel):
SHOW_ADMIN_DETAILS: bool SHOW_ADMIN_DETAILS: bool
WEBUI_URL: str WEBUI_URL: str
ENABLE_SIGNUP: bool ENABLE_SIGNUP: bool
ENABLE_API_KEY: bool ENABLE_API_KEYS: bool
ENABLE_API_KEY_ENDPOINT_RESTRICTIONS: bool ENABLE_API_KEYS_ENDPOINT_RESTRICTIONS: bool
API_KEY_ALLOWED_ENDPOINTS: str API_KEYS_ALLOWED_ENDPOINTS: str
DEFAULT_USER_ROLE: str DEFAULT_USER_ROLE: str
DEFAULT_GROUP_ID: str
JWT_EXPIRES_IN: str JWT_EXPIRES_IN: str
ENABLE_COMMUNITY_SHARING: bool ENABLE_COMMUNITY_SHARING: bool
ENABLE_MESSAGE_RATING: bool ENABLE_MESSAGE_RATING: bool
@ -878,12 +938,12 @@ async def update_admin_config(
request.app.state.config.WEBUI_URL = form_data.WEBUI_URL request.app.state.config.WEBUI_URL = form_data.WEBUI_URL
request.app.state.config.ENABLE_SIGNUP = form_data.ENABLE_SIGNUP request.app.state.config.ENABLE_SIGNUP = form_data.ENABLE_SIGNUP
request.app.state.config.ENABLE_API_KEY = form_data.ENABLE_API_KEY request.app.state.config.ENABLE_API_KEYS = form_data.ENABLE_API_KEYS
request.app.state.config.ENABLE_API_KEY_ENDPOINT_RESTRICTIONS = ( request.app.state.config.ENABLE_API_KEYS_ENDPOINT_RESTRICTIONS = (
form_data.ENABLE_API_KEY_ENDPOINT_RESTRICTIONS form_data.ENABLE_API_KEYS_ENDPOINT_RESTRICTIONS
) )
request.app.state.config.API_KEY_ALLOWED_ENDPOINTS = ( request.app.state.config.API_KEYS_ALLOWED_ENDPOINTS = (
form_data.API_KEY_ALLOWED_ENDPOINTS form_data.API_KEYS_ALLOWED_ENDPOINTS
) )
request.app.state.config.ENABLE_CHANNELS = form_data.ENABLE_CHANNELS request.app.state.config.ENABLE_CHANNELS = form_data.ENABLE_CHANNELS
@ -892,6 +952,8 @@ async def update_admin_config(
if form_data.DEFAULT_USER_ROLE in ["pending", "user", "admin"]: if form_data.DEFAULT_USER_ROLE in ["pending", "user", "admin"]:
request.app.state.config.DEFAULT_USER_ROLE = form_data.DEFAULT_USER_ROLE request.app.state.config.DEFAULT_USER_ROLE = form_data.DEFAULT_USER_ROLE
request.app.state.config.DEFAULT_GROUP_ID = form_data.DEFAULT_GROUP_ID
pattern = r"^(-1|0|(-?\d+(\.\d+)?)(ms|s|m|h|d|w))$" pattern = r"^(-1|0|(-?\d+(\.\d+)?)(ms|s|m|h|d|w))$"
# Check if the input string matches the pattern # Check if the input string matches the pattern
@ -918,10 +980,11 @@ async def update_admin_config(
"SHOW_ADMIN_DETAILS": request.app.state.config.SHOW_ADMIN_DETAILS, "SHOW_ADMIN_DETAILS": request.app.state.config.SHOW_ADMIN_DETAILS,
"WEBUI_URL": request.app.state.config.WEBUI_URL, "WEBUI_URL": request.app.state.config.WEBUI_URL,
"ENABLE_SIGNUP": request.app.state.config.ENABLE_SIGNUP, "ENABLE_SIGNUP": request.app.state.config.ENABLE_SIGNUP,
"ENABLE_API_KEY": request.app.state.config.ENABLE_API_KEY, "ENABLE_API_KEYS": request.app.state.config.ENABLE_API_KEYS,
"ENABLE_API_KEY_ENDPOINT_RESTRICTIONS": request.app.state.config.ENABLE_API_KEY_ENDPOINT_RESTRICTIONS, "ENABLE_API_KEYS_ENDPOINT_RESTRICTIONS": request.app.state.config.ENABLE_API_KEYS_ENDPOINT_RESTRICTIONS,
"API_KEY_ALLOWED_ENDPOINTS": request.app.state.config.API_KEY_ALLOWED_ENDPOINTS, "API_KEYS_ALLOWED_ENDPOINTS": request.app.state.config.API_KEYS_ALLOWED_ENDPOINTS,
"DEFAULT_USER_ROLE": request.app.state.config.DEFAULT_USER_ROLE, "DEFAULT_USER_ROLE": request.app.state.config.DEFAULT_USER_ROLE,
"DEFAULT_GROUP_ID": request.app.state.config.DEFAULT_GROUP_ID,
"JWT_EXPIRES_IN": request.app.state.config.JWT_EXPIRES_IN, "JWT_EXPIRES_IN": request.app.state.config.JWT_EXPIRES_IN,
"ENABLE_COMMUNITY_SHARING": request.app.state.config.ENABLE_COMMUNITY_SHARING, "ENABLE_COMMUNITY_SHARING": request.app.state.config.ENABLE_COMMUNITY_SHARING,
"ENABLE_MESSAGE_RATING": request.app.state.config.ENABLE_MESSAGE_RATING, "ENABLE_MESSAGE_RATING": request.app.state.config.ENABLE_MESSAGE_RATING,
@ -1045,9 +1108,11 @@ async def update_ldap_config(
# create api key # create api key
@router.post("/api_key", response_model=ApiKey) @router.post("/api_key", response_model=ApiKey)
async def generate_api_key(request: Request, user=Depends(get_current_user)): async def generate_api_key(request: Request, user=Depends(get_current_user)):
if not request.app.state.config.ENABLE_API_KEY: if not request.app.state.config.ENABLE_API_KEYS or not has_permission(
user.id, "features.api_keys", request.app.state.config.USER_PERMISSIONS
):
raise HTTPException( raise HTTPException(
status.HTTP_403_FORBIDDEN, status_code=status.HTTP_403_FORBIDDEN,
detail=ERROR_MESSAGES.API_KEY_CREATION_NOT_ALLOWED, detail=ERROR_MESSAGES.API_KEY_CREATION_NOT_ALLOWED,
) )

View file

@ -7,6 +7,7 @@ from open_webui.socket.main import get_event_emitter
from open_webui.models.chats import ( from open_webui.models.chats import (
ChatForm, ChatForm,
ChatImportForm, ChatImportForm,
ChatsImportForm,
ChatResponse, ChatResponse,
Chats, Chats,
ChatTitleIdResponse, ChatTitleIdResponse,
@ -142,26 +143,15 @@ async def create_new_chat(form_data: ChatForm, user=Depends(get_verified_user)):
############################ ############################
# ImportChat # ImportChats
############################ ############################
@router.post("/import", response_model=Optional[ChatResponse]) @router.post("/import", response_model=list[ChatResponse])
async def import_chat(form_data: ChatImportForm, user=Depends(get_verified_user)): async def import_chats(form_data: ChatsImportForm, user=Depends(get_verified_user)):
try: try:
chat = Chats.import_chat(user.id, form_data) chats = Chats.import_chats(user.id, form_data.chats)
if chat: return chats
tags = chat.meta.get("tags", [])
for tag_id in tags:
tag_id = tag_id.replace(" ", "_").lower()
tag_name = " ".join([word.capitalize() for word in tag_id.split("_")])
if (
tag_id != "none"
and Tags.get_tag_by_name_and_user_id(tag_name, user.id) is None
):
Tags.insert_new_tag(tag_name, user.id)
return ChatResponse(**chat.model_dump())
except Exception as e: except Exception as e:
log.exception(e) log.exception(e)
raise HTTPException( raise HTTPException(
@ -228,7 +218,7 @@ async def get_chat_list_by_folder_id(
folder_id: str, page: Optional[int] = 1, user=Depends(get_verified_user) folder_id: str, page: Optional[int] = 1, user=Depends(get_verified_user)
): ):
try: try:
limit = 60 limit = 10
skip = (page - 1) * limit skip = (page - 1) * limit
return [ return [
@ -658,19 +648,28 @@ async def clone_chat_by_id(
"title": form_data.title if form_data.title else f"Clone of {chat.title}", "title": form_data.title if form_data.title else f"Clone of {chat.title}",
} }
chat = Chats.import_chat( chats = Chats.import_chats(
user.id, user.id,
ChatImportForm( [
**{ ChatImportForm(
"chat": updated_chat, **{
"meta": chat.meta, "chat": updated_chat,
"pinned": chat.pinned, "meta": chat.meta,
"folder_id": chat.folder_id, "pinned": chat.pinned,
} "folder_id": chat.folder_id,
), }
)
],
) )
return ChatResponse(**chat.model_dump()) if chats:
chat = chats[0]
return ChatResponse(**chat.model_dump())
else:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=ERROR_MESSAGES.DEFAULT(),
)
else: else:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT() status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT()
@ -698,18 +697,28 @@ async def clone_shared_chat_by_id(id: str, user=Depends(get_verified_user)):
"title": f"Clone of {chat.title}", "title": f"Clone of {chat.title}",
} }
chat = Chats.import_chat( chats = Chats.import_chats(
user.id, user.id,
ChatImportForm( [
**{ ChatImportForm(
"chat": updated_chat, **{
"meta": chat.meta, "chat": updated_chat,
"pinned": chat.pinned, "meta": chat.meta,
"folder_id": chat.folder_id, "pinned": chat.pinned,
} "folder_id": chat.folder_id,
), }
)
],
) )
return ChatResponse(**chat.model_dump())
if chats:
chat = chats[0]
return ChatResponse(**chat.model_dump())
else:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=ERROR_MESSAGES.DEFAULT(),
)
else: else:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT() status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT()
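With the `/import` route now accepting a list of chats (and the clone endpoints reusing the same bulk path), a client payload might look like the sketch below. Field names follow `ChatImportForm` as it appears in this diff (`chat`, `meta`, `pinned`, `folder_id`); the `/api/v1/chats/import` prefix, base URL, and auth scheme are assumptions:

```python
# Hedged client-side sketch of the bulk chat import shown above.
import requests

BASE_URL = "http://localhost:8080"  # assumed Open WebUI base URL
TOKEN = "YOUR_API_TOKEN"

payload = {
    "chats": [
        {
            "chat": {"title": "Imported chat", "messages": []},
            "meta": {"tags": ["imported"]},
            "pinned": False,
            "folder_id": None,
        }
    ]
}

response = requests.post(
    f"{BASE_URL}/api/v1/chats/import",
    json=payload,
    headers={"Authorization": f"Bearer {TOKEN}"},
    timeout=30,
)
print(response.status_code, len(response.json()))  # expects a list[ChatResponse]
```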

View file

@ -144,6 +144,7 @@ class ToolServerConnection(BaseModel):
path: str path: str
type: Optional[str] = "openapi" # openapi, mcp type: Optional[str] = "openapi" # openapi, mcp
auth_type: Optional[str] auth_type: Optional[str]
headers: Optional[dict | str] = None
key: Optional[str] key: Optional[str]
config: Optional[dict] config: Optional[dict]
@ -270,18 +271,26 @@ async def verify_tool_servers_config(
elif form_data.auth_type == "session": elif form_data.auth_type == "session":
token = request.state.token.credentials token = request.state.token.credentials
elif form_data.auth_type == "system_oauth": elif form_data.auth_type == "system_oauth":
oauth_token = None
try: try:
if request.cookies.get("oauth_session_id", None): if request.cookies.get("oauth_session_id", None):
token = await request.app.state.oauth_manager.get_oauth_token( oauth_token = await request.app.state.oauth_manager.get_oauth_token(
user.id, user.id,
request.cookies.get("oauth_session_id", None), request.cookies.get("oauth_session_id", None),
) )
if oauth_token:
token = oauth_token.get("access_token", "")
except Exception as e: except Exception as e:
pass pass
if token: if token:
headers = {"Authorization": f"Bearer {token}"} headers = {"Authorization": f"Bearer {token}"}
if form_data.headers and isinstance(form_data.headers, dict):
if headers is None:
headers = {}
headers.update(form_data.headers)
await client.connect(form_data.url, headers=headers) await client.connect(form_data.url, headers=headers)
specs = await client.list_tool_specs() specs = await client.list_tool_specs()
return { return {
@ -299,6 +308,7 @@ async def verify_tool_servers_config(
await client.disconnect() await client.disconnect()
else: # openapi else: # openapi
token = None token = None
headers = None
if form_data.auth_type == "bearer": if form_data.auth_type == "bearer":
token = form_data.key token = form_data.key
elif form_data.auth_type == "session": elif form_data.auth_type == "session":
@ -306,15 +316,29 @@ async def verify_tool_servers_config(
elif form_data.auth_type == "system_oauth": elif form_data.auth_type == "system_oauth":
try: try:
if request.cookies.get("oauth_session_id", None): if request.cookies.get("oauth_session_id", None):
token = await request.app.state.oauth_manager.get_oauth_token( oauth_token = (
user.id, await request.app.state.oauth_manager.get_oauth_token(
request.cookies.get("oauth_session_id", None), user.id,
request.cookies.get("oauth_session_id", None),
)
) )
if oauth_token:
token = oauth_token.get("access_token", "")
except Exception as e: except Exception as e:
pass pass
if token:
headers = {"Authorization": f"Bearer {token}"}
if form_data.headers and isinstance(form_data.headers, dict):
if headers is None:
headers = {}
headers.update(form_data.headers)
url = get_tool_server_url(form_data.url, form_data.path) url = get_tool_server_url(form_data.url, form_data.path)
return await get_tool_server_data(token, url) return await get_tool_server_data(url, headers=headers)
except HTTPException as e: except HTTPException as e:
raise e raise e
except Exception as e: except Exception as e:
@ -439,6 +463,7 @@ async def set_code_execution_config(
############################ ############################
class ModelsConfigForm(BaseModel): class ModelsConfigForm(BaseModel):
DEFAULT_MODELS: Optional[str] DEFAULT_MODELS: Optional[str]
DEFAULT_PINNED_MODELS: Optional[str]
MODEL_ORDER_LIST: Optional[list[str]] MODEL_ORDER_LIST: Optional[list[str]]
@ -446,6 +471,7 @@ class ModelsConfigForm(BaseModel):
async def get_models_config(request: Request, user=Depends(get_admin_user)): async def get_models_config(request: Request, user=Depends(get_admin_user)):
return { return {
"DEFAULT_MODELS": request.app.state.config.DEFAULT_MODELS, "DEFAULT_MODELS": request.app.state.config.DEFAULT_MODELS,
"DEFAULT_PINNED_MODELS": request.app.state.config.DEFAULT_PINNED_MODELS,
"MODEL_ORDER_LIST": request.app.state.config.MODEL_ORDER_LIST, "MODEL_ORDER_LIST": request.app.state.config.MODEL_ORDER_LIST,
} }
@ -455,9 +481,11 @@ async def set_models_config(
request: Request, form_data: ModelsConfigForm, user=Depends(get_admin_user) request: Request, form_data: ModelsConfigForm, user=Depends(get_admin_user)
): ):
request.app.state.config.DEFAULT_MODELS = form_data.DEFAULT_MODELS request.app.state.config.DEFAULT_MODELS = form_data.DEFAULT_MODELS
request.app.state.config.DEFAULT_PINNED_MODELS = form_data.DEFAULT_PINNED_MODELS
request.app.state.config.MODEL_ORDER_LIST = form_data.MODEL_ORDER_LIST request.app.state.config.MODEL_ORDER_LIST = form_data.MODEL_ORDER_LIST
return { return {
"DEFAULT_MODELS": request.app.state.config.DEFAULT_MODELS, "DEFAULT_MODELS": request.app.state.config.DEFAULT_MODELS,
"DEFAULT_PINNED_MODELS": request.app.state.config.DEFAULT_PINNED_MODELS,
"MODEL_ORDER_LIST": request.app.state.config.MODEL_ORDER_LIST, "MODEL_ORDER_LIST": request.app.state.config.MODEL_ORDER_LIST,
} }
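The tool-server verification above now layers optional custom headers on top of the bearer token (for both the MCP and OpenAPI branches). A small standalone sketch of that merge order:

```python
# Standalone sketch of the header merge used for tool-server verification above:
# start from an optional bearer token, then layer user-supplied headers on top.
# Illustrative only; not the Open WebUI endpoint code.
from typing import Optional


def build_headers(token: Optional[str], custom_headers) -> Optional[dict]:
    headers = {"Authorization": f"Bearer {token}"} if token else None
    if custom_headers and isinstance(custom_headers, dict):
        headers = {**(headers or {}), **custom_headers}
    return headers


print(build_headers("abc", {"X-Org": "acme"}))  # Authorization plus X-Org
print(build_headers(None, {"X-Org": "acme"}))   # custom headers only
print(build_headers(None, None))                # None: nothing to send
```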

View file

@ -7,6 +7,8 @@ from open_webui.models.feedbacks import (
FeedbackModel, FeedbackModel,
FeedbackResponse, FeedbackResponse,
FeedbackForm, FeedbackForm,
FeedbackUserResponse,
FeedbackListResponse,
Feedbacks, Feedbacks,
) )
@ -56,35 +58,10 @@ async def update_config(
} }
class UserResponse(BaseModel): @router.get("/feedbacks/all", response_model=list[FeedbackResponse])
id: str
name: str
email: str
role: str = "pending"
last_active_at: int # timestamp in epoch
updated_at: int # timestamp in epoch
created_at: int # timestamp in epoch
class FeedbackUserResponse(FeedbackResponse):
user: Optional[UserResponse] = None
@router.get("/feedbacks/all", response_model=list[FeedbackUserResponse])
async def get_all_feedbacks(user=Depends(get_admin_user)): async def get_all_feedbacks(user=Depends(get_admin_user)):
feedbacks = Feedbacks.get_all_feedbacks() feedbacks = Feedbacks.get_all_feedbacks()
return feedbacks
feedback_list = []
for feedback in feedbacks:
user = Users.get_user_by_id(feedback.user_id)
feedback_list.append(
FeedbackUserResponse(
**feedback.model_dump(),
user=UserResponse(**user.model_dump()) if user else None,
)
)
return feedback_list
@router.delete("/feedbacks/all") @router.delete("/feedbacks/all")
@ -111,6 +88,31 @@ async def delete_feedbacks(user=Depends(get_verified_user)):
return success return success
PAGE_ITEM_COUNT = 30
@router.get("/feedbacks/list", response_model=FeedbackListResponse)
async def get_feedbacks(
order_by: Optional[str] = None,
direction: Optional[str] = None,
page: Optional[int] = 1,
user=Depends(get_admin_user),
):
limit = PAGE_ITEM_COUNT
page = max(1, page)
skip = (page - 1) * limit
filter = {}
if order_by:
filter["order_by"] = order_by
if direction:
filter["direction"] = direction
result = Feedbacks.get_feedback_items(filter=filter, skip=skip, limit=limit)
return result
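The new paginated list endpoints share a simple page/skip/limit convention plus an optional `order_by`/`direction` filter dict. A compact sketch of that arithmetic:

```python
# Compact sketch of the page -> skip/limit arithmetic and optional filter dict
# used by the new paginated list endpoints. Illustrative only.
from typing import Optional

PAGE_ITEM_COUNT = 30


def build_query(page: Optional[int] = 1, order_by: Optional[str] = None,
                direction: Optional[str] = None):
    page = max(1, page or 1)
    skip = (page - 1) * PAGE_ITEM_COUNT
    filter = {}
    if order_by:
        filter["order_by"] = order_by
    if direction:
        filter["direction"] = direction
    return filter, skip, PAGE_ITEM_COUNT


print(build_query())  # ({}, 0, 30)
print(build_query(page=3, order_by="updated_at", direction="desc"))
# ({'order_by': 'updated_at', 'direction': 'desc'}, 60, 30)
```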
@router.post("/feedback", response_model=FeedbackModel) @router.post("/feedback", response_model=FeedbackModel)
async def create_feedback( async def create_feedback(
request: Request, request: Request,

View file

@ -102,7 +102,7 @@ def process_uploaded_file(request, file, file_path, file_item, file_metadata, us
) )
): ):
file_path = Storage.get_file(file_path) file_path = Storage.get_file(file_path)
result = transcribe(request, file_path, file_metadata) result = transcribe(request, file_path, file_metadata, user)
process_file( process_file(
request, request,

View file

@ -258,7 +258,10 @@ async def update_folder_is_expanded_by_id(
@router.delete("/{id}") @router.delete("/{id}")
async def delete_folder_by_id( async def delete_folder_by_id(
request: Request, id: str, user=Depends(get_verified_user) request: Request,
id: str,
delete_contents: Optional[bool] = True,
user=Depends(get_verified_user),
): ):
if Chats.count_chats_by_folder_id_and_user_id(id, user.id): if Chats.count_chats_by_folder_id_and_user_id(id, user.id):
chat_delete_permission = has_permission( chat_delete_permission = has_permission(
@ -277,8 +280,14 @@ async def delete_folder_by_id(
if folder: if folder:
try: try:
folder_ids = Folders.delete_folder_by_id_and_user_id(id, user.id) folder_ids = Folders.delete_folder_by_id_and_user_id(id, user.id)
for folder_id in folder_ids: for folder_id in folder_ids:
Chats.delete_chats_by_user_id_and_folder_id(user.id, folder_id) if delete_contents:
Chats.delete_chats_by_user_id_and_folder_id(user.id, folder_id)
else:
Chats.move_chats_by_user_id_and_folder_id(
user.id, folder_id, None
)
return True return True
except Exception as e: except Exception as e:

View file

@ -31,11 +31,32 @@ router = APIRouter()
@router.get("/", response_model=list[GroupResponse]) @router.get("/", response_model=list[GroupResponse])
async def get_groups(user=Depends(get_verified_user)): async def get_groups(share: Optional[bool] = None, user=Depends(get_verified_user)):
if user.role == "admin": if user.role == "admin":
return Groups.get_groups() groups = Groups.get_groups()
else: else:
return Groups.get_groups_by_member_id(user.id) groups = Groups.get_groups_by_member_id(user.id)
group_list = []
for group in groups:
if share is not None:
# Check if the group has data and a config with share key
if (
group.data
and "share" in group.data.get("config", {})
and group.data["config"]["share"] != share
):
continue
group_list.append(
GroupResponse(
**group.model_dump(),
member_count=Groups.get_group_member_count_by_id(group.id),
)
)
return group_list
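Groups are filtered on a nested `data.config.share` flag: a group is skipped only when it carries an explicit value that disagrees with the requested `share`. A small sketch of that opt-out check over plain dicts:

```python
# Small sketch of the data.config.share opt-out check used above. Illustrative only.
groups = [
    {"id": "g1", "data": None},
    {"id": "g2", "data": {"config": {"share": False}}},
    {"id": "g3", "data": {"config": {"share": True}}},
]


def filter_groups_by_share(groups: list[dict], share: bool) -> list[dict]:
    result = []
    for group in groups:
        config = (group.get("data") or {}).get("config", {})
        # Groups with no explicit setting keep the default behaviour and pass through.
        if "share" in config and config["share"] != share:
            continue
        result.append(group)
    return result


print([g["id"] for g in filter_groups_by_share(groups, share=True)])   # ['g1', 'g3']
print([g["id"] for g in filter_groups_by_share(groups, share=False)])  # ['g1', 'g2']
```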
############################ ############################
@ -48,7 +69,10 @@ async def create_new_group(form_data: GroupForm, user=Depends(get_admin_user)):
try: try:
group = Groups.insert_new_group(user.id, form_data) group = Groups.insert_new_group(user.id, form_data)
if group: if group:
return group return GroupResponse(
**group.model_dump(),
member_count=Groups.get_group_member_count_by_id(group.id),
)
else: else:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,
@ -71,7 +95,10 @@ async def create_new_group(form_data: GroupForm, user=Depends(get_admin_user)):
async def get_group_by_id(id: str, user=Depends(get_admin_user)): async def get_group_by_id(id: str, user=Depends(get_admin_user)):
group = Groups.get_group_by_id(id) group = Groups.get_group_by_id(id)
if group: if group:
return group return GroupResponse(
**group.model_dump(),
member_count=Groups.get_group_member_count_by_id(group.id),
)
else: else:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
@ -89,12 +116,12 @@ async def update_group_by_id(
id: str, form_data: GroupUpdateForm, user=Depends(get_admin_user) id: str, form_data: GroupUpdateForm, user=Depends(get_admin_user)
): ):
try: try:
if form_data.user_ids:
form_data.user_ids = Users.get_valid_user_ids(form_data.user_ids)
group = Groups.update_group_by_id(id, form_data) group = Groups.update_group_by_id(id, form_data)
if group: if group:
return group return GroupResponse(
**group.model_dump(),
member_count=Groups.get_group_member_count_by_id(group.id),
)
else: else:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,
@ -123,7 +150,10 @@ async def add_user_to_group(
group = Groups.add_users_to_group(id, form_data.user_ids) group = Groups.add_users_to_group(id, form_data.user_ids)
if group: if group:
return group return GroupResponse(
**group.model_dump(),
member_count=Groups.get_group_member_count_by_id(group.id),
)
else: else:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,
@ -144,7 +174,10 @@ async def remove_users_from_group(
try: try:
group = Groups.remove_users_from_group(id, form_data.user_ids) group = Groups.remove_users_from_group(id, form_data.user_ids)
if group: if group:
return group return GroupResponse(
**group.model_dump(),
member_count=Groups.get_group_member_count_by_id(group.id),
)
else: else:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,

View file

@ -44,18 +44,23 @@ def set_image_model(request: Request, model: str):
request.app.state.config.IMAGE_GENERATION_MODEL = model request.app.state.config.IMAGE_GENERATION_MODEL = model
if request.app.state.config.IMAGE_GENERATION_ENGINE in ["", "automatic1111"]: if request.app.state.config.IMAGE_GENERATION_ENGINE in ["", "automatic1111"]:
api_auth = get_automatic1111_api_auth(request) api_auth = get_automatic1111_api_auth(request)
r = requests.get(
url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options", try:
headers={"authorization": api_auth}, r = requests.get(
)
options = r.json()
if model != options["sd_model_checkpoint"]:
options["sd_model_checkpoint"] = model
r = requests.post(
url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options", url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
json=options,
headers={"authorization": api_auth}, headers={"authorization": api_auth},
) )
options = r.json()
if model != options["sd_model_checkpoint"]:
options["sd_model_checkpoint"] = model
r = requests.post(
url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
json=options,
headers={"authorization": api_auth},
)
except Exception as e:
log.debug(f"{e}")
return request.app.state.config.IMAGE_GENERATION_MODEL return request.app.state.config.IMAGE_GENERATION_MODEL
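The options round-trip above is now wrapped in try/except so an unreachable AUTOMATIC1111 instance no longer breaks the config update. A standalone sketch of the same GET-then-POST against `/sdapi/v1/options`; the base URL and auth value are placeholders:

```python
# Standalone sketch of the AUTOMATIC1111 checkpoint switch above: read
# /sdapi/v1/options and POST it back with a new sd_model_checkpoint, swallowing
# errors so an unreachable instance does not break config updates.
import logging

import requests

log = logging.getLogger(__name__)


def set_a1111_checkpoint(base_url: str, api_auth: str, model: str) -> None:
    try:
        r = requests.get(
            url=f"{base_url}/sdapi/v1/options",
            headers={"authorization": api_auth},
            timeout=10,
        )
        options = r.json()
        if model != options.get("sd_model_checkpoint"):
            options["sd_model_checkpoint"] = model
            requests.post(
                url=f"{base_url}/sdapi/v1/options",
                json=options,
                headers={"authorization": api_auth},
                timeout=10,
            )
    except Exception as e:
        log.debug(f"{e}")


# set_a1111_checkpoint("http://localhost:7860", "Basic ...", "sd_xl_base_1.0.safetensors")
```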
@ -106,9 +111,10 @@ class ImagesConfig(BaseModel):
IMAGES_OPENAI_API_BASE_URL: str IMAGES_OPENAI_API_BASE_URL: str
IMAGES_OPENAI_API_KEY: str IMAGES_OPENAI_API_KEY: str
IMAGES_OPENAI_API_VERSION: str IMAGES_OPENAI_API_VERSION: str
IMAGES_OPENAI_API_PARAMS: Optional[dict | str]
AUTOMATIC1111_BASE_URL: str AUTOMATIC1111_BASE_URL: str
AUTOMATIC1111_API_AUTH: str AUTOMATIC1111_API_AUTH: Optional[dict | str]
AUTOMATIC1111_PARAMS: Optional[dict | str] AUTOMATIC1111_PARAMS: Optional[dict | str]
COMFYUI_BASE_URL: str COMFYUI_BASE_URL: str
@ -120,6 +126,7 @@ class ImagesConfig(BaseModel):
IMAGES_GEMINI_API_KEY: str IMAGES_GEMINI_API_KEY: str
IMAGES_GEMINI_ENDPOINT_METHOD: str IMAGES_GEMINI_ENDPOINT_METHOD: str
ENABLE_IMAGE_EDIT: bool
IMAGE_EDIT_ENGINE: str IMAGE_EDIT_ENGINE: str
IMAGE_EDIT_MODEL: str IMAGE_EDIT_MODEL: str
IMAGE_EDIT_SIZE: Optional[str] IMAGE_EDIT_SIZE: Optional[str]
@ -147,6 +154,7 @@ async def get_config(request: Request, user=Depends(get_admin_user)):
"IMAGES_OPENAI_API_BASE_URL": request.app.state.config.IMAGES_OPENAI_API_BASE_URL, "IMAGES_OPENAI_API_BASE_URL": request.app.state.config.IMAGES_OPENAI_API_BASE_URL,
"IMAGES_OPENAI_API_KEY": request.app.state.config.IMAGES_OPENAI_API_KEY, "IMAGES_OPENAI_API_KEY": request.app.state.config.IMAGES_OPENAI_API_KEY,
"IMAGES_OPENAI_API_VERSION": request.app.state.config.IMAGES_OPENAI_API_VERSION, "IMAGES_OPENAI_API_VERSION": request.app.state.config.IMAGES_OPENAI_API_VERSION,
"IMAGES_OPENAI_API_PARAMS": request.app.state.config.IMAGES_OPENAI_API_PARAMS,
"AUTOMATIC1111_BASE_URL": request.app.state.config.AUTOMATIC1111_BASE_URL, "AUTOMATIC1111_BASE_URL": request.app.state.config.AUTOMATIC1111_BASE_URL,
"AUTOMATIC1111_API_AUTH": request.app.state.config.AUTOMATIC1111_API_AUTH, "AUTOMATIC1111_API_AUTH": request.app.state.config.AUTOMATIC1111_API_AUTH,
"AUTOMATIC1111_PARAMS": request.app.state.config.AUTOMATIC1111_PARAMS, "AUTOMATIC1111_PARAMS": request.app.state.config.AUTOMATIC1111_PARAMS,
@ -157,6 +165,7 @@ async def get_config(request: Request, user=Depends(get_admin_user)):
"IMAGES_GEMINI_API_BASE_URL": request.app.state.config.IMAGES_GEMINI_API_BASE_URL, "IMAGES_GEMINI_API_BASE_URL": request.app.state.config.IMAGES_GEMINI_API_BASE_URL,
"IMAGES_GEMINI_API_KEY": request.app.state.config.IMAGES_GEMINI_API_KEY, "IMAGES_GEMINI_API_KEY": request.app.state.config.IMAGES_GEMINI_API_KEY,
"IMAGES_GEMINI_ENDPOINT_METHOD": request.app.state.config.IMAGES_GEMINI_ENDPOINT_METHOD, "IMAGES_GEMINI_ENDPOINT_METHOD": request.app.state.config.IMAGES_GEMINI_ENDPOINT_METHOD,
"ENABLE_IMAGE_EDIT": request.app.state.config.ENABLE_IMAGE_EDIT,
"IMAGE_EDIT_ENGINE": request.app.state.config.IMAGE_EDIT_ENGINE, "IMAGE_EDIT_ENGINE": request.app.state.config.IMAGE_EDIT_ENGINE,
"IMAGE_EDIT_MODEL": request.app.state.config.IMAGE_EDIT_MODEL, "IMAGE_EDIT_MODEL": request.app.state.config.IMAGE_EDIT_MODEL,
"IMAGE_EDIT_SIZE": request.app.state.config.IMAGE_EDIT_SIZE, "IMAGE_EDIT_SIZE": request.app.state.config.IMAGE_EDIT_SIZE,
@ -224,6 +233,9 @@ async def update_config(
request.app.state.config.IMAGES_OPENAI_API_VERSION = ( request.app.state.config.IMAGES_OPENAI_API_VERSION = (
form_data.IMAGES_OPENAI_API_VERSION form_data.IMAGES_OPENAI_API_VERSION
) )
request.app.state.config.IMAGES_OPENAI_API_PARAMS = (
form_data.IMAGES_OPENAI_API_PARAMS
)
request.app.state.config.AUTOMATIC1111_BASE_URL = form_data.AUTOMATIC1111_BASE_URL request.app.state.config.AUTOMATIC1111_BASE_URL = form_data.AUTOMATIC1111_BASE_URL
request.app.state.config.AUTOMATIC1111_API_AUTH = form_data.AUTOMATIC1111_API_AUTH request.app.state.config.AUTOMATIC1111_API_AUTH = form_data.AUTOMATIC1111_API_AUTH
@ -243,15 +255,16 @@ async def update_config(
) )
# Edit Image # Edit Image
request.app.state.config.ENABLE_IMAGE_EDIT = form_data.ENABLE_IMAGE_EDIT
request.app.state.config.IMAGE_EDIT_ENGINE = form_data.IMAGE_EDIT_ENGINE request.app.state.config.IMAGE_EDIT_ENGINE = form_data.IMAGE_EDIT_ENGINE
request.app.state.config.IMAGE_EDIT_MODEL = form_data.IMAGE_EDIT_MODEL request.app.state.config.IMAGE_EDIT_MODEL = form_data.IMAGE_EDIT_MODEL
request.app.state.config.IMAGE_EDIT_SIZE = form_data.IMAGE_EDIT_SIZE request.app.state.config.IMAGE_EDIT_SIZE = form_data.IMAGE_EDIT_SIZE
request.app.state.config.IMAGES_EDIT_OPENAI_API_BASE_URL = ( request.app.state.config.IMAGES_EDIT_OPENAI_API_BASE_URL = (
form_data.IMAGES_OPENAI_API_BASE_URL form_data.IMAGES_EDIT_OPENAI_API_BASE_URL
) )
request.app.state.config.IMAGES_EDIT_OPENAI_API_KEY = ( request.app.state.config.IMAGES_EDIT_OPENAI_API_KEY = (
form_data.IMAGES_OPENAI_API_KEY form_data.IMAGES_EDIT_OPENAI_API_KEY
) )
request.app.state.config.IMAGES_EDIT_OPENAI_API_VERSION = ( request.app.state.config.IMAGES_EDIT_OPENAI_API_VERSION = (
form_data.IMAGES_EDIT_OPENAI_API_VERSION form_data.IMAGES_EDIT_OPENAI_API_VERSION
@ -287,6 +300,7 @@ async def update_config(
"IMAGES_OPENAI_API_BASE_URL": request.app.state.config.IMAGES_OPENAI_API_BASE_URL, "IMAGES_OPENAI_API_BASE_URL": request.app.state.config.IMAGES_OPENAI_API_BASE_URL,
"IMAGES_OPENAI_API_KEY": request.app.state.config.IMAGES_OPENAI_API_KEY, "IMAGES_OPENAI_API_KEY": request.app.state.config.IMAGES_OPENAI_API_KEY,
"IMAGES_OPENAI_API_VERSION": request.app.state.config.IMAGES_OPENAI_API_VERSION, "IMAGES_OPENAI_API_VERSION": request.app.state.config.IMAGES_OPENAI_API_VERSION,
"IMAGES_OPENAI_API_PARAMS": request.app.state.config.IMAGES_OPENAI_API_PARAMS,
"AUTOMATIC1111_BASE_URL": request.app.state.config.AUTOMATIC1111_BASE_URL, "AUTOMATIC1111_BASE_URL": request.app.state.config.AUTOMATIC1111_BASE_URL,
"AUTOMATIC1111_API_AUTH": request.app.state.config.AUTOMATIC1111_API_AUTH, "AUTOMATIC1111_API_AUTH": request.app.state.config.AUTOMATIC1111_API_AUTH,
"AUTOMATIC1111_PARAMS": request.app.state.config.AUTOMATIC1111_PARAMS, "AUTOMATIC1111_PARAMS": request.app.state.config.AUTOMATIC1111_PARAMS,
@ -297,6 +311,7 @@ async def update_config(
"IMAGES_GEMINI_API_BASE_URL": request.app.state.config.IMAGES_GEMINI_API_BASE_URL, "IMAGES_GEMINI_API_BASE_URL": request.app.state.config.IMAGES_GEMINI_API_BASE_URL,
"IMAGES_GEMINI_API_KEY": request.app.state.config.IMAGES_GEMINI_API_KEY, "IMAGES_GEMINI_API_KEY": request.app.state.config.IMAGES_GEMINI_API_KEY,
"IMAGES_GEMINI_ENDPOINT_METHOD": request.app.state.config.IMAGES_GEMINI_ENDPOINT_METHOD, "IMAGES_GEMINI_ENDPOINT_METHOD": request.app.state.config.IMAGES_GEMINI_ENDPOINT_METHOD,
"ENABLE_IMAGE_EDIT": request.app.state.config.ENABLE_IMAGE_EDIT,
"IMAGE_EDIT_ENGINE": request.app.state.config.IMAGE_EDIT_ENGINE, "IMAGE_EDIT_ENGINE": request.app.state.config.IMAGE_EDIT_ENGINE,
"IMAGE_EDIT_MODEL": request.app.state.config.IMAGE_EDIT_MODEL, "IMAGE_EDIT_MODEL": request.app.state.config.IMAGE_EDIT_MODEL,
"IMAGE_EDIT_SIZE": request.app.state.config.IMAGE_EDIT_SIZE, "IMAGE_EDIT_SIZE": request.app.state.config.IMAGE_EDIT_SIZE,
@ -534,6 +549,12 @@ async def image_generations(
if ENABLE_FORWARD_USER_INFO_HEADERS: if ENABLE_FORWARD_USER_INFO_HEADERS:
headers = include_user_info_headers(headers, user) headers = include_user_info_headers(headers, user)
url = f"{request.app.state.config.IMAGES_OPENAI_API_BASE_URL}/images/generations"
if request.app.state.config.IMAGES_OPENAI_API_VERSION:
url = f"{url}?api-version={request.app.state.config.IMAGES_OPENAI_API_VERSION}"
data = { data = {
"model": model, "model": model,
"prompt": form_data.prompt, "prompt": form_data.prompt,
@ -548,18 +569,17 @@ async def image_generations(
if "gpt-image-1" in request.app.state.config.IMAGE_GENERATION_MODEL if "gpt-image-1" in request.app.state.config.IMAGE_GENERATION_MODEL
else {"response_format": "b64_json"} else {"response_format": "b64_json"}
), ),
**(
{}
if not request.app.state.config.IMAGES_OPENAI_API_PARAMS
else request.app.state.config.IMAGES_OPENAI_API_PARAMS
),
} }
api_version_query_param = ""
if request.app.state.config.IMAGES_OPENAI_API_VERSION:
api_version_query_param = (
f"?api-version={request.app.state.config.IMAGES_OPENAI_API_VERSION}"
)
# Use asyncio.to_thread for the requests.post call # Use asyncio.to_thread for the requests.post call
r = await asyncio.to_thread( r = await asyncio.to_thread(
requests.post, requests.post,
url=f"{request.app.state.config.IMAGES_OPENAI_API_BASE_URL}/images/generations{api_version_query_param}", url=url,
json=data, json=data,
headers=headers, headers=headers,
) )
@ -818,13 +838,13 @@ async def image_edits(
except Exception as e: except Exception as e:
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e)) raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
def get_image_file_item(base64_string): def get_image_file_item(base64_string, param_name="image"):
data = base64_string data = base64_string
header, encoded = data.split(",", 1) header, encoded = data.split(",", 1)
mime_type = header.split(";")[0].lstrip("data:") mime_type = header.split(";")[0].lstrip("data:")
image_data = base64.b64decode(encoded) image_data = base64.b64decode(encoded)
return ( return (
"image", param_name,
( (
f"{uuid.uuid4()}.png", f"{uuid.uuid4()}.png",
io.BytesIO(image_data), io.BytesIO(image_data),
@ -859,7 +879,7 @@ async def image_edits(
files = [get_image_file_item(form_data.image)] files = [get_image_file_item(form_data.image)]
elif isinstance(form_data.image, list): elif isinstance(form_data.image, list):
for img in form_data.image: for img in form_data.image:
files.append(get_image_file_item(img)) files.append(get_image_file_item(img, "image[]"))
url_search_params = "" url_search_params = ""
if request.app.state.config.IMAGES_EDIT_OPENAI_API_VERSION: if request.app.state.config.IMAGES_EDIT_OPENAI_API_VERSION:

View file

@ -1,6 +1,7 @@
from fastapi import APIRouter, Depends, HTTPException, Request from fastapi import APIRouter, Depends, HTTPException, Request
from pydantic import BaseModel from pydantic import BaseModel
import logging import logging
import asyncio
from typing import Optional from typing import Optional
from open_webui.models.memories import Memories, MemoryModel from open_webui.models.memories import Memories, MemoryModel
@ -17,7 +18,7 @@ router = APIRouter()
@router.get("/ef") @router.get("/ef")
async def get_embeddings(request: Request): async def get_embeddings(request: Request):
return {"result": request.app.state.EMBEDDING_FUNCTION("hello world")} return {"result": await request.app.state.EMBEDDING_FUNCTION("hello world")}
############################ ############################
@ -51,15 +52,15 @@ async def add_memory(
): ):
memory = Memories.insert_new_memory(user.id, form_data.content) memory = Memories.insert_new_memory(user.id, form_data.content)
vector = await request.app.state.EMBEDDING_FUNCTION(memory.content, user=user)
VECTOR_DB_CLIENT.upsert( VECTOR_DB_CLIENT.upsert(
collection_name=f"user-memory-{user.id}", collection_name=f"user-memory-{user.id}",
items=[ items=[
{ {
"id": memory.id, "id": memory.id,
"text": memory.content, "text": memory.content,
"vector": request.app.state.EMBEDDING_FUNCTION( "vector": vector,
memory.content, user=user
),
"metadata": {"created_at": memory.created_at}, "metadata": {"created_at": memory.created_at},
} }
], ],
@ -86,9 +87,11 @@ async def query_memory(
if not memories: if not memories:
raise HTTPException(status_code=404, detail="No memories found for user") raise HTTPException(status_code=404, detail="No memories found for user")
vector = await request.app.state.EMBEDDING_FUNCTION(form_data.content, user=user)
results = VECTOR_DB_CLIENT.search( results = VECTOR_DB_CLIENT.search(
collection_name=f"user-memory-{user.id}", collection_name=f"user-memory-{user.id}",
vectors=[request.app.state.EMBEDDING_FUNCTION(form_data.content, user=user)], vectors=[vector],
limit=form_data.k, limit=form_data.k,
) )
@ -105,21 +108,28 @@ async def reset_memory_from_vector_db(
VECTOR_DB_CLIENT.delete_collection(f"user-memory-{user.id}") VECTOR_DB_CLIENT.delete_collection(f"user-memory-{user.id}")
memories = Memories.get_memories_by_user_id(user.id) memories = Memories.get_memories_by_user_id(user.id)
# Generate vectors in parallel
vectors = await asyncio.gather(
*[
request.app.state.EMBEDDING_FUNCTION(memory.content, user=user)
for memory in memories
]
)
VECTOR_DB_CLIENT.upsert( VECTOR_DB_CLIENT.upsert(
collection_name=f"user-memory-{user.id}", collection_name=f"user-memory-{user.id}",
items=[ items=[
{ {
"id": memory.id, "id": memory.id,
"text": memory.content, "text": memory.content,
"vector": request.app.state.EMBEDDING_FUNCTION( "vector": vectors[idx],
memory.content, user=user
),
"metadata": { "metadata": {
"created_at": memory.created_at, "created_at": memory.created_at,
"updated_at": memory.updated_at, "updated_at": memory.updated_at,
}, },
} }
for memory in memories for idx, memory in enumerate(memories)
], ],
) )
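The memory reset path now awaits the embedding function and fans out over all memories with `asyncio.gather`, zipping the vectors back by index. A self-contained sketch with a dummy async embedder in place of `request.app.state.EMBEDDING_FUNCTION`:

```python
# Self-contained sketch of the asyncio.gather fan-out used above. Illustrative only.
import asyncio


async def embed(text: str) -> list[float]:
    await asyncio.sleep(0)  # placeholder for an async embedding call
    return [float(len(text)), 0.0]


async def build_items(memories: list[dict]) -> list[dict]:
    # Generate vectors in parallel, then match them to their memories by index.
    vectors = await asyncio.gather(*[embed(m["content"]) for m in memories])
    return [
        {"id": m["id"], "text": m["content"], "vector": vectors[idx]}
        for idx, m in enumerate(memories)
    ]


memories = [{"id": "m1", "content": "likes tea"}, {"id": "m2", "content": "based in Lisbon"}]
print(asyncio.run(build_items(memories)))
```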
@ -164,15 +174,15 @@ async def update_memory_by_id(
raise HTTPException(status_code=404, detail="Memory not found") raise HTTPException(status_code=404, detail="Memory not found")
if form_data.content is not None: if form_data.content is not None:
vector = await request.app.state.EMBEDDING_FUNCTION(memory.content, user=user)
VECTOR_DB_CLIENT.upsert( VECTOR_DB_CLIENT.upsert(
collection_name=f"user-memory-{user.id}", collection_name=f"user-memory-{user.id}",
items=[ items=[
{ {
"id": memory.id, "id": memory.id,
"text": memory.content, "text": memory.content,
"vector": request.app.state.EMBEDDING_FUNCTION( "vector": vector,
memory.content, user=user
),
"metadata": { "metadata": {
"created_at": memory.created_at, "created_at": memory.created_at,
"updated_at": memory.updated_at, "updated_at": memory.updated_at,

View file

@ -9,7 +9,7 @@ from open_webui.models.models import (
ModelForm, ModelForm,
ModelModel, ModelModel,
ModelResponse, ModelResponse,
ModelUserResponse, ModelListResponse,
Models, Models,
) )
@ -35,7 +35,7 @@ log = logging.getLogger(__name__)
router = APIRouter() router = APIRouter()
def validate_model_id(model_id: str) -> bool: def is_valid_model_id(model_id: str) -> bool:
return model_id and len(model_id) <= 256 return model_id and len(model_id) <= 256
@ -44,14 +44,43 @@ def validate_model_id(model_id: str) -> bool:
########################### ###########################
PAGE_ITEM_COUNT = 30
@router.get( @router.get(
"/list", response_model=list[ModelUserResponse] "/list", response_model=ModelListResponse
) # do NOT use "/" as path, conflicts with main.py ) # do NOT use "/" as path, conflicts with main.py
async def get_models(id: Optional[str] = None, user=Depends(get_verified_user)): async def get_models(
if user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL: query: Optional[str] = None,
return Models.get_models() view_option: Optional[str] = None,
else: tag: Optional[str] = None,
return Models.get_models_by_user_id(user.id) order_by: Optional[str] = None,
direction: Optional[str] = None,
page: Optional[int] = 1,
user=Depends(get_verified_user),
):
limit = PAGE_ITEM_COUNT
page = max(1, page)
skip = (page - 1) * limit
filter = {}
if query:
filter["query"] = query
if view_option:
filter["view_option"] = view_option
if tag:
filter["tag"] = tag
if order_by:
filter["order_by"] = order_by
if direction:
filter["direction"] = direction
if not user.role == "admin" or not BYPASS_ADMIN_ACCESS_CONTROL:
filter["user_id"] = user.id
return Models.search_models(user.id, filter=filter, skip=skip, limit=limit)
########################### ###########################
@ -64,6 +93,30 @@ async def get_base_models(user=Depends(get_admin_user)):
return Models.get_base_models() return Models.get_base_models()
###########################
# GetModelTags
###########################
@router.get("/tags", response_model=list[str])
async def get_model_tags(user=Depends(get_verified_user)):
if user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL:
models = Models.get_models()
else:
models = Models.get_models_by_user_id(user.id)
tags_set = set()
for model in models:
if model.meta:
meta = model.meta.model_dump()
for tag in meta.get("tags", []):
tags_set.add((tag.get("name")))
tags = [tag for tag in tags_set]
tags.sort()
return tags
############################ ############################
# CreateNewModel # CreateNewModel
############################ ############################
@ -90,7 +143,7 @@ async def create_new_model(
detail=ERROR_MESSAGES.MODEL_ID_TAKEN, detail=ERROR_MESSAGES.MODEL_ID_TAKEN,
) )
if not validate_model_id(form_data.id): if not is_valid_model_id(form_data.id):
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.MODEL_ID_TOO_LONG, detail=ERROR_MESSAGES.MODEL_ID_TOO_LONG,
@ -113,8 +166,19 @@ async def create_new_model(
@router.get("/export", response_model=list[ModelModel]) @router.get("/export", response_model=list[ModelModel])
async def export_models(user=Depends(get_admin_user)): async def export_models(request: Request, user=Depends(get_verified_user)):
return Models.get_models() if user.role != "admin" and not has_permission(
user.id, "workspace.models_export", request.app.state.config.USER_PERMISSIONS
):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.UNAUTHORIZED,
)
if user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL:
return Models.get_models()
else:
return Models.get_models_by_user_id(user.id)
############################ ############################
@ -128,8 +192,17 @@ class ModelsImportForm(BaseModel):
@router.post("/import", response_model=bool) @router.post("/import", response_model=bool)
async def import_models( async def import_models(
user: str = Depends(get_admin_user), form_data: ModelsImportForm = (...) request: Request,
user=Depends(get_verified_user),
form_data: ModelsImportForm = (...),
): ):
if user.role != "admin" and not has_permission(
user.id, "workspace.models_import", request.app.state.config.USER_PERMISSIONS
):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.UNAUTHORIZED,
)
try: try:
data = form_data.models data = form_data.models
if isinstance(data, list): if isinstance(data, list):
@ -137,7 +210,7 @@ async def import_models(
# Here, you can add logic to validate model_data if needed # Here, you can add logic to validate model_data if needed
model_id = model_data.get("id") model_id = model_data.get("id")
if model_id and validate_model_id(model_id): if model_id and is_valid_model_id(model_id):
existing_model = Models.get_model_by_id(model_id) existing_model = Models.get_model_by_id(model_id)
if existing_model: if existing_model:
# Update existing model # Update existing model
@ -183,6 +256,10 @@ async def sync_models(
########################### ###########################
class ModelIdForm(BaseModel):
id: str
# Note: We're not using the typical url path param here, but instead using a query parameter to allow '/' in the id # Note: We're not using the typical url path param here, but instead using a query parameter to allow '/' in the id
@router.get("/model", response_model=Optional[ModelResponse]) @router.get("/model", response_model=Optional[ModelResponse])
async def get_model_by_id(id: str, user=Depends(get_verified_user)): async def get_model_by_id(id: str, user=Depends(get_verified_user)):
@ -229,6 +306,7 @@ async def get_model_profile_image(id: str, user=Depends(get_verified_user)):
) )
except Exception as e: except Exception as e:
pass pass
return FileResponse(f"{STATIC_DIR}/favicon.png") return FileResponse(f"{STATIC_DIR}/favicon.png")
else: else:
return FileResponse(f"{STATIC_DIR}/favicon.png") return FileResponse(f"{STATIC_DIR}/favicon.png")
@ -276,12 +354,10 @@ async def toggle_model_by_id(id: str, user=Depends(get_verified_user)):
@router.post("/model/update", response_model=Optional[ModelModel]) @router.post("/model/update", response_model=Optional[ModelModel])
async def update_model_by_id( async def update_model_by_id(
id: str,
form_data: ModelForm, form_data: ModelForm,
user=Depends(get_verified_user), user=Depends(get_verified_user),
): ):
model = Models.get_model_by_id(id) model = Models.get_model_by_id(form_data.id)
if not model: if not model:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
@ -298,7 +374,7 @@ async def update_model_by_id(
detail=ERROR_MESSAGES.ACCESS_PROHIBITED, detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
) )
model = Models.update_model_by_id(id, form_data) model = Models.update_model_by_id(form_data.id, ModelForm(**form_data.model_dump()))
return model return model
@ -307,9 +383,9 @@ async def update_model_by_id(
############################ ############################
@router.delete("/model/delete", response_model=bool) @router.post("/model/delete", response_model=bool)
async def delete_model_by_id(id: str, user=Depends(get_verified_user)): async def delete_model_by_id(form_data: ModelIdForm, user=Depends(get_verified_user)):
model = Models.get_model_by_id(id) model = Models.get_model_by_id(form_data.id)
if not model: if not model:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
@ -326,7 +402,7 @@ async def delete_model_by_id(id: str, user=Depends(get_verified_user)):
detail=ERROR_MESSAGES.UNAUTHORIZED, detail=ERROR_MESSAGES.UNAUTHORIZED,
) )
result = Models.delete_model_by_id(id) result = Models.delete_model_by_id(form_data.id)
return result return result
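Update and delete now read the model id from the request body (`form_data.id` / `ModelIdForm`) instead of a query parameter, and delete switched from DELETE to POST. A minimal sketch of the new delete call, again assuming the `/api/v1/models` mount point and placeholder credentials:

```python
import requests

BASE_URL = "http://localhost:8080"  # placeholder
TOKEN = "YOUR_TOKEN"  # placeholder

# The id travels in the JSON body, so ids containing "/" no longer need
# special handling in the URL.
resp = requests.post(
    f"{BASE_URL}/api/v1/models/model/delete",
    json={"id": "openrouter/some-model"},  # illustrative model id
    headers={"Authorization": f"Bearer {TOKEN}"},
)
resp.raise_for_status()
deleted = resp.json()  # the endpoint returns a boolean
```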

View file

@ -16,8 +16,8 @@ from urllib.parse import urlparse
import aiohttp import aiohttp
from aiocache import cached from aiocache import cached
import requests import requests
from urllib.parse import quote
from open_webui.utils.headers import include_user_info_headers
from open_webui.models.chats import Chats from open_webui.models.chats import Chats
from open_webui.models.users import UserModel from open_webui.models.users import UserModel
@ -82,22 +82,17 @@ async def send_get_request(url, key=None, user: UserModel = None):
timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST) timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST)
try: try:
async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session: async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
headers = {
"Content-Type": "application/json",
**({"Authorization": f"Bearer {key}"} if key else {}),
}
if ENABLE_FORWARD_USER_INFO_HEADERS and user:
headers = include_user_info_headers(headers, user)
async with session.get( async with session.get(
url, url,
headers={ headers=headers,
"Content-Type": "application/json",
**({"Authorization": f"Bearer {key}"} if key else {}),
**(
{
"X-OpenWebUI-User-Name": quote(user.name, safe=" "),
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
}
if ENABLE_FORWARD_USER_INFO_HEADERS and user
else {}
),
},
ssl=AIOHTTP_CLIENT_SESSION_SSL, ssl=AIOHTTP_CLIENT_SESSION_SSL,
) as response: ) as response:
return await response.json() return await response.json()
@ -133,28 +128,20 @@ async def send_post_request(
trust_env=True, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT) trust_env=True, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
) )
headers = {
"Content-Type": "application/json",
**({"Authorization": f"Bearer {key}"} if key else {}),
}
if ENABLE_FORWARD_USER_INFO_HEADERS and user:
headers = include_user_info_headers(headers, user)
if metadata and metadata.get("chat_id"):
headers["X-OpenWebUI-Chat-Id"] = metadata.get("chat_id")
r = await session.post( r = await session.post(
url, url,
data=payload, data=payload,
headers={ headers=headers,
"Content-Type": "application/json",
**({"Authorization": f"Bearer {key}"} if key else {}),
**(
{
"X-OpenWebUI-User-Name": quote(user.name, safe=" "),
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
**(
{"X-OpenWebUI-Chat-Id": metadata.get("chat_id")}
if metadata and metadata.get("chat_id")
else {}
),
}
if ENABLE_FORWARD_USER_INFO_HEADERS and user
else {}
),
},
ssl=AIOHTTP_CLIENT_SESSION_SSL, ssl=AIOHTTP_CLIENT_SESSION_SSL,
) )
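Both request paths above now build a base header dict and pass it through `include_user_info_headers` from `open_webui.utils.headers`, with the optional `X-OpenWebUI-Chat-Id` added separately. The helper's implementation is not part of this diff; a minimal sketch of what it could look like, based only on the headers the removed inline code forwarded:

```python
from urllib.parse import quote


def include_user_info_headers(headers: dict, user) -> dict:
    # Sketch only: forward the authenticated user's identity to the upstream
    # server, mirroring the X-OpenWebUI-* entries the inline dicts used to set.
    return {
        **headers,
        "X-OpenWebUI-User-Name": quote(user.name, safe=" "),
        "X-OpenWebUI-User-Id": user.id,
        "X-OpenWebUI-User-Email": user.email,
        "X-OpenWebUI-User-Role": user.role,
    }
```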
@ -246,21 +233,16 @@ async def verify_connection(
timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST), timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST),
) as session: ) as session:
try: try:
headers = {
**({"Authorization": f"Bearer {key}"} if key else {}),
}
if ENABLE_FORWARD_USER_INFO_HEADERS and user:
headers = include_user_info_headers(headers, user)
async with session.get( async with session.get(
f"{url}/api/version", f"{url}/api/version",
headers={ headers=headers,
**({"Authorization": f"Bearer {key}"} if key else {}),
**(
{
"X-OpenWebUI-User-Name": quote(user.name, safe=" "),
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
}
if ENABLE_FORWARD_USER_INFO_HEADERS and user
else {}
),
},
ssl=AIOHTTP_CLIENT_SESSION_SSL, ssl=AIOHTTP_CLIENT_SESSION_SSL,
) as r: ) as r:
if r.status != 200: if r.status != 200:
@ -469,22 +451,17 @@ async def get_ollama_tags(
r = None r = None
try: try:
headers = {
**({"Authorization": f"Bearer {key}"} if key else {}),
}
if ENABLE_FORWARD_USER_INFO_HEADERS and user:
headers = include_user_info_headers(headers, user)
r = requests.request( r = requests.request(
method="GET", method="GET",
url=f"{url}/api/tags", url=f"{url}/api/tags",
headers={ headers=headers,
**({"Authorization": f"Bearer {key}"} if key else {}),
**(
{
"X-OpenWebUI-User-Name": quote(user.name, safe=" "),
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
}
if ENABLE_FORWARD_USER_INFO_HEADERS and user
else {}
),
},
) )
r.raise_for_status() r.raise_for_status()
@ -838,23 +815,18 @@ async def copy_model(
key = get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS) key = get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS)
try: try:
headers = {
"Content-Type": "application/json",
**({"Authorization": f"Bearer {key}"} if key else {}),
}
if ENABLE_FORWARD_USER_INFO_HEADERS and user:
headers = include_user_info_headers(headers, user)
r = requests.request( r = requests.request(
method="POST", method="POST",
url=f"{url}/api/copy", url=f"{url}/api/copy",
headers={ headers=headers,
"Content-Type": "application/json",
**({"Authorization": f"Bearer {key}"} if key else {}),
**(
{
"X-OpenWebUI-User-Name": quote(user.name, safe=" "),
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
}
if ENABLE_FORWARD_USER_INFO_HEADERS and user
else {}
),
},
data=form_data.model_dump_json(exclude_none=True).encode(), data=form_data.model_dump_json(exclude_none=True).encode(),
) )
r.raise_for_status() r.raise_for_status()
@ -908,24 +880,19 @@ async def delete_model(
key = get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS) key = get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS)
try: try:
headers = {
"Content-Type": "application/json",
**({"Authorization": f"Bearer {key}"} if key else {}),
}
if ENABLE_FORWARD_USER_INFO_HEADERS and user:
headers = include_user_info_headers(headers, user)
r = requests.request( r = requests.request(
method="DELETE", method="DELETE",
url=f"{url}/api/delete", url=f"{url}/api/delete",
data=json.dumps(form_data).encode(), headers=headers,
headers={ data=form_data.model_dump_json(exclude_none=True).encode(),
"Content-Type": "application/json",
**({"Authorization": f"Bearer {key}"} if key else {}),
**(
{
"X-OpenWebUI-User-Name": quote(user.name, safe=" "),
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
}
if ENABLE_FORWARD_USER_INFO_HEADERS and user
else {}
),
},
) )
r.raise_for_status() r.raise_for_status()
@ -973,24 +940,19 @@ async def show_model_info(
key = get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS) key = get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS)
try: try:
headers = {
"Content-Type": "application/json",
**({"Authorization": f"Bearer {key}"} if key else {}),
}
if ENABLE_FORWARD_USER_INFO_HEADERS and user:
headers = include_user_info_headers(headers, user)
r = requests.request( r = requests.request(
method="POST", method="POST",
url=f"{url}/api/show", url=f"{url}/api/show",
headers={ headers=headers,
"Content-Type": "application/json", data=form_data.model_dump_json(exclude_none=True).encode(),
**({"Authorization": f"Bearer {key}"} if key else {}),
**(
{
"X-OpenWebUI-User-Name": quote(user.name, safe=" "),
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
}
if ENABLE_FORWARD_USER_INFO_HEADERS and user
else {}
),
},
data=json.dumps(form_data).encode(),
) )
r.raise_for_status() r.raise_for_status()
@ -1064,23 +1026,18 @@ async def embed(
form_data.model = form_data.model.replace(f"{prefix_id}.", "") form_data.model = form_data.model.replace(f"{prefix_id}.", "")
try: try:
headers = {
"Content-Type": "application/json",
**({"Authorization": f"Bearer {key}"} if key else {}),
}
if ENABLE_FORWARD_USER_INFO_HEADERS and user:
headers = include_user_info_headers(headers, user)
r = requests.request( r = requests.request(
method="POST", method="POST",
url=f"{url}/api/embed", url=f"{url}/api/embed",
headers={ headers=headers,
"Content-Type": "application/json",
**({"Authorization": f"Bearer {key}"} if key else {}),
**(
{
"X-OpenWebUI-User-Name": quote(user.name, safe=" "),
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
}
if ENABLE_FORWARD_USER_INFO_HEADERS and user
else {}
),
},
data=form_data.model_dump_json(exclude_none=True).encode(), data=form_data.model_dump_json(exclude_none=True).encode(),
) )
r.raise_for_status() r.raise_for_status()
@ -1151,23 +1108,18 @@ async def embeddings(
form_data.model = form_data.model.replace(f"{prefix_id}.", "") form_data.model = form_data.model.replace(f"{prefix_id}.", "")
try: try:
headers = {
"Content-Type": "application/json",
**({"Authorization": f"Bearer {key}"} if key else {}),
}
if ENABLE_FORWARD_USER_INFO_HEADERS and user:
headers = include_user_info_headers(headers, user)
r = requests.request( r = requests.request(
method="POST", method="POST",
url=f"{url}/api/embeddings", url=f"{url}/api/embeddings",
headers={ headers=headers,
"Content-Type": "application/json",
**({"Authorization": f"Bearer {key}"} if key else {}),
**(
{
"X-OpenWebUI-User-Name": quote(user.name, safe=" "),
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
}
if ENABLE_FORWARD_USER_INFO_HEADERS and user
else {}
),
},
data=form_data.model_dump_json(exclude_none=True).encode(), data=form_data.model_dump_json(exclude_none=True).encode(),
) )
r.raise_for_status() r.raise_for_status()

View file

@ -7,7 +7,6 @@ from typing import Optional
import aiohttp import aiohttp
from aiocache import cached from aiocache import cached
import requests import requests
from urllib.parse import quote
from azure.identity import DefaultAzureCredential, get_bearer_token_provider from azure.identity import DefaultAzureCredential, get_bearer_token_provider
@ -45,10 +44,12 @@ from open_webui.utils.payload import (
) )
from open_webui.utils.misc import ( from open_webui.utils.misc import (
convert_logit_bias_input_to_json, convert_logit_bias_input_to_json,
stream_chunks_handler,
) )
from open_webui.utils.auth import get_admin_user, get_verified_user from open_webui.utils.auth import get_admin_user, get_verified_user
from open_webui.utils.access_control import has_access from open_webui.utils.access_control import has_access
from open_webui.utils.headers import include_user_info_headers
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -66,21 +67,16 @@ async def send_get_request(url, key=None, user: UserModel = None):
timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST) timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST)
try: try:
async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session: async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
headers = {
**({"Authorization": f"Bearer {key}"} if key else {}),
}
if ENABLE_FORWARD_USER_INFO_HEADERS and user:
headers = include_user_info_headers(headers, user)
async with session.get( async with session.get(
url, url,
headers={ headers=headers,
**({"Authorization": f"Bearer {key}"} if key else {}),
**(
{
"X-OpenWebUI-User-Name": quote(user.name, safe=" "),
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
}
if ENABLE_FORWARD_USER_INFO_HEADERS and user
else {}
),
},
ssl=AIOHTTP_CLIENT_SESSION_SSL, ssl=AIOHTTP_CLIENT_SESSION_SSL,
) as response: ) as response:
return await response.json() return await response.json()
@ -140,23 +136,13 @@ async def get_headers_and_cookies(
if "openrouter.ai" in url if "openrouter.ai" in url
else {} else {}
), ),
**(
{
"X-OpenWebUI-User-Name": quote(user.name, safe=" "),
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
**(
{"X-OpenWebUI-Chat-Id": metadata.get("chat_id")}
if metadata and metadata.get("chat_id")
else {}
),
}
if ENABLE_FORWARD_USER_INFO_HEADERS
else {}
),
} }
if ENABLE_FORWARD_USER_INFO_HEADERS and user:
headers = include_user_info_headers(headers, user)
if metadata and metadata.get("chat_id"):
headers["X-OpenWebUI-Chat-Id"] = metadata.get("chat_id")
token = None token = None
auth_type = config.get("auth_type") auth_type = config.get("auth_type")
@ -762,6 +748,7 @@ def get_azure_allowed_params(api_version: str) -> set[str]:
"response_format", "response_format",
"seed", "seed",
"max_completion_tokens", "max_completion_tokens",
"reasoning_effort",
} }
try: try:
@ -952,7 +939,7 @@ async def generate_chat_completion(
if "text/event-stream" in r.headers.get("Content-Type", ""): if "text/event-stream" in r.headers.get("Content-Type", ""):
streaming = True streaming = True
return StreamingResponse( return StreamingResponse(
r.content, stream_chunks_handler(r.content),
status_code=r.status, status_code=r.status,
headers=dict(r.headers), headers=dict(r.headers),
background=BackgroundTask( background=BackgroundTask(
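The streaming branch now wraps the upstream body in `stream_chunks_handler` from `open_webui.utils.misc` before handing it to `StreamingResponse`. Its implementation is not shown here; a hypothetical sketch of a line-buffering re-chunker of this kind, assuming the goal is to forward complete SSE lines even when a single event (for example a base64 image delta) spans many network chunks:

```python
async def stream_chunks_handler(content, max_buffer_size: int = 10 * 1024 * 1024):
    # Hypothetical sketch: "content" is an async iterable of bytes (such as an
    # aiohttp StreamReader). Partial lines are buffered so very large events
    # are yielded intact; the cap guards against unbounded memory growth.
    buffer = b""
    async for chunk in content:
        buffer += chunk
        while b"\n" in buffer:
            line, buffer = buffer.split(b"\n", 1)
            yield line + b"\n"
        if len(buffer) > max_buffer_size:
            yield buffer  # give up on buffering a pathologically long line
            buffer = b""
    if buffer:
        yield buffer
```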

View file

@ -48,8 +48,15 @@ async def get_prompt_list(user=Depends(get_verified_user)):
async def create_new_prompt( async def create_new_prompt(
request: Request, form_data: PromptForm, user=Depends(get_verified_user) request: Request, form_data: PromptForm, user=Depends(get_verified_user)
): ):
if user.role != "admin" and not has_permission( if user.role != "admin" and not (
user.id, "workspace.prompts", request.app.state.config.USER_PERMISSIONS has_permission(
user.id, "workspace.prompts", request.app.state.config.USER_PERMISSIONS
)
or has_permission(
user.id,
"workspace.prompts_import",
request.app.state.config.USER_PERMISSIONS,
)
): ):
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,

View file

@ -32,7 +32,7 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter, TokenTextSpl
from langchain_text_splitters import MarkdownHeaderTextSplitter from langchain_text_splitters import MarkdownHeaderTextSplitter
from langchain_core.documents import Document from langchain_core.documents import Document
from open_webui.models.files import FileModel, Files from open_webui.models.files import FileModel, FileUpdateForm, Files
from open_webui.models.knowledge import Knowledges from open_webui.models.knowledge import Knowledges
from open_webui.storage.provider import Storage from open_webui.storage.provider import Storage
@ -64,6 +64,7 @@ from open_webui.retrieval.web.serply import search_serply
from open_webui.retrieval.web.serpstack import search_serpstack from open_webui.retrieval.web.serpstack import search_serpstack
from open_webui.retrieval.web.tavily import search_tavily from open_webui.retrieval.web.tavily import search_tavily
from open_webui.retrieval.web.bing import search_bing from open_webui.retrieval.web.bing import search_bing
from open_webui.retrieval.web.azure import search_azure
from open_webui.retrieval.web.exa import search_exa from open_webui.retrieval.web.exa import search_exa
from open_webui.retrieval.web.perplexity import search_perplexity from open_webui.retrieval.web.perplexity import search_perplexity
from open_webui.retrieval.web.sougou import search_sougou from open_webui.retrieval.web.sougou import search_sougou
@ -430,6 +431,7 @@ async def get_rag_config(request: Request, user=Depends(get_admin_user)):
"RAG_FULL_CONTEXT": request.app.state.config.RAG_FULL_CONTEXT, "RAG_FULL_CONTEXT": request.app.state.config.RAG_FULL_CONTEXT,
# Hybrid search settings # Hybrid search settings
"ENABLE_RAG_HYBRID_SEARCH": request.app.state.config.ENABLE_RAG_HYBRID_SEARCH, "ENABLE_RAG_HYBRID_SEARCH": request.app.state.config.ENABLE_RAG_HYBRID_SEARCH,
"ENABLE_RAG_HYBRID_SEARCH_ENRICHED_TEXTS": request.app.state.config.ENABLE_RAG_HYBRID_SEARCH_ENRICHED_TEXTS,
"TOP_K_RERANKER": request.app.state.config.TOP_K_RERANKER, "TOP_K_RERANKER": request.app.state.config.TOP_K_RERANKER,
"RELEVANCE_THRESHOLD": request.app.state.config.RELEVANCE_THRESHOLD, "RELEVANCE_THRESHOLD": request.app.state.config.RELEVANCE_THRESHOLD,
"HYBRID_BM25_WEIGHT": request.app.state.config.HYBRID_BM25_WEIGHT, "HYBRID_BM25_WEIGHT": request.app.state.config.HYBRID_BM25_WEIGHT,
@ -528,6 +530,7 @@ async def get_rag_config(request: Request, user=Depends(get_admin_user)):
"PERPLEXITY_API_KEY": request.app.state.config.PERPLEXITY_API_KEY, "PERPLEXITY_API_KEY": request.app.state.config.PERPLEXITY_API_KEY,
"PERPLEXITY_MODEL": request.app.state.config.PERPLEXITY_MODEL, "PERPLEXITY_MODEL": request.app.state.config.PERPLEXITY_MODEL,
"PERPLEXITY_SEARCH_CONTEXT_USAGE": request.app.state.config.PERPLEXITY_SEARCH_CONTEXT_USAGE, "PERPLEXITY_SEARCH_CONTEXT_USAGE": request.app.state.config.PERPLEXITY_SEARCH_CONTEXT_USAGE,
"PERPLEXITY_SEARCH_API_URL": request.app.state.config.PERPLEXITY_SEARCH_API_URL,
"SOUGOU_API_SID": request.app.state.config.SOUGOU_API_SID, "SOUGOU_API_SID": request.app.state.config.SOUGOU_API_SID,
"SOUGOU_API_SK": request.app.state.config.SOUGOU_API_SK, "SOUGOU_API_SK": request.app.state.config.SOUGOU_API_SK,
"WEB_LOADER_ENGINE": request.app.state.config.WEB_LOADER_ENGINE, "WEB_LOADER_ENGINE": request.app.state.config.WEB_LOADER_ENGINE,
@ -585,6 +588,7 @@ class WebConfig(BaseModel):
PERPLEXITY_API_KEY: Optional[str] = None PERPLEXITY_API_KEY: Optional[str] = None
PERPLEXITY_MODEL: Optional[str] = None PERPLEXITY_MODEL: Optional[str] = None
PERPLEXITY_SEARCH_CONTEXT_USAGE: Optional[str] = None PERPLEXITY_SEARCH_CONTEXT_USAGE: Optional[str] = None
PERPLEXITY_SEARCH_API_URL: Optional[str] = None
SOUGOU_API_SID: Optional[str] = None SOUGOU_API_SID: Optional[str] = None
SOUGOU_API_SK: Optional[str] = None SOUGOU_API_SK: Optional[str] = None
WEB_LOADER_ENGINE: Optional[str] = None WEB_LOADER_ENGINE: Optional[str] = None
@ -612,6 +616,7 @@ class ConfigForm(BaseModel):
# Hybrid search settings # Hybrid search settings
ENABLE_RAG_HYBRID_SEARCH: Optional[bool] = None ENABLE_RAG_HYBRID_SEARCH: Optional[bool] = None
ENABLE_RAG_HYBRID_SEARCH_ENRICHED_TEXTS: Optional[bool] = None
TOP_K_RERANKER: Optional[int] = None TOP_K_RERANKER: Optional[int] = None
RELEVANCE_THRESHOLD: Optional[float] = None RELEVANCE_THRESHOLD: Optional[float] = None
HYBRID_BM25_WEIGHT: Optional[float] = None HYBRID_BM25_WEIGHT: Optional[float] = None
@ -718,6 +723,11 @@ async def update_rag_config(
if form_data.ENABLE_RAG_HYBRID_SEARCH is not None if form_data.ENABLE_RAG_HYBRID_SEARCH is not None
else request.app.state.config.ENABLE_RAG_HYBRID_SEARCH else request.app.state.config.ENABLE_RAG_HYBRID_SEARCH
) )
request.app.state.config.ENABLE_RAG_HYBRID_SEARCH_ENRICHED_TEXTS = (
form_data.ENABLE_RAG_HYBRID_SEARCH_ENRICHED_TEXTS
if form_data.ENABLE_RAG_HYBRID_SEARCH_ENRICHED_TEXTS is not None
else request.app.state.config.ENABLE_RAG_HYBRID_SEARCH_ENRICHED_TEXTS
)
request.app.state.config.TOP_K_RERANKER = ( request.app.state.config.TOP_K_RERANKER = (
form_data.TOP_K_RERANKER form_data.TOP_K_RERANKER
@ -1108,6 +1118,9 @@ async def update_rag_config(
request.app.state.config.PERPLEXITY_SEARCH_CONTEXT_USAGE = ( request.app.state.config.PERPLEXITY_SEARCH_CONTEXT_USAGE = (
form_data.web.PERPLEXITY_SEARCH_CONTEXT_USAGE form_data.web.PERPLEXITY_SEARCH_CONTEXT_USAGE
) )
request.app.state.config.PERPLEXITY_SEARCH_API_URL = (
form_data.web.PERPLEXITY_SEARCH_API_URL
)
request.app.state.config.SOUGOU_API_SID = form_data.web.SOUGOU_API_SID request.app.state.config.SOUGOU_API_SID = form_data.web.SOUGOU_API_SID
request.app.state.config.SOUGOU_API_SK = form_data.web.SOUGOU_API_SK request.app.state.config.SOUGOU_API_SK = form_data.web.SOUGOU_API_SK
@ -1253,6 +1266,7 @@ async def update_rag_config(
"PERPLEXITY_API_KEY": request.app.state.config.PERPLEXITY_API_KEY, "PERPLEXITY_API_KEY": request.app.state.config.PERPLEXITY_API_KEY,
"PERPLEXITY_MODEL": request.app.state.config.PERPLEXITY_MODEL, "PERPLEXITY_MODEL": request.app.state.config.PERPLEXITY_MODEL,
"PERPLEXITY_SEARCH_CONTEXT_USAGE": request.app.state.config.PERPLEXITY_SEARCH_CONTEXT_USAGE, "PERPLEXITY_SEARCH_CONTEXT_USAGE": request.app.state.config.PERPLEXITY_SEARCH_CONTEXT_USAGE,
"PERPLEXITY_SEARCH_API_URL": request.app.state.config.PERPLEXITY_SEARCH_API_URL,
"SOUGOU_API_SID": request.app.state.config.SOUGOU_API_SID, "SOUGOU_API_SID": request.app.state.config.SOUGOU_API_SID,
"SOUGOU_API_SK": request.app.state.config.SOUGOU_API_SK, "SOUGOU_API_SK": request.app.state.config.SOUGOU_API_SK,
"WEB_LOADER_ENGINE": request.app.state.config.WEB_LOADER_ENGINE, "WEB_LOADER_ENGINE": request.app.state.config.WEB_LOADER_ENGINE,
@ -1453,10 +1467,13 @@ def save_docs_to_vector_db(
), ),
) )
embeddings = embedding_function( # Run async embedding in sync context
list(map(lambda x: x.replace("\n", " "), texts)), embeddings = asyncio.run(
prefix=RAG_EMBEDDING_CONTENT_PREFIX, embedding_function(
user=user, list(map(lambda x: x.replace("\n", " "), texts)),
prefix=RAG_EMBEDDING_CONTENT_PREFIX,
user=user,
)
) )
log.info(f"embeddings generated {len(embeddings)} for {len(texts)} items") log.info(f"embeddings generated {len(embeddings)} for {len(texts)} items")
@ -1811,7 +1828,9 @@ def process_web(
) )
def search_web(request: Request, engine: str, query: str) -> list[SearchResult]: def search_web(
request: Request, engine: str, query: str, user=None
) -> list[SearchResult]:
"""Search the web using a search engine and return the results as a list of SearchResult objects. """Search the web using a search engine and return the results as a list of SearchResult objects.
Will look for a search engine API key in environment variables in the following order: Will look for a search engine API key in environment variables in the following order:
- SEARXNG_QUERY_URL - SEARXNG_QUERY_URL
@ -1850,6 +1869,8 @@ def search_web(request: Request, engine: str, query: str) -> list[SearchResult]:
query, query,
request.app.state.config.WEB_SEARCH_RESULT_COUNT, request.app.state.config.WEB_SEARCH_RESULT_COUNT,
request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST, request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST,
request.app.state.config.PERPLEXITY_SEARCH_API_URL,
user,
) )
else: else:
raise Exception("No PERPLEXITY_API_KEY found in environment variables") raise Exception("No PERPLEXITY_API_KEY found in environment variables")
@ -2027,6 +2048,24 @@ def search_web(request: Request, engine: str, query: str) -> list[SearchResult]:
request.app.state.config.WEB_SEARCH_RESULT_COUNT, request.app.state.config.WEB_SEARCH_RESULT_COUNT,
request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST, request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST,
) )
elif engine == "azure":
if (
request.app.state.config.AZURE_AI_SEARCH_API_KEY
and request.app.state.config.AZURE_AI_SEARCH_ENDPOINT
and request.app.state.config.AZURE_AI_SEARCH_INDEX_NAME
):
return search_azure(
request.app.state.config.AZURE_AI_SEARCH_API_KEY,
request.app.state.config.AZURE_AI_SEARCH_ENDPOINT,
request.app.state.config.AZURE_AI_SEARCH_INDEX_NAME,
query,
request.app.state.config.WEB_SEARCH_RESULT_COUNT,
request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST,
)
else:
raise Exception(
"AZURE_AI_SEARCH_API_KEY, AZURE_AI_SEARCH_ENDPOINT, and AZURE_AI_SEARCH_INDEX_NAME are required for Azure AI Search"
)
elif engine == "exa": elif engine == "exa":
return search_exa( return search_exa(
request.app.state.config.EXA_API_KEY, request.app.state.config.EXA_API_KEY,
@ -2069,11 +2108,13 @@ def search_web(request: Request, engine: str, query: str) -> list[SearchResult]:
) )
elif engine == "external": elif engine == "external":
return search_external( return search_external(
request,
request.app.state.config.EXTERNAL_WEB_SEARCH_URL, request.app.state.config.EXTERNAL_WEB_SEARCH_URL,
request.app.state.config.EXTERNAL_WEB_SEARCH_API_KEY, request.app.state.config.EXTERNAL_WEB_SEARCH_API_KEY,
query, query,
request.app.state.config.WEB_SEARCH_RESULT_COUNT, request.app.state.config.WEB_SEARCH_RESULT_COUNT,
request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST, request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST,
user=user,
) )
else: else:
raise Exception("No search engine API key found in environment variables") raise Exception("No search engine API key found in environment variables")
@ -2098,6 +2139,7 @@ async def process_web_search(
request, request,
request.app.state.config.WEB_SEARCH_ENGINE, request.app.state.config.WEB_SEARCH_ENGINE,
query, query,
user,
) )
for query in form_data.queries for query in form_data.queries
] ]
@ -2223,7 +2265,7 @@ class QueryDocForm(BaseModel):
@router.post("/query/doc") @router.post("/query/doc")
def query_doc_handler( async def query_doc_handler(
request: Request, request: Request,
form_data: QueryDocForm, form_data: QueryDocForm,
user=Depends(get_verified_user), user=Depends(get_verified_user),
@ -2236,7 +2278,7 @@ def query_doc_handler(
collection_results[form_data.collection_name] = VECTOR_DB_CLIENT.get( collection_results[form_data.collection_name] = VECTOR_DB_CLIENT.get(
collection_name=form_data.collection_name collection_name=form_data.collection_name
) )
return query_doc_with_hybrid_search( return await query_doc_with_hybrid_search(
collection_name=form_data.collection_name, collection_name=form_data.collection_name,
collection_result=collection_results[form_data.collection_name], collection_result=collection_results[form_data.collection_name],
query=form_data.query, query=form_data.query,
@ -2246,8 +2288,8 @@ def query_doc_handler(
k=form_data.k if form_data.k else request.app.state.config.TOP_K, k=form_data.k if form_data.k else request.app.state.config.TOP_K,
reranking_function=( reranking_function=(
( (
lambda sentences: request.app.state.RERANKING_FUNCTION( lambda query, documents: request.app.state.RERANKING_FUNCTION(
sentences, user=user query, documents, user=user
) )
) )
if request.app.state.RERANKING_FUNCTION if request.app.state.RERANKING_FUNCTION
@ -2268,11 +2310,12 @@ def query_doc_handler(
user=user, user=user,
) )
else: else:
query_embedding = await request.app.state.EMBEDDING_FUNCTION(
form_data.query, prefix=RAG_EMBEDDING_QUERY_PREFIX, user=user
)
return query_doc( return query_doc(
collection_name=form_data.collection_name, collection_name=form_data.collection_name,
query_embedding=request.app.state.EMBEDDING_FUNCTION( query_embedding=query_embedding,
form_data.query, prefix=RAG_EMBEDDING_QUERY_PREFIX, user=user
),
k=form_data.k if form_data.k else request.app.state.config.TOP_K, k=form_data.k if form_data.k else request.app.state.config.TOP_K,
user=user, user=user,
) )
@ -2292,10 +2335,11 @@ class QueryCollectionsForm(BaseModel):
r: Optional[float] = None r: Optional[float] = None
hybrid: Optional[bool] = None hybrid: Optional[bool] = None
hybrid_bm25_weight: Optional[float] = None hybrid_bm25_weight: Optional[float] = None
enable_enriched_texts: Optional[bool] = None
@router.post("/query/collection") @router.post("/query/collection")
def query_collection_handler( async def query_collection_handler(
request: Request, request: Request,
form_data: QueryCollectionsForm, form_data: QueryCollectionsForm,
user=Depends(get_verified_user), user=Depends(get_verified_user),
@ -2304,7 +2348,7 @@ def query_collection_handler(
if request.app.state.config.ENABLE_RAG_HYBRID_SEARCH and ( if request.app.state.config.ENABLE_RAG_HYBRID_SEARCH and (
form_data.hybrid is None or form_data.hybrid form_data.hybrid is None or form_data.hybrid
): ):
return query_collection_with_hybrid_search( return await query_collection_with_hybrid_search(
collection_names=form_data.collection_names, collection_names=form_data.collection_names,
queries=[form_data.query], queries=[form_data.query],
embedding_function=lambda query, prefix: request.app.state.EMBEDDING_FUNCTION( embedding_function=lambda query, prefix: request.app.state.EMBEDDING_FUNCTION(
@ -2313,8 +2357,8 @@ def query_collection_handler(
k=form_data.k if form_data.k else request.app.state.config.TOP_K, k=form_data.k if form_data.k else request.app.state.config.TOP_K,
reranking_function=( reranking_function=(
( (
lambda sentences: request.app.state.RERANKING_FUNCTION( lambda query, documents: request.app.state.RERANKING_FUNCTION(
sentences, user=user query, documents, user=user
) )
) )
if request.app.state.RERANKING_FUNCTION if request.app.state.RERANKING_FUNCTION
@ -2332,9 +2376,14 @@ def query_collection_handler(
if form_data.hybrid_bm25_weight if form_data.hybrid_bm25_weight
else request.app.state.config.HYBRID_BM25_WEIGHT else request.app.state.config.HYBRID_BM25_WEIGHT
), ),
enable_enriched_texts=(
form_data.enable_enriched_texts
if form_data.enable_enriched_texts is not None
else request.app.state.config.ENABLE_RAG_HYBRID_SEARCH_ENRICHED_TEXTS
),
) )
else: else:
return query_collection( return await query_collection(
collection_names=form_data.collection_names, collection_names=form_data.collection_names,
queries=[form_data.query], queries=[form_data.query],
embedding_function=lambda query, prefix: request.app.state.EMBEDDING_FUNCTION( embedding_function=lambda query, prefix: request.app.state.EMBEDDING_FUNCTION(
@ -2416,7 +2465,7 @@ if ENV == "dev":
@router.get("/ef/{text}") @router.get("/ef/{text}")
async def get_embeddings(request: Request, text: Optional[str] = "Hello World!"): async def get_embeddings(request: Request, text: Optional[str] = "Hello World!"):
return { return {
"result": request.app.state.EMBEDDING_FUNCTION( "result": await request.app.state.EMBEDDING_FUNCTION(
text, prefix=RAG_EMBEDDING_QUERY_PREFIX text, prefix=RAG_EMBEDDING_QUERY_PREFIX
) )
} }
@ -2447,16 +2496,19 @@ def process_files_batch(
""" """
Process a batch of files and save them to the vector database. Process a batch of files and save them to the vector database.
""" """
results: List[BatchProcessFilesResult] = []
errors: List[BatchProcessFilesResult] = []
collection_name = form_data.collection_name collection_name = form_data.collection_name
file_results: List[BatchProcessFilesResult] = []
file_errors: List[BatchProcessFilesResult] = []
file_updates: List[FileUpdateForm] = []
# Prepare all documents first # Prepare all documents first
all_docs: List[Document] = [] all_docs: List[Document] = []
for file in form_data.files: for file in form_data.files:
try: try:
text_content = file.data.get("content", "") text_content = file.data.get("content", "")
docs: List[Document] = [ docs: List[Document] = [
Document( Document(
page_content=text_content.replace("<br/>", "\n"), page_content=text_content.replace("<br/>", "\n"),
@ -2470,16 +2522,21 @@ def process_files_batch(
) )
] ]
hash = calculate_sha256_string(text_content)
Files.update_file_hash_by_id(file.id, hash)
Files.update_file_data_by_id(file.id, {"content": text_content})
all_docs.extend(docs) all_docs.extend(docs)
results.append(BatchProcessFilesResult(file_id=file.id, status="prepared"))
file_updates.append(
FileUpdateForm(
hash=calculate_sha256_string(text_content),
data={"content": text_content},
)
)
file_results.append(
BatchProcessFilesResult(file_id=file.id, status="prepared")
)
except Exception as e: except Exception as e:
log.error(f"process_files_batch: Error processing file {file.id}: {str(e)}") log.error(f"process_files_batch: Error processing file {file.id}: {str(e)}")
errors.append( file_errors.append(
BatchProcessFilesResult(file_id=file.id, status="failed", error=str(e)) BatchProcessFilesResult(file_id=file.id, status="failed", error=str(e))
) )
@ -2495,20 +2552,18 @@ def process_files_batch(
) )
# Update all files with collection name # Update all files with collection name
for result in results: for file_update, file_result in zip(file_updates, file_results):
Files.update_file_metadata_by_id( Files.update_file_by_id(id=file_result.file_id, form_data=file_update)
result.file_id, {"collection_name": collection_name} file_result.status = "completed"
)
result.status = "completed"
except Exception as e: except Exception as e:
log.error( log.error(
f"process_files_batch: Error saving documents to vector DB: {str(e)}" f"process_files_batch: Error saving documents to vector DB: {str(e)}"
) )
for result in results: for file_result in file_results:
result.status = "failed" file_result.status = "failed"
errors.append( file_errors.append(
BatchProcessFilesResult(file_id=result.file_id, error=str(e)) BatchProcessFilesResult(file_id=file_result.file_id, error=str(e))
) )
return BatchProcessFilesResponse(results=results, errors=errors) return BatchProcessFilesResponse(results=file_results, errors=file_errors)

View file

@ -256,15 +256,16 @@ def get_scim_auth(
) )
# Check if SCIM is enabled # Check if SCIM is enabled
scim_enabled = getattr(request.app.state, "SCIM_ENABLED", False) enable_scim = getattr(request.app.state, "ENABLE_SCIM", False)
log.info( log.info(
f"SCIM auth check - raw SCIM_ENABLED: {scim_enabled}, type: {type(scim_enabled)}" f"SCIM auth check - raw ENABLE_SCIM: {enable_scim}, type: {type(enable_scim)}"
) )
# Handle both PersistentConfig and direct value # Handle both PersistentConfig and direct value
if hasattr(scim_enabled, "value"): if hasattr(enable_scim, "value"):
scim_enabled = scim_enabled.value enable_scim = enable_scim.value
log.info(f"SCIM enabled status after conversion: {scim_enabled}")
if not scim_enabled: if not enable_scim:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, status_code=status.HTTP_403_FORBIDDEN,
detail="SCIM is not enabled", detail="SCIM is not enabled",
@ -348,8 +349,10 @@ def user_to_scim(user: UserModel, request: Request) -> SCIMUser:
def group_to_scim(group: GroupModel, request: Request) -> SCIMGroup: def group_to_scim(group: GroupModel, request: Request) -> SCIMGroup:
"""Convert internal Group model to SCIM Group""" """Convert internal Group model to SCIM Group"""
member_ids = Groups.get_group_user_ids_by_id(group.id)
members = [] members = []
for user_id in group.user_ids:
for user_id in member_ids:
user = Users.get_user_by_id(user_id) user = Users.get_user_by_id(user_id)
if user: if user:
members.append( members.append(
@ -795,9 +798,11 @@ async def create_group(
update_form = GroupUpdateForm( update_form = GroupUpdateForm(
name=new_group.name, name=new_group.name,
description=new_group.description, description=new_group.description,
user_ids=member_ids,
) )
Groups.update_group_by_id(new_group.id, update_form) Groups.update_group_by_id(new_group.id, update_form)
Groups.set_group_user_ids_by_id(new_group.id, member_ids)
new_group = Groups.get_group_by_id(new_group.id) new_group = Groups.get_group_by_id(new_group.id)
return group_to_scim(new_group, request) return group_to_scim(new_group, request)
@ -829,7 +834,7 @@ async def update_group(
# Handle members if provided # Handle members if provided
if group_data.members is not None: if group_data.members is not None:
member_ids = [member.value for member in group_data.members] member_ids = [member.value for member in group_data.members]
update_form.user_ids = member_ids Groups.set_group_user_ids_by_id(group_id, member_ids)
# Update group # Update group
updated_group = Groups.update_group_by_id(group_id, update_form) updated_group = Groups.update_group_by_id(group_id, update_form)
@ -862,7 +867,6 @@ async def patch_group(
update_form = GroupUpdateForm( update_form = GroupUpdateForm(
name=group.name, name=group.name,
description=group.description, description=group.description,
user_ids=group.user_ids.copy() if group.user_ids else [],
) )
for operation in patch_data.Operations: for operation in patch_data.Operations:
@ -875,21 +879,22 @@ async def patch_group(
update_form.name = value update_form.name = value
elif path == "members": elif path == "members":
# Replace all members # Replace all members
update_form.user_ids = [member["value"] for member in value] Groups.set_group_user_ids_by_id(
group_id, [member["value"] for member in value]
)
elif op == "add": elif op == "add":
if path == "members": if path == "members":
# Add members # Add members
if isinstance(value, list): if isinstance(value, list):
for member in value: for member in value:
if isinstance(member, dict) and "value" in member: if isinstance(member, dict) and "value" in member:
if member["value"] not in update_form.user_ids: Groups.add_users_to_group(group_id, [member["value"]])
update_form.user_ids.append(member["value"])
elif op == "remove": elif op == "remove":
if path and path.startswith("members[value eq"): if path and path.startswith("members[value eq"):
# Remove specific member # Remove specific member
member_id = path.split('"')[1] member_id = path.split('"')[1]
if member_id in update_form.user_ids: Groups.remove_users_from_group(group_id, [member_id])
update_form.user_ids.remove(member_id)
# Update group # Update group
updated_group = Groups.update_group_by_id(group_id, update_form) updated_group = Groups.update_group_by_id(group_id, update_form)

View file

@ -33,6 +33,7 @@ from open_webui.config import (
DEFAULT_AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE, DEFAULT_AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE,
DEFAULT_EMOJI_GENERATION_PROMPT_TEMPLATE, DEFAULT_EMOJI_GENERATION_PROMPT_TEMPLATE,
DEFAULT_MOA_GENERATION_PROMPT_TEMPLATE, DEFAULT_MOA_GENERATION_PROMPT_TEMPLATE,
DEFAULT_VOICE_MODE_PROMPT_TEMPLATE,
) )
from open_webui.env import SRC_LOG_LEVELS from open_webui.env import SRC_LOG_LEVELS
@ -68,6 +69,7 @@ async def get_task_config(request: Request, user=Depends(get_verified_user)):
"ENABLE_RETRIEVAL_QUERY_GENERATION": request.app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION, "ENABLE_RETRIEVAL_QUERY_GENERATION": request.app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION,
"QUERY_GENERATION_PROMPT_TEMPLATE": request.app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE, "QUERY_GENERATION_PROMPT_TEMPLATE": request.app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE,
"TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE": request.app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE, "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE": request.app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
"VOICE_MODE_PROMPT_TEMPLATE": request.app.state.config.VOICE_MODE_PROMPT_TEMPLATE,
} }
@ -87,6 +89,7 @@ class TaskConfigForm(BaseModel):
ENABLE_RETRIEVAL_QUERY_GENERATION: bool ENABLE_RETRIEVAL_QUERY_GENERATION: bool
QUERY_GENERATION_PROMPT_TEMPLATE: str QUERY_GENERATION_PROMPT_TEMPLATE: str
TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE: str TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE: str
VOICE_MODE_PROMPT_TEMPLATE: Optional[str]
@router.post("/config/update") @router.post("/config/update")
@ -136,6 +139,10 @@ async def update_task_config(
form_data.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE form_data.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
) )
request.app.state.config.VOICE_MODE_PROMPT_TEMPLATE = (
form_data.VOICE_MODE_PROMPT_TEMPLATE
)
return { return {
"TASK_MODEL": request.app.state.config.TASK_MODEL, "TASK_MODEL": request.app.state.config.TASK_MODEL,
"TASK_MODEL_EXTERNAL": request.app.state.config.TASK_MODEL_EXTERNAL, "TASK_MODEL_EXTERNAL": request.app.state.config.TASK_MODEL_EXTERNAL,
@ -152,6 +159,7 @@ async def update_task_config(
"ENABLE_RETRIEVAL_QUERY_GENERATION": request.app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION, "ENABLE_RETRIEVAL_QUERY_GENERATION": request.app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION,
"QUERY_GENERATION_PROMPT_TEMPLATE": request.app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE, "QUERY_GENERATION_PROMPT_TEMPLATE": request.app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE,
"TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE": request.app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE, "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE": request.app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
"VOICE_MODE_PROMPT_TEMPLATE": request.app.state.config.VOICE_MODE_PROMPT_TEMPLATE,
} }

View file

@ -247,9 +247,19 @@ async def load_tool_from_url(
@router.get("/export", response_model=list[ToolModel]) @router.get("/export", response_model=list[ToolModel])
async def export_tools(user=Depends(get_admin_user)): async def export_tools(request: Request, user=Depends(get_verified_user)):
tools = Tools.get_tools() if user.role != "admin" and not has_permission(
return tools user.id, "workspace.tools_export", request.app.state.config.USER_PERMISSIONS
):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.UNAUTHORIZED,
)
if user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL:
return Tools.get_tools()
else:
return Tools.get_tools_by_user_id(user.id, "read")
############################ ############################
@ -263,8 +273,13 @@ async def create_new_tools(
form_data: ToolForm, form_data: ToolForm,
user=Depends(get_verified_user), user=Depends(get_verified_user),
): ):
if user.role != "admin" and not has_permission( if user.role != "admin" and not (
user.id, "workspace.tools", request.app.state.config.USER_PERMISSIONS has_permission(
user.id, "workspace.tools", request.app.state.config.USER_PERMISSIONS
)
or has_permission(
user.id, "workspace.tools_import", request.app.state.config.USER_PERMISSIONS
)
): ):
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,

View file

@ -16,6 +16,7 @@ from open_webui.models.groups import Groups
from open_webui.models.chats import Chats from open_webui.models.chats import Chats
from open_webui.models.users import ( from open_webui.models.users import (
UserModel, UserModel,
UserGroupIdsModel,
UserListResponse, UserListResponse,
UserInfoListResponse, UserInfoListResponse,
UserIdNameListResponse, UserIdNameListResponse,
@ -35,7 +36,12 @@ from open_webui.constants import ERROR_MESSAGES
from open_webui.env import SRC_LOG_LEVELS, STATIC_DIR from open_webui.env import SRC_LOG_LEVELS, STATIC_DIR
from open_webui.utils.auth import get_admin_user, get_password_hash, get_verified_user from open_webui.utils.auth import (
get_admin_user,
get_password_hash,
get_verified_user,
validate_password,
)
from open_webui.utils.access_control import get_permissions, has_permission from open_webui.utils.access_control import get_permissions, has_permission
@ -91,7 +97,25 @@ async def get_users(
if direction: if direction:
filter["direction"] = direction filter["direction"] = direction
return Users.get_users(filter=filter, skip=skip, limit=limit) result = Users.get_users(filter=filter, skip=skip, limit=limit)
users = result["users"]
total = result["total"]
return {
"users": [
UserGroupIdsModel(
**{
**user.model_dump(),
"group_ids": [
group.id for group in Groups.get_groups_by_member_id(user.id)
],
}
)
for user in users
],
"total": total,
}
@router.get("/all", response_model=UserInfoListResponse) @router.get("/all", response_model=UserInfoListResponse)
@ -150,13 +174,24 @@ class WorkspacePermissions(BaseModel):
knowledge: bool = False knowledge: bool = False
prompts: bool = False prompts: bool = False
tools: bool = False tools: bool = False
models_import: bool = False
models_export: bool = False
prompts_import: bool = False
prompts_export: bool = False
tools_import: bool = False
tools_export: bool = False
class SharingPermissions(BaseModel): class SharingPermissions(BaseModel):
public_models: bool = True models: bool = False
public_knowledge: bool = True public_models: bool = False
public_prompts: bool = True knowledge: bool = False
public_knowledge: bool = False
prompts: bool = False
public_prompts: bool = False
tools: bool = False
public_tools: bool = True public_tools: bool = True
notes: bool = False
public_notes: bool = True public_notes: bool = True
@ -183,6 +218,7 @@ class ChatPermissions(BaseModel):
class FeaturesPermissions(BaseModel): class FeaturesPermissions(BaseModel):
api_keys: bool = False
direct_tool_servers: bool = False direct_tool_servers: bool = False
web_search: bool = True web_search: bool = True
image_generation: bool = True image_generation: bool = True
@ -471,8 +507,12 @@ async def update_user_by_id(
) )
if form_data.password: if form_data.password:
try:
validate_password(form_data.password)
except Exception as e:
raise HTTPException(400, detail=str(e))
hashed = get_password_hash(form_data.password) hashed = get_password_hash(form_data.password)
log.debug(f"hashed: {hashed}")
Auths.update_user_password_by_id(user_id, hashed) Auths.update_user_password_by_id(user_id, hashed)
Auths.update_email_by_id(user_id, form_data.email.lower()) Auths.update_email_by_id(user_id, form_data.email.lower())

View file

@ -124,12 +124,3 @@ async def download_db(user=Depends(get_admin_user)):
media_type="application/octet-stream", media_type="application/octet-stream",
filename="webui.db", filename="webui.db",
) )
@router.get("/litellm/config")
async def download_litellm_config_yaml(user=Depends(get_admin_user)):
return FileResponse(
f"{DATA_DIR}/litellm/config.yaml",
media_type="application/octet-stream",
filename="config.yaml",
)

View file

@ -32,6 +32,11 @@ from open_webui.env import (
WEBSOCKET_SENTINEL_PORT, WEBSOCKET_SENTINEL_PORT,
WEBSOCKET_SENTINEL_HOSTS, WEBSOCKET_SENTINEL_HOSTS,
REDIS_KEY_PREFIX, REDIS_KEY_PREFIX,
WEBSOCKET_REDIS_OPTIONS,
WEBSOCKET_SERVER_PING_TIMEOUT,
WEBSOCKET_SERVER_PING_INTERVAL,
WEBSOCKET_SERVER_LOGGING,
WEBSOCKET_SERVER_ENGINEIO_LOGGING,
) )
from open_webui.utils.auth import decode_token from open_webui.utils.auth import decode_token
from open_webui.socket.utils import RedisDict, RedisLock, YdocManager from open_webui.socket.utils import RedisDict, RedisLock, YdocManager
@ -61,10 +66,13 @@ if WEBSOCKET_MANAGER == "redis":
mgr = socketio.AsyncRedisManager( mgr = socketio.AsyncRedisManager(
get_sentinel_url_from_env( get_sentinel_url_from_env(
WEBSOCKET_REDIS_URL, WEBSOCKET_SENTINEL_HOSTS, WEBSOCKET_SENTINEL_PORT WEBSOCKET_REDIS_URL, WEBSOCKET_SENTINEL_HOSTS, WEBSOCKET_SENTINEL_PORT
) ),
redis_options=WEBSOCKET_REDIS_OPTIONS,
) )
else: else:
mgr = socketio.AsyncRedisManager(WEBSOCKET_REDIS_URL) mgr = socketio.AsyncRedisManager(
WEBSOCKET_REDIS_URL, redis_options=WEBSOCKET_REDIS_OPTIONS
)
sio = socketio.AsyncServer( sio = socketio.AsyncServer(
cors_allowed_origins=SOCKETIO_CORS_ORIGINS, cors_allowed_origins=SOCKETIO_CORS_ORIGINS,
async_mode="asgi", async_mode="asgi",
@ -72,6 +80,10 @@ if WEBSOCKET_MANAGER == "redis":
allow_upgrades=ENABLE_WEBSOCKET_SUPPORT, allow_upgrades=ENABLE_WEBSOCKET_SUPPORT,
always_connect=True, always_connect=True,
client_manager=mgr, client_manager=mgr,
logger=WEBSOCKET_SERVER_LOGGING,
ping_interval=WEBSOCKET_SERVER_PING_INTERVAL,
ping_timeout=WEBSOCKET_SERVER_PING_TIMEOUT,
engineio_logger=WEBSOCKET_SERVER_ENGINEIO_LOGGING,
) )
else: else:
sio = socketio.AsyncServer( sio = socketio.AsyncServer(
@ -80,6 +92,10 @@ else:
transports=(["websocket"] if ENABLE_WEBSOCKET_SUPPORT else ["polling"]), transports=(["websocket"] if ENABLE_WEBSOCKET_SUPPORT else ["polling"]),
allow_upgrades=ENABLE_WEBSOCKET_SUPPORT, allow_upgrades=ENABLE_WEBSOCKET_SUPPORT,
always_connect=True, always_connect=True,
logger=WEBSOCKET_SERVER_LOGGING,
ping_interval=WEBSOCKET_SERVER_PING_INTERVAL,
ping_timeout=WEBSOCKET_SERVER_PING_TIMEOUT,
engineio_logger=WEBSOCKET_SERVER_ENGINEIO_LOGGING,
) )
@ -282,6 +298,8 @@ async def connect(sid, environ, auth):
else: else:
USER_POOL[user.id] = [sid] USER_POOL[user.id] = [sid]
await sio.enter_room(sid, f"user:{user.id}")
@sio.on("user-join") @sio.on("user-join")
async def user_join(sid, data): async def user_join(sid, data):
@ -304,6 +322,7 @@ async def user_join(sid, data):
else: else:
USER_POOL[user.id] = [sid] USER_POOL[user.id] = [sid]
await sio.enter_room(sid, f"user:{user.id}")
# Join all the channels # Join all the channels
channels = Channels.get_channels_by_user_id(user.id) channels = Channels.get_channels_by_user_id(user.id)
log.debug(f"{channels=}") log.debug(f"{channels=}")
@ -649,40 +668,24 @@ async def disconnect(sid):
def get_event_emitter(request_info, update_db=True): def get_event_emitter(request_info, update_db=True):
async def __event_emitter__(event_data): async def __event_emitter__(event_data):
user_id = request_info["user_id"] user_id = request_info["user_id"]
chat_id = request_info["chat_id"]
message_id = request_info["message_id"]
session_ids = list( await sio.emit(
set( "events",
USER_POOL.get(user_id, []) {
+ ( "chat_id": chat_id,
[request_info.get("session_id")] "message_id": message_id,
if request_info.get("session_id") "data": event_data,
else [] },
) room=f"user:{user_id}",
)
) )
chat_id = request_info.get("chat_id", None)
message_id = request_info.get("message_id", None)
emit_tasks = [
sio.emit(
"events",
{
"chat_id": chat_id,
"message_id": message_id,
"data": event_data,
},
to=session_id,
)
for session_id in session_ids
]
await asyncio.gather(*emit_tasks)
if ( if (
update_db update_db
and message_id and message_id
and not request_info.get("chat_id", "").startswith("local:") and not request_info.get("chat_id", "").startswith("local:")
): ):
if "type" in event_data and event_data["type"] == "status": if "type" in event_data and event_data["type"] == "status":
Chats.add_message_status_to_chat_by_id_and_message_id( Chats.add_message_status_to_chat_by_id_and_message_id(
request_info["chat_id"], request_info["chat_id"],
@ -772,7 +775,14 @@ def get_event_emitter(request_info, update_db=True):
}, },
) )
return __event_emitter__ if (
"user_id" in request_info
and "chat_id" in request_info
and "message_id" in request_info
):
return __event_emitter__
else:
return None
def get_event_call(request_info): def get_event_call(request_info):
@ -788,7 +798,14 @@ def get_event_call(request_info):
) )
return response return response
return __event_caller__ if (
"session_id" in request_info
and "chat_id" in request_info
and "message_id" in request_info
):
return __event_caller__
else:
return None
get_event_caller = get_event_call get_event_caller = get_event_call
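Event delivery now relies on per-user Socket.IO rooms: each connection joins `user:{user_id}` at connect/user-join time, and the emitter targets that room instead of gathering per-session emits from `USER_POOL`. A minimal sketch of the pattern with python-socketio, with the user id resolution stubbed out:

```python
import socketio

sio = socketio.AsyncServer(async_mode="asgi", cors_allowed_origins=[])


@sio.event
async def connect(sid, environ, auth):
    user_id = "demo-user"  # placeholder: the real handler resolves this from the auth token
    await sio.enter_room(sid, f"user:{user_id}")


async def emit_to_user(user_id: str, event_data: dict):
    # Every session the user has open shares the room, so a single emit
    # replaces the old fan-out over individual session ids.
    await sio.emit("events", event_data, room=f"user:{user_id}")
```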

View file

@ -21,13 +21,18 @@ from typing import Optional, Union, List, Dict
from opentelemetry import trace from opentelemetry import trace
from open_webui.utils.access_control import has_permission
from open_webui.models.users import Users from open_webui.models.users import Users
from open_webui.constants import ERROR_MESSAGES from open_webui.constants import ERROR_MESSAGES
from open_webui.env import ( from open_webui.env import (
ENABLE_PASSWORD_VALIDATION,
OFFLINE_MODE, OFFLINE_MODE,
LICENSE_BLOB, LICENSE_BLOB,
PASSWORD_VALIDATION_REGEX_PATTERN,
REDIS_KEY_PREFIX,
pk, pk,
WEBUI_SECRET_KEY, WEBUI_SECRET_KEY,
TRUSTED_SIGNATURE_KEY, TRUSTED_SIGNATURE_KEY,
@ -159,6 +164,20 @@ def get_password_hash(password: str) -> str:
return bcrypt.hashpw(password.encode("utf-8"), bcrypt.gensalt()).decode("utf-8") return bcrypt.hashpw(password.encode("utf-8"), bcrypt.gensalt()).decode("utf-8")
def validate_password(password: str) -> bool:
# The password passed to bcrypt must be 72 bytes or fewer. If it is longer, it will be truncated before hashing.
if len(password.encode("utf-8")) > 72:
raise Exception(
ERROR_MESSAGES.PASSWORD_TOO_LONG,
)
if ENABLE_PASSWORD_VALIDATION:
if not PASSWORD_VALIDATION_REGEX_PATTERN.match(password):
raise Exception(ERROR_MESSAGES.INVALID_PASSWORD())
return True
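`validate_password` enforces bcrypt's 72-byte input limit and, when `ENABLE_PASSWORD_VALIDATION` is on, a configurable regex. The default pattern is not shown in this hunk; a sketch of a policy regex of the kind implied (at least 8 characters with lowercase, uppercase, digit, and special character — treat the exact expression as an assumption):

```python
import re

# Assumed example policy: >= 8 chars, at least one lowercase, one uppercase,
# one digit, and one non-alphanumeric character.
PASSWORD_VALIDATION_REGEX_PATTERN = re.compile(
    r"^(?=.*[a-z])(?=.*[A-Z])(?=.*\d)(?=.*[^A-Za-z0-9]).{8,}$"
)

assert PASSWORD_VALIDATION_REGEX_PATTERN.match("Str0ng!Pass")
assert not PASSWORD_VALIDATION_REGEX_PATTERN.match("weakpass")
```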
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Verify a password against its hash"""
    return (

@@ -178,6 +197,9 @@ def create_token(data: dict, expires_delta: Union[timedelta, None] = None) -> str:
        expire = datetime.now(UTC) + expires_delta
        payload.update({"exp": expire})

+    jti = str(uuid.uuid4())
+    payload.update({"jti": jti})
+
    encoded_jwt = jwt.encode(payload, SESSION_SECRET, algorithm=ALGORITHM)
    return encoded_jwt

@@ -190,6 +212,43 @@ def decode_token(token: str) -> Optional[dict]:
        return None
async def is_valid_token(request, decoded) -> bool:
# Require Redis to check revoked tokens
if request.app.state.redis:
jti = decoded.get("jti")
if jti:
revoked = await request.app.state.redis.get(
f"{REDIS_KEY_PREFIX}:auth:token:{jti}:revoked"
)
if revoked:
return False
return True
async def invalidate_token(request, token):
decoded = decode_token(token)
# Require Redis to store revoked tokens
if request.app.state.redis:
jti = decoded.get("jti")
exp = decoded.get("exp")
if jti:
ttl = exp - int(
datetime.now(UTC).timestamp()
) # Calculate time-to-live for the token
if ttl > 0:
# Store the revoked token in Redis with an expiration time
await request.app.state.redis.set(
f"{REDIS_KEY_PREFIX}:auth:token:{jti}:revoked",
"1",
ex=ttl,
)
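A rough sketch of the intended revocation flow, using only the helpers added above (the FastAPI request and token variables are illustrative):

    # On logout: mark the token's jti as revoked until the token would expire anyway.
    await invalidate_token(request, token)

    # On subsequent requests: decode_token() yields the jti, and is_valid_token()
    # rejects the request while Redis still holds the revoked marker.
    decoded = decode_token(token)
    if decoded and decoded.get("jti") and not await is_valid_token(request, decoded):
        raise HTTPException(status_code=401, detail="Invalid token")

Note that without a configured Redis instance both helpers fall back to treating tokens as valid, so revocation is effectively a Redis-backed feature.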
def extract_token_from_auth_header(auth_header: str):
    return auth_header[len("Bearer ") :]

@@ -209,7 +268,7 @@ def get_http_authorization_cred(auth_header: Optional[str]):
    return None

-def get_current_user(
+async def get_current_user(
    request: Request,
    response: Response,
    background_tasks: BackgroundTasks,

@@ -228,30 +287,7 @@ def get_current_user(
    # auth by api key
    if token.startswith("sk-"):
-        if not request.state.enable_api_key:
-            raise HTTPException(
-                status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.API_KEY_NOT_ALLOWED
-            )
-
-        if request.app.state.config.ENABLE_API_KEY_ENDPOINT_RESTRICTIONS:
-            allowed_paths = [
-                path.strip()
-                for path in str(
-                    request.app.state.config.API_KEY_ALLOWED_ENDPOINTS
-                ).split(",")
-            ]
-
-            # Check if the request path matches any allowed endpoint.
-            if not any(
-                request.url.path == allowed
-                or request.url.path.startswith(allowed + "/")
-                for allowed in allowed_paths
-            ):
-                raise HTTPException(
-                    status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.API_KEY_NOT_ALLOWED
-                )
-
-        user = get_current_user_by_api_key(token)
+        user = get_current_user_by_api_key(request, token)

        # Add user info to current span
        current_span = trace.get_current_span()

@@ -264,7 +300,6 @@ def get_current_user(
        return user

    # auth by jwt token
    try:
        try:
            data = decode_token(token)

@@ -275,6 +310,12 @@ def get_current_user(
            )

        if data is not None and "id" in data:
+            if data.get("jti") and not await is_valid_token(request, data):
+                raise HTTPException(
+                    status_code=status.HTTP_401_UNAUTHORIZED,
+                    detail="Invalid token",
+                )
+
            user = Users.get_user_by_id(data["id"])
            if user is None:
                raise HTTPException(

@@ -327,7 +368,7 @@ def get_current_user(
        raise e

-def get_current_user_by_api_key(api_key: str):
+def get_current_user_by_api_key(request, api_key: str):
    user = Users.get_user_by_api_key(api_key)

    if user is None:

@@ -335,16 +376,28 @@ def get_current_user_by_api_key(api_key: str):
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.INVALID_TOKEN,
        )
-    else:
-        # Add user info to current span
-        current_span = trace.get_current_span()
-        if current_span:
-            current_span.set_attribute("client.user.id", user.id)
-            current_span.set_attribute("client.user.email", user.email)
-            current_span.set_attribute("client.user.role", user.role)
-            current_span.set_attribute("client.auth.type", "api_key")
-
-        Users.update_user_last_active_by_id(user.id)
+
+    if not request.state.enable_api_keys or (
+        user.role != "admin"
+        and not has_permission(
+            user.id,
+            "features.api_keys",
+            request.app.state.config.USER_PERMISSIONS,
+        )
+    ):
+        raise HTTPException(
+            status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.API_KEY_NOT_ALLOWED
+        )
+
+    # Add user info to current span
+    current_span = trace.get_current_span()
+    if current_span:
+        current_span.set_attribute("client.user.id", user.id)
+        current_span.set_attribute("client.user.email", user.email)
+        current_span.set_attribute("client.user.role", user.role)
+        current_span.set_attribute("client.auth.type", "api_key")
+
+    Users.update_user_last_active_by_id(user.id)

    return user

View file

@ -16,10 +16,15 @@ from open_webui.routers.files import upload_file_handler
import mimetypes import mimetypes
import base64 import base64
import io import io
import re
BASE64_IMAGE_URL_PREFIX = re.compile(r"data:image/\w+;base64,", re.IGNORECASE)
MARKDOWN_IMAGE_URL_PATTERN = re.compile(r"!\[(.*?)\]\((.+?)\)", re.IGNORECASE)
def get_image_url_from_base64(request, base64_image_string, metadata, user):
-    if "data:image/png;base64" in base64_image_string:
+    if BASE64_IMAGE_URL_PREFIX.match(base64_image_string):
        image_url = ""
        # Extract base64 image data from the line
        image_data, content_type = get_image_data(base64_image_string)

@@ -35,6 +40,19 @@ def get_image_url_from_base64(request, base64_image_string, metadata, user):
    return None
def convert_markdown_base64_images(request, content: str, metadata, user):
def replace(match):
base64_string = match.group(2)
MIN_REPLACEMENT_URL_LENGTH = 1024
if len(base64_string) > MIN_REPLACEMENT_URL_LENGTH:
url = get_image_url_from_base64(request, base64_string, metadata, user)
if url:
return f"![{match.group(1)}]({url})"
return match.group(0)
return MARKDOWN_IMAGE_URL_PATTERN.sub(replace, content)
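Illustrative example of what the substitution does (the rewritten URL is hypothetical; real URLs come from get_image_url_from_base64 and the file upload handler):

    content = "Here you go: ![diagram](data:image/png;base64,iVBORw0KGgo...<several KB>...)"
    # Inline data URLs longer than 1024 characters are uploaded and the markdown
    # is rewritten to point at the stored file; shorter ones are left untouched.
    content = convert_markdown_base64_images(request, content, metadata, user)
    # e.g. -> "Here you go: ![diagram](/api/v1/files/<file-id>/content)"  (illustrative)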
def load_b64_audio_data(b64_str): def load_b64_audio_data(b64_str):
try: try:
if "," in b64_str: if "," in b64_str:

View file

@@ -58,7 +58,7 @@ from open_webui.routers.memories import query_memory, QueryMemoryForm
from open_webui.utils.webhook import post_webhook
from open_webui.utils.files import (
-    get_audio_url_from_base64,
+    convert_markdown_base64_images,
    get_file_url_from_base64,
    get_image_url_from_base64,
)
@@ -104,6 +104,7 @@ from open_webui.utils.mcp.client import MCPClient
from open_webui.config import (
    CACHE_DIR,
+    DEFAULT_VOICE_MODE_PROMPT_TEMPLATE,
    DEFAULT_TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
    DEFAULT_CODE_INTERPRETER_PROMPT,
    CODE_INTERPRETER_BLOCKED_MODULES,

@@ -111,6 +112,7 @@ from open_webui.config import (
from open_webui.env import (
    SRC_LOG_LEVELS,
    GLOBAL_LOG_LEVEL,
+    ENABLE_CHAT_RESPONSE_BASE64_IMAGE_URL_CONVERSION,
    CHAT_RESPONSE_STREAM_DELTA_CHUNK_SIZE,
    CHAT_RESPONSE_MAX_TOOL_CALL_RETRIES,
    BYPASS_MODEL_ACCESS_CONTROL,
@@ -302,19 +304,27 @@ async def chat_completion_tools_handler(
def get_tools_function_calling_payload(messages, task_model_id, content):
    user_message = get_last_user_message(messages)

+    if user_message and messages and messages[-1]["role"] == "user":
+        # Remove the last user message to avoid duplication
+        messages = messages[:-1]
+
    recent_messages = messages[-4:] if len(messages) > 4 else messages
    chat_history = "\n".join(
        f"{message['role'].upper()}: \"\"\"{get_content_from_message(message)}\"\"\""
        for message in recent_messages
    )

-    prompt = f"History:\n{chat_history}\nQuery: {user_message}"
+    prompt = (
+        f"History:\n{chat_history}\nQuery: {user_message}"
+        if chat_history
+        else f"Query: {user_message}"
+    )

    return {
        "model": task_model_id,
        "messages": [
            {"role": "system", "content": content},
-            {"role": "user", "content": f"Query: {prompt}"},
+            {"role": "user", "content": prompt},
        ],
        "stream": False,
        "metadata": {"task": str(TASKS.FUNCTION_CALLING)},
@@ -733,6 +743,27 @@ def get_last_images(message_list):
    return images
def get_image_urls(delta_images, request, metadata, user) -> list[str]:
if not isinstance(delta_images, list):
return []
image_urls = []
for img in delta_images:
if not isinstance(img, dict) or img.get("type") != "image_url":
continue
url = img.get("image_url", {}).get("url")
if not url:
continue
if url.startswith("data:image/png;base64"):
url = get_image_url_from_base64(request, url, metadata, user)
image_urls.append(url)
return image_urls
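For context, a sketch of the delta shape this helper expects, based on the "choices[0].delta.images" format mentioned in the release notes (payload values are illustrative):

    delta = {
        "content": "",
        "images": [
            {
                "type": "image_url",
                "image_url": {"url": "data:image/png;base64,iVBORw0KGgo..."},
            }
        ],
    }
    # Base64 PNG data URLs are re-hosted via get_image_url_from_base64; anything
    # else (e.g. an https:// URL) is passed through unchanged.
    urls = get_image_urls(delta.get("images", []), request, metadata, user)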
async def chat_image_generation_handler( async def chat_image_generation_handler(
request: Request, form_data: dict, extra_params: dict, user request: Request, form_data: dict, extra_params: dict, user
): ):
@@ -760,42 +791,13 @@ async def chat_image_generation_handler(
    input_images = get_last_images(message_list)

    system_message_content = ""

-    if len(input_images) == 0:
-        # Create image(s)
-        if request.app.state.config.ENABLE_IMAGE_PROMPT_GENERATION:
-            try:
-                res = await generate_image_prompt(
-                    request,
-                    {
-                        "model": form_data["model"],
-                        "messages": form_data["messages"],
-                    },
-                    user,
-                )
-
-                response = res["choices"][0]["message"]["content"]
-
-                try:
-                    bracket_start = response.find("{")
-                    bracket_end = response.rfind("}") + 1
-
-                    if bracket_start == -1 or bracket_end == -1:
-                        raise Exception("No JSON object found in the response")
-
-                    response = response[bracket_start:bracket_end]
-                    response = json.loads(response)
-                    prompt = response.get("prompt", [])
-                except Exception as e:
-                    prompt = user_message
-            except Exception as e:
-                log.exception(e)
-                prompt = user_message
+    if len(input_images) > 0 and request.app.state.config.ENABLE_IMAGE_EDIT:
+        # Edit image(s)
        try:
-            images = await image_generations(
+            images = await image_edits(
                request=request,
-                form_data=CreateImageForm(**{"prompt": prompt}),
+                form_data=EditImageForm(**{"prompt": prompt, "image": input_images}),
                user=user,
            )

@@ -843,12 +845,43 @@ async def chat_image_generation_handler(
            )
            system_message_content = f"<context>Image generation was attempted but failed. The system is currently unable to generate the image. Tell the user that an error occurred: {error_message}</context>"
    else:
-        # Edit image(s)
+        # Create image(s)
+        if request.app.state.config.ENABLE_IMAGE_PROMPT_GENERATION:
+            try:
+                res = await generate_image_prompt(
+                    request,
+                    {
+                        "model": form_data["model"],
+                        "messages": form_data["messages"],
+                    },
+                    user,
+                )
+
+                response = res["choices"][0]["message"]["content"]
+
+                try:
+                    bracket_start = response.find("{")
+                    bracket_end = response.rfind("}") + 1
+
+                    if bracket_start == -1 or bracket_end == -1:
+                        raise Exception("No JSON object found in the response")
+
+                    response = response[bracket_start:bracket_end]
+                    response = json.loads(response)
+                    prompt = response.get("prompt", [])
+                except Exception as e:
+                    prompt = user_message
+            except Exception as e:
+                log.exception(e)
+                prompt = user_message
+
        try:
-            images = await image_edits(
+            images = await image_generations(
                request=request,
-                form_data=EditImageForm(**{"prompt": prompt, "image": input_images}),
+                form_data=CreateImageForm(**{"prompt": prompt}),
                user=user,
            )
@@ -960,37 +993,32 @@ async def chat_completion_files_handler(
        queries = [get_last_user_message(body["messages"])]

        try:
-            # Offload get_sources_from_items to a separate thread
-            loop = asyncio.get_running_loop()
-            with ThreadPoolExecutor() as executor:
-                sources = await loop.run_in_executor(
-                    executor,
-                    lambda: get_sources_from_items(
-                        request=request,
-                        items=files,
-                        queries=queries,
-                        embedding_function=lambda query, prefix: request.app.state.EMBEDDING_FUNCTION(
-                            query, prefix=prefix, user=user
-                        ),
-                        k=request.app.state.config.TOP_K,
-                        reranking_function=(
-                            (
-                                lambda sentences: request.app.state.RERANKING_FUNCTION(
-                                    sentences, user=user
-                                )
-                            )
-                            if request.app.state.RERANKING_FUNCTION
-                            else None
-                        ),
-                        k_reranker=request.app.state.config.TOP_K_RERANKER,
-                        r=request.app.state.config.RELEVANCE_THRESHOLD,
-                        hybrid_bm25_weight=request.app.state.config.HYBRID_BM25_WEIGHT,
-                        hybrid_search=request.app.state.config.ENABLE_RAG_HYBRID_SEARCH,
-                        full_context=all_full_context
-                        or request.app.state.config.RAG_FULL_CONTEXT,
-                        user=user,
-                    ),
-                )
+            # Directly await async get_sources_from_items (no thread needed - fully async now)
+            sources = await get_sources_from_items(
+                request=request,
+                items=files,
+                queries=queries,
+                embedding_function=lambda query, prefix: request.app.state.EMBEDDING_FUNCTION(
+                    query, prefix=prefix, user=user
+                ),
+                k=request.app.state.config.TOP_K,
+                reranking_function=(
+                    (
+                        lambda query, documents: request.app.state.RERANKING_FUNCTION(
+                            query, documents, user=user
+                        )
+                    )
+                    if request.app.state.RERANKING_FUNCTION
+                    else None
+                ),
+                k_reranker=request.app.state.config.TOP_K_RERANKER,
+                r=request.app.state.config.RELEVANCE_THRESHOLD,
+                hybrid_bm25_weight=request.app.state.config.HYBRID_BM25_WEIGHT,
+                hybrid_search=request.app.state.config.ENABLE_RAG_HYBRID_SEARCH,
+                full_context=all_full_context
+                or request.app.state.config.RAG_FULL_CONTEXT,
+                user=user,
+            )

        except Exception as e:
            log.exception(e)
@@ -1097,7 +1125,7 @@ async def process_chat_payload(request, form_data, user, metadata, model):
        pass

    event_emitter = get_event_emitter(metadata)
-    event_call = get_event_call(metadata)
+    event_caller = get_event_call(metadata)

    oauth_token = None
    try:

@@ -1111,14 +1139,13 @@ async def process_chat_payload(request, form_data, user, metadata, model):
    extra_params = {
        "__event_emitter__": event_emitter,
-        "__event_call__": event_call,
+        "__event_call__": event_caller,
        "__user__": user.model_dump() if isinstance(user, UserModel) else {},
        "__metadata__": metadata,
-        "__oauth_token__": oauth_token,
        "__request__": request,
        "__model__": model,
+        "__oauth_token__": oauth_token,
    }
# Initialize events to store additional event to be sent to the client # Initialize events to store additional event to be sent to the client
# Initialize contexts and citation # Initialize contexts and citation
if getattr(request.state, "direct", False) and hasattr(request.state, "model"): if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
@@ -1229,6 +1256,18 @@ async def process_chat_payload(request, form_data, user, metadata, model):
    features = form_data.pop("features", None)
    if features:
+        if "voice" in features and features["voice"]:
+            if request.app.state.config.VOICE_MODE_PROMPT_TEMPLATE != None:
+                if request.app.state.config.VOICE_MODE_PROMPT_TEMPLATE != "":
+                    template = request.app.state.config.VOICE_MODE_PROMPT_TEMPLATE
+                else:
+                    template = DEFAULT_VOICE_MODE_PROMPT_TEMPLATE
+
+                form_data["messages"] = add_or_update_system_message(
+                    template,
+                    form_data["messages"],
+                )
+
        if "memory" in features and features["memory"]:
            form_data = await chat_memory_handler(
                request, form_data, extra_params, user
@ -1323,7 +1362,6 @@ async def process_chat_payload(request, form_data, user, metadata, model):
continue continue
auth_type = mcp_server_connection.get("auth_type", "") auth_type = mcp_server_connection.get("auth_type", "")
headers = {} headers = {}
if auth_type == "bearer": if auth_type == "bearer":
headers["Authorization"] = ( headers["Authorization"] = (
@ -1359,6 +1397,11 @@ async def process_chat_payload(request, form_data, user, metadata, model):
log.error(f"Error getting OAuth token: {e}") log.error(f"Error getting OAuth token: {e}")
oauth_token = None oauth_token = None
connection_headers = mcp_server_connection.get("headers", None)
if connection_headers and isinstance(connection_headers, dict):
for key, value in connection_headers.items():
headers[key] = value
mcp_clients[server_id] = MCPClient() mcp_clients[server_id] = MCPClient()
await mcp_clients[server_id].connect( await mcp_clients[server_id].connect(
url=mcp_server_connection.get("url", ""), url=mcp_server_connection.get("url", ""),
@ -2556,6 +2599,26 @@ async def process_chat_response(
"arguments" "arguments"
] += delta_arguments ] += delta_arguments
image_urls = get_image_urls(
delta.get("images", []), request, metadata, user
)
if image_urls:
message_files = Chats.add_message_files_by_id_and_message_id(
metadata["chat_id"],
metadata["message_id"],
[
{"type": "image", "url": url}
for url in image_urls
],
)
await event_emitter(
{
"type": "files",
"data": {"files": message_files},
}
)
value = delta.get("content") value = delta.get("content")
reasoning_content = ( reasoning_content = (
@ -2614,6 +2677,11 @@ async def process_chat_response(
} }
) )
if ENABLE_CHAT_RESPONSE_BASE64_IMAGE_URL_CONVERSION:
value = convert_markdown_base64_images(
request, value, metadata, user
)
content = f"{content}{value}" content = f"{content}{value}"
if not content_blocks: if not content_blocks:
content_blocks.append( content_blocks.append(

View file

@@ -8,10 +8,11 @@ from datetime import timedelta
from pathlib import Path
from typing import Callable, Optional
import json
+import aiohttp
import collections.abc

-from open_webui.env import SRC_LOG_LEVELS
+from open_webui.env import SRC_LOG_LEVELS, CHAT_STREAM_RESPONSE_CHUNK_MAX_BUFFER_SIZE

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MAIN"])

@@ -539,3 +540,68 @@ def extract_urls(text: str) -> list[str]:
        r"(https?://[^\s]+)", re.IGNORECASE
    )  # Matches http and https URLs
    return url_pattern.findall(text)
def stream_chunks_handler(stream: aiohttp.StreamReader):
"""
Handle stream response chunks, supporting large data chunks that exceed the original 16kb limit.
When a single line exceeds max_buffer_size, returns an empty JSON string {} and skips subsequent data
until encountering normally sized data.
:param stream: The stream reader to handle.
:return: An async generator that yields the stream data.
"""
max_buffer_size = CHAT_STREAM_RESPONSE_CHUNK_MAX_BUFFER_SIZE
if max_buffer_size is None or max_buffer_size <= 0:
return stream
async def yield_safe_stream_chunks():
buffer = b""
skip_mode = False
async for data, _ in stream.iter_chunks():
if not data:
continue
# In skip_mode, if buffer already exceeds the limit, clear it (it's part of an oversized line)
if skip_mode and len(buffer) > max_buffer_size:
buffer = b""
lines = (buffer + data).split(b"\n")
# Process complete lines (except the last possibly incomplete fragment)
for i in range(len(lines) - 1):
line = lines[i]
if skip_mode:
# Skip mode: check if current line is small enough to exit skip mode
if len(line) <= max_buffer_size:
skip_mode = False
yield line
else:
yield b"data: {}"
else:
# Normal mode: check if line exceeds limit
if len(line) > max_buffer_size:
skip_mode = True
yield b"data: {}"
log.info(f"Skip mode triggered, line size: {len(line)}")
else:
yield line
# Save the last incomplete fragment
buffer = lines[-1]
# Check if buffer exceeds limit
if not skip_mode and len(buffer) > max_buffer_size:
skip_mode = True
log.info(f"Skip mode triggered, buffer size: {len(buffer)}")
# Clear oversized buffer to prevent unlimited growth
buffer = b""
# Process remaining buffer data
if buffer and not skip_mode:
yield buffer
return yield_safe_stream_chunks()
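A minimal usage sketch, assuming an aiohttp response object from the upstream model API (variable names are illustrative, not part of this diff):

    import json

    # resp is an aiohttp.ClientResponse streamed from the upstream model API.
    # stream_chunks_handler() passes the reader through untouched when
    # CHAT_STREAM_RESPONSE_CHUNK_MAX_BUFFER_SIZE is unset or <= 0.
    async for line in stream_chunks_handler(resp.content):
        if not line.startswith(b"data: "):
            continue
        payload = line[len(b"data: "):].strip()
        # Oversized lines arrive as b"data: {}", so json.loads() still succeeds
        # and the empty delta is simply skipped downstream.
        chunk = json.loads(payload) if payload else {}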

View file

@ -12,6 +12,7 @@ from open_webui.functions import get_function_models
from open_webui.models.functions import Functions from open_webui.models.functions import Functions
from open_webui.models.models import Models from open_webui.models.models import Models
from open_webui.models.groups import Groups
from open_webui.utils.plugin import ( from open_webui.utils.plugin import (
@ -356,6 +357,7 @@ def get_filtered_models(models, user):
or (user.role == "admin" and not BYPASS_ADMIN_ACCESS_CONTROL) or (user.role == "admin" and not BYPASS_ADMIN_ACCESS_CONTROL)
) and not BYPASS_MODEL_ACCESS_CONTROL: ) and not BYPASS_MODEL_ACCESS_CONTROL:
filtered_models = [] filtered_models = []
user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user.id)}
for model in models: for model in models:
if model.get("arena"): if model.get("arena"):
if has_access( if has_access(
@ -364,6 +366,7 @@ def get_filtered_models(models, user):
access_control=model.get("info", {}) access_control=model.get("info", {})
.get("meta", {}) .get("meta", {})
.get("access_control", {}), .get("access_control", {}),
user_group_ids=user_group_ids,
): ):
filtered_models.append(model) filtered_models.append(model)
continue continue
@ -377,6 +380,7 @@ def get_filtered_models(models, user):
user.id, user.id,
type="read", type="read",
access_control=model_info.access_control, access_control=model_info.access_control,
user_group_ids=user_group_ids,
) )
): ):
filtered_models.append(model) filtered_models.append(model)

View file

@ -14,7 +14,7 @@ import fnmatch
import time import time
import secrets import secrets
from cryptography.fernet import Fernet from cryptography.fernet import Fernet
from typing import Literal
import aiohttp import aiohttp
from authlib.integrations.starlette_client import OAuth from authlib.integrations.starlette_client import OAuth
@@ -72,13 +72,20 @@ from open_webui.utils.auth import get_password_hash, create_token
from open_webui.utils.webhook import post_webhook

from mcp.shared.auth import (
-    OAuthClientMetadata,
+    OAuthClientMetadata as MCPOAuthClientMetadata,
    OAuthMetadata,
)
from authlib.oauth2.rfc6749.errors import OAuth2Error
class OAuthClientMetadata(MCPOAuthClientMetadata):
token_endpoint_auth_method: Literal[
"none", "client_secret_basic", "client_secret_post"
] = "client_secret_post"
pass
class OAuthClientInformationFull(OAuthClientMetadata): class OAuthClientInformationFull(OAuthClientMetadata):
issuer: Optional[str] = None # URL of the OAuth server that issued this client issuer: Optional[str] = None # URL of the OAuth server that issued this client
@@ -238,24 +245,33 @@ def get_parsed_and_base_url(server_url) -> tuple[urllib.parse.ParseResult, str]:
def get_discovery_urls(server_url) -> list[str]:
    parsed, base_url = get_parsed_and_base_url(server_url)

-    urls = [
-        urllib.parse.urljoin(base_url, "/.well-known/oauth-authorization-server"),
-        urllib.parse.urljoin(base_url, "/.well-known/openid-configuration"),
-    ]
+    urls = []

    if parsed.path and parsed.path != "/":
-        urls.append(
-            urllib.parse.urljoin(
-                base_url,
-                f"/.well-known/oauth-authorization-server{parsed.path.rstrip('/')}",
-            )
-        )
-        urls.append(
-            urllib.parse.urljoin(
-                base_url, f"/.well-known/openid-configuration{parsed.path.rstrip('/')}"
-            )
-        )
+        # Generate discovery URLs based on https://modelcontextprotocol.io/specification/draft/basic/authorization#authorization-server-metadata-discovery
+        tenant = parsed.path.rstrip("/")
+        urls.extend(
+            [
+                urllib.parse.urljoin(
+                    base_url,
+                    f"/.well-known/oauth-authorization-server{tenant}",
+                ),
+                urllib.parse.urljoin(
+                    base_url, f"/.well-known/openid-configuration{tenant}"
+                ),
+                urllib.parse.urljoin(
+                    base_url, f"{tenant}/.well-known/openid-configuration"
+                ),
+            ]
+        )
+
+    urls.extend(
+        [
+            urllib.parse.urljoin(base_url, "/.well-known/oauth-authorization-server"),
+            urllib.parse.urljoin(base_url, "/.well-known/openid-configuration"),
+        ]
+    )

    return urls
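For a tenant-scoped server URL such as https://auth.example.com/tenant-a (a hypothetical example), the reordered helper should now yield the path-aware candidates first, followed by the root-level fallbacks:

    get_discovery_urls("https://auth.example.com/tenant-a")
    # Expected ordering (illustrative):
    # [
    #     "https://auth.example.com/.well-known/oauth-authorization-server/tenant-a",
    #     "https://auth.example.com/.well-known/openid-configuration/tenant-a",
    #     "https://auth.example.com/tenant-a/.well-known/openid-configuration",
    #     "https://auth.example.com/.well-known/oauth-authorization-server",
    #     "https://auth.example.com/.well-known/openid-configuration",
    # ]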
@@ -280,7 +296,6 @@ async def get_oauth_client_info_with_dynamic_client_registration(
        redirect_uris=[f"{redirect_base_url}/oauth/clients/{client_id}/callback"],
        grant_types=["authorization_code", "refresh_token"],
        response_types=["code"],
-        token_endpoint_auth_method="client_secret_post",
    )

    # Attempt to fetch OAuth server metadata to get registration endpoint & scopes

@@ -303,6 +318,17 @@ async def get_oauth_client_info_with_dynamic_client_registration(
                    oauth_client_metadata.scope = " ".join(
                        oauth_server_metadata.scopes_supported
                    )
if (
oauth_server_metadata.token_endpoint_auth_methods_supported
and oauth_client_metadata.token_endpoint_auth_method
not in oauth_server_metadata.token_endpoint_auth_methods_supported
):
# Pick the first supported method from the server
oauth_client_metadata.token_endpoint_auth_method = oauth_server_metadata.token_endpoint_auth_methods_supported[
0
]
break break
except Exception as e: except Exception as e:
log.error(f"Error parsing OAuth metadata from {url}: {e}") log.error(f"Error parsing OAuth metadata from {url}: {e}")
@ -330,6 +356,13 @@ async def get_oauth_client_info_with_dynamic_client_registration(
registration_response_json = ( registration_response_json = (
await oauth_client_registration_response.json() await oauth_client_registration_response.json()
) )
# The mcp package requires optional unset values to be None. If an empty string is passed, it gets validated and fails.
# This replaces all empty strings with None.
registration_response_json = {
k: (None if v == "" else v)
for k, v in registration_response_json.items()
}
oauth_client_info = OAuthClientInformationFull.model_validate( oauth_client_info = OAuthClientInformationFull.model_validate(
{ {
**registration_response_json, **registration_response_json,
@@ -374,9 +407,20 @@ class OAuthClientManager:
                "name": client_id,
                "client_id": oauth_client_info.client_id,
                "client_secret": oauth_client_info.client_secret,
-                "client_kwargs": (
-                    {"scope": oauth_client_info.scope} if oauth_client_info.scope else {}
-                ),
+                "client_kwargs": {
+                    **(
+                        {"scope": oauth_client_info.scope}
+                        if oauth_client_info.scope
+                        else {}
+                    ),
+                    **(
+                        {
+                            "token_endpoint_auth_method": oauth_client_info.token_endpoint_auth_method
+                        }
+                        if oauth_client_info.token_endpoint_auth_method
+                        else {}
+                    ),
+                },
                "server_metadata_url": (
                    oauth_client_info.issuer if oauth_client_info.issuer else None
                ),
@@ -690,16 +734,17 @@ class OAuthClientManager:
        error_message = None
        try:
            client_info = self.get_client_info(client_id)
-            token_params = {}
+            auth_params = {}
            if (
                client_info
                and hasattr(client_info, "client_id")
                and hasattr(client_info, "client_secret")
            ):
-                token_params["client_id"] = client_info.client_id
-                token_params["client_secret"] = client_info.client_secret
+                auth_params["client_id"] = client_info.client_id
+                auth_params["client_secret"] = client_info.client_secret

-            token = await client.authorize_access_token(request, **token_params)
+            token = await client.authorize_access_token(request, **auth_params)
if token: if token:
try: try:
# Add timestamp for tracking # Add timestamp for tracking
@ -978,6 +1023,10 @@ class OAuthManager:
for nested_claim in nested_claims: for nested_claim in nested_claims:
claim_data = claim_data.get(nested_claim, {}) claim_data = claim_data.get(nested_claim, {})
# Try flat claim structure as alternative
if not claim_data:
claim_data = user_data.get(oauth_claim, {})
oauth_roles = [] oauth_roles = []
if isinstance(claim_data, list): if isinstance(claim_data, list):
@@ -1111,22 +1160,21 @@ class OAuthManager:
                        f"Removing user from group {group_model.name} as it is no longer in their oauth groups"
                    )

-                    user_ids = group_model.user_ids
-                    user_ids = [i for i in user_ids if i != user.id]
+                    Groups.remove_users_from_group(group_model.id, [user.id])

                    # In case a group is created, but perms are never assigned to the group by hitting "save"
                    group_permissions = group_model.permissions
                    if not group_permissions:
                        group_permissions = default_permissions

-                    update_form = GroupUpdateForm(
-                        name=group_model.name,
-                        description=group_model.description,
-                        permissions=group_permissions,
-                        user_ids=user_ids,
-                    )
                    Groups.update_group_by_id(
-                        id=group_model.id, form_data=update_form, overwrite=False
+                        id=group_model.id,
+                        form_data=GroupUpdateForm(
+                            name=group_model.name,
+                            description=group_model.description,
+                            permissions=group_permissions,
+                        ),
+                        overwrite=False,
                    )
        # Add user to new groups

@@ -1142,22 +1190,21 @@ class OAuthManager:
                        f"Adding user to group {group_model.name} as it was found in their oauth groups"
                    )

-                    user_ids = group_model.user_ids
-                    user_ids.append(user.id)
+                    Groups.add_users_to_group(group_model.id, [user.id])

                    # In case a group is created, but perms are never assigned to the group by hitting "save"
                    group_permissions = group_model.permissions
                    if not group_permissions:
                        group_permissions = default_permissions

-                    update_form = GroupUpdateForm(
-                        name=group_model.name,
-                        description=group_model.description,
-                        permissions=group_permissions,
-                        user_ids=user_ids,
-                    )
                    Groups.update_group_by_id(
-                        id=group_model.id, form_data=update_form, overwrite=False
+                        id=group_model.id,
+                        form_data=GroupUpdateForm(
+                            name=group_model.name,
+                            description=group_model.description,
+                            permissions=group_permissions,
+                        ),
+                        overwrite=False,
                    )
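The OAuth group sync no longer rewrites the whole user_ids list through GroupUpdateForm; membership changes go through the new bulk helpers while permissions are refreshed separately. A rough sketch of the resulting flow (the group and user objects are illustrative):

    # Drop the user from groups that no longer appear in their OAuth claim...
    Groups.remove_users_from_group(group_model.id, [user.id])
    # ...and add them to newly matched ones.
    Groups.add_users_to_group(group_model.id, [user.id])
    # Permissions are still normalized via Groups.update_group_by_id(), which now
    # omits user_ids entirely from the GroupUpdateForm payload.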
async def _process_picture_url( async def _process_picture_url(
@ -1224,8 +1271,16 @@ class OAuthManager:
error_message = None error_message = None
try: try:
client = self.get_client(provider) client = self.get_client(provider)
auth_params = {}
if client:
if hasattr(client, "client_id"):
auth_params["client_id"] = client.client_id
if hasattr(client, "client_secret"):
auth_params["client_secret"] = client.client_secret
try: try:
token = await client.authorize_access_token(request) token = await client.authorize_access_token(request, **auth_params)
except Exception as e: except Exception as e:
detailed_error = _build_oauth_callback_error_message(e) detailed_error = _build_oauth_callback_error_message(e)
log.warning( log.warning(

View file

@@ -208,20 +208,21 @@ def rag_template(template: str, context: str, query: str):
    if "[query]" in context:
        query_placeholder = "{{QUERY" + str(uuid.uuid4()) + "}}"
        template = template.replace("[query]", query_placeholder)
-        query_placeholders.append(query_placeholder)
+        query_placeholders.append((query_placeholder, "[query]"))

    if "{{QUERY}}" in context:
        query_placeholder = "{{QUERY" + str(uuid.uuid4()) + "}}"
        template = template.replace("{{QUERY}}", query_placeholder)
-        query_placeholders.append(query_placeholder)
+        query_placeholders.append((query_placeholder, "{{QUERY}}"))

    template = template.replace("[context]", context)
    template = template.replace("{{CONTEXT}}", context)
    template = template.replace("[query]", query)
    template = template.replace("{{QUERY}}", query)

-    for query_placeholder in query_placeholders:
-        template = template.replace(query_placeholder, query)
+    for query_placeholder, original_placeholder in query_placeholders:
+        template = template.replace(query_placeholder, original_placeholder)

    return template

View file

@ -99,6 +99,9 @@ def _build_meter_provider(resource: Resource) -> MeterProvider:
View( View(
instrument_name="webui.users.active", instrument_name="webui.users.active",
), ),
View(
instrument_name="webui.users.active.today",
),
] ]
provider = MeterProvider( provider = MeterProvider(
@ -159,6 +162,18 @@ def setup_metrics(app: FastAPI, resource: Resource) -> None:
callbacks=[observe_active_users], callbacks=[observe_active_users],
) )
def observe_users_active_today(
options: metrics.CallbackOptions,
) -> Sequence[metrics.Observation]:
return [metrics.Observation(value=Users.get_num_users_active_today())]
meter.create_observable_gauge(
name="webui.users.active.today",
description="Number of users active since midnight today",
unit="users",
callbacks=[observe_users_active_today],
)
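For readers unfamiliar with the OpenTelemetry metrics API used here: observable gauges are pull-based, so the callback is invoked by the metric reader on each collection cycle rather than pushed to. A minimal standalone sketch of the same pattern (the returned value is a made-up stand-in for Users.get_num_users_active_today):

    from typing import Sequence
    from opentelemetry import metrics

    def observe_example(options: metrics.CallbackOptions) -> Sequence[metrics.Observation]:
        # Called by the configured MetricReader each time metrics are collected.
        return [metrics.Observation(value=42)]  # stand-in value

    meter = metrics.get_meter(__name__)
    meter.create_observable_gauge(
        name="webui.example.gauge",
        description="Illustrative observable gauge",
        unit="users",
        callbacks=[observe_example],
    )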
# FastAPI middleware # FastAPI middleware
@app.middleware("http") @app.middleware("http")
async def _metrics_middleware(request: Request, call_next): async def _metrics_middleware(request: Request, call_next):

View file

@ -155,7 +155,9 @@ async def get_tools(
auth_type = tool_server_connection.get("auth_type", "bearer") auth_type = tool_server_connection.get("auth_type", "bearer")
cookies = {} cookies = {}
headers = {} headers = {
"Content-Type": "application/json",
}
if auth_type == "bearer": if auth_type == "bearer":
headers["Authorization"] = ( headers["Authorization"] = (
@ -177,7 +179,10 @@ async def get_tools(
f"Bearer {oauth_token.get('access_token', '')}" f"Bearer {oauth_token.get('access_token', '')}"
) )
headers["Content-Type"] = "application/json" connection_headers = tool_server_connection.get("headers", None)
if connection_headers and isinstance(connection_headers, dict):
for key, value in connection_headers.items():
headers[key] = value
def make_tool_function( def make_tool_function(
function_name, tool_server_data, headers function_name, tool_server_data, headers
@ -232,14 +237,16 @@ async def get_tools(
module, _ = load_tool_module_by_id(tool_id) module, _ = load_tool_module_by_id(tool_id)
request.app.state.TOOLS[tool_id] = module request.app.state.TOOLS[tool_id] = module
extra_params["__id__"] = tool_id __user__ = {
**extra_params["__user__"],
}
# Set valves for the tool # Set valves for the tool
if hasattr(module, "valves") and hasattr(module, "Valves"): if hasattr(module, "valves") and hasattr(module, "Valves"):
valves = Tools.get_tool_valves_by_id(tool_id) or {} valves = Tools.get_tool_valves_by_id(tool_id) or {}
module.valves = module.Valves(**valves) module.valves = module.Valves(**valves)
if hasattr(module, "UserValves"): if hasattr(module, "UserValves"):
extra_params["__user__"]["valves"] = module.UserValves( # type: ignore __user__["valves"] = module.UserValves( # type: ignore
**Tools.get_user_valves_by_id_and_user_id(tool_id, user.id) **Tools.get_user_valves_by_id_and_user_id(tool_id, user.id)
) )
@ -261,7 +268,12 @@ async def get_tools(
function_name = spec["name"] function_name = spec["name"]
tool_function = getattr(module, function_name) tool_function = getattr(module, function_name)
callable = get_async_tool_function_and_apply_extra_params( callable = get_async_tool_function_and_apply_extra_params(
tool_function, extra_params tool_function,
{
**extra_params,
"__id__": tool_id,
"__user__": __user__,
},
) )
# TODO: Support Pydantic models as parameters # TODO: Support Pydantic models as parameters
@ -561,20 +573,21 @@ async def get_tool_servers(request: Request):
return tool_servers return tool_servers
-async def get_tool_server_data(token: str, url: str) -> Dict[str, Any]:
-    headers = {
+async def get_tool_server_data(url: str, headers: Optional[dict]) -> Dict[str, Any]:
+    _headers = {
        "Accept": "application/json",
        "Content-Type": "application/json",
    }
-    if token:
-        headers["Authorization"] = f"Bearer {token}"
+
+    if headers:
+        _headers.update(headers)

    error = None
    try:
        timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_TOOL_SERVER_DATA)
        async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
            async with session.get(
-                url, headers=headers, ssl=AIOHTTP_CLIENT_SESSION_TOOL_SERVER_SSL
+                url, headers=_headers, ssl=AIOHTTP_CLIENT_SESSION_TOOL_SERVER_SSL
            ) as response:
) as response: ) as response:
if response.status != 200: if response.status != 200:
error_body = await response.json() error_body = await response.json()
@ -644,7 +657,10 @@ async def get_tool_servers_data(servers: List[Dict[str, Any]]) -> List[Dict[str,
openapi_path = server.get("path", "openapi.json") openapi_path = server.get("path", "openapi.json")
spec_url = get_tool_server_url(server_url, openapi_path) spec_url = get_tool_server_url(server_url, openapi_path)
# Fetch from URL # Fetch from URL
task = get_tool_server_data(token, spec_url) task = get_tool_server_data(
spec_url,
{"Authorization": f"Bearer {token}"} if token else None,
)
elif spec_type == "json" and server.get("spec", ""): elif spec_type == "json" and server.get("spec", ""):
# Use provided JSON spec # Use provided JSON spec
spec_json = None spec_json = None

View file

@ -0,0 +1,50 @@
# Minimal requirements for backend to run
# WIP: use this as a reference to build a minimal docker image
fastapi==0.118.0
uvicorn[standard]==0.37.0
pydantic==2.11.9
python-multipart==0.0.20
itsdangerous==2.2.0
python-socketio==5.13.0
python-jose==3.5.0
cryptography
bcrypt==5.0.0
argon2-cffi==25.1.0
PyJWT[crypto]==2.10.1
authlib==1.6.5
requests==2.32.5
aiohttp==3.12.15
async-timeout
aiocache
aiofiles
starlette-compress==1.6.0
httpx[socks,http2,zstd,cli,brotli]==0.28.1
starsessions[redis]==2.2.1
sqlalchemy==2.0.38
alembic==1.14.0
peewee==3.18.1
peewee-migrate==1.12.2
pycrdt==0.12.25
redis
APScheduler==3.10.4
RestrictedPython==8.0
loguru==0.7.3
asgiref==3.8.1
mcp==1.21.2
openai
langchain==0.3.27
langchain-community==0.3.29
fake-useragent==2.2.0
chromadb==1.1.0
black==25.9.0
pydub

View file

@ -37,11 +37,11 @@ asgiref==3.8.1
# AI libraries # AI libraries
tiktoken tiktoken
mcp==1.14.1 mcp==1.21.2
openai openai
anthropic anthropic
google-genai==1.38.0 google-genai==1.52.0
google-generativeai==0.8.5 google-generativeai==0.8.5
langchain==0.3.27 langchain==0.3.27
@ -49,6 +49,7 @@ langchain-community==0.3.29
fake-useragent==2.2.0 fake-useragent==2.2.0
chromadb==1.1.0 chromadb==1.1.0
weaviate-client==4.17.0
opensearch-py==2.8.0 opensearch-py==2.8.0
transformers transformers
@ -63,7 +64,8 @@ fpdf2==2.8.2
pymdown-extensions==10.14.2 pymdown-extensions==10.14.2
docx2txt==0.8 docx2txt==0.8
python-pptx==1.0.2 python-pptx==1.0.2
unstructured==0.18.15 unstructured==0.18.18
msoffcrypto-tool==5.4.2
nltk==3.9.1 nltk==3.9.1
Markdown==3.9 Markdown==3.9
pypandoc==1.15 pypandoc==1.15
@ -75,7 +77,6 @@ validators==0.35.0
psutil psutil
sentencepiece sentencepiece
soundfile==0.13.1 soundfile==0.13.1
azure-ai-documentintelligence==1.0.2
pillow==11.3.0 pillow==11.3.0
opencv-python-headless==4.11.0.86 opencv-python-headless==4.11.0.86
@ -85,7 +86,6 @@ rank-bm25==0.2.2
onnxruntime==1.20.1 onnxruntime==1.20.1
faster-whisper==1.1.1 faster-whisper==1.1.1
black==25.9.0 black==25.9.0
youtube-transcript-api==1.2.2 youtube-transcript-api==1.2.2
pytube==15.0.0 pytube==15.0.0
@ -93,6 +93,11 @@ pytube==15.0.0
pydub pydub
ddgs==9.0.0 ddgs==9.0.0
azure-ai-documentintelligence==1.0.2
azure-identity==1.25.0
azure-storage-blob==12.24.1
azure-search-documents==11.6.0
## Google Drive ## Google Drive
google-api-python-client google-api-python-client
google-auth-httplib2 google-auth-httplib2
@ -101,10 +106,7 @@ google-auth-oauthlib
googleapis-common-protos==1.70.0 googleapis-common-protos==1.70.0
google-cloud-storage==2.19.0 google-cloud-storage==2.19.0
azure-identity==1.25.0 ## Databases
azure-storage-blob==12.24.1
pymongo pymongo
psycopg2-binary==2.9.10 psycopg2-binary==2.9.10
pgvector==0.4.1 pgvector==0.4.1

4
package-lock.json generated
View file

@@ -1,12 +1,12 @@
{
	"name": "open-webui",
-	"version": "0.6.36",
+	"version": "0.6.37",
	"lockfileVersion": 3,
	"requires": true,
	"packages": {
		"": {
			"name": "open-webui",
-			"version": "0.6.36",
+			"version": "0.6.37",
			"dependencies": {
				"@azure/msal-browser": "^4.5.0",
				"@codemirror/lang-javascript": "^6.2.2",

View file

@@ -1,6 +1,6 @@
{
	"name": "open-webui",
-	"version": "0.6.36",
+	"version": "0.6.37",
	"private": true,
	"scripts": {
		"dev": "npm run pyodide:fetch && vite dev --host",

View file

@ -37,9 +37,6 @@ dependencies = [
"pycrdt==0.12.25", "pycrdt==0.12.25",
"redis", "redis",
"PyMySQL==1.1.1",
"boto3==1.40.5",
"APScheduler==3.10.4", "APScheduler==3.10.4",
"RestrictedPython==8.0", "RestrictedPython==8.0",
@ -47,11 +44,11 @@ dependencies = [
"asgiref==3.8.1", "asgiref==3.8.1",
"tiktoken", "tiktoken",
"mcp==1.14.1", "mcp==1.21.2",
"openai", "openai",
"anthropic", "anthropic",
"google-genai==1.38.0", "google-genai==1.52.0",
"google-generativeai==0.8.5", "google-generativeai==0.8.5",
"langchain==0.3.27", "langchain==0.3.27",
@ -60,6 +57,8 @@ dependencies = [
"fake-useragent==2.2.0", "fake-useragent==2.2.0",
"chromadb==1.0.20", "chromadb==1.0.20",
"opensearch-py==2.8.0", "opensearch-py==2.8.0",
"PyMySQL==1.1.1",
"boto3==1.40.5",
"transformers", "transformers",
"sentence-transformers==5.1.1", "sentence-transformers==5.1.1",
@ -73,7 +72,8 @@ dependencies = [
"pymdown-extensions==10.14.2", "pymdown-extensions==10.14.2",
"docx2txt==0.8", "docx2txt==0.8",
"python-pptx==1.0.2", "python-pptx==1.0.2",
"unstructured==0.18.15", "unstructured==0.18.18",
"msoffcrypto-tool==5.4.2",
"nltk==3.9.1", "nltk==3.9.1",
"Markdown==3.9", "Markdown==3.9",
"pypandoc==1.15", "pypandoc==1.15",
@ -146,12 +146,14 @@ all = [
"elasticsearch==9.1.0", "elasticsearch==9.1.0",
"qdrant-client==1.14.3", "qdrant-client==1.14.3",
"weaviate-client==4.17.0",
"pymilvus==2.6.2", "pymilvus==2.6.2",
"pinecone==6.0.2", "pinecone==6.0.2",
"oracledb==3.2.0", "oracledb==3.2.0",
"colbert-ai==0.2.21", "colbert-ai==0.2.21",
"firecrawl-py==4.5.0", "firecrawl-py==4.5.0",
"azure-search-documents==11.6.0",
] ]
[project.scripts] [project.scripts]

View file

@ -30,8 +30,33 @@
font-display: swap; font-display: swap;
} }
/* --app-text-scale is updated via the UI Scale slider (Interface.svelte) */
:root {
--app-text-scale: 1;
}
html { html {
word-break: break-word; word-break: break-word;
/* font-size scales the entire document via the same UI control */
font-size: calc(1rem * var(--app-text-scale, 1));
}
#sidebar-chat-item {
/* sidebar item sizing scales for the chat list entries */
min-height: calc(32px * var(--app-text-scale, 1));
padding-inline: calc(11px * var(--app-text-scale, 1));
padding-block: calc(6px * var(--app-text-scale, 1));
}
#sidebar-chat-item div[dir='auto'] {
/* chat title line height follows the text scale */
height: calc(20px * var(--app-text-scale, 1));
line-height: calc(20px * var(--app-text-scale, 1));
}
#sidebar-chat-item input {
/* editing state input height is kept in sync */
min-height: calc(20px * var(--app-text-scale, 1));
} }
code { code {

View file

@ -174,7 +174,6 @@
</span> --> </span> -->
</div> </div>
</body> </body>
</html>
<style type="text/css" nonce=""> <style type="text/css" nonce="">
html { html {
@ -243,3 +242,5 @@
animation: pulse 1.5s cubic-bezier(0.4, 0, 0.6, 1) infinite; animation: pulse 1.5s cubic-bezier(0.4, 0, 0.6, 1) infinite;
} }
</style> </style>
</html>

View file

@@ -65,15 +65,7 @@ export const unarchiveAllChats = async (token: string) => {
	return res;
};

-export const importChat = async (
-	token: string,
-	chat: object,
-	meta: object | null,
-	pinned?: boolean,
-	folderId?: string | null,
-	createdAt: number | null = null,
-	updatedAt: number | null = null
-) => {
+export const importChats = async (token: string, chats: object[]) => {
	let error = null;

	const res = await fetch(`${WEBUI_API_BASE_URL}/chats/import`, {

@@ -84,12 +76,7 @@ export const importChat = async (
			authorization: `Bearer ${token}`
		},
		body: JSON.stringify({
-			chat: chat,
-			meta: meta ?? {},
-			pinned: pinned,
-			folder_id: folderId,
-			created_at: createdAt ?? null,
-			updated_at: updatedAt ?? null
+			chats
		})
	})
		.then(async (res) => {

View file

@ -93,6 +93,45 @@ export const getAllFeedbacks = async (token: string = '') => {
return res; return res;
}; };
export const getFeedbackItems = async (token: string = '', orderBy, direction, page) => {
let error = null;
const searchParams = new URLSearchParams();
if (orderBy) searchParams.append('order_by', orderBy);
if (direction) searchParams.append('direction', direction);
if (page) searchParams.append('page', page.toString());
const res = await fetch(
`${WEBUI_API_BASE_URL}/evaluations/feedbacks/list?${searchParams.toString()}`,
{
method: 'GET',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
authorization: `Bearer ${token}`
}
}
)
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.then((json) => {
return json;
})
.catch((err) => {
error = err.detail;
console.error(err);
return null;
});
if (error) {
throw error;
}
return res;
};
export const exportAllFeedbacks = async (token: string = '') => { export const exportAllFeedbacks = async (token: string = '') => {
let error = null; let error = null;

View file

@ -239,10 +239,13 @@ export const updateFolderItemsById = async (token: string, id: string, items: Fo
return res; return res;
}; };
-export const deleteFolderById = async (token: string, id: string) => {
+export const deleteFolderById = async (token: string, id: string, deleteContents: boolean) => {
	let error = null;

-	const res = await fetch(`${WEBUI_API_BASE_URL}/folders/${id}`, {
+	const searchParams = new URLSearchParams();
+	searchParams.append('delete_contents', deleteContents ? 'true' : 'false');
+
+	const res = await fetch(`${WEBUI_API_BASE_URL}/folders/${id}?${searchParams.toString()}`, {
		method: 'DELETE',
		headers: {
			Accept: 'application/json',

View file

@ -31,10 +31,15 @@ export const createNewGroup = async (token: string, group: object) => {
return res; return res;
}; };
-export const getGroups = async (token: string = '') => {
+export const getGroups = async (token: string = '', share?: boolean) => {
	let error = null;

-	const res = await fetch(`${WEBUI_API_BASE_URL}/groups/`, {
+	const searchParams = new URLSearchParams();
+	if (share !== undefined) {
+		searchParams.append('share', String(share));
+	}
+
+	const res = await fetch(`${WEBUI_API_BASE_URL}/groups/?${searchParams.toString()}`, {
		method: 'GET',
		headers: {
			Accept: 'application/json',
@ -160,3 +165,73 @@ export const deleteGroupById = async (token: string, id: string) => {
return res; return res;
}; };
export const addUserToGroup = async (token: string, id: string, userIds: string[]) => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/groups/id/${id}/users/add`, {
method: 'POST',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
authorization: `Bearer ${token}`
},
body: JSON.stringify({
user_ids: userIds
})
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.then((json) => {
return json;
})
.catch((err) => {
error = err.detail;
console.error(err);
return null;
});
if (error) {
throw error;
}
return res;
};
export const removeUserFromGroup = async (token: string, id: string, userIds: string[]) => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/groups/id/${id}/users/remove`, {
method: 'POST',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
authorization: `Bearer ${token}`
},
body: JSON.stringify({
user_ids: userIds
})
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.then((json) => {
return json;
})
.catch((err) => {
error = err.detail;
console.error(err);
return null;
});
if (error) {
throw error;
}
return res;
};

View file

@@ -1425,7 +1425,7 @@ export const getVersion = async (token: string) => {
		throw error;
	}

-	return res?.version ?? null;
+	return res;
};
export const getVersionUpdates = async (token: string) => { export const getVersionUpdates = async (token: string) => {

View file

@@ -1,9 +1,68 @@
import { WEBUI_API_BASE_URL } from '$lib/constants';

-export const getModelItems = async (token: string = '') => {
+export const getModelItems = async (
+	token: string = '',
+	query,
+	viewOption,
+	selectedTag,
+	orderBy,
+	direction,
+	page
+) => {
	let error = null;

-	const res = await fetch(`${WEBUI_API_BASE_URL}/models/list`, {
+	const searchParams = new URLSearchParams();
if (query) {
searchParams.append('query', query);
}
if (viewOption) {
searchParams.append('view_option', viewOption);
}
if (selectedTag) {
searchParams.append('tag', selectedTag);
}
if (orderBy) {
searchParams.append('order_by', orderBy);
}
if (direction) {
searchParams.append('direction', direction);
}
if (page) {
searchParams.append('page', page.toString());
}
const res = await fetch(`${WEBUI_API_BASE_URL}/models/list?${searchParams.toString()}`, {
method: 'GET',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
authorization: `Bearer ${token}`
}
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.then((json) => {
return json;
})
.catch((err) => {
error = err;
console.error(err);
return null;
});
if (error) {
throw error;
}
return res;
};
export const getModelTags = async (token: string = '') => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/models/tags`, {
method: 'GET', method: 'GET',
headers: { headers: {
Accept: 'application/json', Accept: 'application/json',
@@ -192,17 +251,14 @@ export const toggleModelById = async (token: string, id: string) => {
export const updateModelById = async (token: string, id: string, model: object) => {
	let error = null;

-	const searchParams = new URLSearchParams();
-	searchParams.append('id', id);
-
-	const res = await fetch(`${WEBUI_API_BASE_URL}/models/model/update?${searchParams.toString()}`, {
+	const res = await fetch(`${WEBUI_API_BASE_URL}/models/model/update`, {
		method: 'POST',
		headers: {
			Accept: 'application/json',
			'Content-Type': 'application/json',
			authorization: `Bearer ${token}`
		},
-		body: JSON.stringify(model)
+		body: JSON.stringify({ ...model, id })
	})
		.then(async (res) => {
			if (!res.ok) throw await res.json();

@@ -228,16 +284,14 @@ export const updateModelById = async (token: string, id: string, model: object)
export const deleteModelById = async (token: string, id: string) => {
	let error = null;

-	const searchParams = new URLSearchParams();
-	searchParams.append('id', id);
-
-	const res = await fetch(`${WEBUI_API_BASE_URL}/models/model/delete?${searchParams.toString()}`, {
-		method: 'DELETE',
+	const res = await fetch(`${WEBUI_API_BASE_URL}/models/model/delete`, {
+		method: 'POST',
		headers: {
			Accept: 'application/json',
			'Content-Type': 'application/json',
			authorization: `Bearer ${token}`
-		}
+		},
+		body: JSON.stringify({ id })
	})
		.then(async (res) => {
			if (!res.ok) throw await res.json();

View file

@ -179,39 +179,3 @@ export const downloadDatabase = async (token: string) => {
throw error; throw error;
} }
}; };
export const downloadLiteLLMConfig = async (token: string) => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/utils/litellm/config`, {
method: 'GET',
headers: {
'Content-Type': 'application/json',
Authorization: `Bearer ${token}`
}
})
.then(async (response) => {
if (!response.ok) {
throw await response.json();
}
return response.blob();
})
.then((blob) => {
const url = window.URL.createObjectURL(blob);
const a = document.createElement('a');
a.href = url;
a.download = 'config.yaml';
document.body.appendChild(a);
a.click();
window.URL.revokeObjectURL(url);
})
.catch((err) => {
console.error(err);
error = err.detail;
return null;
});
if (error) {
throw error;
}
};

View file

@ -426,7 +426,7 @@
<div class="flex-1"> <div class="flex-1">
<Tooltip <Tooltip
content={$i18n.t( content={$i18n.t(
'Enter additional headers in JSON format (e.g. {{\'{{"X-Custom-Header": "value"}}\'}})' 'Enter additional headers in JSON format (e.g. {"X-Custom-Header": "value"}'
)} )}
> >
<Textarea <Textarea

View file

@ -22,6 +22,7 @@
import AccessControl from './workspace/common/AccessControl.svelte'; import AccessControl from './workspace/common/AccessControl.svelte';
import Spinner from '$lib/components/common/Spinner.svelte'; import Spinner from '$lib/components/common/Spinner.svelte';
import XMark from '$lib/components/icons/XMark.svelte'; import XMark from '$lib/components/icons/XMark.svelte';
import Textarea from './common/Textarea.svelte';
export let onSubmit: Function = () => {}; export let onSubmit: Function = () => {};
export let onDelete: Function = () => {}; export let onDelete: Function = () => {};
@ -44,6 +45,7 @@
let auth_type = 'bearer'; let auth_type = 'bearer';
let key = ''; let key = '';
let headers = '';
let accessControl = {}; let accessControl = {};
@ -110,6 +112,20 @@
} }
} }
if (headers) {
try {
let _headers = JSON.parse(headers);
if (typeof _headers !== 'object' || Array.isArray(_headers)) {
_headers = null;
throw new Error('Headers must be a valid JSON object');
}
headers = JSON.stringify(_headers, null, 2);
} catch (error) {
toast.error($i18n.t('Headers must be a valid JSON object'));
return;
}
}
if (direct) { if (direct) {
const res = await getToolServerData( const res = await getToolServerData(
auth_type === 'bearer' ? key : localStorage.token, auth_type === 'bearer' ? key : localStorage.token,
@ -128,6 +144,7 @@
path, path,
type, type,
auth_type, auth_type,
headers: headers ? JSON.parse(headers) : undefined,
key, key,
config: { config: {
enable: enable, enable: enable,
@ -177,6 +194,7 @@
if (data.path) path = data.path; if (data.path) path = data.path;
if (data.auth_type) auth_type = data.auth_type; if (data.auth_type) auth_type = data.auth_type;
if (data.headers) headers = JSON.stringify(data.headers, null, 2);
if (data.key) key = data.key; if (data.key) key = data.key;
if (data.info) { if (data.info) {
@ -210,6 +228,7 @@
path, path,
auth_type, auth_type,
headers: headers ? JSON.parse(headers) : undefined,
key, key,
info: { info: {
@ -256,6 +275,19 @@
} }
} }
if (headers) {
try {
const _headers = JSON.parse(headers);
if (typeof _headers !== 'object' || Array.isArray(_headers)) {
throw new Error('Headers must be a valid JSON object');
}
headers = JSON.stringify(_headers, null, 2);
} catch (error) {
toast.error($i18n.t('Headers must be a valid JSON object'));
return;
}
}
const connection = { const connection = {
type, type,
url, url,
@ -265,9 +297,12 @@
path, path,
auth_type, auth_type,
headers: headers ? JSON.parse(headers) : undefined,
key, key,
config: { config: {
enable: enable, enable: enable,
access_control: accessControl access_control: accessControl
}, },
info: { info: {
@ -313,6 +348,8 @@
path = connection?.path ?? 'openapi.json'; path = connection?.path ?? 'openapi.json';
auth_type = connection?.auth_type ?? 'bearer'; auth_type = connection?.auth_type ?? 'bearer';
headers = connection?.headers ? JSON.stringify(connection.headers, null, 2) : '';
key = connection?.key ?? ''; key = connection?.key ?? '';
id = connection.info?.id ?? ''; id = connection.info?.id ?? '';
@ -657,6 +694,33 @@
</div> </div>
{#if !direct} {#if !direct}
<div class="flex gap-2 mt-2">
<div class="flex flex-col w-full">
<label
for="headers-input"
class={`mb-0.5 text-xs text-gray-500
${($settings?.highContrastMode ?? false) ? 'text-gray-800 dark:text-gray-100' : ''}`}
>{$i18n.t('Headers')}</label
>
<div class="flex-1">
<Tooltip
content={$i18n.t(
'Enter additional headers in JSON format (e.g. {"X-Custom-Header": "value"})'
)}
>
<Textarea
className="w-full text-sm outline-hidden"
bind:value={headers}
placeholder={$i18n.t('Enter additional headers in JSON format')}
required={false}
minSize={30}
/>
</Tooltip>
</div>
</div>
</div>
<hr class=" border-gray-100 dark:border-gray-700/10 my-2.5 w-full" /> <hr class=" border-gray-100 dark:border-gray-700/10 my-2.5 w-full" />
<div class="flex gap-2"> <div class="flex gap-2">

View file

@ -33,7 +33,9 @@
let feedbacks = []; let feedbacks = [];
onMount(async () => { onMount(async () => {
// TODO: feedbacks elo rating calculation should be done in the backend; remove below line later
feedbacks = await getAllFeedbacks(localStorage.token); feedbacks = await getAllFeedbacks(localStorage.token);
loaded = true; loaded = true;
const containerElement = document.getElementById('users-tabs-container'); const containerElement = document.getElementById('users-tabs-container');
@ -117,7 +119,7 @@
{#if selectedTab === 'leaderboard'} {#if selectedTab === 'leaderboard'}
<Leaderboard {feedbacks} /> <Leaderboard {feedbacks} />
{:else if selectedTab === 'feedbacks'} {:else if selectedTab === 'feedbacks'}
<Feedbacks {feedbacks} /> <Feedbacks />
{/if} {/if}
</div> </div>
</div> </div>

View file

@ -10,7 +10,7 @@
import { onMount, getContext } from 'svelte'; import { onMount, getContext } from 'svelte';
const i18n = getContext('i18n'); const i18n = getContext('i18n');
import { deleteFeedbackById, exportAllFeedbacks, getAllFeedbacks } from '$lib/apis/evaluations'; import { deleteFeedbackById, exportAllFeedbacks, getFeedbackItems } from '$lib/apis/evaluations';
import Tooltip from '$lib/components/common/Tooltip.svelte'; import Tooltip from '$lib/components/common/Tooltip.svelte';
import Download from '$lib/components/icons/Download.svelte'; import Download from '$lib/components/icons/Download.svelte';
@ -23,78 +23,25 @@
import ChevronUp from '$lib/components/icons/ChevronUp.svelte'; import ChevronUp from '$lib/components/icons/ChevronUp.svelte';
import ChevronDown from '$lib/components/icons/ChevronDown.svelte'; import ChevronDown from '$lib/components/icons/ChevronDown.svelte';
import { WEBUI_BASE_URL } from '$lib/constants'; import { WEBUI_API_BASE_URL, WEBUI_BASE_URL } from '$lib/constants';
import { config } from '$lib/stores'; import { config } from '$lib/stores';
import Spinner from '$lib/components/common/Spinner.svelte';
export let feedbacks = [];
let page = 1; let page = 1;
$: paginatedFeedbacks = sortedFeedbacks.slice((page - 1) * 10, page * 10); let items = null;
let total = null;
let orderBy: string = 'updated_at'; let orderBy: string = 'updated_at';
let direction: 'asc' | 'desc' = 'desc'; let direction: 'asc' | 'desc' = 'desc';
type Feedback = { const setSortKey = (key) => {
id: string;
data: {
rating: number;
model_id: string;
sibling_model_ids: string[] | null;
reason: string;
comment: string;
tags: string[];
};
user: {
name: string;
profile_image_url: string;
};
updated_at: number;
};
type ModelStats = {
rating: number;
won: number;
lost: number;
};
function setSortKey(key: string) {
if (orderBy === key) { if (orderBy === key) {
direction = direction === 'asc' ? 'desc' : 'asc'; direction = direction === 'asc' ? 'desc' : 'asc';
} else { } else {
orderBy = key; orderBy = key;
if (key === 'user' || key === 'model_id') { direction = 'asc';
direction = 'asc';
} else {
direction = 'desc';
}
} }
page = 1; };
}
$: sortedFeedbacks = [...feedbacks].sort((a, b) => {
let aVal, bVal;
switch (orderBy) {
case 'user':
aVal = a.user?.name || '';
bVal = b.user?.name || '';
return direction === 'asc' ? aVal.localeCompare(bVal) : bVal.localeCompare(aVal);
case 'model_id':
aVal = a.data.model_id || '';
bVal = b.data.model_id || '';
return direction === 'asc' ? aVal.localeCompare(bVal) : bVal.localeCompare(aVal);
case 'rating':
aVal = a.data.rating;
bVal = b.data.rating;
return direction === 'asc' ? aVal - bVal : bVal - aVal;
case 'updated_at':
aVal = a.updated_at;
bVal = b.updated_at;
return direction === 'asc' ? aVal - bVal : bVal - aVal;
default:
return 0;
}
});
let showFeedbackModal = false; let showFeedbackModal = false;
let selectedFeedback = null; let selectedFeedback = null;
@ -115,13 +62,41 @@
// //
////////////////////// //////////////////////
const getFeedbacks = async () => {
try {
const res = await getFeedbackItems(localStorage.token, orderBy, direction, page).catch(
(error) => {
toast.error(`${error}`);
return null;
}
);
if (res) {
items = res.items;
total = res.total;
}
} catch (err) {
console.error(err);
}
};
$: if (page) {
getFeedbacks();
}
$: if (orderBy && direction) {
getFeedbacks();
}
const deleteFeedbackHandler = async (feedbackId: string) => { const deleteFeedbackHandler = async (feedbackId: string) => {
const response = await deleteFeedbackById(localStorage.token, feedbackId).catch((err) => { const response = await deleteFeedbackById(localStorage.token, feedbackId).catch((err) => {
toast.error(err); toast.error(err);
return null; return null;
}); });
if (response) { if (response) {
feedbacks = feedbacks.filter((f) => f.id !== feedbackId); toast.success($i18n.t('Feedback deleted successfully'));
page = 1;
getFeedbacks();
} }
}; };
@ -169,256 +144,266 @@
<FeedbackModal bind:show={showFeedbackModal} {selectedFeedback} onClose={closeFeedbackModal} /> <FeedbackModal bind:show={showFeedbackModal} {selectedFeedback} onClose={closeFeedbackModal} />
<div class="mt-0.5 mb-1 gap-1 flex flex-row justify-between"> {#if items === null || total === null}
<div class="flex md:self-center text-lg font-medium px-0.5"> <div class="my-10">
{$i18n.t('Feedback History')} <Spinner className="size-5" />
</div>
{:else}
<div class="mt-0.5 mb-1 gap-1 flex flex-row justify-between">
<div class="flex items-center md:self-center text-xl font-medium px-0.5 gap-2 shrink-0">
<div>
{$i18n.t('Feedback History')}
</div>
<div class="flex self-center w-[1px] h-6 mx-2.5 bg-gray-50 dark:bg-gray-850" /> <div class="text-lg font-medium text-gray-500 dark:text-gray-500">
{total}
</div>
</div>
<span class="text-lg font-medium text-gray-500 dark:text-gray-300">{feedbacks.length}</span> {#if total > 0}
<div>
<Tooltip content={$i18n.t('Export')}>
<button
class=" p-2 rounded-xl hover:bg-gray-100 dark:bg-gray-900 dark:hover:bg-gray-850 transition font-medium text-sm flex items-center space-x-1"
on:click={() => {
exportHandler();
}}
>
<Download className="size-3" />
</button>
</Tooltip>
</div>
{/if}
</div> </div>
{#if feedbacks.length > 0} <div class="scrollbar-hidden relative whitespace-nowrap overflow-x-auto max-w-full">
<div> {#if (items ?? []).length === 0}
<Tooltip content={$i18n.t('Export')}> <div class="text-center text-xs text-gray-500 dark:text-gray-400 py-1">
<button {$i18n.t('No feedbacks found')}
class=" p-2 rounded-xl hover:bg-gray-100 dark:bg-gray-900 dark:hover:bg-gray-850 transition font-medium text-sm flex items-center space-x-1" </div>
on:click={() => { {:else}
exportHandler(); <table
}} class="w-full text-sm text-left text-gray-500 dark:text-gray-400 table-auto max-w-full"
> >
<Download className="size-3" /> <thead class="text-xs text-gray-800 uppercase bg-transparent dark:text-gray-200">
</button> <tr class=" border-b-[1.5px] border-gray-50 dark:border-gray-850">
</Tooltip> <th
</div> scope="col"
{/if} class="px-2.5 py-2 cursor-pointer select-none w-3"
</div> on:click={() => setSortKey('user')}
>
<div class="scrollbar-hidden relative whitespace-nowrap overflow-x-auto max-w-full"> <div class="flex gap-1.5 items-center justify-end">
{#if (feedbacks ?? []).length === 0} {$i18n.t('User')}
<div class="text-center text-xs text-gray-500 dark:text-gray-400 py-1"> {#if orderBy === 'user'}
{$i18n.t('No feedbacks found')} <span class="font-normal">
</div> {#if direction === 'asc'}
{:else} <ChevronUp className="size-2" />
<table class="w-full text-sm text-left text-gray-500 dark:text-gray-400 table-auto max-w-full"> {:else}
<thead class="text-xs text-gray-800 uppercase bg-transparent dark:text-gray-200"> <ChevronDown className="size-2" />
<tr class=" border-b-[1.5px] border-gray-50 dark:border-gray-850"> {/if}
<th </span>
scope="col" {:else}
class="px-2.5 py-2 cursor-pointer select-none w-3" <span class="invisible">
on:click={() => setSortKey('user')}
>
<div class="flex gap-1.5 items-center justify-end">
{$i18n.t('User')}
{#if orderBy === 'user'}
<span class="font-normal">
{#if direction === 'asc'}
<ChevronUp className="size-2" /> <ChevronUp className="size-2" />
{:else} </span>
<ChevronDown className="size-2" /> {/if}
{/if}
</span>
{:else}
<span class="invisible">
<ChevronUp className="size-2" />
</span>
{/if}
</div>
</th>
<th
scope="col"
class="px-2.5 py-2 cursor-pointer select-none"
on:click={() => setSortKey('model_id')}
>
<div class="flex gap-1.5 items-center">
{$i18n.t('Models')}
{#if orderBy === 'model_id'}
<span class="font-normal">
{#if direction === 'asc'}
<ChevronUp className="size-2" />
{:else}
<ChevronDown className="size-2" />
{/if}
</span>
{:else}
<span class="invisible">
<ChevronUp className="size-2" />
</span>
{/if}
</div>
</th>
<th
scope="col"
class="px-2.5 py-2 text-right cursor-pointer select-none w-fit"
on:click={() => setSortKey('rating')}
>
<div class="flex gap-1.5 items-center justify-end">
{$i18n.t('Result')}
{#if orderBy === 'rating'}
<span class="font-normal">
{#if direction === 'asc'}
<ChevronUp className="size-2" />
{:else}
<ChevronDown className="size-2" />
{/if}
</span>
{:else}
<span class="invisible">
<ChevronUp className="size-2" />
</span>
{/if}
</div>
</th>
<th
scope="col"
class="px-2.5 py-2 text-right cursor-pointer select-none w-0"
on:click={() => setSortKey('updated_at')}
>
<div class="flex gap-1.5 items-center justify-end">
{$i18n.t('Updated At')}
{#if orderBy === 'updated_at'}
<span class="font-normal">
{#if direction === 'asc'}
<ChevronUp className="size-2" />
{:else}
<ChevronDown className="size-2" />
{/if}
</span>
{:else}
<span class="invisible">
<ChevronUp className="size-2" />
</span>
{/if}
</div>
</th>
<th scope="col" class="px-2.5 py-2 text-right cursor-pointer select-none w-0"> </th>
</tr>
</thead>
<tbody class="">
{#each paginatedFeedbacks as feedback (feedback.id)}
<tr
class="bg-white dark:bg-gray-900 dark:border-gray-850 text-xs cursor-pointer hover:bg-gray-50 dark:hover:bg-gray-850/50 transition"
on:click={() => openFeedbackModal(feedback)}
>
<td class=" py-0.5 text-right font-semibold">
<div class="flex justify-center">
<Tooltip content={feedback?.user?.name}>
<div class="shrink-0">
<img
src={feedback?.user?.profile_image_url ?? `${WEBUI_BASE_URL}/user.png`}
alt={feedback?.user?.name}
class="size-5 rounded-full object-cover shrink-0"
/>
</div>
</Tooltip>
</div> </div>
</td> </th>
<td class=" py-1 pl-3 flex flex-col"> <th
<div class="flex flex-col items-start gap-0.5 h-full"> scope="col"
<div class="flex flex-col h-full"> class="px-2.5 py-2 cursor-pointer select-none"
{#if feedback.data?.sibling_model_ids} on:click={() => setSortKey('model_id')}
<div class="font-semibold text-gray-600 dark:text-gray-400 flex-1"> >
{feedback.data?.model_id} <div class="flex gap-1.5 items-center">
</div> {$i18n.t('Models')}
{#if orderBy === 'model_id'}
<Tooltip content={feedback.data.sibling_model_ids.join(', ')}> <span class="font-normal">
<div class=" text-[0.65rem] text-gray-600 dark:text-gray-400 line-clamp-1"> {#if direction === 'asc'}
{#if feedback.data.sibling_model_ids.length > 2} <ChevronUp className="size-2" />
<!-- {$i18n.t('and {{COUNT}} more')} --> {:else}
{feedback.data.sibling_model_ids.slice(0, 2).join(', ')}, {$i18n.t( <ChevronDown className="size-2" />
'and {{COUNT}} more', {/if}
{ COUNT: feedback.data.sibling_model_ids.length - 2 } </span>
)} {:else}
{:else} <span class="invisible">
{feedback.data.sibling_model_ids.join(', ')} <ChevronUp className="size-2" />
{/if} </span>
</div> {/if}
</Tooltip>
{:else}
<div
class=" text-sm font-medium text-gray-600 dark:text-gray-400 flex-1 py-1.5"
>
{feedback.data?.model_id}
</div>
{/if}
</div>
</div> </div>
</td> </th>
{#if feedback?.data?.rating} <th
<td class="px-3 py-1 text-right font-medium text-gray-900 dark:text-white w-max"> scope="col"
<div class=" flex justify-end"> class="px-2.5 py-2 text-right cursor-pointer select-none w-fit"
{#if feedback?.data?.rating.toString() === '1'} on:click={() => setSortKey('rating')}
<Badge type="info" content={$i18n.t('Won')} /> >
{:else if feedback?.data?.rating.toString() === '0'} <div class="flex gap-1.5 items-center justify-end">
<Badge type="muted" content={$i18n.t('Draw')} /> {$i18n.t('Result')}
{:else if feedback?.data?.rating.toString() === '-1'} {#if orderBy === 'rating'}
<Badge type="error" content={$i18n.t('Lost')} /> <span class="font-normal">
{/if} {#if direction === 'asc'}
<ChevronUp className="size-2" />
{:else}
<ChevronDown className="size-2" />
{/if}
</span>
{:else}
<span class="invisible">
<ChevronUp className="size-2" />
</span>
{/if}
</div>
</th>
<th
scope="col"
class="px-2.5 py-2 text-right cursor-pointer select-none w-0"
on:click={() => setSortKey('updated_at')}
>
<div class="flex gap-1.5 items-center justify-end">
{$i18n.t('Updated At')}
{#if orderBy === 'updated_at'}
<span class="font-normal">
{#if direction === 'asc'}
<ChevronUp className="size-2" />
{:else}
<ChevronDown className="size-2" />
{/if}
</span>
{:else}
<span class="invisible">
<ChevronUp className="size-2" />
</span>
{/if}
</div>
</th>
<th scope="col" class="px-2.5 py-2 text-right cursor-pointer select-none w-0"> </th>
</tr>
</thead>
<tbody class="">
{#each items as feedback (feedback.id)}
<tr
class="bg-white dark:bg-gray-900 dark:border-gray-850 text-xs cursor-pointer hover:bg-gray-50 dark:hover:bg-gray-850/50 transition"
on:click={() => openFeedbackModal(feedback)}
>
<td class=" py-0.5 text-right font-medium">
<div class="flex justify-center">
<Tooltip content={feedback?.user?.name}>
<div class="shrink-0">
<img
src={`${WEBUI_API_BASE_URL}/users/${feedback.user.id}/profile/image`}
alt={feedback?.user?.name}
class="size-5 rounded-full object-cover shrink-0"
/>
</div>
</Tooltip>
</div> </div>
</td> </td>
{/if}
<td class=" px-3 py-1 text-right font-medium"> <td class=" py-1 pl-3 flex flex-col">
{dayjs(feedback.updated_at * 1000).fromNow()} <div class="flex flex-col items-start gap-0.5 h-full">
</td> <div class="flex flex-col h-full">
{#if feedback.data?.sibling_model_ids}
<div class="font-medium text-gray-600 dark:text-gray-400 flex-1">
{feedback.data?.model_id}
</div>
<td class=" px-3 py-1 text-right font-semibold" on:click={(e) => e.stopPropagation()}> <Tooltip content={feedback.data.sibling_model_ids.join(', ')}>
<FeedbackMenu <div class=" text-[0.65rem] text-gray-600 dark:text-gray-400 line-clamp-1">
on:delete={(e) => { {#if feedback.data.sibling_model_ids.length > 2}
deleteFeedbackHandler(feedback.id); <!-- {$i18n.t('and {{COUNT}} more')} -->
}} {feedback.data.sibling_model_ids.slice(0, 2).join(', ')}, {$i18n.t(
> 'and {{COUNT}} more',
<button { COUNT: feedback.data.sibling_model_ids.length - 2 }
class="self-center w-fit text-sm p-1.5 dark:text-gray-300 dark:hover:text-white hover:bg-black/5 dark:hover:bg-white/5 rounded-xl" )}
{:else}
{feedback.data.sibling_model_ids.join(', ')}
{/if}
</div>
</Tooltip>
{:else}
<div
class=" text-sm font-medium text-gray-600 dark:text-gray-400 flex-1 py-1.5"
>
{feedback.data?.model_id}
</div>
{/if}
</div>
</div>
</td>
{#if feedback?.data?.rating}
<td class="px-3 py-1 text-right font-medium text-gray-900 dark:text-white w-max">
<div class=" flex justify-end">
{#if feedback?.data?.rating.toString() === '1'}
<Badge type="info" content={$i18n.t('Won')} />
{:else if feedback?.data?.rating.toString() === '0'}
<Badge type="muted" content={$i18n.t('Draw')} />
{:else if feedback?.data?.rating.toString() === '-1'}
<Badge type="error" content={$i18n.t('Lost')} />
{/if}
</div>
</td>
{/if}
<td class=" px-3 py-1 text-right font-medium">
{dayjs(feedback.updated_at * 1000).fromNow()}
</td>
<td class=" px-3 py-1 text-right font-medium" on:click={(e) => e.stopPropagation()}>
<FeedbackMenu
on:delete={(e) => {
deleteFeedbackHandler(feedback.id);
}}
> >
<EllipsisHorizontal /> <button
</button> class="self-center w-fit text-sm p-1.5 dark:text-gray-300 dark:hover:text-white hover:bg-black/5 dark:hover:bg-white/5 rounded-xl"
</FeedbackMenu> >
</td> <EllipsisHorizontal />
</tr> </button>
{/each} </FeedbackMenu>
</tbody> </td>
</table> </tr>
{/if} {/each}
</div> </tbody>
</table>
{#if feedbacks.length > 0 && $config?.features?.enable_community_sharing} {/if}
<div class=" flex flex-col justify-end w-full text-right gap-1">
<div class="line-clamp-1 text-gray-500 text-xs">
{$i18n.t('Help us create the best community leaderboard by sharing your feedback history!')}
</div>
<div class="flex space-x-1 ml-auto">
<Tooltip
content={$i18n.t(
'To protect your privacy, only ratings, model IDs, tags, and metadata are shared from your feedback—your chat logs remain private and are not included.'
)}
>
<button
class="flex text-xs items-center px-3 py-1.5 rounded-xl bg-gray-50 hover:bg-gray-100 dark:bg-gray-850 dark:hover:bg-gray-800 dark:text-gray-200 transition"
on:click={async () => {
shareHandler();
}}
>
<div class=" self-center mr-2 font-medium line-clamp-1">
{$i18n.t('Share to Open WebUI Community')}
</div>
<div class=" self-center">
<CloudArrowUp className="size-3" strokeWidth="3" />
</div>
</button>
</Tooltip>
</div>
</div> </div>
{/if}
{#if feedbacks.length > 10} {#if total > 0 && $config?.features?.enable_community_sharing}
<Pagination bind:page count={feedbacks.length} perPage={10} /> <div class=" flex flex-col justify-end w-full text-right gap-1">
<div class="line-clamp-1 text-gray-500 text-xs">
{$i18n.t('Help us create the best community leaderboard by sharing your feedback history!')}
</div>
<div class="flex space-x-1 ml-auto">
<Tooltip
content={$i18n.t(
'To protect your privacy, only ratings, model IDs, tags, and metadata are shared from your feedback—your chat logs remain private and are not included.'
)}
>
<button
class="flex text-xs items-center px-3 py-1.5 rounded-xl bg-gray-50 hover:bg-gray-100 dark:bg-gray-850 dark:hover:bg-gray-800 dark:text-gray-200 transition"
on:click={async () => {
shareHandler();
}}
>
<div class=" self-center mr-2 font-medium line-clamp-1">
{$i18n.t('Share to Open WebUI Community')}
</div>
<div class=" self-center">
<CloudArrowUp className="size-3" strokeWidth="3" />
</div>
</button>
</Tooltip>
</div>
</div>
{/if}
{#if total > 30}
<Pagination bind:page count={total} perPage={30} />
{/if}
{/if} {/if}
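
The refactored feedback table above switches from client-side slicing to server-driven pagination (30 items per page, `total` from the API), and the simplified `setSortKey` now always starts a newly chosen column in ascending order. The sort toggle as a small pure function, for reference (names are ours):

```ts
type Direction = 'asc' | 'desc';

// Clicking the active column flips the direction; clicking another column
// switches to it and starts ascending, mirroring setSortKey above.
export function toggleSort(
	state: { orderBy: string; direction: Direction },
	key: string
): { orderBy: string; direction: Direction } {
	if (state.orderBy === key) {
		return { orderBy: key, direction: state.direction === 'asc' ? 'desc' : 'asc' };
	}
	return { orderBy: key, direction: 'asc' };
}
```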

View file

@ -10,7 +10,7 @@
import ChevronUp from '$lib/components/icons/ChevronUp.svelte'; import ChevronUp from '$lib/components/icons/ChevronUp.svelte';
import ChevronDown from '$lib/components/icons/ChevronDown.svelte'; import ChevronDown from '$lib/components/icons/ChevronDown.svelte';
import { WEBUI_BASE_URL } from '$lib/constants'; import { WEBUI_API_BASE_URL, WEBUI_BASE_URL } from '$lib/constants';
const i18n = getContext('i18n'); const i18n = getContext('i18n');
@ -339,16 +339,14 @@
<div <div
class="pt-0.5 pb-1 gap-1 flex flex-col md:flex-row justify-between sticky top-0 z-10 bg-white dark:bg-gray-900" class="pt-0.5 pb-1 gap-1 flex flex-col md:flex-row justify-between sticky top-0 z-10 bg-white dark:bg-gray-900"
> >
<div class="flex md:self-center text-lg font-medium px-0.5 shrink-0 items-center"> <div class="flex items-center md:self-center text-xl font-medium px-0.5 gap-2 shrink-0">
<div class=" gap-1"> <div>
{$i18n.t('Leaderboard')} {$i18n.t('Leaderboard')}
</div> </div>
<div class="flex self-center w-[1px] h-6 mx-2.5 bg-gray-50 dark:bg-gray-850" /> <div class="text-lg font-medium text-gray-500 dark:text-gray-500">
{rankedModels.length}
<span class="text-lg font-medium text-gray-500 dark:text-gray-300 mr-1.5" </div>
>{rankedModels.length}</span
>
</div> </div>
<div class=" flex space-x-2"> <div class=" flex space-x-2">
@ -517,7 +515,7 @@
<div class="flex items-center gap-2"> <div class="flex items-center gap-2">
<div class="shrink-0"> <div class="shrink-0">
<img <img
src={model?.info?.meta?.profile_image_url ?? `${WEBUI_BASE_URL}/favicon.png`} src={`${WEBUI_API_BASE_URL}/models/model/profile/image?id=${model.id}`}
alt={model.name} alt={model.name}
class="size-5 rounded-full object-cover shrink-0" class="size-5 rounded-full object-cover shrink-0"
/> />
@ -532,7 +530,7 @@
{model.rating} {model.rating}
</td> </td>
<td class=" px-3 py-1.5 text-right font-semibold text-green-500"> <td class=" px-3 py-1.5 text-right font-medium text-green-500">
<div class=" w-10"> <div class=" w-10">
{#if model.stats.won === '-'} {#if model.stats.won === '-'}
- -
@ -545,7 +543,7 @@
</div> </div>
</td> </td>
<td class="px-3 py-1.5 text-right font-semibold text-red-500"> <td class="px-3 py-1.5 text-right font-medium text-red-500">
<div class=" w-10"> <div class=" w-10">
{#if model.stats.lost === '-'} {#if model.stats.lost === '-'}
- -
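
The leaderboard now loads model avatars from a backend route instead of `meta.profile_image_url`. A small helper for building that URL; the route itself comes from the diff, while the `encodeURIComponent` call is an extra precaution in this sketch (the component interpolates the id directly).

```ts
// Builds the model profile image URL used by the leaderboard and the
// workspace model list.
export function modelProfileImageUrl(apiBaseUrl: string, modelId: string): string {
	return `${apiBaseUrl}/models/model/profile/image?id=${encodeURIComponent(modelId)}`;
}
```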

View file

@ -548,12 +548,7 @@
{:else if TTS_ENGINE === 'elevenlabs'} {:else if TTS_ENGINE === 'elevenlabs'}
<div> <div>
<div class="mt-1 flex gap-2 mb-1"> <div class="mt-1 flex gap-2 mb-1">
<input <SensitiveInput placeholder={$i18n.t('API Key')} bind:value={TTS_API_KEY} required />
class="flex-1 w-full rounded-lg py-2 pl-4 text-sm bg-gray-50 dark:text-gray-300 dark:bg-gray-850 outline-hidden"
placeholder={$i18n.t('API Key')}
bind:value={TTS_API_KEY}
required
/>
</div> </div>
</div> </div>
{:else if TTS_ENGINE === 'azure'} {:else if TTS_ENGINE === 'azure'}

View file

@ -2,7 +2,7 @@
import fileSaver from 'file-saver'; import fileSaver from 'file-saver';
const { saveAs } = fileSaver; const { saveAs } = fileSaver;
import { downloadDatabase, downloadLiteLLMConfig } from '$lib/apis/utils'; import { downloadDatabase } from '$lib/apis/utils';
import { onMount, getContext } from 'svelte'; import { onMount, getContext } from 'svelte';
import { config, user } from '$lib/stores'; import { config, user } from '$lib/stores';
import { toast } from 'svelte-sonner'; import { toast } from 'svelte-sonner';

View file

@ -212,6 +212,18 @@
await embeddingModelUpdateHandler(); await embeddingModelUpdateHandler();
} }
if (RAGConfig.DOCLING_PARAMS) {
try {
JSON.parse(RAGConfig.DOCLING_PARAMS);
} catch (e) {
toast.error(
$i18n.t('Invalid JSON format in {{NAME}}', {
NAME: $i18n.t('Docling Parameters')
})
);
return;
}
}
if (RAGConfig.MINERU_PARAMS) { if (RAGConfig.MINERU_PARAMS) {
try { try {
JSON.parse(RAGConfig.MINERU_PARAMS); JSON.parse(RAGConfig.MINERU_PARAMS);
@ -232,6 +244,10 @@
DOCLING_PICTURE_DESCRIPTION_API: JSON.parse( DOCLING_PICTURE_DESCRIPTION_API: JSON.parse(
RAGConfig.DOCLING_PICTURE_DESCRIPTION_API || '{}' RAGConfig.DOCLING_PICTURE_DESCRIPTION_API || '{}'
), ),
DOCLING_PARAMS:
typeof RAGConfig.DOCLING_PARAMS === 'string' && RAGConfig.DOCLING_PARAMS.trim() !== ''
? JSON.parse(RAGConfig.DOCLING_PARAMS)
: {},
MINERU_PARAMS: MINERU_PARAMS:
typeof RAGConfig.MINERU_PARAMS === 'string' && RAGConfig.MINERU_PARAMS.trim() !== '' typeof RAGConfig.MINERU_PARAMS === 'string' && RAGConfig.MINERU_PARAMS.trim() !== ''
? JSON.parse(RAGConfig.MINERU_PARAMS) ? JSON.parse(RAGConfig.MINERU_PARAMS)
@ -275,6 +291,10 @@
null, null,
2 2
); );
config.DOCLING_PARAMS =
typeof config.DOCLING_PARAMS === 'object'
? JSON.stringify(config.DOCLING_PARAMS ?? {}, null, 2)
: config.DOCLING_PARAMS;
config.MINERU_PARAMS = config.MINERU_PARAMS =
typeof config.MINERU_PARAMS === 'object' typeof config.MINERU_PARAMS === 'object'
@ -737,18 +757,18 @@
{/if} {/if}
{/if} {/if}
<div class="flex justify-between w-full mt-2"> <div class="flex flex-col gap-2 mt-2">
<div class="self-center text-xs font-medium"> <div class=" flex flex-col w-full justify-between">
<Tooltip content={''} placement="top-start"> <div class=" mb-1 text-xs font-medium">
{$i18n.t('Parameters')} {$i18n.t('Parameters')}
</Tooltip> </div>
</div> <div class="flex w-full items-center relative">
<div class=""> <Textarea
<Textarea bind:value={RAGConfig.DOCLING_PARAMS}
bind:value={RAGConfig.DOCLING_PARAMS} placeholder={$i18n.t('Enter additional parameters in JSON format')}
placeholder={$i18n.t('Enter additional parameters in JSON format')} minSize={100}
minSize={100} />
/> </div>
</div> </div>
</div> </div>
{:else if RAGConfig.CONTENT_EXTRACTION_ENGINE === 'document_intelligence'} {:else if RAGConfig.CONTENT_EXTRACTION_ENGINE === 'document_intelligence'}
@ -823,6 +843,7 @@
<SensitiveInput <SensitiveInput
placeholder={$i18n.t('Enter MinerU API Key')} placeholder={$i18n.t('Enter MinerU API Key')}
bind:value={RAGConfig.MINERU_API_KEY} bind:value={RAGConfig.MINERU_API_KEY}
required={false}
/> />
</div> </div>
@ -1131,6 +1152,21 @@
</div> </div>
{#if RAGConfig.ENABLE_RAG_HYBRID_SEARCH === true} {#if RAGConfig.ENABLE_RAG_HYBRID_SEARCH === true}
<div class="mb-2.5 flex w-full justify-between">
<div class="self-center text-xs font-medium">
{$i18n.t('Enrich Hybrid Search Text')}
</div>
<div class="flex items-center relative">
<Tooltip
content={$i18n.t(
'Adds filenames, titles, sections, and snippets into the BM25 text to improve lexical recall.'
)}
>
<Switch bind:state={RAGConfig.ENABLE_RAG_HYBRID_SEARCH_ENRICHED_TEXTS} />
</Tooltip>
</div>
</div>
<div class=" mb-2.5 flex flex-col w-full justify-between"> <div class=" mb-2.5 flex flex-col w-full justify-between">
<div class="flex w-full justify-between"> <div class="flex w-full justify-between">
<div class=" self-center text-xs font-medium"> <div class=" self-center text-xs font-medium">

View file

@ -5,6 +5,7 @@
import Cog6 from '$lib/components/icons/Cog6.svelte'; import Cog6 from '$lib/components/icons/Cog6.svelte';
import ArenaModelModal from './ArenaModelModal.svelte'; import ArenaModelModal from './ArenaModelModal.svelte';
import { WEBUI_API_BASE_URL } from '$lib/constants';
export let model; export let model;
let showModel = false; let showModel = false;
@ -27,7 +28,7 @@
<div class="flex flex-col flex-1"> <div class="flex flex-col flex-1">
<div class="flex gap-2.5 items-center"> <div class="flex gap-2.5 items-center">
<img <img
src={model.meta.profile_image_url} src={`${WEBUI_API_BASE_URL}/models/model/profile/image?id=${model.id}`}
alt={model.name} alt={model.name}
class="size-8 rounded-full object-cover shrink-0" class="size-8 rounded-full object-cover shrink-0"
/> />

View file

@ -10,6 +10,7 @@
updateLdapConfig, updateLdapConfig,
updateLdapServer updateLdapServer
} from '$lib/apis/auths'; } from '$lib/apis/auths';
import { getGroups } from '$lib/apis/groups';
import SensitiveInput from '$lib/components/common/SensitiveInput.svelte'; import SensitiveInput from '$lib/components/common/SensitiveInput.svelte';
import Switch from '$lib/components/common/Switch.svelte'; import Switch from '$lib/components/common/Switch.svelte';
import Tooltip from '$lib/components/common/Tooltip.svelte'; import Tooltip from '$lib/components/common/Tooltip.svelte';
@ -32,6 +33,7 @@
let adminConfig = null; let adminConfig = null;
let webhookUrl = ''; let webhookUrl = '';
let groups = [];
// LDAP // LDAP
let ENABLE_LDAP = false; let ENABLE_LDAP = false;
@ -104,6 +106,9 @@
})(), })(),
(async () => { (async () => {
LDAP_SERVER = await getLdapServer(localStorage.token); LDAP_SERVER = await getLdapServer(localStorage.token);
})(),
(async () => {
groups = await getGroups(localStorage.token);
})() })()
]); ]);
@ -299,6 +304,22 @@
</div> </div>
</div> </div>
<div class=" mb-2.5 flex w-full justify-between">
<div class=" self-center text-xs font-medium">{$i18n.t('Default Group')}</div>
<div class="flex items-center relative">
<select
class="dark:bg-gray-900 w-fit pr-8 rounded-sm px-2 text-xs bg-transparent outline-hidden text-right"
bind:value={adminConfig.DEFAULT_GROUP_ID}
placeholder={$i18n.t('Select a group')}
>
<option value={''}>{$i18n.t('None')}</option>
{#each groups as group}
<option value={group.id}>{group.name}</option>
{/each}
</select>
</div>
</div>
<div class=" mb-2.5 flex w-full justify-between pr-2"> <div class=" mb-2.5 flex w-full justify-between pr-2">
<div class=" self-center text-xs font-medium">{$i18n.t('Enable New Sign Ups')}</div> <div class=" self-center text-xs font-medium">{$i18n.t('Enable New Sign Ups')}</div>
@ -338,31 +359,31 @@
</div> </div>
<div class="mb-2.5 flex w-full justify-between pr-2"> <div class="mb-2.5 flex w-full justify-between pr-2">
<div class=" self-center text-xs font-medium">{$i18n.t('Enable API Key')}</div> <div class=" self-center text-xs font-medium">{$i18n.t('Enable API Keys')}</div>
<Switch bind:state={adminConfig.ENABLE_API_KEY} /> <Switch bind:state={adminConfig.ENABLE_API_KEYS} />
</div> </div>
{#if adminConfig?.ENABLE_API_KEY} {#if adminConfig?.ENABLE_API_KEYS}
<div class="mb-2.5 flex w-full justify-between pr-2"> <div class="mb-2.5 flex w-full justify-between pr-2">
<div class=" self-center text-xs font-medium"> <div class=" self-center text-xs font-medium">
{$i18n.t('API Key Endpoint Restrictions')} {$i18n.t('API Key Endpoint Restrictions')}
</div> </div>
<Switch bind:state={adminConfig.ENABLE_API_KEY_ENDPOINT_RESTRICTIONS} /> <Switch bind:state={adminConfig.ENABLE_API_KEYS_ENDPOINT_RESTRICTIONS} />
</div> </div>
{#if adminConfig?.ENABLE_API_KEY_ENDPOINT_RESTRICTIONS} {#if adminConfig?.ENABLE_API_KEYS_ENDPOINT_RESTRICTIONS}
<div class=" flex w-full flex-col pr-2"> <div class=" flex w-full flex-col pr-2 mb-2.5">
<div class=" text-xs font-medium"> <div class=" text-xs font-medium">
{$i18n.t('Allowed Endpoints')} {$i18n.t('Allowed Endpoints')}
</div> </div>
<input <input
class="w-full mt-1 rounded-lg text-sm dark:text-gray-300 bg-transparent outline-hidden" class="w-full mt-1 text-sm dark:text-gray-300 bg-transparent outline-hidden"
type="text" type="text"
placeholder={`e.g.) /api/v1/messages, /api/v1/channels`} placeholder={`e.g.) /api/v1/messages, /api/v1/channels`}
bind:value={adminConfig.API_KEY_ALLOWED_ENDPOINTS} bind:value={adminConfig.API_KEYS_ALLOWED_ENDPOINTS}
/> />
<div class="mt-2 text-xs text-gray-400 dark:text-gray-500"> <div class="mt-2 text-xs text-gray-400 dark:text-gray-500">

View file

@ -112,32 +112,41 @@
config.ENABLE_IMAGE_GENERATION = false; config.ENABLE_IMAGE_GENERATION = false;
return null; return null;
} else if (config.IMAGE_GENERATION_ENGINE === 'openai' && config.OPENAI_API_KEY === '') { } else if (config.IMAGE_GENERATION_ENGINE === 'openai' && config.IMAGES_OPENAI_API_KEY === '') {
toast.error($i18n.t('OpenAI API Key is required.')); toast.error($i18n.t('OpenAI API Key is required.'));
config.ENABLE_IMAGE_GENERATION = false; config.ENABLE_IMAGE_GENERATION = false;
return null; return null;
} else if (config.IMAGE_GENERATION_ENGINE === 'gemini' && config.GEMINI_API_KEY === '') { } else if (config.IMAGE_GENERATION_ENGINE === 'gemini' && config.IMAGES_GEMINI_API_KEY === '') {
toast.error($i18n.t('Gemini API Key is required.')); toast.error($i18n.t('Gemini API Key is required.'));
config.ENABLE_IMAGE_GENERATION = false; config.ENABLE_IMAGE_GENERATION = false;
return null; return null;
} }
const res = await updateConfig(localStorage.token, config).catch((error) => { const res = await updateConfig(localStorage.token, {
...config,
AUTOMATIC1111_PARAMS:
typeof config.AUTOMATIC1111_PARAMS === 'string' && config.AUTOMATIC1111_PARAMS.trim() !== ''
? JSON.parse(config.AUTOMATIC1111_PARAMS)
: {},
IMAGES_OPENAI_API_PARAMS:
typeof config.IMAGES_OPENAI_API_PARAMS === 'string' &&
config.IMAGES_OPENAI_API_PARAMS.trim() !== ''
? JSON.parse(config.IMAGES_OPENAI_API_PARAMS)
: {}
}).catch((error) => {
toast.error(`${error}`); toast.error(`${error}`);
return null; return null;
}); });
if (res) { if (res) {
config = res; if (res.ENABLE_IMAGE_GENERATION) {
if (config.ENABLE_IMAGE_GENERATION) {
backendConfig.set(await getBackendConfig()); backendConfig.set(await getBackendConfig());
getModels(); getModels();
} }
return config; return res;
} }
return null; return null;
@ -245,6 +254,16 @@
} }
} }
config.IMAGES_OPENAI_API_PARAMS =
typeof config.IMAGES_OPENAI_API_PARAMS === 'object'
? JSON.stringify(config.IMAGES_OPENAI_API_PARAMS ?? {}, null, 2)
: config.IMAGES_OPENAI_API_PARAMS;
config.AUTOMATIC1111_PARAMS =
typeof config.AUTOMATIC1111_PARAMS === 'object'
? JSON.stringify(config.AUTOMATIC1111_PARAMS ?? {}, null, 2)
: config.AUTOMATIC1111_PARAMS;
REQUIRED_EDIT_WORKFLOW_NODES = REQUIRED_EDIT_WORKFLOW_NODES.map((node) => { REQUIRED_EDIT_WORKFLOW_NODES = REQUIRED_EDIT_WORKFLOW_NODES.map((node) => {
const n = const n =
config.IMAGES_EDIT_COMFYUI_WORKFLOW_NODES.find((n) => n.type === node.type) ?? node; config.IMAGES_EDIT_COMFYUI_WORKFLOW_NODES.find((n) => n.type === node.type) ?? node;
@ -456,6 +475,26 @@
</div> </div>
</div> </div>
</div> </div>
<div class="mb-2.5">
<div class="flex w-full justify-between items-center">
<div class="text-xs pr-2 shrink-0">
<div class="">
{$i18n.t('Additional Parameters')}
</div>
</div>
</div>
<div class="mt-1.5 flex w-full">
<div class="flex-1 mr-2">
<Textarea
className="rounded-lg w-full py-2 px-3 text-sm bg-gray-50 dark:text-gray-300 dark:bg-gray-850 outline-hidden"
bind:value={config.IMAGES_OPENAI_API_PARAMS}
placeholder={$i18n.t('Enter additional parameters in JSON format')}
minSize={100}
/>
</div>
</div>
</div>
{:else if (config?.IMAGE_GENERATION_ENGINE ?? 'automatic1111') === 'automatic1111'} {:else if (config?.IMAGE_GENERATION_ENGINE ?? 'automatic1111') === 'automatic1111'}
<div class="mb-2.5"> <div class="mb-2.5">
<div class="flex w-full justify-between items-center"> <div class="flex w-full justify-between items-center">
@ -849,23 +888,15 @@
<div class="flex w-full justify-between items-center"> <div class="flex w-full justify-between items-center">
<div class="text-xs pr-2"> <div class="text-xs pr-2">
<div class=""> <div class="">
{$i18n.t('Image Edit Engine')} {$i18n.t('Image Edit')}
</div> </div>
</div> </div>
<select <Switch bind:state={config.ENABLE_IMAGE_EDIT} />
class=" dark:bg-gray-900 w-fit pr-8 cursor-pointer rounded-sm px-2 text-xs bg-transparent outline-hidden text-right"
bind:value={config.IMAGE_EDIT_ENGINE}
placeholder={$i18n.t('Select Engine')}
>
<option value="openai">{$i18n.t('Default (Open AI)')}</option>
<option value="comfyui">{$i18n.t('ComfyUI')}</option>
<option value="gemini">{$i18n.t('Gemini')}</option>
</select>
</div> </div>
</div> </div>
{#if config.ENABLE_IMAGE_GENERATION} {#if config?.ENABLE_IMAGE_GENERATION && config?.ENABLE_IMAGE_EDIT}
<div class="mb-2.5"> <div class="mb-2.5">
<div class="flex w-full justify-between items-center"> <div class="flex w-full justify-between items-center">
<div class="text-xs pr-2"> <div class="text-xs pr-2">
@ -910,6 +941,26 @@
</div> </div>
{/if} {/if}
<div class="mb-2.5">
<div class="flex w-full justify-between items-center">
<div class="text-xs pr-2">
<div class="">
{$i18n.t('Image Edit Engine')}
</div>
</div>
<select
class=" dark:bg-gray-900 w-fit pr-8 cursor-pointer rounded-sm px-2 text-xs bg-transparent outline-hidden text-right"
bind:value={config.IMAGE_EDIT_ENGINE}
placeholder={$i18n.t('Select Engine')}
>
<option value="openai">{$i18n.t('Default (Open AI)')}</option>
<option value="comfyui">{$i18n.t('ComfyUI')}</option>
<option value="gemini">{$i18n.t('Gemini')}</option>
</select>
</div>
</div>
{#if config?.IMAGE_EDIT_ENGINE === 'openai'} {#if config?.IMAGE_EDIT_ENGINE === 'openai'}
<div class="mb-2.5"> <div class="mb-2.5">
<div class="flex w-full justify-between items-center"> <div class="flex w-full justify-between items-center">
@ -1170,7 +1221,7 @@
</div> </div>
</div> </div>
{/if} {/if}
{:else if config?.IMAGE_GENERATION_ENGINE === 'gemini'} {:else if config?.IMAGE_EDIT_ENGINE === 'gemini'}
<div class="mb-2.5"> <div class="mb-2.5">
<div class="flex w-full justify-between items-center"> <div class="flex w-full justify-between items-center">
<div class="text-xs pr-2 shrink-0"> <div class="text-xs pr-2 shrink-0">

View file

@ -21,6 +21,7 @@
import Textarea from '$lib/components/common/Textarea.svelte'; import Textarea from '$lib/components/common/Textarea.svelte';
import Spinner from '$lib/components/common/Spinner.svelte'; import Spinner from '$lib/components/common/Spinner.svelte';
import Banners from './Interface/Banners.svelte'; import Banners from './Interface/Banners.svelte';
import PromptSuggestions from '$lib/components/workspace/Models/PromptSuggestions.svelte';
const dispatch = createEventDispatcher(); const dispatch = createEventDispatcher();
@ -41,7 +42,8 @@
ENABLE_SEARCH_QUERY_GENERATION: true, ENABLE_SEARCH_QUERY_GENERATION: true,
ENABLE_RETRIEVAL_QUERY_GENERATION: true, ENABLE_RETRIEVAL_QUERY_GENERATION: true,
QUERY_GENERATION_PROMPT_TEMPLATE: '', QUERY_GENERATION_PROMPT_TEMPLATE: '',
TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE: '' TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE: '',
VOICE_MODE_PROMPT_TEMPLATE: ''
}; };
let promptSuggestions = []; let promptSuggestions = [];
@ -237,6 +239,41 @@
</div> </div>
{/if} {/if}
<div class="mb-2.5 flex w-full items-center justify-between">
<div class=" self-center text-xs font-medium">
{$i18n.t('Voice Mode Custom Prompt')}
</div>
<Switch
state={taskConfig.VOICE_MODE_PROMPT_TEMPLATE != null}
on:change={(e) => {
if (e.detail) {
taskConfig.VOICE_MODE_PROMPT_TEMPLATE = '';
} else {
taskConfig.VOICE_MODE_PROMPT_TEMPLATE = null;
}
}}
/>
</div>
{#if taskConfig.VOICE_MODE_PROMPT_TEMPLATE != null}
<div class="mb-2.5">
<div class=" mb-1 text-xs font-medium">{$i18n.t('Voice Mode Prompt')}</div>
<Tooltip
content={$i18n.t('Leave empty to use the default prompt, or enter a custom prompt')}
placement="top-start"
>
<Textarea
bind:value={taskConfig.VOICE_MODE_PROMPT_TEMPLATE}
placeholder={$i18n.t(
'Leave empty to use the default prompt, or enter a custom prompt'
)}
/>
</Tooltip>
</div>
{/if}
<div class="mb-2.5 flex w-full items-center justify-between"> <div class="mb-2.5 flex w-full items-center justify-between">
<div class=" self-center text-xs font-medium"> <div class=" self-center text-xs font-medium">
{$i18n.t('Follow Up Generation')} {$i18n.t('Follow Up Generation')}
@ -430,201 +467,13 @@
</div> </div>
{#if $user?.role === 'admin'} {#if $user?.role === 'admin'}
<div class=" space-y-3"> <PromptSuggestions bind:promptSuggestions />
<div class="flex w-full justify-between mb-2">
<div class=" self-center text-xs">
{$i18n.t('Default Prompt Suggestions')}
</div>
<button {#if promptSuggestions.length > 0}
class="p-1 px-3 text-xs flex rounded-sm transition" <div class="text-xs text-left w-full mt-2">
type="button" {$i18n.t('Adjusting these settings will apply changes universally to all users.')}
on:click={() => {
if (promptSuggestions.length === 0 || promptSuggestions.at(-1).content !== '') {
promptSuggestions = [...promptSuggestions, { content: '', title: ['', ''] }];
}
}}
>
<svg
xmlns="http://www.w3.org/2000/svg"
viewBox="0 0 20 20"
fill="currentColor"
class="w-4 h-4"
>
<path
d="M10.75 4.75a.75.75 0 00-1.5 0v4.5h-4.5a.75.75 0 000 1.5h4.5v4.5a.75.75 0 001.5 0v-4.5h4.5a.75.75 0 000-1.5h-4.5v-4.5z"
/>
</svg>
</button>
</div> </div>
<div class="grid lg:grid-cols-2 flex-col gap-1.5"> {/if}
{#each promptSuggestions as prompt, promptIdx}
<div
class=" flex border rounded-xl border-gray-50 dark:border-none dark:bg-gray-850 py-1.5"
>
<div class="flex flex-col flex-1 pl-1">
<div class="py-1 gap-1">
<input
class="px-3 text-sm font-medium w-full bg-transparent outline-hidden"
placeholder={$i18n.t('Title (e.g. Tell me a fun fact)')}
bind:value={prompt.title[0]}
/>
<input
class="px-3 text-xs w-full bg-transparent outline-hidden text-gray-600 dark:text-gray-400"
placeholder={$i18n.t('Subtitle (e.g. about the Roman Empire)')}
bind:value={prompt.title[1]}
/>
</div>
<hr class="border-gray-50 dark:border-gray-850 my-1" />
<textarea
class="px-3 py-1.5 text-xs w-full bg-transparent outline-hidden resize-none"
placeholder={$i18n.t(
'Prompt (e.g. Tell me a fun fact about the Roman Empire)'
)}
rows="3"
bind:value={prompt.content}
/>
</div>
<div class="">
<button
class="p-3"
type="button"
on:click={() => {
promptSuggestions.splice(promptIdx, 1);
promptSuggestions = promptSuggestions;
}}
>
<svg
xmlns="http://www.w3.org/2000/svg"
viewBox="0 0 20 20"
fill="currentColor"
class="w-4 h-4"
>
<path
d="M6.28 5.22a.75.75 0 00-1.06 1.06L8.94 10l-3.72 3.72a.75.75 0 101.06 1.06L10 11.06l3.72 3.72a.75.75 0 101.06-1.06L11.06 10l3.72-3.72a.75.75 0 00-1.06-1.06L10 8.94 6.28 5.22z"
/>
</svg>
</button>
</div>
</div>
{/each}
</div>
{#if promptSuggestions.length > 0}
<div class="text-xs text-left w-full mt-2">
{$i18n.t('Adjusting these settings will apply changes universally to all users.')}
</div>
{/if}
<div class="flex items-center justify-end space-x-2 mt-2">
<input
id="prompt-suggestions-import-input"
type="file"
accept=".json"
hidden
on:change={(e) => {
const files = e.target.files;
if (!files || files.length === 0) {
return;
}
console.log(files);
let reader = new FileReader();
reader.onload = async (event) => {
try {
let suggestions = JSON.parse(event.target.result);
suggestions = suggestions.map((s) => {
if (typeof s.title === 'string') {
s.title = [s.title, ''];
} else if (!Array.isArray(s.title)) {
s.title = ['', ''];
}
return s;
});
promptSuggestions = [...promptSuggestions, ...suggestions];
} catch (error) {
toast.error($i18n.t('Invalid JSON file'));
return;
}
};
reader.readAsText(files[0]);
e.target.value = ''; // Reset the input value
}}
/>
<button
class="flex text-xs items-center space-x-1 px-3 py-1.5 rounded-xl bg-gray-50 hover:bg-gray-100 dark:bg-gray-800 dark:hover:bg-gray-700 dark:text-gray-200 transition"
type="button"
on:click={() => {
const input = document.getElementById('prompt-suggestions-import-input');
if (input) {
input.click();
}
}}
>
<div class=" self-center mr-2 font-medium line-clamp-1">
{$i18n.t('Import Prompt Suggestions')}
</div>
<div class=" self-center">
<svg
xmlns="http://www.w3.org/2000/svg"
viewBox="0 0 16 16"
fill="currentColor"
class="w-3.5 h-3.5"
>
<path
fill-rule="evenodd"
d="M4 2a1.5 1.5 0 0 0-1.5 1.5v9A1.5 1.5 0 0 0 4 14h8a1.5 1.5 0 0 0 1.5-1.5V6.621a1.5 1.5 0 0 0-.44-1.06L9.94 2.439A1.5 1.5 0 0 0 8.878 2H4Zm4 9.5a.75.75 0 0 1-.75-.75V8.06l-.72.72a.75.75 0 0 1-1.06-1.06l2-2a.75.75 0 0 1 1.06 0l2 2a.75.75 0 1 1-1.06 1.06l-.72-.72v2.69a.75.75 0 0 1-.75.75Z"
clip-rule="evenodd"
/>
</svg>
</div>
</button>
{#if promptSuggestions.length}
<button
class="flex text-xs items-center space-x-1 px-3 py-1.5 rounded-xl bg-gray-50 hover:bg-gray-100 dark:bg-gray-800 dark:hover:bg-gray-700 dark:text-gray-200 transition"
type="button"
on:click={async () => {
let blob = new Blob([JSON.stringify(promptSuggestions)], {
type: 'application/json'
});
saveAs(blob, `prompt-suggestions-export-${Date.now()}.json`);
}}
>
<div class=" self-center mr-2 font-medium line-clamp-1">
{$i18n.t('Export Prompt Suggestions')} ({promptSuggestions.length})
</div>
<div class=" self-center">
<svg
xmlns="http://www.w3.org/2000/svg"
viewBox="0 0 16 16"
fill="currentColor"
class="w-3.5 h-3.5"
>
<path
fill-rule="evenodd"
d="M4 2a1.5 1.5 0 0 0-1.5 1.5v9A1.5 1.5 0 0 0 4 14h8a1.5 1.5 0 0 0 1.5-1.5V6.621a1.5 1.5 0 0 0-.44-1.06L9.94 2.439A1.5 1.5 0 0 0 8.878 2H4Zm4 3.5a.75.75 0 0 1 .75.75v2.69l.72-.72a.75.75 0 1 1 1.06 1.06l-2 2a.75.75 0 0 1-1.06 0l-2-2a.75.75 0 0 1 1.06-1.06l.72.72V6.25A.75.75 0 0 1 8 5.5Z"
clip-rule="evenodd"
/>
</svg>
</div>
</button>
{/if}
</div>
</div>
{/if} {/if}
</div> </div>
</div> </div>
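
The Voice Mode Custom Prompt switch above uses `null` to mean "disabled, use the default prompt" and a string (possibly empty) to mean "custom prompt enabled". The convention as two tiny helpers; the names are ours, the behaviour mirrors the switch handler in the diff.

```ts
// null  -> feature disabled, backend default prompt is used
// ''    -> feature enabled, custom prompt not yet written
// 'txt' -> feature enabled with a custom prompt
export function isCustomPromptEnabled(template: string | null): boolean {
	return template != null;
}

export function setCustomPromptEnabled(enabled: boolean): string | null {
	return enabled ? '' : null;
}
```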

View file

@ -37,7 +37,8 @@
import EllipsisHorizontal from '$lib/components/icons/EllipsisHorizontal.svelte'; import EllipsisHorizontal from '$lib/components/icons/EllipsisHorizontal.svelte';
import EyeSlash from '$lib/components/icons/EyeSlash.svelte'; import EyeSlash from '$lib/components/icons/EyeSlash.svelte';
import Eye from '$lib/components/icons/Eye.svelte'; import Eye from '$lib/components/icons/Eye.svelte';
import { WEBUI_BASE_URL } from '$lib/constants'; import { WEBUI_API_BASE_URL, WEBUI_BASE_URL } from '$lib/constants';
import { goto } from '$app/navigation';
let shiftKey = false; let shiftKey = false;
@ -200,6 +201,16 @@
} }
}; };
const cloneHandler = async (model) => {
sessionStorage.model = JSON.stringify({
...model,
base_model_id: model.id,
id: `${model.id}-clone`,
name: `${model.name} (Clone)`
});
goto('/workspace/models/create');
};
const exportModelHandler = async (model) => { const exportModelHandler = async (model) => {
let blob = new Blob([JSON.stringify([model])], { let blob = new Blob([JSON.stringify([model])], {
type: 'application/json' type: 'application/json'
@ -250,9 +261,8 @@
{#if selectedModelId === null} {#if selectedModelId === null}
<div class="flex flex-col gap-1 mt-1.5 mb-2"> <div class="flex flex-col gap-1 mt-1.5 mb-2">
<div class="flex justify-between items-center"> <div class="flex justify-between items-center">
<div class="flex items-center md:self-center text-xl font-medium px-0.5"> <div class="flex items-center md:self-center text-xl font-medium px-0.5 gap-2">
{$i18n.t('Models')} {$i18n.t('Models')}
<div class="flex self-center w-[1px] h-6 mx-2.5 bg-gray-50 dark:bg-gray-850" />
<span class="text-lg font-medium text-gray-500 dark:text-gray-300" <span class="text-lg font-medium text-gray-500 dark:text-gray-300"
>{filteredModels.length}</span >{filteredModels.length}</span
> >
@ -335,7 +345,7 @@
: 'opacity-50 dark:opacity-50'} " : 'opacity-50 dark:opacity-50'} "
> >
<img <img
src={model?.meta?.profile_image_url ?? `${WEBUI_BASE_URL}/static/favicon.png`} src={`${WEBUI_API_BASE_URL}/models/model/profile/image?id=${model.id}`}
alt="modelfile profile" alt="modelfile profile"
class=" rounded-full w-full h-auto object-cover" class=" rounded-full w-full h-auto object-cover"
/> />
@ -420,6 +430,9 @@
copyLinkHandler={() => { copyLinkHandler={() => {
copyLinkHandler(model); copyLinkHandler(model);
}} }}
cloneHandler={() => {
cloneHandler(model);
}}
onClose={() => {}} onClose={() => {}}
> >
<button <button
@ -469,11 +482,11 @@
if (importFiles.length > 0) { if (importFiles.length > 0) {
const reader = new FileReader(); const reader = new FileReader();
reader.onload = async (event) => { reader.onload = async (event) => {
modelsImportInProgress = true;
try { try {
const models = JSON.parse(String(event.target.result)); const models = JSON.parse(String(event.target.result));
modelsImportInProgress = true;
const res = await importModels(localStorage.token, models); const res = await importModels(localStorage.token, models);
modelsImportInProgress = false;
if (res) { if (res) {
toast.success($i18n.t('Models imported successfully')); toast.success($i18n.t('Models imported successfully'));
@ -482,9 +495,11 @@
toast.error($i18n.t('Failed to import models')); toast.error($i18n.t('Failed to import models'));
} }
} catch (e) { } catch (e) {
toast.error($i18n.t('Invalid JSON file')); toast.error(e?.detail ?? $i18n.t('Invalid JSON file'));
console.error(e); console.error(e);
} }
modelsImportInProgress = false;
}; };
reader.readAsText(importFiles[0]); reader.readAsText(importFiles[0]);
} }

View file

@ -7,18 +7,20 @@
import { models } from '$lib/stores'; import { models } from '$lib/stores';
import { deleteAllModels } from '$lib/apis/models'; import { deleteAllModels } from '$lib/apis/models';
import { getModelsConfig, setModelsConfig } from '$lib/apis/configs';
import Modal from '$lib/components/common/Modal.svelte'; import Modal from '$lib/components/common/Modal.svelte';
import ConfirmDialog from '$lib/components/common/ConfirmDialog.svelte'; import ConfirmDialog from '$lib/components/common/ConfirmDialog.svelte';
import Tooltip from '$lib/components/common/Tooltip.svelte'; import Tooltip from '$lib/components/common/Tooltip.svelte';
import ModelList from './ModelList.svelte'; import ModelList from './ModelList.svelte';
import { getModelsConfig, setModelsConfig } from '$lib/apis/configs';
import Spinner from '$lib/components/common/Spinner.svelte'; import Spinner from '$lib/components/common/Spinner.svelte';
import Minus from '$lib/components/icons/Minus.svelte'; import Minus from '$lib/components/icons/Minus.svelte';
import Plus from '$lib/components/icons/Plus.svelte'; import Plus from '$lib/components/icons/Plus.svelte';
import ChevronUp from '$lib/components/icons/ChevronUp.svelte'; import ChevronUp from '$lib/components/icons/ChevronUp.svelte';
import ChevronDown from '$lib/components/icons/ChevronDown.svelte'; import ChevronDown from '$lib/components/icons/ChevronDown.svelte';
import XMark from '$lib/components/icons/XMark.svelte'; import XMark from '$lib/components/icons/XMark.svelte';
import ModelSelector from './ModelSelector.svelte';
import Model from '../Evaluations/Model.svelte';
export let show = false; export let show = false;
export let initHandler = () => {}; export let initHandler = () => {};
@ -27,6 +29,10 @@
let selectedModelId = ''; let selectedModelId = '';
let defaultModelIds = []; let defaultModelIds = [];
let selectedPinnedModelId = '';
let defaultPinnedModelIds = [];
let modelIds = []; let modelIds = [];
let sortKey = ''; let sortKey = '';
@ -38,25 +44,6 @@
$: if (show) { $: if (show) {
init(); init();
} }
$: if (selectedModelId) {
onModelSelect();
}
const onModelSelect = () => {
if (selectedModelId === '') {
return;
}
if (defaultModelIds.includes(selectedModelId)) {
selectedModelId = '';
return;
}
defaultModelIds = [...defaultModelIds, selectedModelId];
selectedModelId = '';
};
const init = async () => { const init = async () => {
config = await getModelsConfig(localStorage.token); config = await getModelsConfig(localStorage.token);
@ -65,6 +52,13 @@
} else { } else {
defaultModelIds = []; defaultModelIds = [];
} }
if (config?.DEFAULT_PINNED_MODELS) {
defaultPinnedModelIds = (config?.DEFAULT_PINNED_MODELS).split(',').filter((id) => id);
} else {
defaultPinnedModelIds = [];
}
const modelOrderList = config.MODEL_ORDER_LIST || []; const modelOrderList = config.MODEL_ORDER_LIST || [];
const allModelIds = $models.map((model) => model.id); const allModelIds = $models.map((model) => model.id);
@ -86,6 +80,7 @@
const res = await setModelsConfig(localStorage.token, { const res = await setModelsConfig(localStorage.token, {
DEFAULT_MODELS: defaultModelIds.join(','), DEFAULT_MODELS: defaultModelIds.join(','),
DEFAULT_PINNED_MODELS: defaultPinnedModelIds.join(','),
MODEL_ORDER_LIST: modelIds MODEL_ORDER_LIST: modelIds
}); });
@ -191,59 +186,19 @@
<hr class=" border-gray-100 dark:border-gray-700/10 my-2.5 w-full" /> <hr class=" border-gray-100 dark:border-gray-700/10 my-2.5 w-full" />
<div> <ModelSelector
<div class="flex flex-col w-full"> title={$i18n.t('Default Models')}
<div class="mb-1 flex justify-between"> models={$models}
<div class="text-xs text-gray-500">{$i18n.t('Default Models')}</div> bind:modelIds={defaultModelIds}
</div> />
<div class="flex items-center -mr-1"> <hr class=" border-gray-100 dark:border-gray-700/10 my-2.5 w-full" />
<select
class="w-full py-1 text-sm rounded-lg bg-transparent {selectedModelId
? ''
: 'text-gray-500'} placeholder:text-gray-300 dark:placeholder:text-gray-700 outline-hidden"
bind:value={selectedModelId}
>
<option value="">{$i18n.t('Select a model')}</option>
{#each $models as model}
<option value={model.id} class="bg-gray-50 dark:bg-gray-700"
>{model.name}</option
>
{/each}
</select>
</div>
<!-- <hr class=" border-gray-100 dark:border-gray-700/10 my-2.5 w-full" /> --> <ModelSelector
title={$i18n.t('Default Pinned Models')}
{#if defaultModelIds.length > 0} models={$models}
<div class="flex flex-col"> bind:modelIds={defaultPinnedModelIds}
{#each defaultModelIds as modelId, modelIdx} />
<div class=" flex gap-2 w-full justify-between items-center">
<div class=" text-sm flex-1 py-1 rounded-lg">
{$models.find((model) => model.id === modelId)?.name}
</div>
<div class="shrink-0">
<button
type="button"
on:click={() => {
defaultModelIds = defaultModelIds.filter(
(_, idx) => idx !== modelIdx
);
}}
>
<Minus strokeWidth="2" className="size-3.5" />
</button>
</div>
</div>
{/each}
</div>
{:else}
<div class="text-gray-500 text-xs text-center py-2">
{$i18n.t('No models selected')}
</div>
{/if}
</div>
</div>
<div class="flex justify-between pt-3 text-sm font-medium gap-1.5"> <div class="flex justify-between pt-3 text-sm font-medium gap-1.5">
<Tooltip content={$i18n.t('This will delete all models including custom models')}> <Tooltip content={$i18n.t('This will delete all models including custom models')}>
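
`DEFAULT_PINNED_MODELS` is stored the same way as `DEFAULT_MODELS` above: a single comma-separated string in the config, an array of ids in the UI. The round-trip as plain functions (names are ours):

```ts
// Config string -> id array, dropping empty entries left by stray commas.
export function idsFromConfig(value?: string | null): string[] {
	return (value ?? '').split(',').filter((id) => id);
}

// Id array -> config string, as sent to setModelsConfig above.
export function idsToConfig(ids: string[]): string {
	return ids.join(',');
}

// idsFromConfig('gpt-4o,,llama3') -> ['gpt-4o', 'llama3']
// idsToConfig(['gpt-4o', 'llama3']) -> 'gpt-4o,llama3'
```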

View file

@ -25,6 +25,7 @@
export let exportHandler: Function; export let exportHandler: Function;
export let hideHandler: Function; export let hideHandler: Function;
export let copyLinkHandler: Function; export let copyLinkHandler: Function;
export let cloneHandler: Function;
export let onClose: Function; export let onClose: Function;
@ -114,6 +115,17 @@
<div class="flex items-center">{$i18n.t('Copy Link')}</div> <div class="flex items-center">{$i18n.t('Copy Link')}</div>
</DropdownMenu.Item> </DropdownMenu.Item>
<DropdownMenu.Item
class="flex gap-2 items-center px-3 py-1.5 text-sm font-medium cursor-pointer hover:bg-gray-50 dark:hover:bg-gray-800 rounded-md"
on:click={() => {
cloneHandler();
}}
>
<DocumentDuplicate />
<div class="flex items-center">{$i18n.t('Clone')}</div>
</DropdownMenu.Item>
<DropdownMenu.Item <DropdownMenu.Item
class="flex gap-2 items-center px-3 py-1.5 text-sm font-medium cursor-pointer hover:bg-gray-50 dark:hover:bg-gray-800 rounded-md" class="flex gap-2 items-center px-3 py-1.5 text-sm font-medium cursor-pointer hover:bg-gray-50 dark:hover:bg-gray-800 rounded-md"
on:click={() => { on:click={() => {

View file

@ -0,0 +1,70 @@
<script lang="ts">
import { getContext } from 'svelte';
const i18n = getContext('i18n');
import Minus from '$lib/components/icons/Minus.svelte';
export let title = '';
export let models = [];
export let modelIds = [];
let selectedModelId = '';
</script>
<div>
<div class="flex flex-col w-full">
<div class="mb-1 flex justify-between">
<div class="text-xs text-gray-500">{title}</div>
</div>
<div class="flex items-center -mr-1">
<select
class="w-full py-1 text-sm rounded-lg bg-transparent {selectedModelId
? ''
: 'text-gray-500'} placeholder:text-gray-300 dark:placeholder:text-gray-700 outline-hidden"
bind:value={selectedModelId}
on:change={() => {
if (selectedModelId && !modelIds.includes(selectedModelId)) {
modelIds = [...modelIds, selectedModelId];
}
selectedModelId = '';
}}
>
<option value="">{$i18n.t('Select a model')}</option>
{#each models as model}
{#if !modelIds.includes(model.id)}
<option value={model.id} class="bg-gray-50 dark:bg-gray-700">{model.name}</option>
{/if}
{/each}
</select>
</div>
<!-- <hr class=" border-gray-100 dark:border-gray-700/10 my-2.5 w-full" /> -->
{#if modelIds.length > 0}
<div class="flex flex-col">
{#each modelIds as modelId, modelIdx}
<div class=" flex gap-2 w-full justify-between items-center">
<div class=" text-sm flex-1 py-1 rounded-lg">
{models.find((model) => model.id === modelId)?.name}
</div>
<div class="shrink-0">
<button
type="button"
on:click={() => {
modelIds = modelIds.filter((_, idx) => idx !== modelIdx);
}}
>
<Minus strokeWidth="2" className="size-3.5" />
</button>
</div>
</div>
{/each}
</div>
{:else}
<div class="text-gray-500 text-xs text-center py-2">
{$i18n.t('No models selected')}
</div>
{/if}
</div>
</div>
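
The new `ModelSelector.svelte` above keeps its two-way bound `modelIds` free of duplicates and hides already-selected models from the dropdown. The core logic, extracted as pure functions for reference (names are ours):

```ts
// Append a newly selected id unless it is empty or already present,
// mirroring the on:change handler in ModelSelector.svelte.
export function addModelId(modelIds: string[], selected: string): string[] {
	return selected && !modelIds.includes(selected) ? [...modelIds, selected] : modelIds;
}

// Options still offered by the dropdown: every model not yet selected.
export function selectableModels<T extends { id: string }>(
	models: T[],
	modelIds: string[]
): T[] {
	return models.filter((model) => !modelIds.includes(model.id));
}
```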

View file

@@ -224,7 +224,7 @@
 	<div class="overflow-y-scroll scrollbar-hidden h-full">
 		{#if PIPELINES_LIST !== null}
 			<div class="flex w-full justify-between mb-2">
-				<div class=" self-center text-sm font-semibold">
+				<div class=" self-center text-sm font-medium">
 					{$i18n.t('Manage Pipelines')}
 				</div>
 			</div>
@@ -410,7 +410,7 @@
 			</div>
 			<div class="mt-2 text-xs text-gray-500">
-				<span class=" font-semibold dark:text-gray-200">{$i18n.t('Warning:')}</span>
+				<span class=" font-medium dark:text-gray-200">{$i18n.t('Warning:')}</span>
 				{$i18n.t('Pipelines are a plugin system with arbitrary code execution —')}
 				<span class=" font-medium dark:text-gray-400"
 					>{$i18n.t("don't fetch random pipelines from sources you don't trust.")}</span
@@ -423,7 +423,7 @@
 		{#if pipelines !== null}
 			{#if pipelines.length > 0}
 				<div class="flex w-full justify-between mb-2">
-					<div class=" self-center text-sm font-semibold">
+					<div class=" self-center text-sm font-medium">
 						{$i18n.t('Pipelines Valves')}
 					</div>
 				</div>

View file

@@ -150,6 +150,26 @@
 				</div>
 			</div>
 		{:else if webConfig.WEB_SEARCH_ENGINE === 'perplexity_search'}
+			<div class="mb-2.5 flex w-full flex-col">
+				<div>
+					<div class=" self-center text-xs font-medium mb-1">
+						{$i18n.t('Perplexity Search API URL')}
+					</div>
+					<div class="flex w-full">
+						<div class="flex-1">
+							<input
+								class="w-full rounded-lg py-2 px-4 text-sm bg-gray-50 dark:text-gray-300 dark:bg-gray-850 outline-hidden"
+								type="text"
+								placeholder={$i18n.t('Enter Perplexity Search API URL')}
+								bind:value={webConfig.PERPLEXITY_SEARCH_API_URL}
+								autocomplete="off"
+							/>
+						</div>
+					</div>
+				</div>
+			</div>
 			<div class="mb-2.5 flex w-full flex-col">
 				<div>
 					<div class=" self-center text-xs font-medium mb-1">
@@ -664,7 +684,7 @@
 				<input
 					class="w-full rounded-lg py-2 px-4 text-sm bg-gray-50 dark:text-gray-300 dark:bg-gray-850 outline-hidden"
 					placeholder={$i18n.t(
-						'Enter domains separated by commas (e.g., example.com,site.org)'
+						'Enter domains separated by commas (e.g., example.com,site.org,!excludedsite.com)'
 					)}
 					bind:value={webConfig.WEB_SEARCH_DOMAIN_FILTER_LIST}
 				/>
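The updated placeholder hints that a leading "!" marks a domain to exclude. A hedged TypeScript sketch of one way such a comma-separated value could be split into allow and block lists; the parsing below is an assumption for illustration, not the backend's actual implementation.

// Example value mirroring the new placeholder; "!" is assumed to mean "exclude this domain".
const WEB_SEARCH_DOMAIN_FILTER_LIST = 'example.com,site.org,!excludedsite.com';

const entries = WEB_SEARCH_DOMAIN_FILTER_LIST.split(',')
	.map((entry) => entry.trim())
	.filter((entry) => entry.length > 0);

const allowedDomains = entries.filter((entry) => !entry.startsWith('!'));
const excludedDomains = entries
	.filter((entry) => entry.startsWith('!'))
	.map((entry) => entry.slice(1));

console.log({ allowedDomains, excludedDomains });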

View file

@@ -33,9 +33,6 @@
 	let loaded = false;
-	let users = [];
-	let total = 0;
 	let groups = [];
 	let filteredGroups;
@@ -93,16 +90,6 @@
 			return;
 		}
-		const res = await getAllUsers(localStorage.token).catch((error) => {
-			toast.error(`${error}`);
-			return null;
-		});
-		if (res) {
-			users = res.users;
-			total = res.total;
-		}
 		defaultPermissions = await getUserDefaultPermissions(localStorage.token);
 		await setGroups();
 		loaded = true;
@@ -118,11 +105,14 @@
 	/>
 	<div class="mt-0.5 mb-2 gap-1 flex flex-col md:flex-row justify-between">
-		<div class="flex md:self-center text-lg font-medium px-0.5">
-			{$i18n.t('Groups')}
-			<div class="flex self-center w-[1px] h-6 mx-2.5 bg-gray-50 dark:bg-gray-850" />
-			<span class="text-lg font-medium text-gray-500 dark:text-gray-300">{groups.length}</span>
+		<div class="flex items-center md:self-center text-xl font-medium px-0.5 gap-2 shrink-0">
+			<div>
+				{$i18n.t('Groups')}
+			</div>
+			<div class="text-lg font-medium text-gray-500 dark:text-gray-500">
+				{groups.length}
+			</div>
 		</div>
 		<div class="flex gap-1">
@@ -179,7 +169,7 @@
 			</div>
 		{:else}
 			<div>
-				<div class=" flex items-center gap-3 justify-between text-xs uppercase px-1 font-semibold">
+				<div class=" flex items-center gap-3 justify-between text-xs uppercase px-1 font-medium">
 					<div class="w-full basis-3/5">{$i18n.t('Group')}</div>
 					<div class="w-full basis-2/5 text-right">{$i18n.t('Users')}</div>
@@ -189,7 +179,7 @@
 				{#each filteredGroups as group}
 					<div class="my-2">
-						<GroupItem {group} {users} {setGroups} {defaultPermissions} />
+						<GroupItem {group} {setGroups} {defaultPermissions} />
 					</div>
 				{/each}
 			</div>

View file

@@ -5,7 +5,7 @@
 	import Spinner from '$lib/components/common/Spinner.svelte';
 	import Modal from '$lib/components/common/Modal.svelte';
-	import Display from './Display.svelte';
+	import General from './General.svelte';
 	import Permissions from './Permissions.svelte';
 	import Users from './Users.svelte';
 	import UserPlusSolid from '$lib/components/icons/UserPlusSolid.svelte';
@@ -19,7 +19,6 @@
 	export let show = false;
 	export let edit = false;
-	export let users = [];
 	export let group = null;
 	export let defaultPermissions = {};
@@ -31,21 +30,36 @@
 	let loading = false;
 	let showDeleteConfirmDialog = false;
+	let userCount = 0;
 	export let name = '';
 	export let description = '';
+	export let data = {};
 	export let permissions = {
 		workspace: {
 			models: false,
 			knowledge: false,
 			prompts: false,
-			tools: false
+			tools: false,
+			models_import: false,
+			models_export: false,
+			prompts_import: false,
+			prompts_export: false,
+			tools_import: false,
+			tools_export: false
 		},
 		sharing: {
+			models: false,
 			public_models: false,
+			knowledge: false,
 			public_knowledge: false,
+			prompts: false,
 			public_prompts: false,
-			public_tools: false
+			tools: false,
+			public_tools: false,
+			notes: false,
+			public_notes: false
 		},
 		chat: {
 			controls: true,
@@ -69,13 +83,14 @@
 			temporary_enforced: false
 		},
 		features: {
+			api_keys: false,
 			direct_tool_servers: false,
 			web_search: true,
 			image_generation: true,
-			code_interpreter: true
+			code_interpreter: true,
+			notes: true
 		}
 	};
-	export let userIds = [];
 	const submitHandler = async () => {
 		loading = true;
@@ -83,8 +98,8 @@
 		const group = {
 			name,
 			description,
-			permissions,
-			user_ids: userIds
+			data,
+			permissions
 		};
 		await onSubmit(group);
@@ -98,8 +113,9 @@
 			name = group.name;
 			description = group.description;
 			permissions = group?.permissions ?? {};
-			userIds = group?.user_ids ?? [];
+			data = group?.data ?? {};
+			userCount = group?.member_count ?? 0;
 		}
 	};
@@ -121,7 +137,7 @@
 	}}
 />
-<Modal size="md" bind:show>
+<Modal size="lg" bind:show>
 	<div>
 		<div class=" flex justify-between dark:text-gray-100 px-5 pt-4 mb-1.5">
 			<div class=" text-lg font-medium self-center font-primary">
@@ -220,20 +236,48 @@
 						<div class=" self-center mr-2">
 							<UserPlusSolid />
 						</div>
-						<div class=" self-center">{$i18n.t('Users')} ({userIds.length})</div>
+						<div class=" self-center">{$i18n.t('Users')}</div>
 					</button>
 				{/if}
 			</div>
-			<div
-				class="flex-1 mt-1 lg:mt-1 lg:h-[22rem] lg:max-h-[22rem] overflow-y-auto scrollbar-hidden"
-			>
-				{#if selectedTab == 'general'}
-					<Display bind:name bind:description />
-				{:else if selectedTab == 'permissions'}
-					<Permissions bind:permissions {defaultPermissions} />
-				{:else if selectedTab == 'users'}
-					<Users bind:userIds {users} />
+			<div class="flex-1 mt-1 lg:mt-1 lg:h-[30rem] lg:max-h-[30rem] flex flex-col">
+				<div class="w-full h-full overflow-y-auto scrollbar-hidden">
+					{#if selectedTab == 'general'}
+						<General
+							bind:name
+							bind:description
+							bind:data
+							{edit}
+							onDelete={() => {
+								showDeleteConfirmDialog = true;
+							}}
+						/>
+					{:else if selectedTab == 'permissions'}
+						<Permissions bind:permissions {defaultPermissions} />
+					{:else if selectedTab == 'users'}
+						<Users bind:userCount groupId={group?.id} />
+					{/if}
+				</div>
+				{#if ['general', 'permissions'].includes(selectedTab)}
+					<div class="flex justify-end pt-3 text-sm font-medium gap-1.5">
+						<button
+							class="px-3.5 py-1.5 text-sm font-medium bg-black hover:bg-gray-900 text-white dark:bg-white dark:text-black dark:hover:bg-gray-100 transition rounded-full flex flex-row space-x-1 items-center {loading
+								? ' cursor-not-allowed'
+								: ''}"
+							type="submit"
+							disabled={loading}
+						>
+							{$i18n.t('Save')}
+							{#if loading}
+								<div class="ml-2 self-center">
+									<Spinner />
+								</div>
+							{/if}
+						</button>
+					</div>
 				{/if}
 			</div>
 		</div>
@@ -286,38 +330,6 @@
 							</button>
 						{/if}
 					</div> -->
-				<div class="flex justify-between pt-3 text-sm font-medium gap-1.5">
-					{#if edit}
-						<button
-							class="px-3.5 py-1.5 text-sm font-medium dark:bg-black dark:hover:bg-gray-900 dark:text-white bg-white text-black hover:bg-gray-100 transition rounded-full flex flex-row space-x-1 items-center"
-							type="button"
-							on:click={() => {
-								showDeleteConfirmDialog = true;
-							}}
-						>
-							{$i18n.t('Delete')}
-						</button>
-					{:else}
-						<div></div>
-					{/if}
-					<button
-						class="px-3.5 py-1.5 text-sm font-medium bg-black hover:bg-gray-900 text-white dark:bg-white dark:text-black dark:hover:bg-gray-100 transition rounded-full flex flex-row space-x-1 items-center {loading
-							? ' cursor-not-allowed'
-							: ''}"
-						type="submit"
-						disabled={loading}
-					>
-						{$i18n.t('Save')}
-						{#if loading}
-							<div class="ml-2 self-center">
-								<Spinner />
-							</div>
-						{/if}
-					</button>
-				</div>
 			</form>
 		</div>
 	</div>

View file

@@ -2,12 +2,17 @@
 	import { getContext } from 'svelte';
 	import Textarea from '$lib/components/common/Textarea.svelte';
 	import Tooltip from '$lib/components/common/Tooltip.svelte';
+	import Switch from '$lib/components/common/Switch.svelte';
 	const i18n = getContext('i18n');
 	export let name = '';
 	export let color = '';
 	export let description = '';
+	export let data = {};
+	export let edit = false;
+	export let onDelete: Function = () => {};
 </script>
 <div class="flex gap-2">
@@ -59,3 +64,47 @@
 		/>
 	</div>
 </div>
+<hr class="border-gray-50 dark:border-gray-850 my-1" />
+<div class="flex flex-col w-full mt-2">
+	<div class=" mb-1 text-xs text-gray-500">{$i18n.t('Setting')}</div>
+	<div>
+		<div class=" flex w-full justify-between">
+			<div class=" self-center text-xs">
+				{$i18n.t('Allow Group Sharing')}
+			</div>
+			<div class="flex items-center gap-2 p-1">
+				<Switch
+					tooltip={true}
+					state={data?.config?.share ?? true}
+					on:change={(e) => {
+						if (data?.config?.share) {
+							data.config.share = e.detail;
+						} else {
+							data.config = { ...(data?.config ?? {}), share: e.detail };
+						}
+					}}
+				/>
+			</div>
+		</div>
+	</div>
+</div>
+{#if edit}
+	<div class="flex flex-col w-full mt-2">
+		<div class=" mb-0.5 text-xs text-gray-500">{$i18n.t('Actions')}</div>
+		<div class="flex-1">
+			<button
+				class="text-xs bg-transparent hover:underline cursor-pointer"
+				type="button"
+				on:click={() => onDelete()}
+			>
+				{$i18n.t('Delete')}
+			</button>
+		</div>
+	</div>
+{/if}
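Taken together with EditGroupModal above, the General tab now writes the group-level sharing opt-out into data.config.share, and submitHandler sends data alongside permissions while member changes are handled separately in the Users tab. An illustrative TypeScript sketch of the submitted shape (all values are examples):

// Example of the object passed to onSubmit(group) in EditGroupModal; values are illustrative.
const group = {
	name: 'Engineering',
	description: 'Internal engineering group',
	data: {
		config: {
			share: false // "Allow Group Sharing" toggle; treated as enabled when unset
		}
	},
	permissions: {
		workspace: { models: false, knowledge: false, prompts: false, tools: false },
		sharing: { models: false, public_models: false }
		// remaining permission groups follow the defaults shown in the diff
	}
};

console.log(JSON.stringify(group, null, 2));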

View file

@@ -10,9 +10,8 @@
 	import Pencil from '$lib/components/icons/Pencil.svelte';
 	import User from '$lib/components/icons/User.svelte';
 	import UserCircleSolid from '$lib/components/icons/UserCircleSolid.svelte';
-	import GroupModal from './EditGroupModal.svelte';
+	import EditGroupModal from './EditGroupModal.svelte';
-	export let users = [];
 	export let group = {
 		name: 'Admins',
 		user_ids: [1, 2, 3]
@@ -55,10 +54,9 @@
 	});
 </script>
-<GroupModal
+<EditGroupModal
 	bind:show={showEdit}
 	edit
-	{users}
 	{group}
 	{defaultPermissions}
 	onSubmit={updateHandler}
@@ -81,7 +79,7 @@
 	</div>
 	<div class="flex items-center gap-1.5 w-fit font-medium text-right justify-end">
-		{group.user_ids.length}
+		{group?.member_count}
 		<div>
 			<User className="size-3.5" />

View file

@@ -11,13 +11,24 @@
 			models: false,
 			knowledge: false,
 			prompts: false,
-			tools: false
+			tools: false,
+			models_import: false,
+			models_export: false,
+			prompts_import: false,
+			prompts_export: false,
+			tools_import: false,
+			tools_export: false
 		},
 		sharing: {
+			models: false,
 			public_models: false,
+			knowledge: false,
 			public_knowledge: false,
+			prompts: false,
 			public_prompts: false,
+			tools: false,
 			public_tools: false,
+			notes: false,
 			public_notes: false
 		},
 		chat: {
@@ -42,6 +53,7 @@
 			temporary_enforced: false
 		},
 		features: {
+			api_keys: false,
 			direct_tool_servers: false,
 			web_search: true,
 			image_generation: true,
@@ -90,8 +102,24 @@
 				</div>
 				<Switch bind:state={permissions.workspace.models} />
 			</div>
-			{#if defaultPermissions?.workspace?.models && !permissions.workspace.models}
-				<div>
+			{#if permissions.workspace.models}
+				<div class="ml-2 flex flex-col gap-2 pt-0.5 pb-1">
+					<div class="flex w-full justify-between">
+						<div class="self-center text-xs">
+							{$i18n.t('Import Models')}
+						</div>
+						<Switch bind:state={permissions.workspace.models_import} />
+					</div>
+					<div class="flex w-full justify-between">
+						<div class="self-center text-xs">
+							{$i18n.t('Export Models')}
+						</div>
+						<Switch bind:state={permissions.workspace.models_export} />
+					</div>
+				</div>
+			{:else if defaultPermissions?.workspace?.models}
+				<div class="pb-0.5">
 					<div class="text-xs text-gray-500">
 						{$i18n.t('This is a default user permission and will remain enabled.')}
 					</div>
@@ -122,8 +150,24 @@
 				</div>
 				<Switch bind:state={permissions.workspace.prompts} />
 			</div>
-			{#if defaultPermissions?.workspace?.prompts && !permissions.workspace.prompts}
-				<div>
+			{#if permissions.workspace.prompts}
+				<div class="ml-2 flex flex-col gap-2 pt-0.5 pb-1">
+					<div class="flex w-full justify-between">
+						<div class="self-center text-xs">
+							{$i18n.t('Import Prompts')}
+						</div>
+						<Switch bind:state={permissions.workspace.prompts_import} />
+					</div>
+					<div class="flex w-full justify-between">
+						<div class="self-center text-xs">
+							{$i18n.t('Export Prompts')}
+						</div>
+						<Switch bind:state={permissions.workspace.prompts_export} />
+					</div>
+				</div>
+			{:else if defaultPermissions?.workspace?.prompts}
+				<div class="pb-0.5">
 					<div class="text-xs text-gray-500">
 						{$i18n.t('This is a default user permission and will remain enabled.')}
 					</div>
@@ -144,8 +188,24 @@
 				</div>
 				<Switch bind:state={permissions.workspace.tools} />
 			</Tooltip>
-			{#if defaultPermissions?.workspace?.tools && !permissions.workspace.tools}
-				<div>
+			{#if permissions.workspace.tools}
+				<div class="ml-2 flex flex-col gap-2 pt-0.5 pb-1">
+					<div class="flex w-full justify-between">
+						<div class="self-center text-xs">
+							{$i18n.t('Import Tools')}
+						</div>
+						<Switch bind:state={permissions.workspace.tools_import} />
+					</div>
+					<div class="flex w-full justify-between">
+						<div class="self-center text-xs">
+							{$i18n.t('Export Tools')}
+						</div>
+						<Switch bind:state={permissions.workspace.tools_export} />
+					</div>
+				</div>
+			{:else if defaultPermissions?.workspace?.tools}
+				<div class="pb-0.5">
 					<div class="text-xs text-gray-500">
 						{$i18n.t('This is a default user permission and will remain enabled.')}
 					</div>
@@ -162,11 +222,11 @@
 		<div class="flex flex-col w-full">
 			<div class="flex w-full justify-between my-1">
 				<div class=" self-center text-xs font-medium">
-					{$i18n.t('Models Public Sharing')}
+					{$i18n.t('Models Sharing')}
 				</div>
-				<Switch bind:state={permissions.sharing.public_models} />
+				<Switch bind:state={permissions.sharing.models} />
 			</div>
-			{#if defaultPermissions?.sharing?.public_models && !permissions.sharing.public_models}
+			{#if defaultPermissions?.sharing?.models && !permissions.sharing.models}
 				<div>
 					<div class="text-xs text-gray-500">
 						{$i18n.t('This is a default user permission and will remain enabled.')}
@@ -175,14 +235,32 @@
 			{/if}
 		</div>
+		{#if permissions.sharing.models}
+			<div class="flex flex-col w-full">
+				<div class="flex w-full justify-between my-1">
+					<div class=" self-center text-xs font-medium">
+						{$i18n.t('Models Public Sharing')}
+					</div>
+					<Switch bind:state={permissions.sharing.public_models} />
+				</div>
+				{#if defaultPermissions?.sharing?.public_models && !permissions.sharing.public_models}
+					<div>
+						<div class="text-xs text-gray-500">
+							{$i18n.t('This is a default user permission and will remain enabled.')}
+						</div>
+					</div>
+				{/if}
+			</div>
+		{/if}
 		<div class="flex flex-col w-full">
 			<div class="flex w-full justify-between my-1">
 				<div class=" self-center text-xs font-medium">
-					{$i18n.t('Knowledge Public Sharing')}
+					{$i18n.t('Knowledge Sharing')}
 				</div>
-				<Switch bind:state={permissions.sharing.public_knowledge} />
+				<Switch bind:state={permissions.sharing.knowledge} />
 			</div>
-			{#if defaultPermissions?.sharing?.public_knowledge && !permissions.sharing.public_knowledge}
+			{#if defaultPermissions?.sharing?.knowledge && !permissions.sharing.knowledge}
 				<div>
 					<div class="text-xs text-gray-500">
 						{$i18n.t('This is a default user permission and will remain enabled.')}
@@ -191,14 +269,32 @@
 			{/if}
 		</div>
+		{#if permissions.sharing.knowledge}
+			<div class="flex flex-col w-full">
+				<div class="flex w-full justify-between my-1">
+					<div class=" self-center text-xs font-medium">
+						{$i18n.t('Knowledge Public Sharing')}
+					</div>
+					<Switch bind:state={permissions.sharing.public_knowledge} />
+				</div>
+				{#if defaultPermissions?.sharing?.public_knowledge && !permissions.sharing.public_knowledge}
+					<div>
+						<div class="text-xs text-gray-500">
+							{$i18n.t('This is a default user permission and will remain enabled.')}
+						</div>
+					</div>
+				{/if}
+			</div>
+		{/if}
 		<div class="flex flex-col w-full">
 			<div class="flex w-full justify-between my-1">
 				<div class=" self-center text-xs font-medium">
-					{$i18n.t('Prompts Public Sharing')}
+					{$i18n.t('Prompts Sharing')}
 				</div>
-				<Switch bind:state={permissions.sharing.public_prompts} />
+				<Switch bind:state={permissions.sharing.prompts} />
 			</div>
-			{#if defaultPermissions?.sharing?.public_prompts && !permissions.sharing.public_prompts}
+			{#if defaultPermissions?.sharing?.prompts && !permissions.sharing.prompts}
 				<div>
 					<div class="text-xs text-gray-500">
 						{$i18n.t('This is a default user permission and will remain enabled.')}
@@ -207,14 +303,32 @@
 			{/if}
 		</div>
+		{#if permissions.sharing.prompts}
+			<div class="flex flex-col w-full">
+				<div class="flex w-full justify-between my-1">
+					<div class=" self-center text-xs font-medium">
+						{$i18n.t('Prompts Public Sharing')}
+					</div>
+					<Switch bind:state={permissions.sharing.public_prompts} />
+				</div>
+				{#if defaultPermissions?.sharing?.public_prompts && !permissions.sharing.public_prompts}
+					<div>
+						<div class="text-xs text-gray-500">
+							{$i18n.t('This is a default user permission and will remain enabled.')}
+						</div>
+					</div>
+				{/if}
+			</div>
+		{/if}
 		<div class="flex flex-col w-full">
 			<div class="flex w-full justify-between my-1">
 				<div class=" self-center text-xs font-medium">
-					{$i18n.t('Tools Public Sharing')}
+					{$i18n.t('Tools Sharing')}
 				</div>
-				<Switch bind:state={permissions.sharing.public_tools} />
+				<Switch bind:state={permissions.sharing.tools} />
 			</div>
-			{#if defaultPermissions?.sharing?.public_tools && !permissions.sharing.public_tools}
+			{#if defaultPermissions?.sharing?.tools && !permissions.sharing.tools}
 				<div>
 					<div class="text-xs text-gray-500">
 						{$i18n.t('This is a default user permission and will remain enabled.')}
@@ -223,14 +337,32 @@
 			{/if}
 		</div>
+		{#if permissions.sharing.tools}
+			<div class="flex flex-col w-full">
+				<div class="flex w-full justify-between my-1">
+					<div class=" self-center text-xs font-medium">
+						{$i18n.t('Tools Public Sharing')}
+					</div>
+					<Switch bind:state={permissions.sharing.public_tools} />
+				</div>
+				{#if defaultPermissions?.sharing?.public_tools && !permissions.sharing.public_tools}
+					<div>
+						<div class="text-xs text-gray-500">
+							{$i18n.t('This is a default user permission and will remain enabled.')}
+						</div>
+					</div>
+				{/if}
+			</div>
+		{/if}
 		<div class="flex flex-col w-full">
 			<div class="flex w-full justify-between my-1">
 				<div class=" self-center text-xs font-medium">
-					{$i18n.t('Notes Public Sharing')}
+					{$i18n.t('Notes Sharing')}
 				</div>
-				<Switch bind:state={permissions.sharing.public_notes} />
+				<Switch bind:state={permissions.sharing.notes} />
 			</div>
-			{#if defaultPermissions?.sharing?.public_notes && !permissions.sharing.public_notes}
+			{#if defaultPermissions?.sharing?.notes && !permissions.sharing.notes}
 				<div>
 					<div class="text-xs text-gray-500">
 						{$i18n.t('This is a default user permission and will remain enabled.')}
@@ -238,6 +370,24 @@
 				</div>
 			{/if}
 		</div>
+		{#if permissions.sharing.notes}
+			<div class="flex flex-col w-full">
+				<div class="flex w-full justify-between my-1">
+					<div class=" self-center text-xs font-medium">
+						{$i18n.t('Notes Public Sharing')}
+					</div>
+					<Switch bind:state={permissions.sharing.public_notes} />
+				</div>
+				{#if defaultPermissions?.sharing?.public_notes && !permissions.sharing.public_notes}
+					<div>
+						<div class="text-xs text-gray-500">
+							{$i18n.t('This is a default user permission and will remain enabled.')}
+						</div>
+					</div>
+				{/if}
+			</div>
+		{/if}
 	</div>
 	<hr class=" border-gray-100 dark:border-gray-850" />
@@ -559,6 +709,22 @@
 	<div>
 		<div class=" mb-2 text-sm font-medium">{$i18n.t('Features Permissions')}</div>
+		<div class="flex flex-col w-full">
+			<div class="flex w-full justify-between my-1">
+				<div class=" self-center text-xs font-medium">
+					{$i18n.t('API Keys')}
+				</div>
+				<Switch bind:state={permissions.features.api_keys} />
+			</div>
+			{#if defaultPermissions?.features?.api_keys && !permissions.features.api_keys}
+				<div>
+					<div class="text-xs text-gray-500">
+						{$i18n.t('This is a default user permission and will remain enabled.')}
+					</div>
+				</div>
+			{/if}
+		</div>
 		<div class="flex flex-col w-full">
 			<div class="flex w-full justify-between my-1">
 				<div class=" self-center text-xs font-medium">

View file

@@ -2,50 +2,93 @@
 	import { getContext } from 'svelte';
 	const i18n = getContext('i18n');
+	import dayjs from 'dayjs';
+	import relativeTime from 'dayjs/plugin/relativeTime';
+	import localizedFormat from 'dayjs/plugin/localizedFormat';
+	dayjs.extend(relativeTime);
+	dayjs.extend(localizedFormat);
+	import { getUsers } from '$lib/apis/users';
+	import { toast } from 'svelte-sonner';
 	import Tooltip from '$lib/components/common/Tooltip.svelte';
-	import Plus from '$lib/components/icons/Plus.svelte';
-	import { WEBUI_BASE_URL } from '$lib/constants';
 	import Checkbox from '$lib/components/common/Checkbox.svelte';
 	import Badge from '$lib/components/common/Badge.svelte';
 	import Search from '$lib/components/icons/Search.svelte';
+	import Pagination from '$lib/components/common/Pagination.svelte';
+	import { addUserToGroup, removeUserFromGroup } from '$lib/apis/groups';
+	import ChevronDown from '$lib/components/icons/ChevronDown.svelte';
+	import ChevronUp from '$lib/components/icons/ChevronUp.svelte';
+	import { WEBUI_API_BASE_URL } from '$lib/constants';
-	export let users = [];
-	export let userIds = [];
+	export let groupId: string;
+	export let userCount = 0;
-	let filteredUsers = [];
+	let users = [];
+	let total = 0;
-	$: filteredUsers = users
-		.filter((user) => {
-			if (query === '') {
-				return true;
-			}
-			return (
-				user.name.toLowerCase().includes(query.toLowerCase()) ||
-				user.email.toLowerCase().includes(query.toLowerCase())
-			);
-		})
-		.sort((a, b) => {
-			const aUserIndex = userIds.indexOf(a.id);
-			const bUserIndex = userIds.indexOf(b.id);
-			// Compare based on userIds or fall back to alphabetical order
-			if (aUserIndex !== -1 && bUserIndex === -1) return -1; // 'a' has valid userId -> prioritize
-			if (bUserIndex !== -1 && aUserIndex === -1) return 1; // 'b' has valid userId -> prioritize
-			// Both a and b are either in the userIds array or not, so we'll sort them by their indices
-			if (aUserIndex !== -1 && bUserIndex !== -1) return aUserIndex - bUserIndex;
-			// If both are not in the userIds, fallback to alphabetical sorting by name
-			return a.name.localeCompare(b.name);
-		});
 	let query = '';
+	let orderBy = `group_id:${groupId}`; // default sort key
+	let direction = 'desc'; // default sort order
+	let page = 1;
+	const setSortKey = (key) => {
+		if (orderBy === key) {
+			direction = direction === 'asc' ? 'desc' : 'asc';
+		} else {
+			orderBy = key;
+			direction = 'asc';
+		}
+	};
+	const getUserList = async () => {
+		try {
+			const res = await getUsers(localStorage.token, query, orderBy, direction, page).catch(
+				(error) => {
+					toast.error(`${error}`);
+					return null;
+				}
+			);
+			if (res) {
+				users = res.users;
+				total = res.total;
+			}
+		} catch (err) {
+			console.error(err);
+		}
+	};
+	const toggleMember = async (userId, state) => {
+		if (state === 'checked') {
+			await addUserToGroup(localStorage.token, groupId, [userId]).catch((error) => {
+				toast.error(`${error}`);
+				return null;
+			});
+		} else {
+			await removeUserFromGroup(localStorage.token, groupId, [userId]).catch((error) => {
+				toast.error(`${error}`);
+				return null;
+			});
+		}
+		page = 1;
+		getUserList();
+	};
+	$: if (page !== null && query !== null && orderBy !== null && direction !== null) {
+		getUserList();
+	}
+	$: if (query) {
+		page = 1;
+	}
 </script>
-<div>
-	<div class="flex w-full">
-		<div class="flex flex-1">
+<div class=" max-h-full h-full w-full flex flex-col overflow-y-hidden">
+	<div class="w-full h-fit mb-1.5">
+		<div class="flex flex-1 h-fit">
 			<div class=" self-center mr-3">
 				<Search />
 			</div>
@@ -57,42 +100,163 @@
 		</div>
 	</div>
-	<div class="mt-3 scrollbar-hidden">
-		<div class="flex flex-col gap-2.5">
-			{#if filteredUsers.length > 0}
-				{#each filteredUsers as user, userIdx (user.id)}
-					<div class="flex flex-row items-center gap-3 w-full text-sm">
-						<div class="flex items-center">
-							<Checkbox
-								state={userIds.includes(user.id) ? 'checked' : 'unchecked'}
-								on:change={(e) => {
-									if (e.detail === 'checked') {
-										userIds = [...userIds, user.id];
-									} else {
-										userIds = userIds.filter((id) => id !== user.id);
-									}
-								}}
-							/>
-						</div>
-						<div class="flex w-full items-center justify-between overflow-hidden">
-							<Tooltip content={user.email} placement="top-start">
-								<div class="flex">
-									<div class=" font-medium self-center truncate">{user.name}</div>
-								</div>
-							</Tooltip>
-							{#if userIds.includes(user.id)}
-								<Badge type="success" content="member" />
-							{/if}
-						</div>
-					</div>
-				{/each}
-			{:else}
-				<div class="text-gray-500 text-xs text-center py-2 px-10">
-					{$i18n.t('No users were found.')}
-				</div>
-			{/if}
-		</div>
-	</div>
+	{#if users.length > 0}
+		<div class="scrollbar-hidden relative whitespace-nowrap overflow-x-auto max-w-full">
+			<table
+				class="w-full text-sm text-left text-gray-500 dark:text-gray-400 table-auto max-w-full"
+			>
+				<thead class="text-xs text-gray-800 uppercase bg-transparent dark:text-gray-200">
+					<tr class=" border-b-[1.5px] border-gray-50 dark:border-gray-850">
+						<th
+							scope="col"
+							class="px-2.5 py-2 cursor-pointer text-left w-8"
+							on:click={() => setSortKey(`group_id:${groupId}`)}
+						>
+							<div class="flex gap-1.5 items-center">
+								{$i18n.t('MBR')}
+								{#if orderBy === `group_id:${groupId}`}
+									<span class="font-normal"
+										>{#if direction === 'asc'}
+											<ChevronUp className="size-2" />
+										{:else}
+											<ChevronDown className="size-2" />
+										{/if}
+									</span>
+								{:else}
+									<span class="invisible">
+										<ChevronUp className="size-2" />
+									</span>
+								{/if}
+							</div>
+						</th>
+						<th
+							scope="col"
+							class="px-2.5 py-2 cursor-pointer select-none"
+							on:click={() => setSortKey('role')}
+						>
+							<div class="flex gap-1.5 items-center">
+								{$i18n.t('Role')}
+								{#if orderBy === 'role'}
+									<span class="font-normal"
+										>{#if direction === 'asc'}
+											<ChevronUp className="size-2" />
+										{:else}
+											<ChevronDown className="size-2" />
+										{/if}
+									</span>
+								{:else}
+									<span class="invisible">
+										<ChevronUp className="size-2" />
+									</span>
+								{/if}
+							</div>
+						</th>
+						<th
+							scope="col"
+							class="px-2.5 py-2 cursor-pointer select-none"
+							on:click={() => setSortKey('name')}
+						>
+							<div class="flex gap-1.5 items-center">
+								{$i18n.t('Name')}
+								{#if orderBy === 'name'}
+									<span class="font-normal"
+										>{#if direction === 'asc'}
+											<ChevronUp className="size-2" />
+										{:else}
+											<ChevronDown className="size-2" />
+										{/if}
+									</span>
+								{:else}
+									<span class="invisible">
+										<ChevronUp className="size-2" />
+									</span>
+								{/if}
+							</div>
+						</th>
+						<th
+							scope="col"
+							class="px-2.5 py-2 cursor-pointer select-none"
+							on:click={() => setSortKey('last_active_at')}
+						>
+							<div class="flex gap-1.5 items-center">
+								{$i18n.t('Last Active')}
+								{#if orderBy === 'last_active_at'}
+									<span class="font-normal"
+										>{#if direction === 'asc'}
+											<ChevronUp className="size-2" />
+										{:else}
+											<ChevronDown className="size-2" />
+										{/if}
+									</span>
+								{:else}
+									<span class="invisible">
+										<ChevronUp className="size-2" />
+									</span>
+								{/if}
+							</div>
+						</th>
+					</tr>
+				</thead>
+				<tbody class="">
+					{#each users as user, userIdx}
+						<tr class="bg-white dark:bg-gray-900 dark:border-gray-850 text-xs">
+							<td class=" px-3 py-1 w-8">
+								<div class="flex w-full justify-center">
+									<Checkbox
+										state={(user?.group_ids ?? []).includes(groupId) ? 'checked' : 'unchecked'}
+										on:change={(e) => {
+											toggleMember(user.id, e.detail);
+										}}
+									/>
+								</div>
+							</td>
+							<td class="px-3 py-1 min-w-[7rem] w-28">
+								<div class=" translate-y-0.5">
+									<Badge
+										type={user.role === 'admin'
+											? 'info'
+											: user.role === 'user'
+												? 'success'
+												: 'muted'}
+										content={$i18n.t(user.role)}
+									/>
+								</div>
+							</td>
+							<td class="px-3 py-1 font-medium text-gray-900 dark:text-white max-w-48">
+								<Tooltip content={user.email} placement="top-start">
+									<div class="flex items-center">
+										<img
+											class="rounded-full w-6 h-6 object-cover mr-2.5 flex-shrink-0"
+											src={`${WEBUI_API_BASE_URL}/users/${user.id}/profile/image`}
+											alt="user"
+										/>
+										<div class="font-medium truncate">{user.name}</div>
+									</div>
+								</Tooltip>
+							</td>
+							<td class=" px-3 py-1">
+								{dayjs(user.last_active_at * 1000).fromNow()}
+							</td>
+						</tr>
+					{/each}
+				</tbody>
+			</table>
+		</div>
+	{:else}
+		<div class="text-gray-500 text-xs text-center py-2 px-10">
+			{$i18n.t('No users were found.')}
+		</div>
+	{/if}
+	{#if total > 30}
+		<Pagination bind:page count={total} perPage={30} />
+	{/if}
 </div>
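Membership is no longer staged in a local userIds array; each checkbox change calls the groups API directly and then refetches a paginated, sortable user list. A standalone TypeScript sketch of that flow using the same client helpers the component imports (the token and groupId values are illustrative):

import { getUsers } from '$lib/apis/users';
import { addUserToGroup, removeUserFromGroup } from '$lib/apis/groups';

// Illustrative inputs; in the component these come from props and localStorage.
const token = localStorage.token;
const groupId = 'example-group-id';

async function toggleMember(userId: string, state: 'checked' | 'unchecked') {
	if (state === 'checked') {
		await addUserToGroup(token, groupId, [userId]);
	} else {
		await removeUserFromGroup(token, groupId, [userId]);
	}
	// Refetch page 1, sorted by membership in this group (same arguments as getUserList above).
	const res = await getUsers(token, '', `group_id:${groupId}`, 'desc', 1);
	if (res) {
		console.log(`loaded ${res.users.length} of ${res.total} users`);
	}
}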

Some files were not shown because too many files have changed in this diff.