mirror of
https://github.com/open-webui/open-webui.git
synced 2025-12-12 20:35:19 +00:00
merge version v0.6.25
This commit is contained in:
commit
876db8ec7f
357 changed files with 32892 additions and 11907 deletions
4
.github/workflows/format-build-frontend.yaml
vendored
4
.github/workflows/format-build-frontend.yaml
vendored
|
|
@ -32,7 +32,7 @@ jobs:
|
|||
node-version: '22'
|
||||
|
||||
- name: Install Dependencies
|
||||
run: npm install
|
||||
run: npm install --force
|
||||
|
||||
- name: Format Frontend
|
||||
run: npm run format
|
||||
|
|
@ -59,7 +59,7 @@ jobs:
|
|||
node-version: '22'
|
||||
|
||||
- name: Install Dependencies
|
||||
run: npm ci
|
||||
run: npm ci --force
|
||||
|
||||
- name: Run vitest
|
||||
run: npm run test:frontend
|
||||
|
|
|
|||
5
.gitignore
vendored
5
.gitignore
vendored
|
|
@ -1,3 +1,5 @@
|
|||
x.py
|
||||
yarn.lock
|
||||
.DS_Store
|
||||
node_modules
|
||||
/build
|
||||
|
|
@ -12,7 +14,8 @@ vite.config.ts.timestamp-*
|
|||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
.nvmrc
|
||||
CLAUDE.md
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
|
|
|
|||
298
CHANGELOG.md
298
CHANGELOG.md
|
|
@ -5,6 +5,304 @@ All notable changes to this project will be documented in this file.
|
|||
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
|
||||
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
|
||||
|
||||
## [0.6.25] - 2025-08-22
|
||||
|
||||
### Fixed
|
||||
|
||||
- 🖼️ **Image Generation Reliability Restored**: Fixed a key issue causing image generation failures.
|
||||
- 🏆 **Reranking Functionality Restored**: Resolved errors with rerank feature.
|
||||
|
||||
## [0.6.24] - 2025-08-21
|
||||
|
||||
### Added
|
||||
|
||||
- ♿ **High Contrast Mode in Chat Messages**: Implemented enhanced High Contrast Mode support for chat messages, making text and important details easier to read and improving accessibility for users with visual preferences or requirements.
|
||||
- 🌎 **Localization & Internationalization Improvements**: Enhanced and expanded translations for a more natural and professional user experience for speakers of these languages across the entire interface.
|
||||
|
||||
### Fixed
|
||||
|
||||
- 🖼️ **ComfyUI Image Generation Restored**: Fixed a critical bug where ComfyUI-based image generation was not functioning, ensuring users can once again effortlessly create and interact with AI-generated visuals in their workflows.
|
||||
- 🛠️ **Tool Server Loading and Visibility Restored**: Resolved an issue where connected tool servers were not loading or visible, restoring seamless integration and uninterrupted access to all external and custom tools directly within the platform.
|
||||
- 🛡️ **Redis User Session Reliability**: Fixed a problem affecting the saving of user sessions in Redis, ensuring reliable login sessions, stable authentication, and secure multi-user environments.
|
||||
|
||||
## [0.6.23] - 2025-08-21
|
||||
|
||||
### Added
|
||||
|
||||
- ⚡ **Asynchronous Chat Payload Processing**: Refactored the chat completion pipeline to return a response immediately for streaming requests involving web search or tool calls. This enables users to stop ongoing generations promptly and preventing network timeouts during lengthy preprocessing phases, thus significantly improving user experience and responsiveness.
|
||||
- 📁 **Asynchronous File Upload with Polling**: Implemented an asynchronous file upload process with frontend polling to resolve gateway timeouts and improve reliability when uploading large files. This ensures that even lengthy file processing, such as embedding or transcription, does not block the user interface or lead to connection timeouts, providing a smoother experience for all file operations.
|
||||
- 📈 **Database Performance Indexes and Migration Script**: Introduced new database indexes on the "chat", "tag", and "function" tables to significantly enhance query performance for SQLite and PostgreSQL installations. For existing deployments, a new Alembic migration script is included to seamlessly apply these indexes, ensuring faster filtering and sorting operations across the platform.
|
||||
- ✨ **Enhanced Database Performance Options**: Introduced new configurable options to significantly improve database performance, especially for SQLite. This includes "DATABASE_ENABLE_SQLITE_WAL" to enable SQLite WAL (Write-Ahead Logging) mode for concurrent operations, and "DATABASE_DEDUPLICATE_INTERVAL" which, in conjunction with a new deduplication mechanism, reduces redundant updates to "user.last_active_at", minimizing write conflicts across all database types.
|
||||
- 💾 **Save Temporary Chats Button**: Introduced a new 'Save Chat' button for conversations initiated in temporary mode. This allows users to permanently save valuable temporary conversations to their chat history, providing greater flexibility and ensuring important discussions are not lost.
|
||||
- 📂 **Chat Movement Options in Menu**: Added the ability to move chats directly to folders from the chat menu. This enhances chat organization and allows users to manage their conversations more efficiently by relocating them between folders with ease.
|
||||
- 💬 **Language-Aware Follow-Up Suggestions**: Enhanced the AI's follow-up question generation to dynamically adapt to the primary language of the current chat. Follow-up prompts will now be suggested in the same language the user and AI are conversing in, ensuring more natural and contextually relevant interactions.
|
||||
- 👤 **Expanded User Profile Details**: Introduced new user profile fields including username, bio, gender, and date of birth, allowing for more comprehensive user customization and information management. This enhancement includes corresponding updates to the database schema, API, and user interface for seamless integration.
|
||||
- 👥 **Direct Navigation to User Groups from User Edit**: Enhanced the user edit modal to include a direct link to the associated user group. This allows administrators to quickly navigate from a user's profile to their group settings, streamlining user and group management workflows.
|
||||
- 🔧 **Enhanced External Tool Server Compatibility**: Improved handling of responses from external tool servers, allowing both the backend and frontend to process plain text content in addition to JSON, ensuring greater flexibility and integration with diverse tool outputs.
|
||||
- 🗣️ **Enhanced Audio Transcription Language Fallback and Deepgram Support**: Implemented a robust language fallback mechanism for both OpenAI and Deepgram Speech-to-Text (STT) API calls. If a specified language parameter is not supported by the model or provider, the system will now intelligently retry the transcription without the language parameter or with a default, ensuring greater reliability and preventing failed API calls. This also specifically adds and refines support for the audio language parameter in Deepgram API integrations.
|
||||
- ⚡ **Optimized Hybrid Search Performance for BM25 Weight Configuration**: Enhanced hybrid search to significantly improve performance when the BM25 weight is set to 0 or less. This optimization intelligently disables unnecessary collection retrieval and BM25 ranking calculations, leading to faster search results without impacting accuracy for configurations that do not utilize lexical search contributions.
|
||||
- 🔒 **Configurable Code Interpreter Module Blacklist**: Introduced the "CODE_INTERPRETER_BLACKLISTED_MODULES" environment variable, allowing administrators to specify Python modules that are forbidden from being imported or executed within the code interpreter. This significantly enhances the security posture by mitigating risks associated with arbitrary code execution, such as unauthorized data access, system manipulation, or outbound connections.
|
||||
- 🔐 **Enhanced OAuth Role Claim Handling**: Improved compatibility with diverse OAuth providers by allowing role claims to be supplied as single strings or integers, in addition to arrays. The system now automatically normalizes these single-value claims into arrays for consistent processing, streamlining integration with identity providers that format role data differently.
|
||||
- ⚙️ **Configurable Tool Call Timeout**: Introduced the "AIOHTTP_CLIENT_TIMEOUT" environment variable, allowing administrators to specify custom timeout durations for external tool calls, which is crucial for integrations with tools that have varying or extended response times.
|
||||
- 🛠️ **Improved Tool Callable Generation for Google genai SDK**: Enhanced the creation of tool callables to directly support native function calling within the Google 'genai' SDK. This refactoring ensures proper signature inference and removes extraneous parameters, enabling seamless integration for advanced AI workflows using Google's generative AI models.
|
||||
- ✨ **Dynamic Loading of 'kokoro-js'**: Implemented dynamic loading for the 'kokoro-js' library, preventing failures and improving compatibility on older iOS browsers that may not support direct imports or certain modern JavaScript APIs like 'DecompressionStream'.
|
||||
- 🖥️ **Improved Command List Visibility on Small Screens**: Resolved an issue where the top items in command lists (e.g., Knowledge Base, Models, Prompts) were hidden or overlapped by the header on smaller screen sizes or specific browser zoom levels. The command option lists now dynamically adjust their height, ensuring all items are fully visible and accessible with proper scrolling.
|
||||
- 📦 **Improved Docker Image Compatibility for Arbitrary UIDs**: Fixed issues preventing the Open WebUI container from running in environments with arbitrary User IDs (UIDs), such as OpenShift's restricted Security Context Constraints (SCC). The Dockerfile has been updated to correctly set file system permissions for "/app" and "/root" directories, ensuring they are writable by processes running with a supplemental GID 0, thus resolving permission errors for Python libraries and application caches.
|
||||
- ♿ **Accessibility Enhancements**: Significantly improved the semantic structure of chat messages by using "section", "h2", "ul", and "li" HTML tags, and enhanced screen reader compatibility by explicitly hiding decorative images with "aria-hidden" attributes. This refactoring provides clearer structural context and improves overall accessibility and web standards compliance for the conversation flow.
|
||||
- 🌐 **Localization & Internationalization Improvements**: Significantly expanded internationalization support throughout the user interface, translating numerous user-facing strings in toast messages, placeholders, and other UI elements. This, alongside continuous refinement and expansion of translations for languages including Brazilian Portuguese, Kabyle (Taqbaylit), Czech, Finnish, Chinese (Simplified), Chinese (Traditional), and German, and general fixes for several other translation files, further enhances linguistic coverage and user experience.
|
||||
|
||||
### Fixed
|
||||
|
||||
- 🛡️ **Resolved Critical OIDC SSO Login Failure**: Fixed a critical issue where OIDC Single Sign-On (SSO) logins failed due to an error in setting the authentication token as a cookie during the redirect process. This ensures reliable and seamless authentication for users utilizing OIDC providers, restoring full login functionality that was impacted by previous security hardening.
|
||||
- ⚡ **Prevented UI Blocking by Unreachable Webhooks**: Resolved a critical performance and user experience issue where synchronous webhook calls to unreachable or slow endpoints would block the entire user interface for all users. Webhook requests are now processed asynchronously using "aiohttp", ensuring that the UI remains responsive and functional even if webhook delivery encounters delays or failures.
|
||||
- 🔒 **Password Change Option Hidden for Externally Authenticated Users**: Resolved an issue where the password change dialog was visible to users authenticated via external methods (e.g., LDAP, OIDC, Trusted Header). The option to change a password in user settings is now correctly hidden for these users, as their passwords are managed externally, streamlining the user interface and preventing confusion.
|
||||
- 💬 **Resolved Temporary Chat and Permission Enforcement Issues**: Fixed a bug where temporary chats (identified by "chat_id = local") incorrectly triggered database checks, leading to 404 errors. This also resolves the issue where the 'USER_PERMISSIONS_CHAT_TEMPORARY_ENFORCED' setting was not functioning as intended, ensuring temporary chat mode now works correctly for user roles.
|
||||
- 🔐 **Admin Model Visibility for Administrators**: Private models remained visible and usable for administrators in the chat model selector, even when the intended privacy setting ("ENABLE_ADMIN_WORKSPACE_CONTENT_ACCESS" - now renamed to "BYPASS_ADMIN_ACCESS_CONTROL") was disabled. This ensures consistent enforcement of model access controls and adherence to the principle of least privilege.
|
||||
- 🔍 **Clarified Web Search Engine Label for DDGS**: Addressed user confusion and inaccurate labeling by renaming "duckduckgo" to "DDGS" (Dux Distributed Global Search) in the web search engine selector. This clarifies that the system utilizes DDGS, a metasearch library that aggregates results from various search providers, accurately reflecting its underlying functionality rather than implying exclusive use of DuckDuckGo's search engine.
|
||||
- 🛠️ **Improved Settings UI Reactivity and Visibility**: Resolved an issue where settings tabs for 'Connections' and 'Tools' did not dynamically update their visibility based on global administrative feature flags (e.g., 'enable_direct_connections'). The UI now reactively shows or hides these sections, ensuring a consistent and clear experience when administrators control feature availability.
|
||||
- 🎚️ **Restored Model and Banner Reordering Functionality**: Fixed a bug that prevented administrators from reordering models in the Admin Panel's 'Models' settings and banners in the 'Interface' settings via drag-and-drop. The sortable functionality has been restored, allowing for proper customization of display order.
|
||||
- 📝 **Restored Custom Pending User Overlay Visibility**: Fixed an issue where the custom title and description configured for pending users were not visible. The application now correctly exposes these UI configuration settings to pending users, ensuring that the custom onboarding messages are displayed as intended.
|
||||
- 📥 **Fixed Community Function Import Compatibility**: Resolved an issue that prevented the successful import of function files downloaded from openwebui.com due to schema differences. The system now correctly processes these files, allowing for seamless integration of community-contributed functions.
|
||||
- 📦 **Fixed Stale Ollama Version in Docker Images**: Resolved an issue where the Ollama installation within Docker images could become stale due to caching during the build process. The Dockerfile now includes a mechanism to invalidate the build cache for the Ollama installation step, ensuring that the latest version of Ollama is always installed.
|
||||
- 🗄️ **Improved Milvus Query Handling for Large Datasets**: Fixed a "MilvusException" that occurred when attempting to query more than 16384 entries from a Milvus collection. The query logic has been refactored to use "query_iterator()", enabling efficient fetching of larger result sets in batches and resolving the previous limitation on the number of entries that could be retrieved.
|
||||
- 🐛 **Restored Message Toolbar Icons for Empty Messages with Files**: Fixed an issue where the edit, copy, and delete icons were not displayed on user messages that contained an attached file but no text content. This ensures full interaction capabilities for all message types, allowing users to manage their messages consistently.
|
||||
- 💬 **Resolved Streaming Interruption for Kimi-Dev Models**: Fixed an issue where streaming responses from Kimi-Dev models would halt prematurely upon encountering specific 'thinking' tokens (◁think▷, ◁/think▷). The system now correctly processes these tokens, ensuring uninterrupted streaming and proper handling of hidden or collapsible thinking sections.
|
||||
- 🔍 **Enhanced Knowledge Base Search Functionality**: Improved the search capability within the 'Knowledge' section of the Workspace. Previously, searching for knowledge bases required exact term matches or starting with the first letter. Now, the search algorithm has been refined to allow broader, less exact matches, making it easier and more intuitive to find relevant knowledge bases.
|
||||
- 📝 **Resolved Chinese Input 'Enter' Key Issue (macOS & iOS Safari)**: Fixed a bug where pressing the 'Enter' key during text composition with Input Method Editors (IMEs) on macOS and iOS Safari browsers would prematurely send the message. The system now robustly handles the composition state by addressing a 'compositionend' event bug specific to Safari, ensuring a smooth and expected typing experience for users of various languages, including Chinese and Korean.
|
||||
- 🔐 **Resolved OAUTH_GROUPS_CLAIM Configuration Issue**: Fixed a bug where the "OAUTH_GROUPS_CLAIM" environment variable was not correctly parsed due to a typo in the configuration file. This ensures that OAuth group management features, including automatic group creation, now correctly utilize the specified claim from the identity provider, allowing for seamless integration with external user directories like Keycloak.
|
||||
- 🗄️ **Resolved Azure PostgreSQL pgvector Extension Permissions**: Fixed an issue preventing the creation of "pgvector" and "pgcrypto" extensions on Azure PostgreSQL Flexible Servers due to permission limitations (e.g., 'Only members of "azure_pg_admin" are allowed to use "CREATE EXTENSION"'). The extension creation process now includes a conditional check, ensuring seamless deployment and compatibility with Azure PostgreSQL environments even with restricted database user permissions.
|
||||
- 🛠️ **Improved Backend Path Resolution and Alembic Stability**: Fixed issues causing Alembic database migrations to fail due to incorrect path resolution within the application. By implementing canonical path resolution for core directories and refining Alembic configuration, the robustness and correctness of internal pathing have been significantly enhanced, ensuring reliable database operations.
|
||||
- 📊 **Resolved Arena Model Identification in Feedback History**: Fixed an issue where the model used for feedback in arena settings was incorrectly reported as 'arena-model' in the evaluation history. The system now correctly logs and displays the actual model ID that received the feedback, restoring clarity and enabling proper analysis of model performance in arena environments.
|
||||
- 🎨 **Resolved Icon Overlap in 'Her' Theme**: Fixed a visual glitch in the 'Her' theme where icons would overlap on the loading screen and certain icons appeared incongruous. The display has been corrected to ensure proper visual presentation and theme consistency.
|
||||
- 🛠️ **Resolved Model Sorting TypeError with Null Names**: Fixed a "TypeError" that occurred in the "/api/models" endpoint when sorting models with null or missing names. The model sorting logic has been improved to gracefully handle such edge cases by ensuring that model IDs and names are treated as empty strings if their values are null or undefined, preventing comparison errors and improving API stability.
|
||||
- 💬 **Resolved Silently Dropped Streaming Response Chunks**: Fixed an issue where the final partial chunks of streaming chat responses could be silently dropped, leading to incomplete message delivery. The system now reliably flush any pending delta data upon stream termination, early breaks (e.g., code interpreter tags), or connection closure, ensuring complete and accurate response delivery.
|
||||
- 📱 **Disabled Overscroll for iOS Frontend**: Fixed an issue where overscrolling was enabled on iOS devices, causing unexpected scrolling behavior over fixed or sticky elements within the PWA. Overscroll has now been disabled, providing a more native application-like experience for iOS users.
|
||||
- 📝 **Resolved Code Block Input Issue with Shift+Enter**: Fixed a bug where typing three backticks followed by a language and then pressing Shift+Enter would cause the code block prefix to disappear, preventing proper code formatting. The system now correctly preserves the code block syntax, ensuring consistent behavior for multi-line code input.
|
||||
- 🛠️ **Improved OpenAI Model List Handling for Null Names**: Fixed an edge case where some OpenAI-compatible API providers might return models with a null value for their 'name' field. This could lead to issues like broken model list sorting. The system now gracefully handles these instances by removing the null 'name' key, ensuring stable model retrieval and display.
|
||||
- 🔍 **Resolved DDGS Concurrent Request Configuration**: Fixed an issue where the configured number of concurrent requests was not being honored for the DDGS (Dux Distributed Global Search) metasearch engine. The system now correctly applies the specified concurrency setting, improving efficiency for web searches.
|
||||
- 🛠️ **Improved Tool List Synchronization in Multi-Replica Deployments**: Resolved an issue where tool updates were not consistently reflected across all instances in multi-replica environments, leading to stale tool lists for users on other replicas. The tool list in the message input menu is now automatically refreshed each time it is accessed, ensuring all users always see the most current set of available tools.
|
||||
- 🛠️ **Resolved Duplicate Tool Name Collision**: Fixed an issue where tools with identical names from different external servers were silently removed, preventing their simultaneous use. The system now correctly handles tool name collisions by internally prefixing tools with their server identifier, allowing multiple instances of similarly named tools from different servers to be active and usable by LLMs.
|
||||
- 🖼️ **Resolved Image Generation API Size Parameter Issue**: Fixed a bug where the "/api/v1/images/generations" API endpoint did not correctly apply the 'size' parameter specified in the request payload for image generation. The system now properly honors the requested image dimensions (e.g., '1980x1080'), ensuring that generated images match the user's explicit size preference rather than defaulting to settings.
|
||||
- 🗄️ **Resolved S3 Vector Upload Limitations**: Fixed an issue that prevented uploading more than 500 vectors to S3 Vector buckets due to API limitations, which resulted in a "ValidationException". S3 vector uploads are now batched in groups of 500, ensuring successful processing of larger datasets.
|
||||
- 🛠️ **Fixed Tool Installation Error During Startup**: Resolved a "NoneType" error that occurred during tool installation at startup when 'tool.user' was unexpectedly null. The system now includes a check to ensure 'tool.user' exists before attempting to access its properties, preventing crashes and ensuring robust tool initialization.
|
||||
- 🛠️ **Improved Azure OpenAI GPT-5 Parameter Handling**: Fixed an issue with Azure OpenAI SDK parameter handling to correctly support GPT-5 models. The 'max_tokens' parameter is now appropriately converted to 'max_completion_tokens' for GPT-5 models, ensuring consistent behavior and proper function execution similar to existing o-series models.
|
||||
- 🐛 **Resolved Exception with Missing Group Permissions**: Fixed an exception that occurred in the access control logic when group permission objects were missing or null. The system now correctly handles cases where groups may not have explicit permission definitions, ensuring that 'None' checks prevent errors and maintain application stability when processing user permissions.
|
||||
- 🛠️ **Improved OpenAI API Base URL Handling**: Fixed an issue where a trailing slash in the 'OPENAI_API_BASE_URL' configuration could lead to models not being detected or the endpoint failing. The system now automatically removes trailing slashes from the configured URL, ensuring robust and consistent connections to OpenAI-compatible APIs.
|
||||
- 🖼️ **Resolved S3-Compatible Storage Upload Failures**: Fixed an issue where uploads to S3-compatible storage providers would fail with an "XAmzContentSHA256Mismatch" error. The system now correctly handles checksum calculations, ensuring reliable file and image uploads to S3-compatible services.
|
||||
- 🌐 **Corrected 'Releases' Link**: Fixed an issue where the 'Releases' button in the user menu directed to an incorrect URL, now correctly linking to the Open WebUI GitHub releases page.
|
||||
- 🛠️ **Resolved Model Sorting Errors with Null or Undefined Names**: Fixed multiple "TypeError" instances that occurred when attempting to sort model lists where model names were null or undefined. The sorting logic across various UI components (including Ollama model selection, leaderboard, and admin model settings) has been made more robust by gracefully handling absent model names, preventing crashes and ensuring consistent alphabetical sorting based on available name or ID.
|
||||
- 🎨 **Resolved Banner Dismissal Issue with Iteration IDs**: Fixed a bug where dismissing banners could lead to unintended multiple banner dismissals or other incorrect behavior, especially when banners lacked unique iteration IDs. Unique IDs are now assigned during banner iteration, ensuring proper individual dismissal and consistent display behavior.
|
||||
|
||||
### Changed
|
||||
|
||||
- 🛂 **Environment Variable for Admin Access Control**: The environment variable "ENABLE_ADMIN_WORKSPACE_CONTENT_ACCESS" has been renamed to "BYPASS_ADMIN_ACCESS_CONTROL". This new name more accurately reflects its function as a control to allow administrators to bypass model access restrictions. Users are encouraged to update their configurations to use the new variable name; existing configurations using the old name will still be honored for backward compatibility.
|
||||
- 🗂️ **Core Directory Path Resolution Updated**: The internal mechanism for resolving core application directory paths ("OPEN_WEBUI_DIR", "BACKEND_DIR", "BASE_DIR") has been updated to use canonical resolution via "Path().resolve()". This change improves path reliability but may require adjustments for any external scripts or configurations that previously relied on specific non-canonical path interpretations.
|
||||
- 🗃️ **Database Performance Options**: New database performance options, "DATABASE_ENABLE_SQLITE_WAL" and "DATABASE_DEDUPLICATE_INTERVAL", are now available. If "DATABASE_ENABLE_SQLITE_WAL" is enabled, SQLite will operate in WAL mode, which may alter SQLite's file locking behavior. If "DATABASE_DEDUPLICATE_INTERVAL" is set to a non-zero value, the "user.last_active_at" timestamp will be updated less frequently, leading to slightly less real-time accuracy for this specific field but significantly reducing database write conflicts and improving overall performance. Both options are disabled by default.
|
||||
- 🌐 **Renamed Web Search Concurrency Setting**: The environment variable "WEB_SEARCH_CONCURRENT_REQUESTS" has been renamed to "WEB_LOADER_CONCURRENT_REQUESTS". This change clarifies its scope, explicitly applying to the concurrency of the web loader component (which fetches content from search results) rather than the initial search engine query. Users relying on the old environment variable name for configuring web search concurrency must update their configurations to use "WEB_LOADER_CONCURRENT_REQUESTS".
|
||||
|
||||
## [0.6.22] - 2025-08-11
|
||||
|
||||
### Added
|
||||
|
||||
- 🔗 **OpenAI API '/v1' Endpoint Compatibility**: Enhanced API compatibility by supporting requests to paths like '/v1/models', '/v1/embeddings', and '/v1/chat/completions'. This allows Open WebUI to integrate more seamlessly with tools that expect OpenAI's '/v1' API structure.
|
||||
- 🪄 **Toggle for Guided Response Regeneration Menu**: Introduced a new setting in 'Interface' settings, providing the ability to enable or disable the expanded guided response regeneration menu. This offers users more control over their chat workflow and interface preferences.
|
||||
- ✨ **General UI/UX Enhancements**: Implemented various user interface and experience improvements, including more rounded corners for cards in the Knowledge, Prompts, and Tools sections, and minor layout adjustments within the chat Navbar for improved visual consistency.
|
||||
- 🌐 **Localization & Internationalization Improvements**: Introduced support for the Kabyle (Taqbaylit) language, refined and expanded translations for Chinese, expanding the platform's linguistic coverage.
|
||||
|
||||
### Fixed
|
||||
|
||||
- 🐞 **OpenAI Error Message Propagation**: Resolved an issue where specific OpenAI API errors (e.g., 'Organization Not Verified') were obscured by generic 'JSONResponse' iterable errors. The system now correctly propagates detailed and actionable error messages from OpenAI to the user.
|
||||
- 🌲 **Pinecone Insert Issue**: Fixed a bug that prevented proper insertion of items into Pinecone vector databases.
|
||||
- 📦 **S3 Vector Issue**: Resolved a bug where s3vector functionality failed due to incorrect import paths.
|
||||
- 🏠 **Landing Page Option Setting Not Working**: Fixed an issue where the landing page option in settings was not functioning as intended.
|
||||
|
||||
## [0.6.21] - 2025-08-10
|
||||
|
||||
### Added
|
||||
|
||||
- 👥 **User Groups in Edit Modal**: Added display of user groups information in the user edit modal, allowing administrators to view and manage group memberships directly when editing a user.
|
||||
|
||||
### Fixed
|
||||
|
||||
- 🐞 **Chat Completion 'model_id' Error**: Resolved a critical issue where chat completions failed with an "undefined model_id" error after upgrading to version 0.6.20, ensuring all models now function correctly and reliably.
|
||||
- 🛠️ **Audit Log User Information Logging**: Fixed an issue where user information was not being correctly logged in the audit trail due to an unreflected function prototype change, ensuring complete logging for administrative oversight.
|
||||
- 🛠️ **OpenTelemetry Configuration Consistency**: Fixed an issue where OpenTelemetry metric and log exporters' 'insecure' settings did not correctly default to the general OpenTelemetry 'insecure' flag, ensuring consistent security configurations across all OpenTelemetry exports.
|
||||
- 📝 **Reply Input Content Display**: Fixed an issue where replying to a message incorrectly displayed '{{INPUT_CONTENT}}' instead of the actual message content, ensuring proper content display in replies.
|
||||
- 🌐 **Localization & Internationalization Improvements**: Refined and expanded translations for Catalan, Korean, Spanish and Irish, ensuring a more fluent and native experience for global users.
|
||||
|
||||
## [0.6.20] - 2025-08-10
|
||||
|
||||
### Fixed
|
||||
|
||||
- 🛠️ **Quick Actions "Add" Behavior**: Fixed a bug where using the "Add" button in Quick Actions would add the resulting message as the very first message in the chat, instead of appending it to the latest message.
|
||||
|
||||
## [0.6.19] - 2025-08-09
|
||||
|
||||
### Added
|
||||
|
||||
- ✨ **Modernized Sidebar and Major UI Refinements**: The main navigation sidebar has been completely redesigned with a modern, cleaner aesthetic, featuring a sticky header and footer to keep key controls accessible. Core sidebar logic, like the pinned models list, was also refactored into dedicated components for better performance and maintainability.
|
||||
- 🪄 **Guided Response Regeneration**: The "Regenerate" button has been transformed into a powerful new menu. You can now guide the AI's next attempt by suggesting changes in a text prompt, or use one-click options like "Try Again," "Add Details," or "More Concise" to instantly refine and reshape the response to better fit your needs.
|
||||
- 🛠️ **Improved Tool Call Handling for GPT-OSS Models**: Implemented robust handling for tool calls specifically for GPT-OSS models, ensuring proper function execution and integration.
|
||||
- 🛑 **Stop Button for Merge Responses**: Added a dedicated stop button to immediately halt the generation of merged AI responses, providing users with more control over ongoing outputs.
|
||||
- 🔄 **Experimental SCIM 2.0 Support**: Implemented SCIM 2.0 (System for Cross-domain Identity Management) protocol support, enabling enterprise-grade automated user and group provisioning from identity providers like Okta, Azure AD, and Google Workspace for seamless user lifecycle management. Configuration is managed securely via environment variables.
|
||||
- 🗂️ **Amazon S3 Vector Support**: You can now use Amazon S3 Vector as a high-performance vector database for your Retrieval-Augmented Generation (RAG) workflows. This provides a scalable, cloud-native storage option for users deeply integrated into the AWS ecosystem, simplifying infrastructure and enabling enterprise-scale knowledge management.
|
||||
- 🗄️ **Oracle 23ai Vector Search Support**: Added support for Oracle 23ai's new vector search capabilities as a supported vector database, providing a robust and scalable option for managing large-scale documents and integrating vector search with existing business data at the database level.
|
||||
- ⚡ **Qdrant Performance and Configuration Enhancements**: The Qdrant client has been significantly improved with faster data retrieval logic for 'get' and 'query' operations. New environment variables ('QDRANT_TIMEOUT', 'QDRANT_HNSW_M') provide administrators with finer control over query timeouts and HNSW index parameters, enabling better performance tuning for large-scale deployments.
|
||||
- 🔐 **Encrypted SQLite Database with SQLCipher**: You can now encrypt your entire SQLite database at rest using SQLCipher. By setting the 'DATABASE_TYPE' to 'sqlite+sqlcipher' and providing a 'DATABASE_PASSWORD', all data is transparently encrypted, providing an essential security layer for protecting sensitive information in self-hosted deployments. Note that this requires additional system libraries and the 'sqlcipher3-wheels' Python package.
|
||||
- 🚀 **Efficient Redis Connection Management**: Implemented a shared connection pool cache to reuse Redis connections, dramatically reducing the number of active clients. This prevents connection exhaustion errors, improves performance, and ensures greater stability in high-concurrency deployments and those using Redis Sentinel.
|
||||
- ⚡ **Batched Response Streaming for High Performance**: Dramatically improve performance and stability during high-speed response streaming by batching multiple tokens together before sending them to the client. A new 'Stream Delta Chunk Size' advanced parameter can be set per-model or in user/chat settings, significantly reducing CPU load on the server, Redis, and client, and preventing connection issues in high-concurrency environments.
|
||||
- ⚙️ **Global Batched Streaming Configuration**: Administrators can now set a system-wide default for response streaming using the new 'CHAT_RESPONSE_STREAM_DELTA_CHUNK_SIZE' environment variable. This allows for global performance tuning, while still letting per-model or per-chat settings override the default for more granular control.
|
||||
- 🔎 **Advanced Chat Search with Status Filters**: Quickly find any conversation with powerful new search filters. You can now instantly narrow down your chats using prefixes like 'pinned:true', 'shared:true', and 'archived:true' directly in the search bar. An intelligent dropdown menu assists you by suggesting available filter options as you type, streamlining your workflow and making chat management more efficient than ever.
|
||||
- 🛂 **Granular Chat Controls Permissions**: Administrators can now manage chat settings with greater detail. The main "Chat Controls" permission now acts as a master switch, while new granular toggles for "Valves", "System Prompts", and "Advanced Parameters" allow for more specific control over which sections are visible to users inside the panel.
|
||||
- ✍️ **Formatting Toolbar for Chat Input**: Introduced a dedicated formatting toolbar for the rich text chat input field, providing users with more accessible options for text styling and editing, configurable via interface settings.
|
||||
- 📑 **Tabbed View for Multi-Model Responses**: You can now enable a new tabbed interface to view responses from multiple models. Instead of side-scrolling cards, this compact view organizes each model's response into its own tab, making it easier to compare outputs and saving vertical space. This feature can be toggled on or off in Interface settings.
|
||||
- ↕️ **Reorder Pinned Models via Drag-and-Drop**: You can now organize your pinned models in the sidebar by simply dragging and dropping them into your preferred order. This custom layout is saved automatically, giving you more flexible control over your workspace.
|
||||
- 📌 **Quick Model Unpin Shortcut**: You can now quickly unpin a model by holding the Shift key and hovering over it to reveal an instant unpin button, streamlining your workspace customization.
|
||||
- ⚡ **Improved Chat Input Performance**: The chat input is now significantly more responsive, especially when pasting or typing large amounts of text. This was achieved by implementing a debounce mechanism for the auto-save feature, which prevents UI lag and ensures a smooth, uninterrupted typing experience.
|
||||
- ✍️ **Customizable Floating Quick Actions with Tool Support**: Take full control of your text interaction workflow with new customizable floating quick actions. In Settings, you can create, edit, or disable these actions and even integrate tools using the '{{TOOL:tool_id}}' syntax in your prompts, enabling powerful one-click automations on selected text. This is in addition to using placeholders like '{{CONTENT}}' and '{{INPUT_CONTENT}}' for custom text transformations.
|
||||
- 🔒 **Admin Workspace Privacy Control**: Introduced the 'ENABLE_ADMIN_WORKSPACE_CONTENT_ACCESS' environment variable (defaults to 'True') allowing administrators to control their access privileges to workspace items (Knowledge, Models, Prompts, Tools). When disabled, administrators adhere to the same access control rules as regular users, enhancing data separation for multi-tenant deployments.
|
||||
- 🗄️ **Comprehensive Model Configuration Management**: Administrators can now export the entire model configuration to a file and use a new declarative sync endpoint to manage models in bulk. This powerful feature enables seamless backups, migrations, and state replication across multiple instances.
|
||||
- 📦 **Native Redis Cluster Mode Support**: Added full support for connecting to Redis in cluster mode, allowing for scalable and highly available Redis deployments beyond Sentinel-managed setups. New environment variables 'REDIS_CLUSTER' and 'WEBSOCKET_REDIS_CLUSTER' enable the use of 'redis.cluster.RedisCluster' clients.
|
||||
- 📊 **Granular OpenTelemetry Metrics Configuration**: Introduced dedicated environment variables and enhanced configuration options for OpenTelemetry metrics, allowing for separate OTLP endpoints, basic authentication credentials, and protocol (HTTP/gRPC) specifically for metrics export, independent of trace settings. This provides greater flexibility for integrating with diverse observability stacks.
|
||||
- 🪵 **Granular OpenTelemetry Logging Configuration**: Enhanced the OpenTelemetry logging integration by introducing dedicated environment variables for logs, allowing separate OTLP endpoints, basic authentication credentials, and protocol (HTTP/gRPC) specifically for log export, independent of general OTel settings. The application's default Python logger now leverages this configuration to automatically send logs to your OTel endpoint when enabled via 'ENABLE_OTEL_LOGS'.
|
||||
- 📁 **Enhanced Folder Chat Management with Sorting and Time Blocks**: The chat list within folders now supports comprehensive sorting options by title and updated time, along with intelligent time-based grouping (e.g., "Today," "Yesterday") similar to the main chat view, making navigation and organization of project-specific conversations significantly easier.
|
||||
- ⚙️ **Configurable Datalab Marker API & Advanced Processing Options**: Enhanced Datalab Marker API integration, allowing administrators to configure custom API base URLs for self-hosting and to specify comprehensive processing options via a new 'additional_config' JSON parameter. This replaces the deprecated language selection feature and provides granular control over document extraction, with streamlined API endpoint resolution for more robust self-hosted deployments.
|
||||
- 🧑💼 **Export All Users to CSV**: Administrators can now export a complete list of all users to a CSV file directly from the Admin Panel's database settings. This provides a simple, one-click way to generate user data for auditing, reporting, or management purposes.
|
||||
- 🛂 **Customizable OAuth 'sub' Claim**: Administrators can now use the 'OAUTH_SUB_CLAIM_OVERRIDE' environment variable to specify which claim from the identity provider should be used as the unique user identifier ('sub'). This provides greater flexibility and control for complex enterprise authentication setups where modifying the IDP's default claims is not possible.
|
||||
- 👁️ **Password Visibility Toggle for Input Fields**: Password fields across the application (login, registration, user management, and account settings) now utilize a new 'SensitiveInput' component, providing a consistent toggle to reveal/hide passwords for improved usability and security.
|
||||
- 🛂 **Optional "Confirm Password" on Sign-Up**: To help prevent password typos during account creation, administrators can now enable a "Confirm Password" field on the sign-up page. This feature is disabled by default and can be activated via an environment variable for enhanced user experience.
|
||||
- 💬 **View Full Chat from User Feedback**: Administrators can now easily navigate to the full conversation associated with a user feedback entry directly from the feedback modal, streamlining the review and troubleshooting process.
|
||||
- 🎚️ **Intuitive Hybrid Search BM25-Weight Slider**: The numerical input for the BM25-Weight parameter in Hybrid Search has been replaced with an interactive slider, offering a more intuitive way to adjust the balance between lexical and semantic search. A "Default/Custom" toggle and clearer labels enhance usability and understanding of this key parameter.
|
||||
- ⚙️ **Enhanced Bulk Function Synchronization**: The API endpoint for synchronizing functions has been significantly improved to reliably handle bulk updates. This ensures that importing and managing large libraries of functions is more robust and error-free for administrators.
|
||||
- 🖼️ **Option to Disable Image Compression in Channels**: Introduced a new setting under Interface options to allow users to force-disable image compression specifically for images posted in channels, ensuring higher resolution for critical visual content.
|
||||
- 🔗 **Custom CORS Scheme Support**: Introduced a new environment variable 'CORS_ALLOW_CUSTOM_SCHEME' that allows administrators to define custom URL schemes (e.g., 'app://') for CORS origins, enabling greater flexibility for local development or desktop client integrations.
|
||||
- ♿ **Translatable and Accessible Banners**: Enhanced banner elements with translatable badge text and proper ARIA attributes (aria-label, aria-hidden) for SVG icons, significantly improving accessibility and screen reader compatibility.
|
||||
- ⚠️ **OAuth Configuration Warning for Missing OPENID_PROVIDER_URL**: Added a proactive startup warning that notifies administrators when OAuth providers (Google, Microsoft, or GitHub) are configured but the essential 'OPENID_PROVIDER_URL' environment variable is missing. This prevents silent OAuth logout failures and guides administrators to complete their setup correctly.
|
||||
- ♿ **Major Accessibility Enhancements**: Key parts of the interface have been made significantly more accessible. The user profile menu is now fully navigable via keyboard, essential controls in the Playground now include proper ARIA labels for screen readers, and decorative images have been hidden from assistive technologies to reduce audio clutter. Menu buttons also feature enhanced accessibility with 'aria-label', 'aria-hidden' for SVGs, and 'aria-pressed' for toggle buttons.
|
||||
- ⚙️ **General Backend Refactoring**: Implemented various backend improvements to enhance performance, stability, and security, ensuring a more resilient and reliable platform for all users, including refining logging output to be cleaner and more efficient by conditionally including 'extra_json' fields and improving consistent metadata handling in vector database operations, and laying preliminary scaffolding for future analytics features.
|
||||
- 🌐 **Localization & Internationalization Improvements**: Refined and expanded translations for Catalan, Danish, Korean, Persian, Polish, Simplified Chinese, and Spanish, ensuring a more fluent and native experience for global users across all supported languages.
|
||||
|
||||
### Fixed
|
||||
|
||||
- 🛡️ **Hardened Channel Message Security**: Fixed a key permission flaw that allowed users with channel access to edit or delete messages belonging to others. The system now correctly enforces that users can only modify their own messages, protecting data integrity in shared channels.
|
||||
- 🛡️ **Hardened OAuth Security by Removing JWT from URL**: Fixed a critical security vulnerability where the authentication token was exposed in the URL after a successful OAuth login. The token is now transferred via a browser cookie, preventing potential leaks through browser history or server logs and protecting user sessions.
|
||||
- 🛡️ **Hardened Chat Completion API Security**: The chat completion API endpoint now includes an explicit ownership check, ensuring non-admin users cannot access chats that do not belong to them and preventing potential unauthorized access.
|
||||
- 🛠️ **Resilient Model Loading**: Fixed an issue where a failure in loading the model list (e.g., from a misconfigured provider) would prevent the entire user interface, including the admin panel, from loading. The application now gracefully handles these errors, ensuring the UI remains accessible.
|
||||
- 🔒 **Resolved FIPS Self-Test Failure**: Fixed a critical issue that prevented Open WebUI from running on FIPS-compliant systems, specifically resolving the "FATAL FIPS SELFTEST FAILURE" error related to OpenSSL and SentenceTransformers, restoring compatibility with secure environments.
|
||||
- 📦 **Redis Cluster Connection Restored**: Fixed an issue where the backend was unable to connect to Redis in cluster mode, now ensuring seamless integration with scalable Redis cluster deployments.
|
||||
- 📦 **PGVector Connection Stability**: Fixed an issue where read-only operations could leave database transactions idle, preventing potential connection errors and improving overall database stability and resource management.
|
||||
- 🛠️ **OpenAPI Tool Integration for Array Parameters Fixed**: Resolved a critical bug where external tools using array parameters (e.g., for tags) would fail when used with OpenAI models. The system now correctly generates the required 'items' property in the function schema, restoring functionality and preventing '400 Bad Request' errors.
|
||||
- 🛠️ **Tool Creation for Users Restored**: Fixed a bug in the code editor where status messages were incorrectly prepended to tool scripts, causing a syntax error upon saving. All authorized users can now reliably create and save new tools.
|
||||
- 📁 **Folder Knowledge Processing Restored**: Fixed a bug where files uploaded to folder and model knowledge bases were not being extracted or analyzed for Retrieval-Augmented Generation (RAG) when the 'Max Upload Count' setting was empty, ensuring seamless document processing and knowledge augmentation.
|
||||
- 🧠 **Custom Model Knowledge Base Updates Recognized**: Fixed a bug where custom models linked to to knowledge bases did not automatically recognize newly added files to those knowledge bases. Models now correctly incorporate the latest information from updated knowledge collections.
|
||||
- 📦 **Comprehensive Redis Key Prefixing**: Corrected hardcoded prefixes to ensure the REDIS_KEY_PREFIX is now respected across all WebSocket and task management keys. This prevents data collisions in multi-instance deployments and improves compatibility with Redis cluster mode.
|
||||
- ✨ **More Descriptive OpenAI Router Errors**: The OpenAI-compatible API router now propagates detailed upstream error messages instead of returning a generic 'Bad Request'. This provides clear, actionable feedback for developers and API users, making it significantly easier to debug and resolve issues with model requests.
|
||||
- 🔐 **Hardened OIDC Signout Flow**: The OpenID Connect signout process now verifies that the 'OPENID_PROVIDER_URL' is configured before attempting to communicate with it, preventing potential errors and ensuring a more reliable logout experience.
|
||||
- 🍓 **Raspberry Pi Compatibility Restored**: Pinned the pyarrow library to version 20.0.0, resolving an "Illegal Instruction" crash on ARM-based devices like the Raspberry Pi and ensuring stable operation on this hardware.
|
||||
- 📁 **Folder System Prompt Variables Restored**: Fixed a bug where prompt variables (e.g., '{{CURRENT_DATETIME}}') were not being rendered in Folder-level System Prompts. This restores an important capability for creating dynamic, context-aware instructions for all chats within a project folder.
|
||||
- 📝 **Note Access in Knowledge Retrieval Fixed**: Corrected a permission oversight in knowledge retrieval, ensuring users can always use their own notes as a source for RAG without needing explicit sharing permissions.
|
||||
- 🤖 **Title Generation Compatibility for GPT-5 Models**: Added support for 'gpt-5' models in the payload handler, which correctly converts the deprecated 'max_tokens' parameter to 'max_completion_tokens'. This resolves title generation failures and ensures seamless operation with the latest generation of models.
|
||||
- ⚙️ **Correct API 'finish_reason' in Streaming Responses**: Fixed an issue where intermediate 'reasoning_content' chunks in streaming API responses incorrectly reported a 'finish_reason' of 'stop'. The 'finish_reason' is now correctly set to 'null' for these chunks, ensuring compatibility with third-party applications that rely on this field.
|
||||
- 📈 **Evaluation Pages Stability**: Resolved a crash on the Leaderboard and Feedbacks pages when processing legacy feedback entries that were missing a 'rating' field. The system now gracefully handles this older data, ensuring both pages load reliably for all users.
|
||||
- 🤝 **Reliable Collaborative Session Cleanup**: Fixed an asynchronous bug in the real-time collaboration engine that prevented document sessions from being properly cleaned up after all users had left. This ensures greater stability and resource management for features like Collaborative Notes.
|
||||
- 🧠 **Enhanced Memory Stability and Security**: Refactored memory update and delete operations to strictly enforce user ownership, preventing potential data integrity issues. Additionally, improved error handling for memory queries now provides clearer feedback when no memories exists.
|
||||
- 🧑⚖️ **Restored Admin Access to User Feedback**: Fixed a permission issue that blocked administrators from viewing or editing user feedback they didn't create, ensuring they can properly manage all evaluations across the platform.
|
||||
- 🔐 **PGVector Encryption Fix for Metadata**: Corrected a SQL syntax error in the experimental 'PGVECTOR_PGCRYPTO' feature that prevented encrypted metadata from being saved. Document uploads to encrypted PGVector collections now work as intended.
|
||||
- 🔍 **Serply Web Search Integration Restored**: Fixed an issue where incorrect parameters were passed to the Serply web search engine, restoring its functionality for RAG and web search workflows.
|
||||
- 🔍 **Resilient Web Search Processing**: Web search retrieval now gracefully handles search results that are missing a 'snippet', preventing crashes and ensuring that RAG workflows complete successfully even with incomplete data from search engines.
|
||||
- 🖼️ **Table Pasting in Rich Text Input Displayed Correctly**: Fixed an issue where pasting table text into the rich text input would incorrectly display it as code. Tables are now properly rendered as expected, improving content formatting and user experience.
|
||||
- ✍️ **Rich Text Input TypeError Resolution**: Addressed a potential 'TypeError: ue.getWordAtDocPos is not a function' in 'MessageInput.svelte' by refactoring how the 'getWordAtDocPos' function is accessed and referenced from 'RichTextInput.svelte', ensuring stable rich text input behavior, especially after production restarts.
|
||||
- ✏️ **Manual Code Block Creation in Chat Restored**: Fixed an issue where typing three backticks and then pressing Shift+Enter would incorrectly remove the backticks when "Enter to Send" mode was active. This ensures users can reliably create multi-line code blocks manually.
|
||||
- 🎨 **Consistent Dark Mode Background**: Fixed an issue where the application background could incorrectly flash or remain white during page loads and refreshes in dark mode, ensuring a seamless and consistent visual experience.
|
||||
- 🎨 **'Her' Theme Rendering Fixed**: Corrected a bug that caused the "Her" theme to incorrectly render as a dark theme in some situations. The theme now reliably applies its intended light appearance across all sessions.
|
||||
- 📜 **Corrected Markdown Table Line Break Rendering**: Fixed an issue where line breaks ('<br>') within Markdown tables were displayed as raw HTML instead of being rendered correctly. This ensures that tables with multi-line cell content are now displayed as intended.
|
||||
- 🚦 **Corrected App Configuration for Pending Users**: Fixed an issue where users awaiting approval could incorrectly load the full application interface, leading to a confusing or broken UI. This ensures that only fully approved users receive the standard app 'config', resulting in a smoother and more reliable onboarding experience.
|
||||
- 🔄 **Chat Cloning Now Includes Tags, Folder Status, and Pinned Status**: When cloning a chat or shared chat, its associated tags, folder organization, and pinned status are now correctly replicated, ensuring consistent chat management.
|
||||
- ⚙️ **Enhanced Backend Reliability**: Resolved a potential crash in knowledge base retrieval when referencing a deleted note. Additionally, chat processing was refactored to ensure model information is saved more reliably, enhancing overall system stability.
|
||||
- ⚙️ **Floating 'Ask/Explain' Modal Stability**: Fixed an issue that spammed the console with errors when navigating away while a model was generating a response in the floating 'Ask' or 'Explain' modals. In-flight requests are now properly cancelled, improving application stability.
|
||||
- ⚡ **Optimized User Count Checks**: Improved performance for user count and existence checks across the application by replacing resource-intensive 'COUNT' queries with more efficient 'EXISTS' queries, reducing database load.
|
||||
- 🔐 **Hardened OpenTelemetry Exporter Configuration**: The OTLP HTTP exporter no longer uses a potentially insecure explicit flag, improving security by relying on the connection URL's protocol (HTTP/HTTPS) to ensure transport safety.
|
||||
- 📱 **Mobile User Menu Closing Behavior Fixed**: Resolved an issue where the user menu would remain open on mobile devices after selecting an option, ensuring the menu correctly closes and returns focus to the main interface for a smoother mobile experience.
|
||||
- 📱 **OnBoarding Page Display Fixed on Mobile**: Resolved an issue where buttons on the OnBoarding page were not consistently visible on certain mobile browsers, ensuring a functional and complete user experience across devices.
|
||||
- ↕️ **Improved Pinned Models Drag-and-Drop Behavior**: The drag-and-drop functionality for reordering pinned models is now explicitly disabled on mobile devices, ensuring better usability and preventing potential UI conflicts or unexpected behavior.
|
||||
- 📱 **PWA Rotation Behavior Corrected**: The Progressive Web App now correctly respects the device's screen orientation lock, preventing unwanted rotation and ensuring a more native mobile experience.
|
||||
- ✏️ **Improved Chat Title Editing Behavior**: Changes to a chat title are now reliably saved when the user clicks away or presses Enter, replacing a less intuitive behavior that could accidentally discard edits. This makes renaming chats a smoother and more predictable experience.
|
||||
- ✏️ **Underscores Allowed in Prompt Commands**: Fixed the validation for prompt commands to correctly allow the use of underscores ('\_'), aligning with documentation examples and improving flexibility in naming custom prompts.
|
||||
- 💡 **Title Generation Button Behavior Fixed**: Resolved an issue where clicking the "Generate Title" button while editing a chat or note title would incorrectly save the title before generation could start. The focus is now managed correctly, ensuring a smooth and predictable user experience.
|
||||
- ✏️ **Consistent Chat Input Height**: Fixed a minor visual bug where the chat input field's height would change slightly when toggling the "Rich Text Input for Chat" setting, ensuring a more stable and consistent layout.
|
||||
- 🙈 **Admin UI Toggle Stability**: Fixed a visual glitch in the Admin settings where toggle switches could briefly display an incorrect state on page load, ensuring the UI always accurately reflects the saved settings.
|
||||
- 🙈 **Community Sharing Button Visibility**: The "Share to Community" button on the feedback page is now correctly hidden when the Enable Community Sharing feature is disabled in the admin settings, ensuring the UI respects the configured sharing policy.
|
||||
- 🙈 **"Help Us Translate" Link Visibility**: The "Help us translate" link in settings is now correctly hidden in deployments with specific license configurations, ensuring a cleaner interface for enterprise users.
|
||||
- 🔗 **Robust Tool Server URL Handling**: Fixed an issue where providing a full URL for a tool server's OpenAPI specification resulted in an invalid path. The system now correctly handles both absolute URLs and relative paths, improving configuration flexibility.
|
||||
- 🔧 **Improved Azure URL Detection**: The logic for identifying Azure OpenAI endpoints has been made more robust, ensuring all valid Azure URLs are now correctly detected for a smoother connection setup.
|
||||
- ⚙️ **Corrected Direct Connection Save Logic**: Fixed a bug in the Admin Connections settings page by removing a redundant save action for 'Direct Connections', leading to more reliable and predictable behavior when updating settings.
|
||||
- 🔗 **Corrected "Discover" Links**: The "Discover" links for models, prompts, tools, and functions now point to their specific, relevant pages on openwebui.com, improving content discovery for users.
|
||||
- ⏱️ **Refined Display of AI Thought Duration**: Adjusted the display logic for AI thought (reasoning) durations to more accurately show very short thought times as "less than a second," improving clarity in AI process feedback.
|
||||
- 📜 **Markdown Line Break Rendering Refinement**: Improved handling of line breaks within Markdown rendering for better visual consistency.
|
||||
- 🛠️ **Corrected OpenTelemetry Docker Compose Example**: The docker-compose.otel.yaml file has been fixed and enhanced by removing duplicates, adding necessary environment variables, and hardening security settings, ensuring a more reliable out-of-box observability setup.
|
||||
- 🛠️ **Development Script CORS Fix**: Corrected the CORS origin URL in the local development script (dev.sh) by removing the trailing slash, ensuring a more reliable and consistent setup for developers.
|
||||
- ⬆️ **OpenTelemetry Libraries Updated**: Upgraded all OpenTelemetry-related libraries to their latest versions, ensuring better performance, stability, and compatibility for observability.
|
||||
|
||||
### Changed
|
||||
|
||||
- ❗ **Docling Integration Upgraded to v1 API (Breaking Change)**: The integration with the Docling document processing engine has been updated to its new, stable '/v1' API. This is required for compatibility with Docling version 1.0.0 and newer. As a result, older versions of Docling are no longer supported. Users who rely on Docling for document ingestion **must upgrade** their docling-serve instance to ensure continued functionality.
|
||||
- 🗣️ **Admin-First Whisper Language Priority**: The global WHISPER_LANGUAGE setting now acts as a strict override for audio transcriptions. If set, it will be used for all speech-to-text tasks, ignoring any language specified by the user on a per-request basis. This gives administrators more control over transcription consistency.
|
||||
- ✂️ **Datalab Marker API Language Selection Removed**: The separate language selection option for the Datalab Marker API has been removed, as its functionality is now integrated and superseded by the more comprehensive 'additional_config' parameter. Users should transition to using 'additional_config' for relevant language and processing settings.
|
||||
- 📄 **Documentation and Releases Links Visibility**: The "Documentation" and "Releases" links in the user menu are now visible only to admin users, streamlining the user interface for non-admin roles.
|
||||
|
||||
## [0.6.18] - 2025-07-19
|
||||
|
||||
### Fixed
|
||||
|
||||
- 🚑 **Users Not Loading in Groups**: Resolved an issue where user list was not displaying within user groups, restoring full visibility and management of group memberships for teams and admins.
|
||||
|
||||
## [0.6.17] - 2025-07-19
|
||||
|
||||
### Added
|
||||
|
||||
- 📂 **Dedicated Folder View with Chat List**: Clicking a folder now reveals a brand-new landing page showcasing a list of all chats within that folder, making navigation simpler and giving teams immediate visibility into project-specific conversations.
|
||||
- 🆕 **Streamlined Folder Creation Modal**: Creating a new folder is now a seamless, unified experience with a dedicated modal that visually and functionally matches the edit folder flow, making workspace organization more intuitive and error-free for all users.
|
||||
- 🗃️ **Direct File Uploads to Folder Knowledge**: You can now upload files straight to a folder’s knowledge—empowering you to enrich project spaces by adding resources and documents directly, without the need to pre-create knowledge bases beforehand.
|
||||
- 🔎 **Chat Preview in Search**: When searching chats, instantly preview results in context without having to open them—making discovery, auditing, and recall dramatically quicker, especially in large, active teams.
|
||||
- 🖼️ **Image Upload and Inline Insertion in Notes**: Notes now support inserting images directly among your text, letting you create rich, visually structured documentation, brainstorms, or reports in a more natural and engaging way—no more images just as attachments.
|
||||
- 📱 **Enhanced Note Selection Editing and Q&A**: Select any portion of your notes to either edit just the highlighted part or ask focused questions about that content—streamlining workflows, boosting productivity, and making reviews or AI-powered enhancements more targeted.
|
||||
- 📝 **Copy Notes as Rich Text**: Copy entire notes—including all formatting, images, and structure—directly as rich text for seamless pasting into emails, reports, or other tools, maintaining clarity and consistency outside the WebUI.
|
||||
- ⚡ **Fade-In Streaming Text Experience**: Live-generated responses now elegantly fade in as the AI streams them, creating a more natural and visually engaging reading experience; easily toggled off in Interface settings if you prefer static displays.
|
||||
- 🔄 **Settings for Follow-Up Prompts**: Fine-tune your follow-up prompt experience—with new controls, you can choose to keep them visible or have them inserted directly into the message input instead of auto-submitting, giving you more flexibility and control over your workflow.
|
||||
- 🔗 **Prompt Variable Documentation Quick Link**: Access documentation for prompt variables in one click from the prompt editor modal—shortening the learning curve and making advanced prompt-building more accessible.
|
||||
- 📈 **Active and Total User Metrics for Telemetry**: Gain valuable insights into usage patterns and platform engagement with new metrics tracking active and total users—enhancing auditability and planning for large organizations.
|
||||
- 🏷️ **Traceability with Log Trace and Span IDs**: Each log entry now carries detailed trace and span IDs, making it much easier for admins to pinpoint and resolve issues across distributed systems or in complex troubleshooting.
|
||||
- 👥 **User Group Add/Remove Endpoints**: Effortlessly add or remove users from groups with new, improved endpoints—giving admins and team leads faster, clearer control over collaboration and permissions.
|
||||
- ⚙️ **Note Settings and Controls Streamlined**: The main “Settings” for notes are now simply called “Controls”, and note files now reside in a dedicated controls section, decluttering navigation and making it easier to find and configure note-related options.
|
||||
- 🚀 **Faster Admin User Page Loads**: The user list endpoint for admins has been optimized to exclude heavy profile images, speeding up load times for large teams and reducing waiting during administrative tasks.
|
||||
- 📡 **Chat ID Header Forwarding**: Ollama and OpenAI router requests now include the chat ID in request headers, enabling better request correlation and debugging capabilities across AI model integrations.
|
||||
- 🧠 **Enhanced Reasoning Tag Processing**: Improved and expanded reasoning tag parsing to handle various tag formats more robustly, including standard XML-style tags and custom delimiters, ensuring better AI reasoning transparency and debugging capabilities.
|
||||
- 🔐 **OAuth Token Endpoint Authentication Method**: Added configurable OAuth token endpoint authentication method support, providing enhanced flexibility and security options for enterprise OAuth integrations and identity provider compatibility.
|
||||
- 🛡️ **Redis Sentinel High Availability Support**: Comprehensive Redis Sentinel failover implementation with automatic master discovery, intelligent retry logic for connection failures, and seamless operation during master node outages—eliminating single points of failure and ensuring continuous service availability in production deployments.
|
||||
- 🌐 **Localization & Internationalization Improvements**: Refined and expanded translations for Simplified Chinese, Traditional Chinese, French, German, Korean, and Polish, ensuring a more fluent and native experience for global users across all supported languages.
|
||||
|
||||
### Fixed
|
||||
|
||||
- 🏷️ **Hybrid Search Functionality Restored**: Hybrid search now works seamlessly again—enabling more accurate, relevant, and comprehensive knowledge discovery across all RAG-powered workflows.
|
||||
- 🚦 **Note Chat - Edit Button Disabled During AI Generation**: The edit button when chatting with a note is now disabled while the AI is responding—preventing accidental edits and ensuring workflow clarity during chat sessions.
|
||||
- 🧹 **Cleaner Database Credentials**: Database connection no longer duplicates ‘@’ in credentials, preventing potential connection issues and ensuring smoother, more reliable integrations.
|
||||
- 🧑💻 **File Deletion Now Removes Related Vector Data**: When files are deleted from storage, they are now purged from the vector database as well, ensuring clean data management and preventing clutter or stale search results.
|
||||
- 📁 **Files Modal Translation Issues Fixed**: All modal dialog strings—including “Using Entire Document” and “Using Focused Retrieval”—are now fully translated for a more consistent and localized UI experience.
|
||||
- 🚫 **Drag-and-Drop File Upload Disabled for Unsupported Models**: File upload by drag-and-drop is disabled when using models that do not support attachments—removing confusion and preventing workflow interruptions.
|
||||
- 🔑 **Ollama Tool Calls Now Reliable**: Fixed issues with Ollama-based tool calls, ensuring uninterrupted AI augmentation and tool use for every chat.
|
||||
- 📄 **MIME Type Help String Correction**: Cleaned up mimetype help text by removing extraneous characters, providing clearer guidance for file upload configurations.
|
||||
- 📝 **Note Editor Permission Fix**: Removed unnecessary admin-only restriction from note chat functionality, allowing all authorized users to access note editing features as intended.
|
||||
- 📋 **Chat Sources Handling Improved**: Fixed sources handling logic to prevent duplicate source assignments in chat messages, ensuring cleaner and more accurate source attribution during conversations.
|
||||
- 😀 **Emoji Generation Error Handling**: Improved error handling in audio router and fixed metadata structure for emoji generation tasks, preventing crashes and ensuring more reliable emoji generation functionality.
|
||||
- 🔒 **Folder System Prompt Permission Enforcement**: System prompt fields in folder edit modal are now properly hidden for users without system prompt permissions, ensuring consistent security policy enforcement across all folder management interfaces.
|
||||
- 🌐 **WebSocket Redis Lock Timeout Type Conversion**: Fixed proper integer type conversion for WebSocket Redis lock timeout configuration with robust error handling, preventing potential configuration errors and ensuring stable WebSocket connections.
|
||||
- 📦 **PostHog Dependency Added**: Added PostHog 5.4.0 library to resolve ChromaDB compatibility issues, ensuring stable vector database operations and preventing library version conflicts during deployment.
|
||||
|
||||
### Changed
|
||||
|
||||
- 👀 **Tiptap Editor Upgraded to v3**: The underlying rich text editor has been updated for future-proofing, though some supporting libraries remain on v2 for compatibility. For now, please install dependencies using 'npm install --force' to avoid installation errors.
|
||||
- 🚫 **Removed Redundant or Unused Strings and Elements**: Miscellaneous unused, duplicate, or obsolete code and translations have been cleaned up to maintain a streamlined and high-performance experience.
|
||||
|
||||
## [0.6.16] - 2025-07-14
|
||||
|
||||
### Added
|
||||
|
|
|
|||
49
Dockerfile
49
Dockerfile
|
|
@ -30,7 +30,7 @@ WORKDIR /app
|
|||
RUN apk add --no-cache git
|
||||
|
||||
COPY package.json package-lock.json ./
|
||||
RUN npm ci
|
||||
RUN npm ci --force
|
||||
|
||||
COPY . .
|
||||
ENV APP_BUILD_HASH=${BUILD_HASH}
|
||||
|
|
@ -108,29 +108,13 @@ RUN echo -n 00000000-0000-0000-0000-000000000000 > $HOME/.cache/chroma/telemetry
|
|||
# Make sure the user has access to the app and root directory
|
||||
RUN chown -R $UID:$GID /app $HOME
|
||||
|
||||
RUN if [ "$USE_OLLAMA" = "true" ]; then \
|
||||
apt-get update && \
|
||||
# Install pandoc and netcat
|
||||
apt-get install -y --no-install-recommends git build-essential pandoc netcat-openbsd curl && \
|
||||
apt-get install -y --no-install-recommends gcc python3-dev && \
|
||||
# for RAG OCR
|
||||
apt-get install -y --no-install-recommends ffmpeg libsm6 libxext6 && \
|
||||
# install helper tools
|
||||
apt-get install -y --no-install-recommends curl jq && \
|
||||
# install ollama
|
||||
curl -fsSL https://ollama.com/install.sh | sh && \
|
||||
# cleanup
|
||||
rm -rf /var/lib/apt/lists/*; \
|
||||
else \
|
||||
apt-get update && \
|
||||
# Install pandoc, netcat and gcc
|
||||
apt-get install -y --no-install-recommends git build-essential pandoc gcc netcat-openbsd curl jq && \
|
||||
apt-get install -y --no-install-recommends gcc python3-dev && \
|
||||
# for RAG OCR
|
||||
apt-get install -y --no-install-recommends ffmpeg libsm6 libxext6 && \
|
||||
# cleanup
|
||||
rm -rf /var/lib/apt/lists/*; \
|
||||
fi
|
||||
# Install common system dependencies
|
||||
RUN apt-get update && \
|
||||
apt-get install -y --no-install-recommends \
|
||||
git build-essential pandoc gcc netcat-openbsd curl jq \
|
||||
python3-dev \
|
||||
ffmpeg libsm6 libxext6 \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# install python dependencies
|
||||
COPY --chown=$UID:$GID ./backend/requirements.txt ./requirements.txt
|
||||
|
|
@ -152,7 +136,13 @@ RUN pip3 install --no-cache-dir uv && \
|
|||
fi; \
|
||||
chown -R $UID:$GID /app/backend/data/
|
||||
|
||||
|
||||
# Install Ollama if requested
|
||||
RUN if [ "$USE_OLLAMA" = "true" ]; then \
|
||||
date +%s > /tmp/ollama_build_hash && \
|
||||
echo "Cache broken at timestamp: `cat /tmp/ollama_build_hash`" && \
|
||||
curl -fsSL https://ollama.com/install.sh | sh && \
|
||||
rm -rf /var/lib/apt/lists/*; \
|
||||
fi
|
||||
|
||||
# copy embedding weight from build
|
||||
# RUN mkdir -p /root/.cache/chroma/onnx_models/all-MiniLM-L6-v2
|
||||
|
|
@ -170,6 +160,15 @@ EXPOSE 8080
|
|||
|
||||
HEALTHCHECK CMD curl --silent --fail http://localhost:${PORT:-8080}/health | jq -ne 'input.status == true' || exit 1
|
||||
|
||||
# Minimal, atomic permission hardening for OpenShift (arbitrary UID):
|
||||
# - Group 0 owns /app and /root
|
||||
# - Directories are group-writable and have SGID so new files inherit GID 0
|
||||
RUN set -eux; \
|
||||
chgrp -R 0 /app /root || true; \
|
||||
chmod -R g+rwX /app /root || true; \
|
||||
find /app -type d -exec chmod g+s {} + || true; \
|
||||
find /root -type d -exec chmod g+s {} + || true
|
||||
|
||||
USER $UID:$GID
|
||||
|
||||
ARG BUILD_HASH
|
||||
|
|
|
|||
53
LICENSE_HISTORY
Normal file
53
LICENSE_HISTORY
Normal file
|
|
@ -0,0 +1,53 @@
|
|||
All code and materials created before commit `60d84a3aae9802339705826e9095e272e3c83623` are subject to the following copyright and license:
|
||||
|
||||
Copyright (c) 2023-2025 Timothy Jaeryang Baek
|
||||
All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are met:
|
||||
|
||||
1. Redistributions of source code must retain the above copyright notice, this
|
||||
list of conditions and the following disclaimer.
|
||||
|
||||
2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
this list of conditions and the following disclaimer in the documentation
|
||||
and/or other materials provided with the distribution.
|
||||
|
||||
3. Neither the name of the copyright holder nor the names of its
|
||||
contributors may be used to endorse or promote products derived from
|
||||
this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
All code and materials created before commit `a76068d69cd59568b920dfab85dc573dbbb8f131` are subject to the following copyright and license:
|
||||
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2023 Timothy Jaeryang Baek
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
16
README.md
16
README.md
|
|
@ -31,6 +31,8 @@ For more information, be sure to check out our [Open WebUI Documentation](https:
|
|||
|
||||
- 🛡️ **Granular Permissions and User Groups**: By allowing administrators to create detailed user roles and permissions, we ensure a secure user environment. This granularity not only enhances security but also allows for customized user experiences, fostering a sense of ownership and responsibility amongst users.
|
||||
|
||||
- 🔄 **SCIM 2.0 Support**: Enterprise-grade user and group provisioning through SCIM 2.0 protocol, enabling seamless integration with identity providers like Okta, Azure AD, and Google Workspace for automated user lifecycle management.
|
||||
|
||||
- 📱 **Responsive Design**: Enjoy a seamless experience across Desktop PC, Laptop, and Mobile devices.
|
||||
|
||||
- 📱 **Progressive Web App (PWA) for Mobile**: Enjoy a native app-like experience on your mobile device with our PWA, providing offline access on localhost and a seamless user interface.
|
||||
|
|
@ -68,7 +70,7 @@ Want to learn more about Open WebUI's features? Check out our [Open WebUI docume
|
|||
#### Emerald
|
||||
|
||||
<table>
|
||||
<tr>
|
||||
<!-- <tr>
|
||||
<td>
|
||||
<a href="https://n8n.io/" target="_blank">
|
||||
<img src="https://docs.openwebui.com/sponsors/logos/n8n.png" alt="n8n" style="width: 8rem; height: 8rem; border-radius: .75rem;" />
|
||||
|
|
@ -77,7 +79,7 @@ Want to learn more about Open WebUI's features? Check out our [Open WebUI docume
|
|||
<td>
|
||||
<a href="https://n8n.io/">n8n</a> • Does your interface have a backend yet?<br>Try <a href="https://n8n.io/">n8n</a>
|
||||
</td>
|
||||
</tr>
|
||||
</tr> -->
|
||||
<tr>
|
||||
<td>
|
||||
<a href="https://tailscale.com/blog/self-host-a-local-ai-stack/?utm_source=OpenWebUI&utm_medium=paid-ad-placement&utm_campaign=OpenWebUI-Docs" target="_blank">
|
||||
|
|
@ -88,6 +90,16 @@ Want to learn more about Open WebUI's features? Check out our [Open WebUI docume
|
|||
<a href="https://tailscale.com/blog/self-host-a-local-ai-stack/?utm_source=OpenWebUI&utm_medium=paid-ad-placement&utm_campaign=OpenWebUI-Docs">Tailscale</a> • Connect self-hosted AI to any device with Tailscale
|
||||
</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td>
|
||||
<a href="https://warp.dev/open-webui" target="_blank">
|
||||
<img src="https://docs.openwebui.com/sponsors/logos/warp.png" alt="Warp" style="width: 8rem; height: 8rem; border-radius: .75rem;" />
|
||||
</a>
|
||||
</td>
|
||||
<td>
|
||||
<a href="https://warp.dev/open-webui">Warp</a> • The intelligent terminal for developers
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
---
|
||||
|
|
|
|||
|
|
@ -1,2 +1,3 @@
|
|||
export CORS_ALLOW_ORIGIN="http://localhost:5173"
|
||||
PORT="${PORT:-8080}"
|
||||
uvicorn open_webui.main:app --port $PORT --host 0.0.0.0 --forwarded-allow-ips '*' --reload
|
||||
uvicorn open_webui.main:app --port $PORT --host 0.0.0.0 --forwarded-allow-ips '*' --reload
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ script_location = migrations
|
|||
|
||||
# sys.path path, will be prepended to sys.path if present.
|
||||
# defaults to the current working directory.
|
||||
prepend_sys_path = .
|
||||
prepend_sys_path = ..
|
||||
|
||||
# timezone to use when rendering the date within the migration file
|
||||
# as well as the filename.
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ import redis
|
|||
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Generic, Optional, TypeVar
|
||||
from typing import Generic, Union, Optional, TypeVar
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import requests
|
||||
|
|
@ -168,9 +168,19 @@ class PersistentConfig(Generic[T]):
|
|||
self.config_path = config_path
|
||||
self.env_value = env_value
|
||||
self.config_value = get_config_value(config_path)
|
||||
|
||||
if self.config_value is not None and ENABLE_PERSISTENT_CONFIG:
|
||||
log.info(f"'{env_name}' loaded from the latest database entry")
|
||||
self.value = self.config_value
|
||||
if (
|
||||
self.config_path.startswith("oauth.")
|
||||
and not ENABLE_OAUTH_PERSISTENT_CONFIG
|
||||
):
|
||||
log.info(
|
||||
f"Skipping loading of '{env_name}' as OAuth persistent config is disabled"
|
||||
)
|
||||
self.value = env_value
|
||||
else:
|
||||
log.info(f"'{env_name}' loaded from the latest database entry")
|
||||
self.value = self.config_value
|
||||
else:
|
||||
self.value = env_value
|
||||
|
||||
|
|
@ -213,13 +223,14 @@ class PersistentConfig(Generic[T]):
|
|||
|
||||
class AppConfig:
|
||||
_state: dict[str, PersistentConfig]
|
||||
_redis: Optional[redis.Redis] = None
|
||||
_redis: Union[redis.Redis, redis.cluster.RedisCluster] = None
|
||||
_redis_key_prefix: str
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
redis_url: Optional[str] = None,
|
||||
redis_sentinels: Optional[list] = [],
|
||||
redis_cluster: Optional[bool] = False,
|
||||
redis_key_prefix: str = "open-webui",
|
||||
):
|
||||
super().__setattr__("_state", {})
|
||||
|
|
@ -227,7 +238,12 @@ class AppConfig:
|
|||
if redis_url:
|
||||
super().__setattr__(
|
||||
"_redis",
|
||||
get_redis_connection(redis_url, redis_sentinels, decode_responses=True),
|
||||
get_redis_connection(
|
||||
redis_url,
|
||||
redis_sentinels,
|
||||
redis_cluster,
|
||||
decode_responses=True,
|
||||
),
|
||||
)
|
||||
|
||||
def __setattr__(self, key, value):
|
||||
|
|
@ -296,6 +312,9 @@ JWT_EXPIRES_IN = PersistentConfig(
|
|||
# OAuth config
|
||||
####################################
|
||||
|
||||
ENABLE_OAUTH_PERSISTENT_CONFIG = (
|
||||
os.environ.get("ENABLE_OAUTH_PERSISTENT_CONFIG", "True").lower() == "true"
|
||||
)
|
||||
|
||||
ENABLE_OAUTH_SIGNUP = PersistentConfig(
|
||||
"ENABLE_OAUTH_SIGNUP",
|
||||
|
|
@ -445,6 +464,12 @@ OAUTH_TIMEOUT = PersistentConfig(
|
|||
os.environ.get("OAUTH_TIMEOUT", ""),
|
||||
)
|
||||
|
||||
OAUTH_TOKEN_ENDPOINT_AUTH_METHOD = PersistentConfig(
|
||||
"OAUTH_TOKEN_ENDPOINT_AUTH_METHOD",
|
||||
"oauth.oidc.token_endpoint_auth_method",
|
||||
os.environ.get("OAUTH_TOKEN_ENDPOINT_AUTH_METHOD", None),
|
||||
)
|
||||
|
||||
OAUTH_CODE_CHALLENGE_METHOD = PersistentConfig(
|
||||
"OAUTH_CODE_CHALLENGE_METHOD",
|
||||
"oauth.oidc.code_challenge_method",
|
||||
|
|
@ -457,6 +482,12 @@ OAUTH_PROVIDER_NAME = PersistentConfig(
|
|||
os.environ.get("OAUTH_PROVIDER_NAME", "SSO"),
|
||||
)
|
||||
|
||||
OAUTH_SUB_CLAIM = PersistentConfig(
|
||||
"OAUTH_SUB_CLAIM",
|
||||
"oauth.oidc.sub_claim",
|
||||
os.environ.get("OAUTH_SUB_CLAIM", None),
|
||||
)
|
||||
|
||||
OAUTH_USERNAME_CLAIM = PersistentConfig(
|
||||
"OAUTH_USERNAME_CLAIM",
|
||||
"oauth.oidc.username_claim",
|
||||
|
|
@ -479,7 +510,7 @@ OAUTH_EMAIL_CLAIM = PersistentConfig(
|
|||
OAUTH_GROUPS_CLAIM = PersistentConfig(
|
||||
"OAUTH_GROUPS_CLAIM",
|
||||
"oauth.oidc.group_claim",
|
||||
os.environ.get("OAUTH_GROUP_CLAIM", "groups"),
|
||||
os.environ.get("OAUTH_GROUPS_CLAIM", os.environ.get("OAUTH_GROUP_CLAIM", "groups")),
|
||||
)
|
||||
|
||||
ENABLE_OAUTH_ROLE_MANAGEMENT = PersistentConfig(
|
||||
|
|
@ -636,6 +667,13 @@ def load_oauth_providers():
|
|||
def oidc_oauth_register(client: OAuth):
|
||||
client_kwargs = {
|
||||
"scope": OAUTH_SCOPES.value,
|
||||
**(
|
||||
{
|
||||
"token_endpoint_auth_method": OAUTH_TOKEN_ENDPOINT_AUTH_METHOD.value
|
||||
}
|
||||
if OAUTH_TOKEN_ENDPOINT_AUTH_METHOD.value
|
||||
else {}
|
||||
),
|
||||
**(
|
||||
{"timeout": int(OAUTH_TIMEOUT.value)} if OAUTH_TIMEOUT.value else {}
|
||||
),
|
||||
|
|
@ -667,6 +705,23 @@ def load_oauth_providers():
|
|||
"register": oidc_oauth_register,
|
||||
}
|
||||
|
||||
configured_providers = []
|
||||
if GOOGLE_CLIENT_ID.value:
|
||||
configured_providers.append("Google")
|
||||
if MICROSOFT_CLIENT_ID.value:
|
||||
configured_providers.append("Microsoft")
|
||||
if GITHUB_CLIENT_ID.value:
|
||||
configured_providers.append("GitHub")
|
||||
|
||||
if configured_providers and not OPENID_PROVIDER_URL.value:
|
||||
provider_list = ", ".join(configured_providers)
|
||||
log.warning(
|
||||
f"⚠️ OAuth providers configured ({provider_list}) but OPENID_PROVIDER_URL not set - logout will not work!"
|
||||
)
|
||||
log.warning(
|
||||
f"Set OPENID_PROVIDER_URL to your OAuth provider's OpenID Connect discovery endpoint to fix logout functionality."
|
||||
)
|
||||
|
||||
|
||||
load_oauth_providers()
|
||||
|
||||
|
|
@ -676,6 +731,17 @@ load_oauth_providers()
|
|||
|
||||
STATIC_DIR = Path(os.getenv("STATIC_DIR", OPEN_WEBUI_DIR / "static")).resolve()
|
||||
|
||||
try:
|
||||
if STATIC_DIR.exists():
|
||||
for item in STATIC_DIR.iterdir():
|
||||
if item.is_file() or item.is_symlink():
|
||||
try:
|
||||
item.unlink()
|
||||
except Exception as e:
|
||||
pass
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
for file_path in (FRONTEND_BUILD_DIR / "static").glob("**/*"):
|
||||
if file_path.is_file():
|
||||
target_path = STATIC_DIR / file_path.relative_to(
|
||||
|
|
@ -755,12 +821,6 @@ if CUSTOM_NAME:
|
|||
pass
|
||||
|
||||
|
||||
####################################
|
||||
# LICENSE_KEY
|
||||
####################################
|
||||
|
||||
LICENSE_KEY = os.environ.get("LICENSE_KEY", "")
|
||||
|
||||
####################################
|
||||
# STORAGE PROVIDER
|
||||
####################################
|
||||
|
|
@ -811,7 +871,7 @@ CACHE_DIR.mkdir(parents=True, exist_ok=True)
|
|||
ENABLE_DIRECT_CONNECTIONS = PersistentConfig(
|
||||
"ENABLE_DIRECT_CONNECTIONS",
|
||||
"direct.enable",
|
||||
os.environ.get("ENABLE_DIRECT_CONNECTIONS", "True").lower() == "true",
|
||||
os.environ.get("ENABLE_DIRECT_CONNECTIONS", "False").lower() == "true",
|
||||
)
|
||||
|
||||
####################################
|
||||
|
|
@ -893,6 +953,9 @@ GEMINI_API_BASE_URL = os.environ.get("GEMINI_API_BASE_URL", "")
|
|||
|
||||
if OPENAI_API_BASE_URL == "":
|
||||
OPENAI_API_BASE_URL = "https://api.openai.com/v1"
|
||||
else:
|
||||
if OPENAI_API_BASE_URL.endswith("/"):
|
||||
OPENAI_API_BASE_URL = OPENAI_API_BASE_URL[:-1]
|
||||
|
||||
OPENAI_API_KEYS = os.environ.get("OPENAI_API_KEYS", "")
|
||||
OPENAI_API_KEYS = OPENAI_API_KEYS if OPENAI_API_KEYS != "" else OPENAI_API_KEY
|
||||
|
|
@ -1125,10 +1188,18 @@ USER_PERMISSIONS_CHAT_CONTROLS = (
|
|||
os.environ.get("USER_PERMISSIONS_CHAT_CONTROLS", "True").lower() == "true"
|
||||
)
|
||||
|
||||
USER_PERMISSIONS_CHAT_VALVES = (
|
||||
os.environ.get("USER_PERMISSIONS_CHAT_VALVES", "True").lower() == "true"
|
||||
)
|
||||
|
||||
USER_PERMISSIONS_CHAT_SYSTEM_PROMPT = (
|
||||
os.environ.get("USER_PERMISSIONS_CHAT_SYSTEM_PROMPT", "True").lower() == "true"
|
||||
)
|
||||
|
||||
USER_PERMISSIONS_CHAT_PARAMS = (
|
||||
os.environ.get("USER_PERMISSIONS_CHAT_PARAMS", "True").lower() == "true"
|
||||
)
|
||||
|
||||
USER_PERMISSIONS_CHAT_FILE_UPLOAD = (
|
||||
os.environ.get("USER_PERMISSIONS_CHAT_FILE_UPLOAD", "True").lower() == "true"
|
||||
)
|
||||
|
|
@ -1214,7 +1285,9 @@ DEFAULT_USER_PERMISSIONS = {
|
|||
},
|
||||
"chat": {
|
||||
"controls": USER_PERMISSIONS_CHAT_CONTROLS,
|
||||
"valves": USER_PERMISSIONS_CHAT_VALVES,
|
||||
"system_prompt": USER_PERMISSIONS_CHAT_SYSTEM_PROMPT,
|
||||
"params": USER_PERMISSIONS_CHAT_PARAMS,
|
||||
"file_upload": USER_PERMISSIONS_CHAT_FILE_UPLOAD,
|
||||
"delete": USER_PERMISSIONS_CHAT_DELETE,
|
||||
"edit": USER_PERMISSIONS_CHAT_EDIT,
|
||||
|
|
@ -1281,6 +1354,18 @@ WEBHOOK_URL = PersistentConfig(
|
|||
|
||||
ENABLE_ADMIN_EXPORT = os.environ.get("ENABLE_ADMIN_EXPORT", "True").lower() == "true"
|
||||
|
||||
ENABLE_ADMIN_WORKSPACE_CONTENT_ACCESS = (
|
||||
os.environ.get("ENABLE_ADMIN_WORKSPACE_CONTENT_ACCESS", "True").lower() == "true"
|
||||
)
|
||||
|
||||
BYPASS_ADMIN_ACCESS_CONTROL = (
|
||||
os.environ.get(
|
||||
"BYPASS_ADMIN_ACCESS_CONTROL",
|
||||
os.environ.get("ENABLE_ADMIN_WORKSPACE_CONTENT_ACCESS", "True"),
|
||||
).lower()
|
||||
== "true"
|
||||
)
|
||||
|
||||
ENABLE_ADMIN_CHAT_ACCESS = (
|
||||
os.environ.get("ENABLE_ADMIN_CHAT_ACCESS", "True").lower() == "true"
|
||||
)
|
||||
|
|
@ -1319,10 +1404,11 @@ if THREAD_POOL_SIZE is not None and isinstance(THREAD_POOL_SIZE, str):
|
|||
def validate_cors_origin(origin):
|
||||
parsed_url = urlparse(origin)
|
||||
|
||||
# Check if the scheme is either http or https
|
||||
if parsed_url.scheme not in ["http", "https"]:
|
||||
# Check if the scheme is either http or https, or a custom scheme
|
||||
schemes = ["http", "https"] + CORS_ALLOW_CUSTOM_SCHEME
|
||||
if parsed_url.scheme not in schemes:
|
||||
raise ValueError(
|
||||
f"Invalid scheme in CORS_ALLOW_ORIGIN: '{origin}'. Only 'http' and 'https' are allowed."
|
||||
f"Invalid scheme in CORS_ALLOW_ORIGIN: '{origin}'. Only 'http' and 'https' and CORS_ALLOW_CUSTOM_SCHEME are allowed."
|
||||
)
|
||||
|
||||
# Ensure that the netloc (domain + port) is present, indicating it's a valid URL
|
||||
|
|
@ -1337,6 +1423,11 @@ def validate_cors_origin(origin):
|
|||
# in your .env file depending on your frontend port, 5173 in this case.
|
||||
CORS_ALLOW_ORIGIN = os.environ.get("CORS_ALLOW_ORIGIN", "*").split(";")
|
||||
|
||||
# Allows custom URL schemes (e.g., app://) to be used as origins for CORS.
|
||||
# Useful for local development or desktop clients with schemes like app:// or other custom protocols.
|
||||
# Provide a semicolon-separated list of allowed schemes in the environment variable CORS_ALLOW_CUSTOM_SCHEMES.
|
||||
CORS_ALLOW_CUSTOM_SCHEME = os.environ.get("CORS_ALLOW_CUSTOM_SCHEME", "").split(";")
|
||||
|
||||
if CORS_ALLOW_ORIGIN == ["*"]:
|
||||
log.warning(
|
||||
"\n\nWARNING: CORS_ALLOW_ORIGIN IS SET TO '*' - NOT RECOMMENDED FOR PRODUCTION DEPLOYMENTS.\n"
|
||||
|
|
@ -1485,7 +1576,7 @@ FOLLOW_UP_GENERATION_PROMPT_TEMPLATE = PersistentConfig(
|
|||
)
|
||||
|
||||
DEFAULT_FOLLOW_UP_GENERATION_PROMPT_TEMPLATE = """### Task:
|
||||
Suggest 3-5 relevant follow-up questions or prompts that the user might naturally ask next in this conversation as a **user**, based on the chat history, to help continue or deepen the discussion.
|
||||
Suggest 3-5 relevant follow-up questions or prompts in the chat's primary language that the user might naturally ask next in this conversation as a **user**, based on the chat history, to help continue or deepen the discussion.
|
||||
### Guidelines:
|
||||
- Write all follow-up questions from the user’s point of view, directed to the assistant.
|
||||
- Make questions concise, clear, and directly related to the discussed topic(s).
|
||||
|
|
@ -1777,6 +1868,11 @@ CODE_INTERPRETER_JUPYTER_TIMEOUT = PersistentConfig(
|
|||
),
|
||||
)
|
||||
|
||||
CODE_INTERPRETER_BLOCKED_MODULES = [
|
||||
library.strip()
|
||||
for library in os.environ.get("CODE_INTERPRETER_BLOCKED_MODULES", "").split(",")
|
||||
if library.strip()
|
||||
]
|
||||
|
||||
DEFAULT_CODE_INTERPRETER_PROMPT = """
|
||||
#### Tools Available
|
||||
|
|
@ -1844,6 +1940,8 @@ QDRANT_API_KEY = os.environ.get("QDRANT_API_KEY", None)
|
|||
QDRANT_ON_DISK = os.environ.get("QDRANT_ON_DISK", "false").lower() == "true"
|
||||
QDRANT_PREFER_GRPC = os.environ.get("QDRANT_PREFER_GRPC", "false").lower() == "true"
|
||||
QDRANT_GRPC_PORT = int(os.environ.get("QDRANT_GRPC_PORT", "6334"))
|
||||
QDRANT_TIMEOUT = int(os.environ.get("QDRANT_TIMEOUT", "5"))
|
||||
QDRANT_HNSW_M = int(os.environ.get("QDRANT_HNSW_M", "16"))
|
||||
ENABLE_QDRANT_MULTITENANCY_MODE = (
|
||||
os.environ.get("ENABLE_QDRANT_MULTITENANCY_MODE", "true").lower() == "true"
|
||||
)
|
||||
|
|
@ -1933,6 +2031,37 @@ PINECONE_DIMENSION = int(os.getenv("PINECONE_DIMENSION", 1536)) # or 3072, 1024
|
|||
PINECONE_METRIC = os.getenv("PINECONE_METRIC", "cosine")
|
||||
PINECONE_CLOUD = os.getenv("PINECONE_CLOUD", "aws") # or "gcp" or "azure"
|
||||
|
||||
# ORACLE23AI (Oracle23ai Vector Search)
|
||||
|
||||
ORACLE_DB_USE_WALLET = os.environ.get("ORACLE_DB_USE_WALLET", "false").lower() == "true"
|
||||
ORACLE_DB_USER = os.environ.get("ORACLE_DB_USER", None) #
|
||||
ORACLE_DB_PASSWORD = os.environ.get("ORACLE_DB_PASSWORD", None) #
|
||||
ORACLE_DB_DSN = os.environ.get("ORACLE_DB_DSN", None) #
|
||||
ORACLE_WALLET_DIR = os.environ.get("ORACLE_WALLET_DIR", None)
|
||||
ORACLE_WALLET_PASSWORD = os.environ.get("ORACLE_WALLET_PASSWORD", None)
|
||||
ORACLE_VECTOR_LENGTH = os.environ.get("ORACLE_VECTOR_LENGTH", 768)
|
||||
|
||||
ORACLE_DB_POOL_MIN = int(os.environ.get("ORACLE_DB_POOL_MIN", 2))
|
||||
ORACLE_DB_POOL_MAX = int(os.environ.get("ORACLE_DB_POOL_MAX", 10))
|
||||
ORACLE_DB_POOL_INCREMENT = int(os.environ.get("ORACLE_DB_POOL_INCREMENT", 1))
|
||||
|
||||
|
||||
if VECTOR_DB == "oracle23ai":
|
||||
if not ORACLE_DB_USER or not ORACLE_DB_PASSWORD or not ORACLE_DB_DSN:
|
||||
raise ValueError(
|
||||
"Oracle23ai requires setting ORACLE_DB_USER, ORACLE_DB_PASSWORD, and ORACLE_DB_DSN."
|
||||
)
|
||||
if ORACLE_DB_USE_WALLET and (not ORACLE_WALLET_DIR or not ORACLE_WALLET_PASSWORD):
|
||||
raise ValueError(
|
||||
"Oracle23ai requires setting ORACLE_WALLET_DIR and ORACLE_WALLET_PASSWORD when using wallet authentication."
|
||||
)
|
||||
|
||||
log.info(f"VECTOR_DB: {VECTOR_DB}")
|
||||
|
||||
# S3 Vector
|
||||
S3_VECTOR_BUCKET_NAME = os.environ.get("S3_VECTOR_BUCKET_NAME", None)
|
||||
S3_VECTOR_REGION = os.environ.get("S3_VECTOR_REGION", None)
|
||||
|
||||
####################################
|
||||
# Information Retrieval (RAG)
|
||||
####################################
|
||||
|
|
@ -1994,10 +2123,16 @@ DATALAB_MARKER_API_KEY = PersistentConfig(
|
|||
os.environ.get("DATALAB_MARKER_API_KEY", ""),
|
||||
)
|
||||
|
||||
DATALAB_MARKER_LANGS = PersistentConfig(
|
||||
"DATALAB_MARKER_LANGS",
|
||||
"rag.datalab_marker_langs",
|
||||
os.environ.get("DATALAB_MARKER_LANGS", ""),
|
||||
DATALAB_MARKER_API_BASE_URL = PersistentConfig(
|
||||
"DATALAB_MARKER_API_BASE_URL",
|
||||
"rag.datalab_marker_api_base_url",
|
||||
os.environ.get("DATALAB_MARKER_API_BASE_URL", ""),
|
||||
)
|
||||
|
||||
DATALAB_MARKER_ADDITIONAL_CONFIG = PersistentConfig(
|
||||
"DATALAB_MARKER_ADDITIONAL_CONFIG",
|
||||
"rag.datalab_marker_additional_config",
|
||||
os.environ.get("DATALAB_MARKER_ADDITIONAL_CONFIG", ""),
|
||||
)
|
||||
|
||||
DATALAB_MARKER_USE_LLM = PersistentConfig(
|
||||
|
|
@ -2037,6 +2172,12 @@ DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION = PersistentConfig(
|
|||
== "true",
|
||||
)
|
||||
|
||||
DATALAB_MARKER_FORMAT_LINES = PersistentConfig(
|
||||
"DATALAB_MARKER_FORMAT_LINES",
|
||||
"rag.datalab_marker_format_lines",
|
||||
os.environ.get("DATALAB_MARKER_FORMAT_LINES", "false").lower() == "true",
|
||||
)
|
||||
|
||||
DATALAB_MARKER_OUTPUT_FORMAT = PersistentConfig(
|
||||
"DATALAB_MARKER_OUTPUT_FORMAT",
|
||||
"rag.datalab_marker_output_format",
|
||||
|
|
@ -2486,6 +2627,14 @@ WEB_LOADER_ENGINE = PersistentConfig(
|
|||
os.environ.get("WEB_LOADER_ENGINE", ""),
|
||||
)
|
||||
|
||||
|
||||
WEB_LOADER_CONCURRENT_REQUESTS = PersistentConfig(
|
||||
"WEB_LOADER_CONCURRENT_REQUESTS",
|
||||
"rag.web.loader.concurrent_requests",
|
||||
int(os.getenv("WEB_LOADER_CONCURRENT_REQUESTS", "10")),
|
||||
)
|
||||
|
||||
|
||||
ENABLE_WEB_LOADER_SSL_VERIFICATION = PersistentConfig(
|
||||
"ENABLE_WEB_LOADER_SSL_VERIFICATION",
|
||||
"rag.web.loader.ssl_verification",
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ import sys
|
|||
import shutil
|
||||
from uuid import uuid4
|
||||
from pathlib import Path
|
||||
from cryptography.hazmat.primitives import serialization
|
||||
|
||||
import markdown
|
||||
from bs4 import BeautifulSoup
|
||||
|
|
@ -16,14 +17,17 @@ from open_webui.constants import ERROR_MESSAGES
|
|||
# Load .env file
|
||||
####################################
|
||||
|
||||
OPEN_WEBUI_DIR = Path(__file__).parent # the path containing this file
|
||||
print(OPEN_WEBUI_DIR)
|
||||
# Use .resolve() to get the canonical path, removing any '..' or '.' components
|
||||
ENV_FILE_PATH = Path(__file__).resolve()
|
||||
|
||||
BACKEND_DIR = OPEN_WEBUI_DIR.parent # the path containing this file
|
||||
BASE_DIR = BACKEND_DIR.parent # the path containing the backend/
|
||||
# OPEN_WEBUI_DIR should be the directory where env.py resides (open_webui/)
|
||||
OPEN_WEBUI_DIR = ENV_FILE_PATH.parent
|
||||
|
||||
print(BACKEND_DIR)
|
||||
print(BASE_DIR)
|
||||
# BACKEND_DIR is the parent of OPEN_WEBUI_DIR (backend/)
|
||||
BACKEND_DIR = OPEN_WEBUI_DIR.parent
|
||||
|
||||
# BASE_DIR is the parent of BACKEND_DIR (open-webui-dev/)
|
||||
BASE_DIR = BACKEND_DIR.parent
|
||||
|
||||
try:
|
||||
from dotenv import find_dotenv, load_dotenv
|
||||
|
|
@ -276,9 +280,6 @@ if DATABASE_USER:
|
|||
DATABASE_CRED += f"{DATABASE_USER}"
|
||||
if DATABASE_PASSWORD:
|
||||
DATABASE_CRED += f":{DATABASE_PASSWORD}"
|
||||
if DATABASE_CRED:
|
||||
DATABASE_CRED += "@"
|
||||
|
||||
|
||||
DB_VARS = {
|
||||
"db_type": DATABASE_TYPE,
|
||||
|
|
@ -290,6 +291,9 @@ DB_VARS = {
|
|||
|
||||
if all(DB_VARS.values()):
|
||||
DATABASE_URL = f"{DB_VARS['db_type']}://{DB_VARS['db_cred']}@{DB_VARS['db_host']}:{DB_VARS['db_port']}/{DB_VARS['db_name']}"
|
||||
elif DATABASE_TYPE == "sqlite+sqlcipher" and not os.environ.get("DATABASE_URL"):
|
||||
# Handle SQLCipher with local file when DATABASE_URL wasn't explicitly set
|
||||
DATABASE_URL = f"sqlite+sqlcipher:///{DATA_DIR}/webui.db"
|
||||
|
||||
# Replace the postgres:// with postgresql://
|
||||
if "postgres://" in DATABASE_URL:
|
||||
|
|
@ -335,6 +339,21 @@ else:
|
|||
except Exception:
|
||||
DATABASE_POOL_RECYCLE = 3600
|
||||
|
||||
DATABASE_ENABLE_SQLITE_WAL = (
|
||||
os.environ.get("DATABASE_ENABLE_SQLITE_WAL", "False").lower() == "true"
|
||||
)
|
||||
|
||||
DATABASE_USER_ACTIVE_STATUS_UPDATE_INTERVAL = os.environ.get(
|
||||
"DATABASE_USER_ACTIVE_STATUS_UPDATE_INTERVAL", None
|
||||
)
|
||||
if DATABASE_USER_ACTIVE_STATUS_UPDATE_INTERVAL is not None:
|
||||
try:
|
||||
DATABASE_USER_ACTIVE_STATUS_UPDATE_INTERVAL = float(
|
||||
DATABASE_USER_ACTIVE_STATUS_UPDATE_INTERVAL
|
||||
)
|
||||
except Exception:
|
||||
DATABASE_USER_ACTIVE_STATUS_UPDATE_INTERVAL = 0.0
|
||||
|
||||
RESET_CONFIG_ON_START = (
|
||||
os.environ.get("RESET_CONFIG_ON_START", "False").lower() == "true"
|
||||
)
|
||||
|
|
@ -348,10 +367,22 @@ ENABLE_REALTIME_CHAT_SAVE = (
|
|||
####################################
|
||||
|
||||
REDIS_URL = os.environ.get("REDIS_URL", "")
|
||||
REDIS_CLUSTER = os.environ.get("REDIS_CLUSTER", "False").lower() == "true"
|
||||
|
||||
REDIS_KEY_PREFIX = os.environ.get("REDIS_KEY_PREFIX", "open-webui")
|
||||
|
||||
REDIS_SENTINEL_HOSTS = os.environ.get("REDIS_SENTINEL_HOSTS", "")
|
||||
REDIS_SENTINEL_PORT = os.environ.get("REDIS_SENTINEL_PORT", "26379")
|
||||
|
||||
# Maximum number of retries for Redis operations when using Sentinel fail-over
|
||||
REDIS_SENTINEL_MAX_RETRY_COUNT = os.environ.get("REDIS_SENTINEL_MAX_RETRY_COUNT", "2")
|
||||
try:
|
||||
REDIS_SENTINEL_MAX_RETRY_COUNT = int(REDIS_SENTINEL_MAX_RETRY_COUNT)
|
||||
if REDIS_SENTINEL_MAX_RETRY_COUNT < 1:
|
||||
REDIS_SENTINEL_MAX_RETRY_COUNT = 2
|
||||
except ValueError:
|
||||
REDIS_SENTINEL_MAX_RETRY_COUNT = 2
|
||||
|
||||
####################################
|
||||
# UVICORN WORKERS
|
||||
####################################
|
||||
|
|
@ -371,6 +402,10 @@ except ValueError:
|
|||
####################################
|
||||
|
||||
WEBUI_AUTH = os.environ.get("WEBUI_AUTH", "True").lower() == "true"
|
||||
ENABLE_SIGNUP_PASSWORD_CONFIRMATION = (
|
||||
os.environ.get("ENABLE_SIGNUP_PASSWORD_CONFIRMATION", "False").lower() == "true"
|
||||
)
|
||||
|
||||
WEBUI_AUTH_TRUSTED_EMAIL_HEADER = os.environ.get(
|
||||
"WEBUI_AUTH_TRUSTED_EMAIL_HEADER", None
|
||||
)
|
||||
|
|
@ -424,6 +459,41 @@ ENABLE_COMPRESSION_MIDDLEWARE = (
|
|||
os.environ.get("ENABLE_COMPRESSION_MIDDLEWARE", "True").lower() == "true"
|
||||
)
|
||||
|
||||
|
||||
####################################
|
||||
# SCIM Configuration
|
||||
####################################
|
||||
|
||||
SCIM_ENABLED = os.environ.get("SCIM_ENABLED", "False").lower() == "true"
|
||||
SCIM_TOKEN = os.environ.get("SCIM_TOKEN", "")
|
||||
|
||||
####################################
|
||||
# LICENSE_KEY
|
||||
####################################
|
||||
|
||||
LICENSE_KEY = os.environ.get("LICENSE_KEY", "")
|
||||
|
||||
LICENSE_BLOB = None
|
||||
LICENSE_BLOB_PATH = os.environ.get("LICENSE_BLOB_PATH", DATA_DIR / "l.data")
|
||||
if LICENSE_BLOB_PATH and os.path.exists(LICENSE_BLOB_PATH):
|
||||
with open(LICENSE_BLOB_PATH, "rb") as f:
|
||||
LICENSE_BLOB = f.read()
|
||||
|
||||
LICENSE_PUBLIC_KEY = os.environ.get("LICENSE_PUBLIC_KEY", "")
|
||||
|
||||
pk = None
|
||||
if LICENSE_PUBLIC_KEY:
|
||||
pk = serialization.load_pem_public_key(
|
||||
f"""
|
||||
-----BEGIN PUBLIC KEY-----
|
||||
{LICENSE_PUBLIC_KEY}
|
||||
-----END PUBLIC KEY-----
|
||||
""".encode(
|
||||
"utf-8"
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
####################################
|
||||
# MODELS
|
||||
####################################
|
||||
|
|
@ -438,6 +508,25 @@ else:
|
|||
MODELS_CACHE_TTL = 1
|
||||
|
||||
|
||||
####################################
|
||||
# CHAT
|
||||
####################################
|
||||
|
||||
CHAT_RESPONSE_STREAM_DELTA_CHUNK_SIZE = os.environ.get(
|
||||
"CHAT_RESPONSE_STREAM_DELTA_CHUNK_SIZE", "1"
|
||||
)
|
||||
|
||||
if CHAT_RESPONSE_STREAM_DELTA_CHUNK_SIZE == "":
|
||||
CHAT_RESPONSE_STREAM_DELTA_CHUNK_SIZE = 1
|
||||
else:
|
||||
try:
|
||||
CHAT_RESPONSE_STREAM_DELTA_CHUNK_SIZE = int(
|
||||
CHAT_RESPONSE_STREAM_DELTA_CHUNK_SIZE
|
||||
)
|
||||
except Exception:
|
||||
CHAT_RESPONSE_STREAM_DELTA_CHUNK_SIZE = 1
|
||||
|
||||
|
||||
####################################
|
||||
# WEBSOCKET SUPPORT
|
||||
####################################
|
||||
|
|
@ -450,12 +539,21 @@ ENABLE_WEBSOCKET_SUPPORT = (
|
|||
WEBSOCKET_MANAGER = os.environ.get("WEBSOCKET_MANAGER", "")
|
||||
|
||||
WEBSOCKET_REDIS_URL = os.environ.get("WEBSOCKET_REDIS_URL", REDIS_URL)
|
||||
WEBSOCKET_REDIS_LOCK_TIMEOUT = os.environ.get("WEBSOCKET_REDIS_LOCK_TIMEOUT", 60)
|
||||
WEBSOCKET_REDIS_CLUSTER = (
|
||||
os.environ.get("WEBSOCKET_REDIS_CLUSTER", str(REDIS_CLUSTER)).lower() == "true"
|
||||
)
|
||||
|
||||
websocket_redis_lock_timeout = os.environ.get("WEBSOCKET_REDIS_LOCK_TIMEOUT", "60")
|
||||
|
||||
try:
|
||||
WEBSOCKET_REDIS_LOCK_TIMEOUT = int(websocket_redis_lock_timeout)
|
||||
except ValueError:
|
||||
WEBSOCKET_REDIS_LOCK_TIMEOUT = 60
|
||||
|
||||
WEBSOCKET_SENTINEL_HOSTS = os.environ.get("WEBSOCKET_SENTINEL_HOSTS", "")
|
||||
|
||||
WEBSOCKET_SENTINEL_PORT = os.environ.get("WEBSOCKET_SENTINEL_PORT", "26379")
|
||||
|
||||
|
||||
AIOHTTP_CLIENT_TIMEOUT = os.environ.get("AIOHTTP_CLIENT_TIMEOUT", "")
|
||||
|
||||
if AIOHTTP_CLIENT_TIMEOUT == "":
|
||||
|
|
@ -597,13 +695,34 @@ AUDIT_EXCLUDED_PATHS = [path.lstrip("/") for path in AUDIT_EXCLUDED_PATHS]
|
|||
####################################
|
||||
|
||||
ENABLE_OTEL = os.environ.get("ENABLE_OTEL", "False").lower() == "true"
|
||||
ENABLE_OTEL_TRACES = os.environ.get("ENABLE_OTEL_TRACES", "False").lower() == "true"
|
||||
ENABLE_OTEL_METRICS = os.environ.get("ENABLE_OTEL_METRICS", "False").lower() == "true"
|
||||
ENABLE_OTEL_LOGS = os.environ.get("ENABLE_OTEL_LOGS", "False").lower() == "true"
|
||||
|
||||
OTEL_EXPORTER_OTLP_ENDPOINT = os.environ.get(
|
||||
"OTEL_EXPORTER_OTLP_ENDPOINT", "http://localhost:4317"
|
||||
)
|
||||
OTEL_METRICS_EXPORTER_OTLP_ENDPOINT = os.environ.get(
|
||||
"OTEL_METRICS_EXPORTER_OTLP_ENDPOINT", OTEL_EXPORTER_OTLP_ENDPOINT
|
||||
)
|
||||
OTEL_LOGS_EXPORTER_OTLP_ENDPOINT = os.environ.get(
|
||||
"OTEL_LOGS_EXPORTER_OTLP_ENDPOINT", OTEL_EXPORTER_OTLP_ENDPOINT
|
||||
)
|
||||
OTEL_EXPORTER_OTLP_INSECURE = (
|
||||
os.environ.get("OTEL_EXPORTER_OTLP_INSECURE", "False").lower() == "true"
|
||||
)
|
||||
OTEL_METRICS_EXPORTER_OTLP_INSECURE = (
|
||||
os.environ.get(
|
||||
"OTEL_METRICS_EXPORTER_OTLP_INSECURE", str(OTEL_EXPORTER_OTLP_INSECURE)
|
||||
).lower()
|
||||
== "true"
|
||||
)
|
||||
OTEL_LOGS_EXPORTER_OTLP_INSECURE = (
|
||||
os.environ.get(
|
||||
"OTEL_LOGS_EXPORTER_OTLP_INSECURE", str(OTEL_EXPORTER_OTLP_INSECURE)
|
||||
).lower()
|
||||
== "true"
|
||||
)
|
||||
OTEL_SERVICE_NAME = os.environ.get("OTEL_SERVICE_NAME", "open-webui")
|
||||
OTEL_RESOURCE_ATTRIBUTES = os.environ.get(
|
||||
"OTEL_RESOURCE_ATTRIBUTES", ""
|
||||
|
|
@ -614,11 +733,30 @@ OTEL_TRACES_SAMPLER = os.environ.get(
|
|||
OTEL_BASIC_AUTH_USERNAME = os.environ.get("OTEL_BASIC_AUTH_USERNAME", "")
|
||||
OTEL_BASIC_AUTH_PASSWORD = os.environ.get("OTEL_BASIC_AUTH_PASSWORD", "")
|
||||
|
||||
OTEL_METRICS_BASIC_AUTH_USERNAME = os.environ.get(
|
||||
"OTEL_METRICS_BASIC_AUTH_USERNAME", OTEL_BASIC_AUTH_USERNAME
|
||||
)
|
||||
OTEL_METRICS_BASIC_AUTH_PASSWORD = os.environ.get(
|
||||
"OTEL_METRICS_BASIC_AUTH_PASSWORD", OTEL_BASIC_AUTH_PASSWORD
|
||||
)
|
||||
OTEL_LOGS_BASIC_AUTH_USERNAME = os.environ.get(
|
||||
"OTEL_LOGS_BASIC_AUTH_USERNAME", OTEL_BASIC_AUTH_USERNAME
|
||||
)
|
||||
OTEL_LOGS_BASIC_AUTH_PASSWORD = os.environ.get(
|
||||
"OTEL_LOGS_BASIC_AUTH_PASSWORD", OTEL_BASIC_AUTH_PASSWORD
|
||||
)
|
||||
|
||||
OTEL_OTLP_SPAN_EXPORTER = os.environ.get(
|
||||
"OTEL_OTLP_SPAN_EXPORTER", "grpc"
|
||||
).lower() # grpc or http
|
||||
|
||||
OTEL_METRICS_OTLP_SPAN_EXPORTER = os.environ.get(
|
||||
"OTEL_METRICS_OTLP_SPAN_EXPORTER", OTEL_OTLP_SPAN_EXPORTER
|
||||
).lower() # grpc or http
|
||||
|
||||
OTEL_LOGS_OTLP_SPAN_EXPORTER = os.environ.get(
|
||||
"OTEL_LOGS_OTLP_SPAN_EXPORTER", OTEL_OTLP_SPAN_EXPORTER
|
||||
).lower() # grpc or http
|
||||
|
||||
####################################
|
||||
# TOOLS/FUNCTIONS PIP OPTIONS
|
||||
|
|
|
|||
|
|
@ -47,7 +47,7 @@ from open_webui.utils.misc import (
|
|||
)
|
||||
from open_webui.utils.payload import (
|
||||
apply_model_params_to_body_openai,
|
||||
apply_model_system_prompt_to_body,
|
||||
apply_system_prompt_to_body,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -253,9 +253,7 @@ async def generate_function_chat_completion(
|
|||
if params:
|
||||
system = params.pop("system", None)
|
||||
form_data = apply_model_params_to_body_openai(params, form_data)
|
||||
form_data = apply_model_system_prompt_to_body(
|
||||
system, form_data, metadata, user
|
||||
)
|
||||
form_data = apply_system_prompt_to_body(system, form_data, metadata, user)
|
||||
|
||||
pipe_id = get_pipe_id(form_data)
|
||||
function_module = get_function_module_by_id(request, pipe_id)
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
import os
|
||||
import json
|
||||
import logging
|
||||
from contextlib import contextmanager
|
||||
|
|
@ -13,9 +14,10 @@ from open_webui.env import (
|
|||
DATABASE_POOL_RECYCLE,
|
||||
DATABASE_POOL_SIZE,
|
||||
DATABASE_POOL_TIMEOUT,
|
||||
DATABASE_ENABLE_SQLITE_WAL,
|
||||
)
|
||||
from peewee_migrate import Router
|
||||
from sqlalchemy import Dialect, create_engine, MetaData, types
|
||||
from sqlalchemy import Dialect, create_engine, MetaData, event, types
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.orm import scoped_session, sessionmaker
|
||||
from sqlalchemy.pool import QueuePool, NullPool
|
||||
|
|
@ -79,10 +81,50 @@ handle_peewee_migration(DATABASE_URL)
|
|||
|
||||
|
||||
SQLALCHEMY_DATABASE_URL = DATABASE_URL
|
||||
if "sqlite" in SQLALCHEMY_DATABASE_URL:
|
||||
|
||||
# Handle SQLCipher URLs
|
||||
if SQLALCHEMY_DATABASE_URL.startswith("sqlite+sqlcipher://"):
|
||||
database_password = os.environ.get("DATABASE_PASSWORD")
|
||||
if not database_password or database_password.strip() == "":
|
||||
raise ValueError(
|
||||
"DATABASE_PASSWORD is required when using sqlite+sqlcipher:// URLs"
|
||||
)
|
||||
|
||||
# Extract database path from SQLCipher URL
|
||||
db_path = SQLALCHEMY_DATABASE_URL.replace("sqlite+sqlcipher://", "")
|
||||
if db_path.startswith("/"):
|
||||
db_path = db_path[1:] # Remove leading slash for relative paths
|
||||
|
||||
# Create a custom creator function that uses sqlcipher3
|
||||
def create_sqlcipher_connection():
|
||||
import sqlcipher3
|
||||
|
||||
conn = sqlcipher3.connect(db_path, check_same_thread=False)
|
||||
conn.execute(f"PRAGMA key = '{database_password}'")
|
||||
return conn
|
||||
|
||||
engine = create_engine(
|
||||
"sqlite://", # Dummy URL since we're using creator
|
||||
creator=create_sqlcipher_connection,
|
||||
echo=False,
|
||||
)
|
||||
|
||||
log.info("Connected to encrypted SQLite database using SQLCipher")
|
||||
|
||||
elif "sqlite" in SQLALCHEMY_DATABASE_URL:
|
||||
engine = create_engine(
|
||||
SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False}
|
||||
)
|
||||
|
||||
def on_connect(dbapi_connection, connection_record):
|
||||
cursor = dbapi_connection.cursor()
|
||||
if DATABASE_ENABLE_SQLITE_WAL:
|
||||
cursor.execute("PRAGMA journal_mode=WAL")
|
||||
else:
|
||||
cursor.execute("PRAGMA journal_mode=DELETE")
|
||||
cursor.close()
|
||||
|
||||
event.listen(engine, "connect", on_connect)
|
||||
else:
|
||||
if isinstance(DATABASE_POOL_SIZE, int):
|
||||
if DATABASE_POOL_SIZE > 0:
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
import logging
|
||||
import os
|
||||
from contextvars import ContextVar
|
||||
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
|
|
@ -43,24 +44,47 @@ class ReconnectingPostgresqlDatabase(CustomReconnectMixin, PostgresqlDatabase):
|
|||
|
||||
|
||||
def register_connection(db_url):
|
||||
db = connect(db_url, unquote_user=True, unquote_password=True)
|
||||
if isinstance(db, PostgresqlDatabase):
|
||||
# Enable autoconnect for SQLite databases, managed by Peewee
|
||||
# Check if using SQLCipher protocol
|
||||
if db_url.startswith("sqlite+sqlcipher://"):
|
||||
database_password = os.environ.get("DATABASE_PASSWORD")
|
||||
if not database_password or database_password.strip() == "":
|
||||
raise ValueError(
|
||||
"DATABASE_PASSWORD is required when using sqlite+sqlcipher:// URLs"
|
||||
)
|
||||
from playhouse.sqlcipher_ext import SqlCipherDatabase
|
||||
|
||||
# Parse the database path from SQLCipher URL
|
||||
# Convert sqlite+sqlcipher:///path/to/db.sqlite to /path/to/db.sqlite
|
||||
db_path = db_url.replace("sqlite+sqlcipher://", "")
|
||||
if db_path.startswith("/"):
|
||||
db_path = db_path[1:] # Remove leading slash for relative paths
|
||||
|
||||
# Use Peewee's native SqlCipherDatabase with encryption
|
||||
db = SqlCipherDatabase(db_path, passphrase=database_password)
|
||||
db.autoconnect = True
|
||||
db.reuse_if_open = True
|
||||
log.info("Connected to PostgreSQL database")
|
||||
log.info("Connected to encrypted SQLite database using SQLCipher")
|
||||
|
||||
# Get the connection details
|
||||
connection = parse(db_url, unquote_user=True, unquote_password=True)
|
||||
|
||||
# Use our custom database class that supports reconnection
|
||||
db = ReconnectingPostgresqlDatabase(**connection)
|
||||
db.connect(reuse_if_open=True)
|
||||
elif isinstance(db, SqliteDatabase):
|
||||
# Enable autoconnect for SQLite databases, managed by Peewee
|
||||
db.autoconnect = True
|
||||
db.reuse_if_open = True
|
||||
log.info("Connected to SQLite database")
|
||||
else:
|
||||
raise ValueError("Unsupported database connection")
|
||||
# Standard database connection (existing logic)
|
||||
db = connect(db_url, unquote_user=True, unquote_password=True)
|
||||
if isinstance(db, PostgresqlDatabase):
|
||||
# Enable autoconnect for SQLite databases, managed by Peewee
|
||||
db.autoconnect = True
|
||||
db.reuse_if_open = True
|
||||
log.info("Connected to PostgreSQL database")
|
||||
|
||||
# Get the connection details
|
||||
connection = parse(db_url, unquote_user=True, unquote_password=True)
|
||||
|
||||
# Use our custom database class that supports reconnection
|
||||
db = ReconnectingPostgresqlDatabase(**connection)
|
||||
db.connect(reuse_if_open=True)
|
||||
elif isinstance(db, SqliteDatabase):
|
||||
# Enable autoconnect for SQLite databases, managed by Peewee
|
||||
db.autoconnect = True
|
||||
db.reuse_if_open = True
|
||||
log.info("Connected to SQLite database")
|
||||
else:
|
||||
raise ValueError("Unsupported database connection")
|
||||
return db
|
||||
|
|
|
|||
|
|
@ -57,6 +57,7 @@ from open_webui.utils.logger import start_logger
|
|||
from open_webui.socket.main import (
|
||||
app as socket_app,
|
||||
periodic_usage_pool_cleanup,
|
||||
get_event_emitter,
|
||||
get_models_in_use,
|
||||
get_active_user_ids,
|
||||
)
|
||||
|
|
@ -85,6 +86,7 @@ from open_webui.routers import (
|
|||
tools,
|
||||
users,
|
||||
utils,
|
||||
scim,
|
||||
)
|
||||
|
||||
from open_webui.routers.retrieval import (
|
||||
|
|
@ -102,7 +104,6 @@ from open_webui.models.users import UserModel, Users
|
|||
from open_webui.models.chats import Chats
|
||||
|
||||
from open_webui.config import (
|
||||
LICENSE_KEY,
|
||||
# Ollama
|
||||
ENABLE_OLLAMA_API,
|
||||
OLLAMA_BASE_URLS,
|
||||
|
|
@ -185,6 +186,7 @@ from open_webui.config import (
|
|||
FIRECRAWL_API_BASE_URL,
|
||||
FIRECRAWL_API_KEY,
|
||||
WEB_LOADER_ENGINE,
|
||||
WEB_LOADER_CONCURRENT_REQUESTS,
|
||||
WHISPER_MODEL,
|
||||
WHISPER_VAD_FILTER,
|
||||
WHISPER_LANGUAGE,
|
||||
|
|
@ -227,12 +229,14 @@ from open_webui.config import (
|
|||
CHUNK_SIZE,
|
||||
CONTENT_EXTRACTION_ENGINE,
|
||||
DATALAB_MARKER_API_KEY,
|
||||
DATALAB_MARKER_LANGS,
|
||||
DATALAB_MARKER_API_BASE_URL,
|
||||
DATALAB_MARKER_ADDITIONAL_CONFIG,
|
||||
DATALAB_MARKER_SKIP_CACHE,
|
||||
DATALAB_MARKER_FORCE_OCR,
|
||||
DATALAB_MARKER_PAGINATE,
|
||||
DATALAB_MARKER_STRIP_EXISTING_OCR,
|
||||
DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION,
|
||||
DATALAB_MARKER_FORMAT_LINES,
|
||||
DATALAB_MARKER_OUTPUT_FORMAT,
|
||||
DATALAB_MARKER_USE_LLM,
|
||||
EXTERNAL_DOCUMENT_LOADER_URL,
|
||||
|
|
@ -325,6 +329,7 @@ from open_webui.config import (
|
|||
ENABLE_MESSAGE_RATING,
|
||||
ENABLE_USER_WEBHOOKS,
|
||||
ENABLE_EVALUATION_ARENA_MODELS,
|
||||
BYPASS_ADMIN_ACCESS_CONTROL,
|
||||
USER_PERMISSIONS,
|
||||
DEFAULT_USER_ROLE,
|
||||
PENDING_USER_OVERLAY_CONTENT,
|
||||
|
|
@ -373,6 +378,7 @@ from open_webui.config import (
|
|||
RESPONSE_WATERMARK,
|
||||
# Admin
|
||||
ENABLE_ADMIN_CHAT_ACCESS,
|
||||
BYPASS_ADMIN_ACCESS_CONTROL,
|
||||
ENABLE_ADMIN_EXPORT,
|
||||
# Tasks
|
||||
TASK_MODEL,
|
||||
|
|
@ -395,10 +401,12 @@ from open_webui.config import (
|
|||
reset_config,
|
||||
)
|
||||
from open_webui.env import (
|
||||
LICENSE_KEY,
|
||||
AUDIT_EXCLUDED_PATHS,
|
||||
AUDIT_LOG_LEVEL,
|
||||
CHANGELOG,
|
||||
REDIS_URL,
|
||||
REDIS_CLUSTER,
|
||||
REDIS_KEY_PREFIX,
|
||||
REDIS_SENTINEL_HOSTS,
|
||||
REDIS_SENTINEL_PORT,
|
||||
|
|
@ -412,9 +420,13 @@ from open_webui.env import (
|
|||
WEBUI_SECRET_KEY,
|
||||
WEBUI_SESSION_COOKIE_SAME_SITE,
|
||||
WEBUI_SESSION_COOKIE_SECURE,
|
||||
ENABLE_SIGNUP_PASSWORD_CONFIRMATION,
|
||||
WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
|
||||
WEBUI_AUTH_TRUSTED_NAME_HEADER,
|
||||
WEBUI_AUTH_SIGNOUT_REDIRECT_URL,
|
||||
# SCIM
|
||||
SCIM_ENABLED,
|
||||
SCIM_TOKEN,
|
||||
ENABLE_COMPRESSION_MIDDLEWARE,
|
||||
ENABLE_WEBSOCKET_SUPPORT,
|
||||
BYPASS_MODEL_ACCESS_CONTROL,
|
||||
|
|
@ -455,6 +467,7 @@ from open_webui.utils.redis import get_redis_connection
|
|||
from open_webui.tasks import (
|
||||
redis_task_command_listener,
|
||||
list_task_ids_by_item_id,
|
||||
create_task,
|
||||
stop_task,
|
||||
list_tasks,
|
||||
) # Import from tasks.py
|
||||
|
|
@ -462,6 +475,9 @@ from open_webui.tasks import (
|
|||
from open_webui.utils.redis import get_sentinels_from_env
|
||||
|
||||
|
||||
from open_webui.constants import ERROR_MESSAGES
|
||||
|
||||
|
||||
if SAFE_MODE:
|
||||
print("SAFE MODE ENABLED")
|
||||
Functions.deactivate_all_functions()
|
||||
|
|
@ -524,6 +540,7 @@ async def lifespan(app: FastAPI):
|
|||
redis_sentinels=get_sentinels_from_env(
|
||||
REDIS_SENTINEL_HOSTS, REDIS_SENTINEL_PORT
|
||||
),
|
||||
redis_cluster=REDIS_CLUSTER,
|
||||
async_mode=True,
|
||||
)
|
||||
|
||||
|
|
@ -579,6 +596,7 @@ app.state.instance_id = None
|
|||
app.state.config = AppConfig(
|
||||
redis_url=REDIS_URL,
|
||||
redis_sentinels=get_sentinels_from_env(REDIS_SENTINEL_HOSTS, REDIS_SENTINEL_PORT),
|
||||
redis_cluster=REDIS_CLUSTER,
|
||||
redis_key_prefix=REDIS_KEY_PREFIX,
|
||||
)
|
||||
app.state.redis = None
|
||||
|
|
@ -642,6 +660,15 @@ app.state.TOOL_SERVERS = []
|
|||
|
||||
app.state.config.ENABLE_DIRECT_CONNECTIONS = ENABLE_DIRECT_CONNECTIONS
|
||||
|
||||
########################################
|
||||
#
|
||||
# SCIM
|
||||
#
|
||||
########################################
|
||||
|
||||
app.state.SCIM_ENABLED = SCIM_ENABLED
|
||||
app.state.SCIM_TOKEN = SCIM_TOKEN
|
||||
|
||||
########################################
|
||||
#
|
||||
# MODELS
|
||||
|
|
@ -767,7 +794,8 @@ app.state.config.ENABLE_WEB_LOADER_SSL_VERIFICATION = ENABLE_WEB_LOADER_SSL_VERI
|
|||
|
||||
app.state.config.CONTENT_EXTRACTION_ENGINE = CONTENT_EXTRACTION_ENGINE
|
||||
app.state.config.DATALAB_MARKER_API_KEY = DATALAB_MARKER_API_KEY
|
||||
app.state.config.DATALAB_MARKER_LANGS = DATALAB_MARKER_LANGS
|
||||
app.state.config.DATALAB_MARKER_API_BASE_URL = DATALAB_MARKER_API_BASE_URL
|
||||
app.state.config.DATALAB_MARKER_ADDITIONAL_CONFIG = DATALAB_MARKER_ADDITIONAL_CONFIG
|
||||
app.state.config.DATALAB_MARKER_SKIP_CACHE = DATALAB_MARKER_SKIP_CACHE
|
||||
app.state.config.DATALAB_MARKER_FORCE_OCR = DATALAB_MARKER_FORCE_OCR
|
||||
app.state.config.DATALAB_MARKER_PAGINATE = DATALAB_MARKER_PAGINATE
|
||||
|
|
@ -775,6 +803,7 @@ app.state.config.DATALAB_MARKER_STRIP_EXISTING_OCR = DATALAB_MARKER_STRIP_EXISTI
|
|||
app.state.config.DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION = (
|
||||
DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION
|
||||
)
|
||||
app.state.config.DATALAB_MARKER_FORMAT_LINES = DATALAB_MARKER_FORMAT_LINES
|
||||
app.state.config.DATALAB_MARKER_USE_LLM = DATALAB_MARKER_USE_LLM
|
||||
app.state.config.DATALAB_MARKER_OUTPUT_FORMAT = DATALAB_MARKER_OUTPUT_FORMAT
|
||||
app.state.config.EXTERNAL_DOCUMENT_LOADER_URL = EXTERNAL_DOCUMENT_LOADER_URL
|
||||
|
|
@ -829,7 +858,10 @@ app.state.config.WEB_SEARCH_ENGINE = WEB_SEARCH_ENGINE
|
|||
app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST = WEB_SEARCH_DOMAIN_FILTER_LIST
|
||||
app.state.config.WEB_SEARCH_RESULT_COUNT = WEB_SEARCH_RESULT_COUNT
|
||||
app.state.config.WEB_SEARCH_CONCURRENT_REQUESTS = WEB_SEARCH_CONCURRENT_REQUESTS
|
||||
|
||||
app.state.config.WEB_LOADER_ENGINE = WEB_LOADER_ENGINE
|
||||
app.state.config.WEB_LOADER_CONCURRENT_REQUESTS = WEB_LOADER_CONCURRENT_REQUESTS
|
||||
|
||||
app.state.config.WEB_SEARCH_TRUST_ENV = WEB_SEARCH_TRUST_ENV
|
||||
app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL = (
|
||||
BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL
|
||||
|
|
@ -892,14 +924,19 @@ try:
|
|||
app.state.config.RAG_EMBEDDING_MODEL,
|
||||
RAG_EMBEDDING_MODEL_AUTO_UPDATE,
|
||||
)
|
||||
|
||||
app.state.rf = get_rf(
|
||||
app.state.config.RAG_RERANKING_ENGINE,
|
||||
app.state.config.RAG_RERANKING_MODEL,
|
||||
app.state.config.RAG_EXTERNAL_RERANKER_URL,
|
||||
app.state.config.RAG_EXTERNAL_RERANKER_API_KEY,
|
||||
RAG_RERANKING_MODEL_AUTO_UPDATE,
|
||||
)
|
||||
if (
|
||||
app.state.config.ENABLE_RAG_HYBRID_SEARCH
|
||||
and not app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL
|
||||
):
|
||||
app.state.rf = get_rf(
|
||||
app.state.config.RAG_RERANKING_ENGINE,
|
||||
app.state.config.RAG_RERANKING_MODEL,
|
||||
app.state.config.RAG_EXTERNAL_RERANKER_URL,
|
||||
app.state.config.RAG_EXTERNAL_RERANKER_API_KEY,
|
||||
RAG_RERANKING_MODEL_AUTO_UPDATE,
|
||||
)
|
||||
else:
|
||||
app.state.rf = None
|
||||
except Exception as e:
|
||||
log.error(f"Error updating models: {e}")
|
||||
pass
|
||||
|
|
@ -1211,6 +1248,10 @@ app.include_router(
|
|||
)
|
||||
app.include_router(utils.router, prefix="/api/v1/utils", tags=["utils"])
|
||||
|
||||
# SCIM 2.0 API for identity management
|
||||
if SCIM_ENABLED:
|
||||
app.include_router(scim.router, prefix="/api/v1/scim/v2", tags=["scim"])
|
||||
|
||||
|
||||
try:
|
||||
audit_level = AuditLevel(AUDIT_LOG_LEVEL)
|
||||
|
|
@ -1233,6 +1274,7 @@ if audit_level != AuditLevel.NONE:
|
|||
|
||||
|
||||
@app.get("/api/models")
|
||||
@app.get("/api/v1/models") # Experimental: Compatibility with OpenAI API
|
||||
async def get_models(
|
||||
request: Request, refresh: bool = False, user=Depends(get_verified_user)
|
||||
):
|
||||
|
|
@ -1252,8 +1294,12 @@ async def get_models(
|
|||
|
||||
model_info = Models.get_model_by_id(model["id"])
|
||||
if model_info:
|
||||
if user.id == model_info.user_id or has_access(
|
||||
user.id, type="read", access_control=model_info.access_control
|
||||
if (
|
||||
(user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL)
|
||||
or user.id == model_info.user_id
|
||||
or has_access(
|
||||
user.id, type="read", access_control=model_info.access_control
|
||||
)
|
||||
):
|
||||
filtered_models.append(model)
|
||||
|
||||
|
|
@ -1288,15 +1334,21 @@ async def get_models(
|
|||
model_order_dict = {model_id: i for i, model_id in enumerate(model_order_list)}
|
||||
# Sort models by order list priority, with fallback for those not in the list
|
||||
models.sort(
|
||||
key=lambda x: (model_order_dict.get(x["id"], float("inf")), x["name"])
|
||||
key=lambda model: (
|
||||
model_order_dict.get(model.get("id", ""), float("inf")),
|
||||
(model.get("name", "") or ""),
|
||||
)
|
||||
)
|
||||
|
||||
# Filter out models that the user does not have access to
|
||||
if user.role == "user" and not BYPASS_MODEL_ACCESS_CONTROL:
|
||||
if (
|
||||
user.role == "user"
|
||||
or (user.role == "admin" and not BYPASS_ADMIN_ACCESS_CONTROL)
|
||||
) and not BYPASS_MODEL_ACCESS_CONTROL:
|
||||
models = get_filtered_models(models, user)
|
||||
|
||||
log.debug(
|
||||
f"/api/models returned filtered models accessible to the user: {json.dumps([model['id'] for model in models])}"
|
||||
f"/api/models returned filtered models accessible to the user: {json.dumps([model.get('id') for model in models])}"
|
||||
)
|
||||
return {"data": models}
|
||||
|
||||
|
|
@ -1313,6 +1365,7 @@ async def get_base_models(request: Request, user=Depends(get_admin_user)):
|
|||
|
||||
|
||||
@app.post("/api/embeddings")
|
||||
@app.post("/api/v1/embeddings") # Experimental: Compatibility with OpenAI API
|
||||
async def embeddings(
|
||||
request: Request, form_data: dict, user=Depends(get_verified_user)
|
||||
):
|
||||
|
|
@ -1339,6 +1392,7 @@ async def embeddings(
|
|||
|
||||
|
||||
@app.post("/api/chat/completions")
|
||||
@app.post("/api/v1/chat/completions") # Experimental: Compatibility with OpenAI API
|
||||
async def chat_completion(
|
||||
request: Request,
|
||||
form_data: dict,
|
||||
|
|
@ -1347,13 +1401,13 @@ async def chat_completion(
|
|||
if not request.app.state.MODELS:
|
||||
await get_all_models(request, user=user)
|
||||
|
||||
model_id = form_data.get("model", None)
|
||||
model_item = form_data.pop("model_item", {})
|
||||
tasks = form_data.pop("background_tasks", None)
|
||||
|
||||
metadata = {}
|
||||
try:
|
||||
if not model_item.get("direct", False):
|
||||
model_id = form_data.get("model", None)
|
||||
if model_id not in request.app.state.MODELS:
|
||||
raise Exception("Model not found")
|
||||
|
||||
|
|
@ -1361,7 +1415,9 @@ async def chat_completion(
|
|||
model_info = Models.get_model_by_id(model_id)
|
||||
|
||||
# Check if user has access to the model
|
||||
if not BYPASS_MODEL_ACCESS_CONTROL and user.role == "user":
|
||||
if not BYPASS_MODEL_ACCESS_CONTROL and (
|
||||
user.role != "admin" or not BYPASS_ADMIN_ACCESS_CONTROL
|
||||
):
|
||||
try:
|
||||
check_model_access(user, model)
|
||||
except Exception as e:
|
||||
|
|
@ -1373,6 +1429,19 @@ async def chat_completion(
|
|||
request.state.direct = True
|
||||
request.state.model = model
|
||||
|
||||
model_info_params = (
|
||||
model_info.params.model_dump() if model_info and model_info.params else {}
|
||||
)
|
||||
|
||||
# Chat Params
|
||||
stream_delta_chunk_size = form_data.get("params", {}).get(
|
||||
"stream_delta_chunk_size"
|
||||
)
|
||||
|
||||
# Model Params
|
||||
if model_info_params.get("stream_delta_chunk_size"):
|
||||
stream_delta_chunk_size = model_info_params.get("stream_delta_chunk_size")
|
||||
|
||||
metadata = {
|
||||
"user_id": user.id,
|
||||
"chat_id": form_data.pop("chat_id", None),
|
||||
|
|
@ -1386,64 +1455,103 @@ async def chat_completion(
|
|||
"variables": form_data.get("variables", {}),
|
||||
"model": model,
|
||||
"direct": model_item.get("direct", False),
|
||||
**(
|
||||
{"function_calling": "native"}
|
||||
if form_data.get("params", {}).get("function_calling") == "native"
|
||||
or (
|
||||
model_info
|
||||
and model_info.params.model_dump().get("function_calling")
|
||||
== "native"
|
||||
)
|
||||
else {}
|
||||
),
|
||||
"params": {
|
||||
"stream_delta_chunk_size": stream_delta_chunk_size,
|
||||
"function_calling": (
|
||||
"native"
|
||||
if (
|
||||
form_data.get("params", {}).get("function_calling") == "native"
|
||||
or model_info_params.get("function_calling") == "native"
|
||||
)
|
||||
else "default"
|
||||
),
|
||||
},
|
||||
}
|
||||
|
||||
if metadata.get("chat_id") and (user and user.role != "admin"):
|
||||
if metadata["chat_id"] != "local":
|
||||
chat = Chats.get_chat_by_id_and_user_id(metadata["chat_id"], user.id)
|
||||
if chat is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=ERROR_MESSAGES.DEFAULT(),
|
||||
)
|
||||
|
||||
request.state.metadata = metadata
|
||||
form_data["metadata"] = metadata
|
||||
|
||||
form_data, metadata, events = await process_chat_payload(
|
||||
request, form_data, user, metadata, model
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
log.debug(f"Error processing chat payload: {e}")
|
||||
if metadata.get("chat_id") and metadata.get("message_id"):
|
||||
# Update the chat message with the error
|
||||
Chats.upsert_message_to_chat_by_id_and_message_id(
|
||||
metadata["chat_id"],
|
||||
metadata["message_id"],
|
||||
{
|
||||
"error": {"content": str(e)},
|
||||
},
|
||||
)
|
||||
|
||||
log.debug(f"Error processing chat metadata: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=str(e),
|
||||
)
|
||||
|
||||
try:
|
||||
response = await chat_completion_handler(request, form_data, user)
|
||||
|
||||
return await process_chat_response(
|
||||
request, response, form_data, user, metadata, model, events, tasks
|
||||
)
|
||||
except Exception as e:
|
||||
log.debug(f"Error in chat completion: {e}")
|
||||
if metadata.get("chat_id") and metadata.get("message_id"):
|
||||
# Update the chat message with the error
|
||||
Chats.upsert_message_to_chat_by_id_and_message_id(
|
||||
metadata["chat_id"],
|
||||
metadata["message_id"],
|
||||
{
|
||||
"error": {"content": str(e)},
|
||||
},
|
||||
async def process_chat(request, form_data, user, metadata, model):
|
||||
try:
|
||||
form_data, metadata, events = await process_chat_payload(
|
||||
request, form_data, user, metadata, model
|
||||
)
|
||||
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=str(e),
|
||||
response = await chat_completion_handler(request, form_data, user)
|
||||
if metadata.get("chat_id") and metadata.get("message_id"):
|
||||
try:
|
||||
Chats.upsert_message_to_chat_by_id_and_message_id(
|
||||
metadata["chat_id"],
|
||||
metadata["message_id"],
|
||||
{
|
||||
"model": model_id,
|
||||
},
|
||||
)
|
||||
except:
|
||||
pass
|
||||
|
||||
return await process_chat_response(
|
||||
request, response, form_data, user, metadata, model, events, tasks
|
||||
)
|
||||
except asyncio.CancelledError:
|
||||
log.info("Chat processing was cancelled")
|
||||
try:
|
||||
event_emitter = get_event_emitter(metadata)
|
||||
await event_emitter(
|
||||
{"type": "task-cancelled"},
|
||||
)
|
||||
except Exception as e:
|
||||
pass
|
||||
except Exception as e:
|
||||
log.debug(f"Error processing chat payload: {e}")
|
||||
if metadata.get("chat_id") and metadata.get("message_id"):
|
||||
# Update the chat message with the error
|
||||
try:
|
||||
Chats.upsert_message_to_chat_by_id_and_message_id(
|
||||
metadata["chat_id"],
|
||||
metadata["message_id"],
|
||||
{
|
||||
"error": {"content": str(e)},
|
||||
},
|
||||
)
|
||||
except:
|
||||
pass
|
||||
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=str(e),
|
||||
)
|
||||
|
||||
if (
|
||||
metadata.get("session_id")
|
||||
and metadata.get("chat_id")
|
||||
and metadata.get("message_id")
|
||||
):
|
||||
# Asynchronous Chat Processing
|
||||
task_id, _ = await create_task(
|
||||
request.app.state.redis,
|
||||
process_chat(request, form_data, user, metadata, model),
|
||||
id=metadata["chat_id"],
|
||||
)
|
||||
return {"status": True, "task_id": task_id}
|
||||
else:
|
||||
return await process_chat(request, form_data, user, metadata, model)
|
||||
|
||||
|
||||
# Alias for chat_completion (Legacy)
|
||||
|
|
@ -1563,6 +1671,7 @@ async def get_app_config(request: Request):
|
|||
"features": {
|
||||
"auth": WEBUI_AUTH,
|
||||
"auth_trusted_header": bool(app.state.AUTH_TRUSTED_EMAIL_HEADER),
|
||||
"enable_signup_password_confirmation": ENABLE_SIGNUP_PASSWORD_CONFIRMATION,
|
||||
"enable_ldap": app.state.config.ENABLE_LDAP,
|
||||
"enable_api_key": app.state.config.ENABLE_API_KEY,
|
||||
"enable_signup": app.state.config.ENABLE_SIGNUP,
|
||||
|
|
@ -1641,19 +1750,32 @@ async def get_app_config(request: Request):
|
|||
else {}
|
||||
),
|
||||
}
|
||||
if user is not None
|
||||
if user is not None and (user.role in ["admin", "user"])
|
||||
else {
|
||||
**(
|
||||
{
|
||||
"ui": {
|
||||
"pending_user_overlay_title": app.state.config.PENDING_USER_OVERLAY_TITLE,
|
||||
"pending_user_overlay_content": app.state.config.PENDING_USER_OVERLAY_CONTENT,
|
||||
}
|
||||
}
|
||||
if user and user.role == "pending"
|
||||
else {}
|
||||
),
|
||||
**(
|
||||
{
|
||||
"metadata": {
|
||||
"login_footer": app.state.LICENSE_METADATA.get(
|
||||
"login_footer", ""
|
||||
)
|
||||
),
|
||||
"auth_logo_position": app.state.LICENSE_METADATA.get(
|
||||
"auth_logo_position", ""
|
||||
),
|
||||
}
|
||||
}
|
||||
if app.state.LICENSE_METADATA
|
||||
else {}
|
||||
)
|
||||
),
|
||||
}
|
||||
),
|
||||
}
|
||||
|
|
@ -1765,11 +1887,10 @@ async def get_manifest_json():
|
|||
return {
|
||||
"name": app.state.WEBUI_NAME,
|
||||
"short_name": app.state.WEBUI_NAME,
|
||||
"description": "Open WebUI is an open, extensible, user-friendly interface for AI that adapts to your workflow.",
|
||||
"description": f"{app.state.WEBUI_NAME} is an open, extensible, user-friendly interface for AI that adapts to your workflow.",
|
||||
"start_url": "/",
|
||||
"display": "standalone",
|
||||
"background_color": "#343541",
|
||||
"orientation": "any",
|
||||
"icons": [
|
||||
{
|
||||
"src": "/static/logo.png",
|
||||
|
|
|
|||
|
|
@ -2,8 +2,8 @@ from logging.config import fileConfig
|
|||
|
||||
from alembic import context
|
||||
from open_webui.models.auths import Auth
|
||||
from open_webui.env import DATABASE_URL
|
||||
from sqlalchemy import engine_from_config, pool
|
||||
from open_webui.env import DATABASE_URL, DATABASE_PASSWORD
|
||||
from sqlalchemy import engine_from_config, pool, create_engine
|
||||
|
||||
# this is the Alembic Config object, which provides
|
||||
# access to the values within the .ini file in use.
|
||||
|
|
@ -62,11 +62,38 @@ def run_migrations_online() -> None:
|
|||
and associate a connection with the context.
|
||||
|
||||
"""
|
||||
connectable = engine_from_config(
|
||||
config.get_section(config.config_ini_section, {}),
|
||||
prefix="sqlalchemy.",
|
||||
poolclass=pool.NullPool,
|
||||
)
|
||||
# Handle SQLCipher URLs
|
||||
if DB_URL and DB_URL.startswith("sqlite+sqlcipher://"):
|
||||
if not DATABASE_PASSWORD or DATABASE_PASSWORD.strip() == "":
|
||||
raise ValueError(
|
||||
"DATABASE_PASSWORD is required when using sqlite+sqlcipher:// URLs"
|
||||
)
|
||||
|
||||
# Extract database path from SQLCipher URL
|
||||
db_path = DB_URL.replace("sqlite+sqlcipher://", "")
|
||||
if db_path.startswith("/"):
|
||||
db_path = db_path[1:] # Remove leading slash for relative paths
|
||||
|
||||
# Create a custom creator function that uses sqlcipher3
|
||||
def create_sqlcipher_connection():
|
||||
import sqlcipher3
|
||||
|
||||
conn = sqlcipher3.connect(db_path, check_same_thread=False)
|
||||
conn.execute(f"PRAGMA key = '{DATABASE_PASSWORD}'")
|
||||
return conn
|
||||
|
||||
connectable = create_engine(
|
||||
"sqlite://", # Dummy URL since we're using creator
|
||||
creator=create_sqlcipher_connection,
|
||||
echo=False,
|
||||
)
|
||||
else:
|
||||
# Standard database connection (existing logic)
|
||||
connectable = engine_from_config(
|
||||
config.get_section(config.config_ini_section, {}),
|
||||
prefix="sqlalchemy.",
|
||||
poolclass=pool.NullPool,
|
||||
)
|
||||
|
||||
with connectable.connect() as connection:
|
||||
context.configure(connection=connection, target_metadata=target_metadata)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,46 @@
|
|||
"""Add indexes
|
||||
|
||||
Revision ID: 018012973d35
|
||||
Revises: d31026856c01
|
||||
Create Date: 2025-08-13 03:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
revision = "018012973d35"
|
||||
down_revision = "d31026856c01"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# Chat table indexes
|
||||
op.create_index("folder_id_idx", "chat", ["folder_id"])
|
||||
op.create_index("user_id_pinned_idx", "chat", ["user_id", "pinned"])
|
||||
op.create_index("user_id_archived_idx", "chat", ["user_id", "archived"])
|
||||
op.create_index("updated_at_user_id_idx", "chat", ["updated_at", "user_id"])
|
||||
op.create_index("folder_id_user_id_idx", "chat", ["folder_id", "user_id"])
|
||||
|
||||
# Tag table index
|
||||
op.create_index("user_id_idx", "tag", ["user_id"])
|
||||
|
||||
# Function table index
|
||||
op.create_index("is_global_idx", "function", ["is_global"])
|
||||
|
||||
|
||||
def downgrade():
|
||||
# Chat table indexes
|
||||
op.drop_index("folder_id_idx", table_name="chat")
|
||||
op.drop_index("user_id_pinned_idx", table_name="chat")
|
||||
op.drop_index("user_id_archived_idx", table_name="chat")
|
||||
op.drop_index("updated_at_user_id_idx", table_name="chat")
|
||||
op.drop_index("folder_id_user_id_idx", table_name="chat")
|
||||
|
||||
# Tag table index
|
||||
op.drop_index("user_id_idx", table_name="tag")
|
||||
|
||||
# Function table index
|
||||
|
||||
op.drop_index("is_global_idx", table_name="function")
|
||||
|
|
@ -0,0 +1,32 @@
|
|||
"""update user table
|
||||
|
||||
Revision ID: 3af16a1c9fb6
|
||||
Revises: 018012973d35
|
||||
Create Date: 2025-08-21 02:07:18.078283
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "3af16a1c9fb6"
|
||||
down_revision: Union[str, None] = "018012973d35"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column("user", sa.Column("username", sa.String(length=50), nullable=True))
|
||||
op.add_column("user", sa.Column("bio", sa.Text(), nullable=True))
|
||||
op.add_column("user", sa.Column("gender", sa.Text(), nullable=True))
|
||||
op.add_column("user", sa.Column("date_of_birth", sa.Date(), nullable=True))
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("user", "username")
|
||||
op.drop_column("user", "bio")
|
||||
op.drop_column("user", "gender")
|
||||
op.drop_column("user", "date_of_birth")
|
||||
|
|
@ -73,11 +73,6 @@ class ProfileImageUrlForm(BaseModel):
|
|||
profile_image_url: str
|
||||
|
||||
|
||||
class UpdateProfileForm(BaseModel):
|
||||
profile_image_url: str
|
||||
name: str
|
||||
|
||||
|
||||
class UpdatePasswordForm(BaseModel):
|
||||
password: str
|
||||
new_password: str
|
||||
|
|
|
|||
|
|
@ -6,10 +6,11 @@ from typing import Optional
|
|||
|
||||
from open_webui.internal.db import Base, get_db
|
||||
from open_webui.models.tags import TagModel, Tag, Tags
|
||||
from open_webui.models.folders import Folders
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON
|
||||
from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON, Index
|
||||
from sqlalchemy import or_, func, select, and_, text
|
||||
from sqlalchemy.sql import exists
|
||||
from sqlalchemy.sql.expression import bindparam
|
||||
|
|
@ -40,6 +41,20 @@ class Chat(Base):
|
|||
meta = Column(JSON, server_default="{}")
|
||||
folder_id = Column(Text, nullable=True)
|
||||
|
||||
__table_args__ = (
|
||||
# Performance indexes for common queries
|
||||
# WHERE folder_id = ...
|
||||
Index("folder_id_idx", "folder_id"),
|
||||
# WHERE user_id = ... AND pinned = ...
|
||||
Index("user_id_pinned_idx", "user_id", "pinned"),
|
||||
# WHERE user_id = ... AND archived = ...
|
||||
Index("user_id_archived_idx", "user_id", "archived"),
|
||||
# WHERE user_id = ... ORDER BY updated_at DESC
|
||||
Index("updated_at_user_id_idx", "updated_at", "user_id"),
|
||||
# WHERE folder_id = ... AND user_id = ...
|
||||
Index("folder_id_user_id_idx", "folder_id", "user_id"),
|
||||
)
|
||||
|
||||
|
||||
class ChatModel(BaseModel):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
|
@ -296,6 +311,9 @@ class ChatTable:
|
|||
"user_id": f"shared-{chat_id}",
|
||||
"title": chat.title,
|
||||
"chat": chat.chat,
|
||||
"meta": chat.meta,
|
||||
"pinned": chat.pinned,
|
||||
"folder_id": chat.folder_id,
|
||||
"created_at": chat.created_at,
|
||||
"updated_at": int(time.time()),
|
||||
}
|
||||
|
|
@ -327,7 +345,9 @@ class ChatTable:
|
|||
|
||||
shared_chat.title = chat.title
|
||||
shared_chat.chat = chat.chat
|
||||
|
||||
shared_chat.meta = chat.meta
|
||||
shared_chat.pinned = chat.pinned
|
||||
shared_chat.folder_id = chat.folder_id
|
||||
shared_chat.updated_at = int(time.time())
|
||||
db.commit()
|
||||
db.refresh(shared_chat)
|
||||
|
|
@ -612,8 +632,45 @@ class ChatTable:
|
|||
if word.startswith("tag:")
|
||||
]
|
||||
|
||||
# Extract folder names - handle spaces and case insensitivity
|
||||
folders = Folders.search_folders_by_names(
|
||||
user_id,
|
||||
[
|
||||
word.replace("folder:", "")
|
||||
for word in search_text_words
|
||||
if word.startswith("folder:")
|
||||
],
|
||||
)
|
||||
folder_ids = [folder.id for folder in folders]
|
||||
|
||||
is_pinned = None
|
||||
if "pinned:true" in search_text_words:
|
||||
is_pinned = True
|
||||
elif "pinned:false" in search_text_words:
|
||||
is_pinned = False
|
||||
|
||||
is_archived = None
|
||||
if "archived:true" in search_text_words:
|
||||
is_archived = True
|
||||
elif "archived:false" in search_text_words:
|
||||
is_archived = False
|
||||
|
||||
is_shared = None
|
||||
if "shared:true" in search_text_words:
|
||||
is_shared = True
|
||||
elif "shared:false" in search_text_words:
|
||||
is_shared = False
|
||||
|
||||
search_text_words = [
|
||||
word for word in search_text_words if not word.startswith("tag:")
|
||||
word
|
||||
for word in search_text_words
|
||||
if (
|
||||
not word.startswith("tag:")
|
||||
and not word.startswith("folder:")
|
||||
and not word.startswith("pinned:")
|
||||
and not word.startswith("archived:")
|
||||
and not word.startswith("shared:")
|
||||
)
|
||||
]
|
||||
|
||||
search_text = " ".join(search_text_words)
|
||||
|
|
@ -621,9 +678,23 @@ class ChatTable:
|
|||
with get_db() as db:
|
||||
query = db.query(Chat).filter(Chat.user_id == user_id)
|
||||
|
||||
if not include_archived:
|
||||
if is_archived is not None:
|
||||
query = query.filter(Chat.archived == is_archived)
|
||||
elif not include_archived:
|
||||
query = query.filter(Chat.archived == False)
|
||||
|
||||
if is_pinned is not None:
|
||||
query = query.filter(Chat.pinned == is_pinned)
|
||||
|
||||
if is_shared is not None:
|
||||
if is_shared:
|
||||
query = query.filter(Chat.share_id.isnot(None))
|
||||
else:
|
||||
query = query.filter(Chat.share_id.is_(None))
|
||||
|
||||
if folder_ids:
|
||||
query = query.filter(Chat.folder_id.in_(folder_ids))
|
||||
|
||||
query = query.order_by(Chat.updated_at.desc())
|
||||
|
||||
# Check if the database dialect is either 'sqlite' or 'postgresql'
|
||||
|
|
|
|||
|
|
@ -2,14 +2,14 @@ import logging
|
|||
import time
|
||||
import uuid
|
||||
from typing import Optional
|
||||
import re
|
||||
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from sqlalchemy import BigInteger, Column, Text, JSON, Boolean, func
|
||||
|
||||
from open_webui.internal.db import Base, get_db
|
||||
from open_webui.models.chats import Chats
|
||||
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from sqlalchemy import BigInteger, Column, Text, JSON, Boolean
|
||||
from open_webui.utils.access_control import get_permissions
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
|
@ -63,7 +63,7 @@ class FolderForm(BaseModel):
|
|||
|
||||
class FolderTable:
|
||||
def insert_new_folder(
|
||||
self, user_id: str, name: str, parent_id: Optional[str] = None
|
||||
self, user_id: str, form_data: FolderForm, parent_id: Optional[str] = None
|
||||
) -> Optional[FolderModel]:
|
||||
with get_db() as db:
|
||||
id = str(uuid.uuid4())
|
||||
|
|
@ -71,7 +71,7 @@ class FolderTable:
|
|||
**{
|
||||
"id": id,
|
||||
"user_id": user_id,
|
||||
"name": name,
|
||||
**(form_data.model_dump(exclude_unset=True) or {}),
|
||||
"parent_id": parent_id,
|
||||
"created_at": int(time.time()),
|
||||
"updated_at": int(time.time()),
|
||||
|
|
@ -106,7 +106,7 @@ class FolderTable:
|
|||
|
||||
def get_children_folders_by_id_and_user_id(
|
||||
self, id: str, user_id: str
|
||||
) -> Optional[FolderModel]:
|
||||
) -> Optional[list[FolderModel]]:
|
||||
try:
|
||||
with get_db() as db:
|
||||
folders = []
|
||||
|
|
@ -251,18 +251,15 @@ class FolderTable:
|
|||
log.error(f"update_folder: {e}")
|
||||
return
|
||||
|
||||
def delete_folder_by_id_and_user_id(
|
||||
self, id: str, user_id: str, delete_chats=True
|
||||
) -> bool:
|
||||
def delete_folder_by_id_and_user_id(self, id: str, user_id: str) -> list[str]:
|
||||
try:
|
||||
folder_ids = []
|
||||
with get_db() as db:
|
||||
folder = db.query(Folder).filter_by(id=id, user_id=user_id).first()
|
||||
if not folder:
|
||||
return False
|
||||
return folder_ids
|
||||
|
||||
if delete_chats:
|
||||
# Delete all chats in the folder
|
||||
Chats.delete_chats_by_user_id_and_folder_id(user_id, folder.id)
|
||||
folder_ids.append(folder.id)
|
||||
|
||||
# Delete all children folders
|
||||
def delete_children(folder):
|
||||
|
|
@ -270,12 +267,9 @@ class FolderTable:
|
|||
folder.id, user_id
|
||||
)
|
||||
for folder_child in folder_children:
|
||||
if delete_chats:
|
||||
Chats.delete_chats_by_user_id_and_folder_id(
|
||||
user_id, folder_child.id
|
||||
)
|
||||
|
||||
delete_children(folder_child)
|
||||
folder_ids.append(folder_child.id)
|
||||
|
||||
folder = db.query(Folder).filter_by(id=folder_child.id).first()
|
||||
db.delete(folder)
|
||||
|
|
@ -284,10 +278,62 @@ class FolderTable:
|
|||
delete_children(folder)
|
||||
db.delete(folder)
|
||||
db.commit()
|
||||
return True
|
||||
return folder_ids
|
||||
except Exception as e:
|
||||
log.error(f"delete_folder: {e}")
|
||||
return False
|
||||
return []
|
||||
|
||||
def normalize_folder_name(self, name: str) -> str:
|
||||
# Replace _ and space with a single space, lower case, collapse multiple spaces
|
||||
name = re.sub(r"[\s_]+", " ", name)
|
||||
return name.strip().lower()
|
||||
|
||||
def search_folders_by_names(
|
||||
self, user_id: str, queries: list[str]
|
||||
) -> list[FolderModel]:
|
||||
"""
|
||||
Search for folders for a user where the name matches any of the queries, treating _ and space as equivalent, case-insensitive.
|
||||
"""
|
||||
normalized_queries = [self.normalize_folder_name(q) for q in queries]
|
||||
if not normalized_queries:
|
||||
return []
|
||||
|
||||
results = {}
|
||||
with get_db() as db:
|
||||
folders = db.query(Folder).filter_by(user_id=user_id).all()
|
||||
for folder in folders:
|
||||
if self.normalize_folder_name(folder.name) in normalized_queries:
|
||||
results[folder.id] = FolderModel.model_validate(folder)
|
||||
|
||||
# get children folders
|
||||
children = self.get_children_folders_by_id_and_user_id(
|
||||
folder.id, user_id
|
||||
)
|
||||
for child in children:
|
||||
results[child.id] = child
|
||||
|
||||
# Return the results as a list
|
||||
if not results:
|
||||
return []
|
||||
else:
|
||||
results = list(results.values())
|
||||
return results
|
||||
|
||||
def search_folders_by_name_contains(
|
||||
self, user_id: str, query: str
|
||||
) -> list[FolderModel]:
|
||||
"""
|
||||
Partial match: normalized name contains (as substring) the normalized query.
|
||||
"""
|
||||
normalized_query = self.normalize_folder_name(query)
|
||||
results = []
|
||||
with get_db() as db:
|
||||
folders = db.query(Folder).filter_by(user_id=user_id).all()
|
||||
for folder in folders:
|
||||
norm_name = self.normalize_folder_name(folder.name)
|
||||
if normalized_query in norm_name:
|
||||
results.append(FolderModel.model_validate(folder))
|
||||
return results
|
||||
|
||||
|
||||
Folders = FolderTable()
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ from open_webui.internal.db import Base, JSONField, get_db
|
|||
from open_webui.models.users import Users
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from sqlalchemy import BigInteger, Boolean, Column, String, Text
|
||||
from sqlalchemy import BigInteger, Boolean, Column, String, Text, Index
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["MODELS"])
|
||||
|
|
@ -31,6 +31,8 @@ class Function(Base):
|
|||
updated_at = Column(BigInteger)
|
||||
created_at = Column(BigInteger)
|
||||
|
||||
__table_args__ = (Index("is_global_idx", "is_global"),)
|
||||
|
||||
|
||||
class FunctionMeta(BaseModel):
|
||||
description: Optional[str] = None
|
||||
|
|
@ -250,9 +252,7 @@ class FunctionsTable:
|
|||
|
||||
return user_settings["functions"]["valves"].get(id, {})
|
||||
except Exception as e:
|
||||
log.exception(
|
||||
f"Error getting user values by id {id} and user id {user_id}: {e}"
|
||||
)
|
||||
log.exception(f"Error getting user values by id {id} and user id {user_id}")
|
||||
return None
|
||||
|
||||
def update_user_valves_by_id_and_user_id(
|
||||
|
|
|
|||
|
|
@ -83,10 +83,14 @@ class GroupForm(BaseModel):
|
|||
permissions: Optional[dict] = None
|
||||
|
||||
|
||||
class GroupUpdateForm(GroupForm):
|
||||
class UserIdsForm(BaseModel):
|
||||
user_ids: Optional[list[str]] = None
|
||||
|
||||
|
||||
class GroupUpdateForm(GroupForm, UserIdsForm):
|
||||
pass
|
||||
|
||||
|
||||
class GroupTable:
|
||||
def insert_new_group(
|
||||
self, user_id: str, form_data: GroupForm
|
||||
|
|
@ -275,5 +279,53 @@ class GroupTable:
|
|||
log.exception(e)
|
||||
return False
|
||||
|
||||
def add_users_to_group(
|
||||
self, id: str, user_ids: Optional[list[str]] = None
|
||||
) -> Optional[GroupModel]:
|
||||
try:
|
||||
with get_db() as db:
|
||||
group = db.query(Group).filter_by(id=id).first()
|
||||
if not group:
|
||||
return None
|
||||
|
||||
if not group.user_ids:
|
||||
group.user_ids = []
|
||||
|
||||
for user_id in user_ids:
|
||||
if user_id not in group.user_ids:
|
||||
group.user_ids.append(user_id)
|
||||
|
||||
group.updated_at = int(time.time())
|
||||
db.commit()
|
||||
db.refresh(group)
|
||||
return GroupModel.model_validate(group)
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
return None
|
||||
|
||||
def remove_users_from_group(
|
||||
self, id: str, user_ids: Optional[list[str]] = None
|
||||
) -> Optional[GroupModel]:
|
||||
try:
|
||||
with get_db() as db:
|
||||
group = db.query(Group).filter_by(id=id).first()
|
||||
if not group:
|
||||
return None
|
||||
|
||||
if not group.user_ids:
|
||||
return GroupModel.model_validate(group)
|
||||
|
||||
for user_id in user_ids:
|
||||
if user_id in group.user_ids:
|
||||
group.user_ids.remove(user_id)
|
||||
|
||||
group.updated_at = int(time.time())
|
||||
db.commit()
|
||||
db.refresh(group)
|
||||
return GroupModel.model_validate(group)
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
return None
|
||||
|
||||
|
||||
Groups = GroupTable()
|
||||
|
|
|
|||
|
|
@ -71,9 +71,13 @@ class MemoriesTable:
|
|||
) -> Optional[MemoryModel]:
|
||||
with get_db() as db:
|
||||
try:
|
||||
db.query(Memory).filter_by(id=id, user_id=user_id).update(
|
||||
{"content": content, "updated_at": int(time.time())}
|
||||
)
|
||||
memory = db.get(Memory, id)
|
||||
if not memory or memory.user_id != user_id:
|
||||
return None
|
||||
|
||||
memory.content = content
|
||||
memory.updated_at = int(time.time())
|
||||
|
||||
db.commit()
|
||||
return self.get_memory_by_id(id)
|
||||
except Exception:
|
||||
|
|
@ -127,7 +131,12 @@ class MemoriesTable:
|
|||
def delete_memory_by_id_and_user_id(self, id: str, user_id: str) -> bool:
|
||||
with get_db() as db:
|
||||
try:
|
||||
db.query(Memory).filter_by(id=id, user_id=user_id).delete()
|
||||
memory = db.get(Memory, id)
|
||||
if not memory or memory.user_id != user_id:
|
||||
return None
|
||||
|
||||
# Delete the memory
|
||||
db.delete(memory)
|
||||
db.commit()
|
||||
|
||||
return True
|
||||
|
|
|
|||
|
|
@ -269,5 +269,49 @@ class ModelsTable:
|
|||
except Exception:
|
||||
return False
|
||||
|
||||
def sync_models(self, user_id: str, models: list[ModelModel]) -> list[ModelModel]:
|
||||
try:
|
||||
with get_db() as db:
|
||||
# Get existing models
|
||||
existing_models = db.query(Model).all()
|
||||
existing_ids = {model.id for model in existing_models}
|
||||
|
||||
# Prepare a set of new model IDs
|
||||
new_model_ids = {model.id for model in models}
|
||||
|
||||
# Update or insert models
|
||||
for model in models:
|
||||
if model.id in existing_ids:
|
||||
db.query(Model).filter_by(id=model.id).update(
|
||||
{
|
||||
**model.model_dump(),
|
||||
"user_id": user_id,
|
||||
"updated_at": int(time.time()),
|
||||
}
|
||||
)
|
||||
else:
|
||||
new_model = Model(
|
||||
**{
|
||||
**model.model_dump(),
|
||||
"user_id": user_id,
|
||||
"updated_at": int(time.time()),
|
||||
}
|
||||
)
|
||||
db.add(new_model)
|
||||
|
||||
# Remove models that are no longer present
|
||||
for model in existing_models:
|
||||
if model.id not in new_model_ids:
|
||||
db.delete(model)
|
||||
|
||||
db.commit()
|
||||
|
||||
return [
|
||||
ModelModel.model_validate(model) for model in db.query(Model).all()
|
||||
]
|
||||
except Exception as e:
|
||||
log.exception(f"Error syncing models for user {user_id}: {e}")
|
||||
return []
|
||||
|
||||
|
||||
Models = ModelsTable()
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ from open_webui.internal.db import Base, get_db
|
|||
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from sqlalchemy import BigInteger, Column, String, JSON, PrimaryKeyConstraint
|
||||
from sqlalchemy import BigInteger, Column, String, JSON, PrimaryKeyConstraint, Index
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["MODELS"])
|
||||
|
|
@ -24,6 +24,11 @@ class Tag(Base):
|
|||
user_id = Column(String)
|
||||
meta = Column(JSON, nullable=True)
|
||||
|
||||
__table_args__ = (
|
||||
PrimaryKeyConstraint("id", "user_id", name="pk_id_user_id"),
|
||||
Index("user_id_idx", "user_id"),
|
||||
)
|
||||
|
||||
# Unique constraint ensuring (id, user_id) is unique, not just the `id` column
|
||||
__table_args__ = (PrimaryKeyConstraint("id", "user_id", name="pk_id_user_id"),)
|
||||
|
||||
|
|
|
|||
|
|
@ -175,7 +175,7 @@ class ToolsTable:
|
|||
tool = db.get(Tool, id)
|
||||
return tool.valves if tool.valves else {}
|
||||
except Exception as e:
|
||||
log.exception(f"Error getting tool valves by id {id}: {e}")
|
||||
log.exception(f"Error getting tool valves by id {id}")
|
||||
return None
|
||||
|
||||
def update_tool_valves_by_id(self, id: str, valves: dict) -> Optional[ToolValves]:
|
||||
|
|
|
|||
|
|
@ -4,14 +4,17 @@ from typing import Optional
|
|||
from open_webui.internal.db import Base, JSONField, get_db
|
||||
|
||||
|
||||
from open_webui.env import DATABASE_USER_ACTIVE_STATUS_UPDATE_INTERVAL
|
||||
from open_webui.models.chats import Chats
|
||||
from open_webui.models.groups import Groups
|
||||
from open_webui.utils.misc import throttle
|
||||
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from sqlalchemy import BigInteger, Column, String, Text
|
||||
from sqlalchemy import BigInteger, Column, String, Text, Date
|
||||
from sqlalchemy import or_
|
||||
|
||||
import datetime
|
||||
|
||||
####################
|
||||
# User DB Schema
|
||||
|
|
@ -23,20 +26,28 @@ class User(Base):
|
|||
|
||||
id = Column(String, primary_key=True)
|
||||
name = Column(String)
|
||||
|
||||
email = Column(String)
|
||||
username = Column(String(50), nullable=True)
|
||||
|
||||
role = Column(String)
|
||||
profile_image_url = Column(Text)
|
||||
|
||||
last_active_at = Column(BigInteger)
|
||||
updated_at = Column(BigInteger)
|
||||
created_at = Column(BigInteger)
|
||||
bio = Column(Text, nullable=True)
|
||||
gender = Column(Text, nullable=True)
|
||||
date_of_birth = Column(Date, nullable=True)
|
||||
|
||||
info = Column(JSONField, nullable=True)
|
||||
settings = Column(JSONField, nullable=True)
|
||||
|
||||
api_key = Column(String, nullable=True, unique=True)
|
||||
settings = Column(JSONField, nullable=True)
|
||||
info = Column(JSONField, nullable=True)
|
||||
|
||||
oauth_sub = Column(Text, unique=True)
|
||||
|
||||
last_active_at = Column(BigInteger)
|
||||
|
||||
updated_at = Column(BigInteger)
|
||||
created_at = Column(BigInteger)
|
||||
|
||||
|
||||
class UserSettings(BaseModel):
|
||||
ui: Optional[dict] = {}
|
||||
|
|
@ -47,20 +58,27 @@ class UserSettings(BaseModel):
|
|||
class UserModel(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
|
||||
email: str
|
||||
username: Optional[str] = None
|
||||
|
||||
role: str = "pending"
|
||||
profile_image_url: str
|
||||
|
||||
bio: Optional[str] = None
|
||||
gender: Optional[str] = None
|
||||
date_of_birth: Optional[datetime.date] = None
|
||||
|
||||
info: Optional[dict] = None
|
||||
settings: Optional[UserSettings] = None
|
||||
|
||||
api_key: Optional[str] = None
|
||||
oauth_sub: Optional[str] = None
|
||||
|
||||
last_active_at: int # timestamp in epoch
|
||||
updated_at: int # timestamp in epoch
|
||||
created_at: int # timestamp in epoch
|
||||
|
||||
api_key: Optional[str] = None
|
||||
settings: Optional[UserSettings] = None
|
||||
info: Optional[dict] = None
|
||||
|
||||
oauth_sub: Optional[str] = None
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
|
|
@ -69,11 +87,31 @@ class UserModel(BaseModel):
|
|||
####################
|
||||
|
||||
|
||||
class UpdateProfileForm(BaseModel):
|
||||
profile_image_url: str
|
||||
name: str
|
||||
bio: Optional[str] = None
|
||||
gender: Optional[str] = None
|
||||
date_of_birth: Optional[datetime.date] = None
|
||||
|
||||
|
||||
class UserListResponse(BaseModel):
|
||||
users: list[UserModel]
|
||||
total: int
|
||||
|
||||
|
||||
class UserInfoResponse(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
email: str
|
||||
role: str
|
||||
|
||||
|
||||
class UserInfoListResponse(BaseModel):
|
||||
users: list[UserInfoResponse]
|
||||
total: int
|
||||
|
||||
|
||||
class UserResponse(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
|
|
@ -246,6 +284,10 @@ class UsersTable:
|
|||
with get_db() as db:
|
||||
return db.query(User).count()
|
||||
|
||||
def has_users(self) -> bool:
|
||||
with get_db() as db:
|
||||
return db.query(db.query(User).exists()).scalar()
|
||||
|
||||
def get_first_user(self) -> UserModel:
|
||||
try:
|
||||
with get_db() as db:
|
||||
|
|
@ -295,6 +337,7 @@ class UsersTable:
|
|||
except Exception:
|
||||
return None
|
||||
|
||||
@throttle(DATABASE_USER_ACTIVE_STATUS_UPDATE_INTERVAL)
|
||||
def update_user_last_active_by_id(self, id: str) -> Optional[UserModel]:
|
||||
try:
|
||||
with get_db() as db:
|
||||
|
|
@ -330,7 +373,8 @@ class UsersTable:
|
|||
user = db.query(User).filter_by(id=id).first()
|
||||
return UserModel.model_validate(user)
|
||||
# return UserModel(**user.dict())
|
||||
except Exception:
|
||||
except Exception as e:
|
||||
print(e)
|
||||
return None
|
||||
|
||||
def update_user_settings_by_id(self, id: str, updated: dict) -> Optional[UserModel]:
|
||||
|
|
|
|||
|
|
@ -15,24 +15,28 @@ class DatalabMarkerLoader:
|
|||
self,
|
||||
file_path: str,
|
||||
api_key: str,
|
||||
langs: Optional[str] = None,
|
||||
api_base_url: str,
|
||||
additional_config: Optional[str] = None,
|
||||
use_llm: bool = False,
|
||||
skip_cache: bool = False,
|
||||
force_ocr: bool = False,
|
||||
paginate: bool = False,
|
||||
strip_existing_ocr: bool = False,
|
||||
disable_image_extraction: bool = False,
|
||||
format_lines: bool = False,
|
||||
output_format: str = None,
|
||||
):
|
||||
self.file_path = file_path
|
||||
self.api_key = api_key
|
||||
self.langs = langs
|
||||
self.api_base_url = api_base_url
|
||||
self.additional_config = additional_config
|
||||
self.use_llm = use_llm
|
||||
self.skip_cache = skip_cache
|
||||
self.force_ocr = force_ocr
|
||||
self.paginate = paginate
|
||||
self.strip_existing_ocr = strip_existing_ocr
|
||||
self.disable_image_extraction = disable_image_extraction
|
||||
self.format_lines = format_lines
|
||||
self.output_format = output_format
|
||||
|
||||
def _get_mime_type(self, filename: str) -> str:
|
||||
|
|
@ -60,7 +64,7 @@ class DatalabMarkerLoader:
|
|||
return mime_map.get(ext, "application/octet-stream")
|
||||
|
||||
def check_marker_request_status(self, request_id: str) -> dict:
|
||||
url = f"https://www.datalab.to/api/v1/marker/{request_id}"
|
||||
url = f"{self.api_base_url}/marker/{request_id}"
|
||||
headers = {"X-Api-Key": self.api_key}
|
||||
try:
|
||||
response = requests.get(url, headers=headers)
|
||||
|
|
@ -81,22 +85,24 @@ class DatalabMarkerLoader:
|
|||
)
|
||||
|
||||
def load(self) -> List[Document]:
|
||||
url = "https://www.datalab.to/api/v1/marker"
|
||||
filename = os.path.basename(self.file_path)
|
||||
mime_type = self._get_mime_type(filename)
|
||||
headers = {"X-Api-Key": self.api_key}
|
||||
|
||||
form_data = {
|
||||
"langs": self.langs,
|
||||
"use_llm": str(self.use_llm).lower(),
|
||||
"skip_cache": str(self.skip_cache).lower(),
|
||||
"force_ocr": str(self.force_ocr).lower(),
|
||||
"paginate": str(self.paginate).lower(),
|
||||
"strip_existing_ocr": str(self.strip_existing_ocr).lower(),
|
||||
"disable_image_extraction": str(self.disable_image_extraction).lower(),
|
||||
"format_lines": str(self.format_lines).lower(),
|
||||
"output_format": self.output_format,
|
||||
}
|
||||
|
||||
if self.additional_config and self.additional_config.strip():
|
||||
form_data["additional_config"] = self.additional_config
|
||||
|
||||
log.info(
|
||||
f"Datalab Marker POST request parameters: {{'filename': '{filename}', 'mime_type': '{mime_type}', **{form_data}}}"
|
||||
)
|
||||
|
|
@ -105,7 +111,10 @@ class DatalabMarkerLoader:
|
|||
with open(self.file_path, "rb") as f:
|
||||
files = {"file": (filename, f, mime_type)}
|
||||
response = requests.post(
|
||||
url, data=form_data, files=files, headers=headers
|
||||
f"{self.api_base_url}/marker",
|
||||
data=form_data,
|
||||
files=files,
|
||||
headers=headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
result = response.json()
|
||||
|
|
@ -133,74 +142,92 @@ class DatalabMarkerLoader:
|
|||
|
||||
check_url = result.get("request_check_url")
|
||||
request_id = result.get("request_id")
|
||||
if not check_url:
|
||||
raise HTTPException(
|
||||
status.HTTP_502_BAD_GATEWAY, detail="No request_check_url returned."
|
||||
)
|
||||
|
||||
for _ in range(300): # Up to 10 minutes
|
||||
time.sleep(2)
|
||||
try:
|
||||
poll_response = requests.get(check_url, headers=headers)
|
||||
poll_response.raise_for_status()
|
||||
poll_result = poll_response.json()
|
||||
except (requests.HTTPError, ValueError) as e:
|
||||
raw_body = poll_response.text
|
||||
log.error(f"Polling error: {e}, response body: {raw_body}")
|
||||
raise HTTPException(
|
||||
status.HTTP_502_BAD_GATEWAY, detail=f"Polling failed: {e}"
|
||||
)
|
||||
|
||||
status_val = poll_result.get("status")
|
||||
success_val = poll_result.get("success")
|
||||
|
||||
if status_val == "complete":
|
||||
summary = {
|
||||
k: poll_result.get(k)
|
||||
for k in (
|
||||
"status",
|
||||
"output_format",
|
||||
"success",
|
||||
"error",
|
||||
"page_count",
|
||||
"total_cost",
|
||||
# Check if this is a direct response (self-hosted) or polling response (DataLab)
|
||||
if check_url:
|
||||
# DataLab polling pattern
|
||||
for _ in range(300): # Up to 10 minutes
|
||||
time.sleep(2)
|
||||
try:
|
||||
poll_response = requests.get(check_url, headers=headers)
|
||||
poll_response.raise_for_status()
|
||||
poll_result = poll_response.json()
|
||||
except (requests.HTTPError, ValueError) as e:
|
||||
raw_body = poll_response.text
|
||||
log.error(f"Polling error: {e}, response body: {raw_body}")
|
||||
raise HTTPException(
|
||||
status.HTTP_502_BAD_GATEWAY, detail=f"Polling failed: {e}"
|
||||
)
|
||||
}
|
||||
log.info(
|
||||
f"Marker processing completed successfully: {json.dumps(summary, indent=2)}"
|
||||
)
|
||||
break
|
||||
|
||||
if status_val == "failed" or success_val is False:
|
||||
log.error(
|
||||
f"Marker poll failed full response: {json.dumps(poll_result, indent=2)}"
|
||||
)
|
||||
error_msg = (
|
||||
poll_result.get("error")
|
||||
or "Marker returned failure without error message"
|
||||
status_val = poll_result.get("status")
|
||||
success_val = poll_result.get("success")
|
||||
|
||||
if status_val == "complete":
|
||||
summary = {
|
||||
k: poll_result.get(k)
|
||||
for k in (
|
||||
"status",
|
||||
"output_format",
|
||||
"success",
|
||||
"error",
|
||||
"page_count",
|
||||
"total_cost",
|
||||
)
|
||||
}
|
||||
log.info(
|
||||
f"Marker processing completed successfully: {json.dumps(summary, indent=2)}"
|
||||
)
|
||||
break
|
||||
|
||||
if status_val == "failed" or success_val is False:
|
||||
log.error(
|
||||
f"Marker poll failed full response: {json.dumps(poll_result, indent=2)}"
|
||||
)
|
||||
error_msg = (
|
||||
poll_result.get("error")
|
||||
or "Marker returned failure without error message"
|
||||
)
|
||||
raise HTTPException(
|
||||
status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Marker processing failed: {error_msg}",
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status.HTTP_504_GATEWAY_TIMEOUT,
|
||||
detail="Marker processing timed out",
|
||||
)
|
||||
|
||||
if not poll_result.get("success", False):
|
||||
error_msg = poll_result.get("error") or "Unknown processing error"
|
||||
raise HTTPException(
|
||||
status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Marker processing failed: {error_msg}",
|
||||
detail=f"Final processing failed: {error_msg}",
|
||||
)
|
||||
|
||||
# DataLab format - content in format-specific fields
|
||||
content_key = self.output_format.lower()
|
||||
raw_content = poll_result.get(content_key)
|
||||
final_result = poll_result
|
||||
else:
|
||||
raise HTTPException(
|
||||
status.HTTP_504_GATEWAY_TIMEOUT, detail="Marker processing timed out"
|
||||
)
|
||||
# Self-hosted direct response - content in "output" field
|
||||
if "output" in result:
|
||||
log.info("Self-hosted Marker returned direct response without polling")
|
||||
raw_content = result.get("output")
|
||||
final_result = result
|
||||
else:
|
||||
available_fields = (
|
||||
list(result.keys())
|
||||
if isinstance(result, dict)
|
||||
else "non-dict response"
|
||||
)
|
||||
raise HTTPException(
|
||||
status.HTTP_502_BAD_GATEWAY,
|
||||
detail=f"Custom Marker endpoint returned success but no 'output' field found. Available fields: {available_fields}. Expected either 'request_check_url' for polling or 'output' field for direct response.",
|
||||
)
|
||||
|
||||
if not poll_result.get("success", False):
|
||||
error_msg = poll_result.get("error") or "Unknown processing error"
|
||||
raise HTTPException(
|
||||
status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Final processing failed: {error_msg}",
|
||||
)
|
||||
|
||||
content_key = self.output_format.lower()
|
||||
raw_content = poll_result.get(content_key)
|
||||
|
||||
if content_key == "json":
|
||||
if self.output_format.lower() == "json":
|
||||
full_text = json.dumps(raw_content, indent=2)
|
||||
elif content_key in {"markdown", "html"}:
|
||||
elif self.output_format.lower() in {"markdown", "html"}:
|
||||
full_text = str(raw_content).strip()
|
||||
else:
|
||||
raise HTTPException(
|
||||
|
|
@ -211,14 +238,14 @@ class DatalabMarkerLoader:
|
|||
if not full_text:
|
||||
raise HTTPException(
|
||||
status.HTTP_400_BAD_REQUEST,
|
||||
detail="Datalab Marker returned empty content",
|
||||
detail="Marker returned empty content",
|
||||
)
|
||||
|
||||
marker_output_dir = os.path.join("/app/backend/data/uploads", "marker_output")
|
||||
os.makedirs(marker_output_dir, exist_ok=True)
|
||||
|
||||
file_ext_map = {"markdown": "md", "json": "json", "html": "html"}
|
||||
file_ext = file_ext_map.get(content_key, "txt")
|
||||
file_ext = file_ext_map.get(self.output_format.lower(), "txt")
|
||||
output_filename = f"{os.path.splitext(filename)[0]}.{file_ext}"
|
||||
output_path = os.path.join(marker_output_dir, output_filename)
|
||||
|
||||
|
|
@ -231,13 +258,13 @@ class DatalabMarkerLoader:
|
|||
|
||||
metadata = {
|
||||
"source": filename,
|
||||
"output_format": poll_result.get("output_format", self.output_format),
|
||||
"page_count": poll_result.get("page_count", 0),
|
||||
"output_format": final_result.get("output_format", self.output_format),
|
||||
"page_count": final_result.get("page_count", 0),
|
||||
"processed_with_llm": self.use_llm,
|
||||
"request_id": request_id or "",
|
||||
}
|
||||
|
||||
images = poll_result.get("images", {})
|
||||
images = final_result.get("images", {})
|
||||
if images:
|
||||
metadata["image_count"] = len(images)
|
||||
metadata["images"] = json.dumps(list(images.keys()))
|
||||
|
|
|
|||
|
|
@ -181,7 +181,7 @@ class DoclingLoader:
|
|||
if lang.strip()
|
||||
]
|
||||
|
||||
endpoint = f"{self.url}/v1alpha/convert/file"
|
||||
endpoint = f"{self.url}/v1/convert/file"
|
||||
r = requests.post(endpoint, files=files, data=params)
|
||||
|
||||
if r.ok:
|
||||
|
|
@ -281,10 +281,15 @@ class Loader:
|
|||
"tiff",
|
||||
]
|
||||
):
|
||||
api_base_url = self.kwargs.get("DATALAB_MARKER_API_BASE_URL", "")
|
||||
if not api_base_url or api_base_url.strip() == "":
|
||||
api_base_url = "https://www.datalab.to/api/v1"
|
||||
|
||||
loader = DatalabMarkerLoader(
|
||||
file_path=file_path,
|
||||
api_key=self.kwargs["DATALAB_MARKER_API_KEY"],
|
||||
langs=self.kwargs.get("DATALAB_MARKER_LANGS"),
|
||||
api_base_url=api_base_url,
|
||||
additional_config=self.kwargs.get("DATALAB_MARKER_ADDITIONAL_CONFIG"),
|
||||
use_llm=self.kwargs.get("DATALAB_MARKER_USE_LLM", False),
|
||||
skip_cache=self.kwargs.get("DATALAB_MARKER_SKIP_CACHE", False),
|
||||
force_ocr=self.kwargs.get("DATALAB_MARKER_FORCE_OCR", False),
|
||||
|
|
@ -295,6 +300,7 @@ class Loader:
|
|||
disable_image_extraction=self.kwargs.get(
|
||||
"DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION", False
|
||||
),
|
||||
format_lines=self.kwargs.get("DATALAB_MARKER_FORMAT_LINES", False),
|
||||
output_format=self.kwargs.get(
|
||||
"DATALAB_MARKER_OUTPUT_FORMAT", "markdown"
|
||||
),
|
||||
|
|
|
|||
|
|
@ -124,12 +124,14 @@ def query_doc_with_hybrid_search(
|
|||
hybrid_bm25_weight: float,
|
||||
) -> dict:
|
||||
try:
|
||||
log.debug(f"query_doc_with_hybrid_search:doc {collection_name}")
|
||||
bm25_retriever = BM25Retriever.from_texts(
|
||||
texts=collection_result.documents[0],
|
||||
metadatas=collection_result.metadatas[0],
|
||||
)
|
||||
bm25_retriever.k = k
|
||||
# BM_25 required only if weight is greater than 0
|
||||
if hybrid_bm25_weight > 0:
|
||||
log.debug(f"query_doc_with_hybrid_search:doc {collection_name}")
|
||||
bm25_retriever = BM25Retriever.from_texts(
|
||||
texts=collection_result.documents[0],
|
||||
metadatas=collection_result.metadatas[0],
|
||||
)
|
||||
bm25_retriever.k = k
|
||||
|
||||
vector_search_retriever = VectorSearchRetriever(
|
||||
collection_name=collection_name,
|
||||
|
|
@ -337,18 +339,22 @@ def query_collection_with_hybrid_search(
|
|||
# Fetch collection data once per collection sequentially
|
||||
# Avoid fetching the same data multiple times later
|
||||
collection_results = {}
|
||||
for collection_name in collection_names:
|
||||
try:
|
||||
log.debug(
|
||||
f"query_collection_with_hybrid_search:VECTOR_DB_CLIENT.get:collection {collection_name}"
|
||||
)
|
||||
collection_results[collection_name] = VECTOR_DB_CLIENT.get(
|
||||
collection_name=collection_name
|
||||
)
|
||||
except Exception as e:
|
||||
log.exception(f"Failed to fetch collection {collection_name}: {e}")
|
||||
collection_results[collection_name] = None
|
||||
|
||||
# Only retrieve entire collection if bm_25 calculation is required
|
||||
if hybrid_bm25_weight > 0:
|
||||
for collection_name in collection_names:
|
||||
try:
|
||||
log.debug(
|
||||
f"query_collection_with_hybrid_search:VECTOR_DB_CLIENT.get:collection {collection_name}"
|
||||
)
|
||||
collection_results[collection_name] = VECTOR_DB_CLIENT.get(
|
||||
collection_name=collection_name
|
||||
)
|
||||
except Exception as e:
|
||||
log.exception(f"Failed to fetch collection {collection_name}: {e}")
|
||||
collection_results[collection_name] = None
|
||||
else:
|
||||
for collection_name in collection_names:
|
||||
collection_results[collection_name] = []
|
||||
log.info(
|
||||
f"Starting hybrid search for {len(queries)} queries in {len(collection_names)} collections..."
|
||||
)
|
||||
|
|
@ -508,7 +514,11 @@ def get_sources_from_items(
|
|||
# Note Attached
|
||||
note = Notes.get_note_by_id(item.get("id"))
|
||||
|
||||
if user.role == "admin" or has_access(user.id, "read", note.access_control):
|
||||
if note and (
|
||||
user.role == "admin"
|
||||
or note.user_id == user.id
|
||||
or has_access(user.id, "read", note.access_control)
|
||||
):
|
||||
# User has access to the note
|
||||
query_result = {
|
||||
"documents": [[note.data.get("content", {}).get("md", "")]],
|
||||
|
|
@ -611,6 +621,9 @@ def get_sources_from_items(
|
|||
elif item.get("collection_name"):
|
||||
# Direct Collection Name
|
||||
collection_names.append(item["collection_name"])
|
||||
elif item.get("collection_names"):
|
||||
# Collection Names List
|
||||
collection_names.extend(item["collection_names"])
|
||||
|
||||
# If query_result is None
|
||||
# Fallback to collection names and vector search the collections
|
||||
|
|
@ -939,6 +952,7 @@ class RerankCompressor(BaseDocumentCompressor):
|
|||
) -> Sequence[Document]:
|
||||
reranking = self.reranking_function is not None
|
||||
|
||||
scores = None
|
||||
if reranking:
|
||||
scores = self.reranking_function(
|
||||
[(query, doc.page_content) for doc in documents]
|
||||
|
|
@ -952,22 +966,31 @@ class RerankCompressor(BaseDocumentCompressor):
|
|||
)
|
||||
scores = util.cos_sim(query_embedding, document_embedding)[0]
|
||||
|
||||
docs_with_scores = list(
|
||||
zip(documents, scores.tolist() if not isinstance(scores, list) else scores)
|
||||
)
|
||||
if self.r_score:
|
||||
docs_with_scores = [
|
||||
(d, s) for d, s in docs_with_scores if s >= self.r_score
|
||||
]
|
||||
|
||||
result = sorted(docs_with_scores, key=operator.itemgetter(1), reverse=True)
|
||||
final_results = []
|
||||
for doc, doc_score in result[: self.top_n]:
|
||||
metadata = doc.metadata
|
||||
metadata["score"] = doc_score
|
||||
doc = Document(
|
||||
page_content=doc.page_content,
|
||||
metadata=metadata,
|
||||
if scores is not None:
|
||||
docs_with_scores = list(
|
||||
zip(
|
||||
documents,
|
||||
scores.tolist() if not isinstance(scores, list) else scores,
|
||||
)
|
||||
)
|
||||
final_results.append(doc)
|
||||
return final_results
|
||||
if self.r_score:
|
||||
docs_with_scores = [
|
||||
(d, s) for d, s in docs_with_scores if s >= self.r_score
|
||||
]
|
||||
|
||||
result = sorted(docs_with_scores, key=operator.itemgetter(1), reverse=True)
|
||||
final_results = []
|
||||
for doc, doc_score in result[: self.top_n]:
|
||||
metadata = doc.metadata
|
||||
metadata["score"] = doc_score
|
||||
doc = Document(
|
||||
page_content=doc.page_content,
|
||||
metadata=metadata,
|
||||
)
|
||||
final_results.append(doc)
|
||||
return final_results
|
||||
else:
|
||||
log.warning(
|
||||
"No valid scores found, check your reranking function. Returning original documents."
|
||||
)
|
||||
return documents
|
||||
|
|
|
|||
|
|
@ -11,6 +11,8 @@ from open_webui.retrieval.vector.main import (
|
|||
SearchResult,
|
||||
GetResult,
|
||||
)
|
||||
from open_webui.retrieval.vector.utils import stringify_metadata
|
||||
|
||||
from open_webui.config import (
|
||||
CHROMA_DATA_PATH,
|
||||
CHROMA_HTTP_HOST,
|
||||
|
|
@ -144,7 +146,7 @@ class ChromaClient(VectorDBBase):
|
|||
ids = [item["id"] for item in items]
|
||||
documents = [item["text"] for item in items]
|
||||
embeddings = [item["vector"] for item in items]
|
||||
metadatas = [item["metadata"] for item in items]
|
||||
metadatas = [stringify_metadata(item["metadata"]) for item in items]
|
||||
|
||||
for batch in create_batches(
|
||||
api=self.client,
|
||||
|
|
@ -164,7 +166,7 @@ class ChromaClient(VectorDBBase):
|
|||
ids = [item["id"] for item in items]
|
||||
documents = [item["text"] for item in items]
|
||||
embeddings = [item["vector"] for item in items]
|
||||
metadatas = [item["metadata"] for item in items]
|
||||
metadatas = [stringify_metadata(item["metadata"]) for item in items]
|
||||
|
||||
collection.upsert(
|
||||
ids=ids, documents=documents, embeddings=embeddings, metadatas=metadatas
|
||||
|
|
|
|||
|
|
@ -2,6 +2,8 @@ from elasticsearch import Elasticsearch, BadRequestError
|
|||
from typing import Optional
|
||||
import ssl
|
||||
from elasticsearch.helpers import bulk, scan
|
||||
|
||||
from open_webui.retrieval.vector.utils import stringify_metadata
|
||||
from open_webui.retrieval.vector.main import (
|
||||
VectorDBBase,
|
||||
VectorItem,
|
||||
|
|
@ -243,7 +245,7 @@ class ElasticsearchClient(VectorDBBase):
|
|||
"collection": collection_name,
|
||||
"vector": item["vector"],
|
||||
"text": item["text"],
|
||||
"metadata": item["metadata"],
|
||||
"metadata": stringify_metadata(item["metadata"]),
|
||||
},
|
||||
}
|
||||
for item in batch
|
||||
|
|
@ -264,7 +266,7 @@ class ElasticsearchClient(VectorDBBase):
|
|||
"collection": collection_name,
|
||||
"vector": item["vector"],
|
||||
"text": item["text"],
|
||||
"metadata": item["metadata"],
|
||||
"metadata": stringify_metadata(item["metadata"]),
|
||||
},
|
||||
"doc_as_upsert": True,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,8 +1,12 @@
|
|||
from pymilvus import MilvusClient as Client
|
||||
from pymilvus import FieldSchema, DataType
|
||||
from pymilvus import connections, Collection
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from open_webui.retrieval.vector.utils import stringify_metadata
|
||||
from open_webui.retrieval.vector.main import (
|
||||
VectorDBBase,
|
||||
VectorItem,
|
||||
|
|
@ -186,6 +190,8 @@ class MilvusClient(VectorDBBase):
|
|||
return self._result_to_search_result(result)
|
||||
|
||||
def query(self, collection_name: str, filter: dict, limit: Optional[int] = None):
|
||||
connections.connect(uri=MILVUS_URI, token=MILVUS_TOKEN, db_name=MILVUS_DB)
|
||||
|
||||
# Construct the filter string for querying
|
||||
collection_name = collection_name.replace("-", "_")
|
||||
if not self.has_collection(collection_name):
|
||||
|
|
@ -199,72 +205,36 @@ class MilvusClient(VectorDBBase):
|
|||
for key, value in filter.items()
|
||||
]
|
||||
)
|
||||
max_limit = 16383 # The maximum number of records per request
|
||||
all_results = []
|
||||
if limit is None:
|
||||
# Milvus default limit for query if not specified is 16384, but docs mention iteration.
|
||||
# Let's set a practical high number if "all" is intended, or handle true pagination.
|
||||
# For now, if limit is None, we'll fetch in batches up to a very large number.
|
||||
# This part could be refined based on expected use cases for "get all".
|
||||
# For this function signature, None implies "as many as possible" up to Milvus limits.
|
||||
limit = (
|
||||
16384 * 10
|
||||
) # A large number to signify fetching many, will be capped by actual data or max_limit per call.
|
||||
log.info(
|
||||
f"Limit not specified for query, fetching up to {limit} results in batches."
|
||||
)
|
||||
|
||||
# Initialize offset and remaining to handle pagination
|
||||
offset = 0
|
||||
remaining = limit
|
||||
collection = Collection(f"{self.collection_prefix}_{collection_name}")
|
||||
collection.load()
|
||||
all_results = []
|
||||
|
||||
try:
|
||||
log.info(
|
||||
f"Querying collection {self.collection_prefix}_{collection_name} with filter: '{filter_string}', limit: {limit}"
|
||||
)
|
||||
# Loop until there are no more items to fetch or the desired limit is reached
|
||||
while remaining > 0:
|
||||
current_fetch = min(
|
||||
max_limit, remaining if isinstance(remaining, int) else max_limit
|
||||
)
|
||||
log.debug(
|
||||
f"Querying with offset: {offset}, current_fetch: {current_fetch}"
|
||||
)
|
||||
|
||||
results = self.client.query(
|
||||
collection_name=f"{self.collection_prefix}_{collection_name}",
|
||||
filter=filter_string,
|
||||
output_fields=[
|
||||
"id",
|
||||
"data",
|
||||
"metadata",
|
||||
], # Explicitly list needed fields. Vector not usually needed in query.
|
||||
limit=current_fetch,
|
||||
offset=offset,
|
||||
)
|
||||
iterator = collection.query_iterator(
|
||||
filter=filter_string,
|
||||
output_fields=[
|
||||
"id",
|
||||
"data",
|
||||
"metadata",
|
||||
],
|
||||
limit=limit, # Pass the limit directly; None means no limit.
|
||||
)
|
||||
|
||||
if not results:
|
||||
log.debug("No more results from query.")
|
||||
break
|
||||
|
||||
all_results.extend(results)
|
||||
results_count = len(results)
|
||||
log.debug(f"Fetched {results_count} results in this batch.")
|
||||
|
||||
if isinstance(remaining, int):
|
||||
remaining -= results_count
|
||||
|
||||
offset += results_count
|
||||
|
||||
# Break the loop if the results returned are less than the requested fetch count (means end of data)
|
||||
if results_count < current_fetch:
|
||||
log.debug(
|
||||
"Fetched less than requested, assuming end of results for this query."
|
||||
)
|
||||
while True:
|
||||
result = iterator.next()
|
||||
if not result:
|
||||
iterator.close()
|
||||
break
|
||||
all_results += result
|
||||
|
||||
log.info(f"Total results from query: {len(all_results)}")
|
||||
return self._result_to_get_result([all_results])
|
||||
|
||||
except Exception as e:
|
||||
log.exception(
|
||||
f"Error querying collection {self.collection_prefix}_{collection_name} with filter '{filter_string}' and limit {limit}: {e}"
|
||||
|
|
@ -311,7 +281,7 @@ class MilvusClient(VectorDBBase):
|
|||
"id": item["id"],
|
||||
"vector": item["vector"],
|
||||
"data": {"text": item["text"]},
|
||||
"metadata": item["metadata"],
|
||||
"metadata": stringify_metadata(item["metadata"]),
|
||||
}
|
||||
for item in items
|
||||
],
|
||||
|
|
@ -347,7 +317,7 @@ class MilvusClient(VectorDBBase):
|
|||
"id": item["id"],
|
||||
"vector": item["vector"],
|
||||
"data": {"text": item["text"]},
|
||||
"metadata": item["metadata"],
|
||||
"metadata": stringify_metadata(item["metadata"]),
|
||||
}
|
||||
for item in items
|
||||
],
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ from opensearchpy import OpenSearch
|
|||
from opensearchpy.helpers import bulk
|
||||
from typing import Optional
|
||||
|
||||
from open_webui.retrieval.vector.utils import stringify_metadata
|
||||
from open_webui.retrieval.vector.main import (
|
||||
VectorDBBase,
|
||||
VectorItem,
|
||||
|
|
@ -200,7 +201,7 @@ class OpenSearchClient(VectorDBBase):
|
|||
"_source": {
|
||||
"vector": item["vector"],
|
||||
"text": item["text"],
|
||||
"metadata": item["metadata"],
|
||||
"metadata": stringify_metadata(item["metadata"]),
|
||||
},
|
||||
}
|
||||
for item in batch
|
||||
|
|
@ -222,7 +223,7 @@ class OpenSearchClient(VectorDBBase):
|
|||
"doc": {
|
||||
"vector": item["vector"],
|
||||
"text": item["text"],
|
||||
"metadata": item["metadata"],
|
||||
"metadata": stringify_metadata(item["metadata"]),
|
||||
},
|
||||
"doc_as_upsert": True,
|
||||
}
|
||||
|
|
|
|||
943
backend/open_webui/retrieval/vector/dbs/oracle23ai.py
Normal file
943
backend/open_webui/retrieval/vector/dbs/oracle23ai.py
Normal file
|
|
@ -0,0 +1,943 @@
|
|||
"""
|
||||
Oracle 23ai Vector Database Client - Fixed Version
|
||||
|
||||
# .env
|
||||
VECTOR_DB = "oracle23ai"
|
||||
|
||||
## DBCS or oracle 23ai free
|
||||
ORACLE_DB_USE_WALLET = false
|
||||
ORACLE_DB_USER = "DEMOUSER"
|
||||
ORACLE_DB_PASSWORD = "Welcome123456"
|
||||
ORACLE_DB_DSN = "localhost:1521/FREEPDB1"
|
||||
|
||||
## ADW or ATP
|
||||
# ORACLE_DB_USE_WALLET = true
|
||||
# ORACLE_DB_USER = "DEMOUSER"
|
||||
# ORACLE_DB_PASSWORD = "Welcome123456"
|
||||
# ORACLE_DB_DSN = "medium"
|
||||
# ORACLE_DB_DSN = "(description= (retry_count=3)(retry_delay=3)(address=(protocol=tcps)(port=1522)(host=xx.oraclecloud.com))(connect_data=(service_name=yy.adb.oraclecloud.com))(security=(ssl_server_dn_match=no)))"
|
||||
# ORACLE_WALLET_DIR = "/home/opc/adb_wallet"
|
||||
# ORACLE_WALLET_PASSWORD = "Welcome1"
|
||||
|
||||
ORACLE_VECTOR_LENGTH = 768
|
||||
|
||||
ORACLE_DB_POOL_MIN = 2
|
||||
ORACLE_DB_POOL_MAX = 10
|
||||
ORACLE_DB_POOL_INCREMENT = 1
|
||||
"""
|
||||
|
||||
from typing import Optional, List, Dict, Any, Union
|
||||
from decimal import Decimal
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
import json
|
||||
import array
|
||||
import oracledb
|
||||
|
||||
from open_webui.retrieval.vector.main import (
|
||||
VectorDBBase,
|
||||
VectorItem,
|
||||
SearchResult,
|
||||
GetResult,
|
||||
)
|
||||
|
||||
from open_webui.config import (
|
||||
ORACLE_DB_USE_WALLET,
|
||||
ORACLE_DB_USER,
|
||||
ORACLE_DB_PASSWORD,
|
||||
ORACLE_DB_DSN,
|
||||
ORACLE_WALLET_DIR,
|
||||
ORACLE_WALLET_PASSWORD,
|
||||
ORACLE_VECTOR_LENGTH,
|
||||
ORACLE_DB_POOL_MIN,
|
||||
ORACLE_DB_POOL_MAX,
|
||||
ORACLE_DB_POOL_INCREMENT,
|
||||
)
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||
|
||||
|
||||
class Oracle23aiClient(VectorDBBase):
|
||||
"""
|
||||
Oracle Vector Database Client for vector similarity search using Oracle Database 23ai.
|
||||
|
||||
This client provides an interface to store, retrieve, and search vector embeddings
|
||||
in an Oracle database. It uses connection pooling for efficient database access
|
||||
and supports vector similarity search operations.
|
||||
|
||||
Attributes:
|
||||
pool: Connection pool for Oracle database connections
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""
|
||||
Initialize the Oracle23aiClient with a connection pool.
|
||||
|
||||
Creates a connection pool with configurable min/max connections, initializes
|
||||
the database schema if needed, and sets up necessary tables and indexes.
|
||||
|
||||
Raises:
|
||||
ValueError: If required configuration parameters are missing
|
||||
Exception: If database initialization fails
|
||||
"""
|
||||
self.pool = None
|
||||
|
||||
try:
|
||||
# Create the appropriate connection pool based on DB type
|
||||
if ORACLE_DB_USE_WALLET:
|
||||
self._create_adb_pool()
|
||||
else: # DBCS
|
||||
self._create_dbcs_pool()
|
||||
|
||||
dsn = ORACLE_DB_DSN
|
||||
log.info(f"Creating Connection Pool [{ORACLE_DB_USER}:**@{dsn}]")
|
||||
|
||||
with self.get_connection() as connection:
|
||||
log.info(f"Connection version: {connection.version}")
|
||||
self._initialize_database(connection)
|
||||
|
||||
log.info("Oracle Vector Search initialization complete.")
|
||||
except Exception as e:
|
||||
log.exception(f"Error during Oracle Vector Search initialization: {e}")
|
||||
raise
|
||||
|
||||
def _create_adb_pool(self) -> None:
|
||||
"""
|
||||
Create connection pool for Oracle Autonomous Database.
|
||||
|
||||
Uses wallet-based authentication.
|
||||
"""
|
||||
self.pool = oracledb.create_pool(
|
||||
user=ORACLE_DB_USER,
|
||||
password=ORACLE_DB_PASSWORD,
|
||||
dsn=ORACLE_DB_DSN,
|
||||
min=ORACLE_DB_POOL_MIN,
|
||||
max=ORACLE_DB_POOL_MAX,
|
||||
increment=ORACLE_DB_POOL_INCREMENT,
|
||||
config_dir=ORACLE_WALLET_DIR,
|
||||
wallet_location=ORACLE_WALLET_DIR,
|
||||
wallet_password=ORACLE_WALLET_PASSWORD,
|
||||
)
|
||||
log.info("Created ADB connection pool with wallet authentication.")
|
||||
|
||||
def _create_dbcs_pool(self) -> None:
|
||||
"""
|
||||
Create connection pool for Oracle Database Cloud Service.
|
||||
|
||||
Uses basic authentication without wallet.
|
||||
"""
|
||||
self.pool = oracledb.create_pool(
|
||||
user=ORACLE_DB_USER,
|
||||
password=ORACLE_DB_PASSWORD,
|
||||
dsn=ORACLE_DB_DSN,
|
||||
min=ORACLE_DB_POOL_MIN,
|
||||
max=ORACLE_DB_POOL_MAX,
|
||||
increment=ORACLE_DB_POOL_INCREMENT,
|
||||
)
|
||||
log.info("Created DB connection pool with basic authentication.")
|
||||
|
||||
def get_connection(self):
|
||||
"""
|
||||
Acquire a connection from the connection pool with retry logic.
|
||||
|
||||
Returns:
|
||||
connection: A database connection with output type handler configured
|
||||
"""
|
||||
max_retries = 3
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
connection = self.pool.acquire()
|
||||
connection.outputtypehandler = self._output_type_handler
|
||||
return connection
|
||||
except oracledb.DatabaseError as e:
|
||||
(error_obj,) = e.args
|
||||
log.exception(
|
||||
f"Connection attempt {attempt + 1} failed: {error_obj.message}"
|
||||
)
|
||||
|
||||
if attempt < max_retries - 1:
|
||||
wait_time = 2**attempt
|
||||
log.info(f"Retrying in {wait_time} seconds...")
|
||||
time.sleep(wait_time)
|
||||
else:
|
||||
raise
|
||||
|
||||
def start_health_monitor(self, interval_seconds: int = 60):
|
||||
"""
|
||||
Start a background thread to periodically check the health of the connection pool.
|
||||
|
||||
Args:
|
||||
interval_seconds (int): Number of seconds between health checks
|
||||
"""
|
||||
|
||||
def _monitor():
|
||||
while True:
|
||||
try:
|
||||
log.info("[HealthCheck] Running periodic DB health check...")
|
||||
self.ensure_connection()
|
||||
log.info("[HealthCheck] Connection is healthy.")
|
||||
except Exception as e:
|
||||
log.exception(f"[HealthCheck] Connection health check failed: {e}")
|
||||
time.sleep(interval_seconds)
|
||||
|
||||
thread = threading.Thread(target=_monitor, daemon=True)
|
||||
thread.start()
|
||||
log.info(f"Started DB health monitor every {interval_seconds} seconds.")
|
||||
|
||||
def _reconnect_pool(self):
|
||||
"""
|
||||
Attempt to reinitialize the connection pool if it's been closed or broken.
|
||||
"""
|
||||
try:
|
||||
log.info("Attempting to reinitialize the Oracle connection pool...")
|
||||
|
||||
# Close existing pool if it exists
|
||||
if self.pool:
|
||||
try:
|
||||
self.pool.close()
|
||||
except Exception as close_error:
|
||||
log.warning(f"Error closing existing pool: {close_error}")
|
||||
|
||||
# Re-create the appropriate connection pool based on DB type
|
||||
if ORACLE_DB_USE_WALLET:
|
||||
self._create_adb_pool()
|
||||
else: # DBCS
|
||||
self._create_dbcs_pool()
|
||||
|
||||
log.info("Connection pool reinitialized.")
|
||||
except Exception as e:
|
||||
log.exception(f"Failed to reinitialize the connection pool: {e}")
|
||||
raise
|
||||
|
||||
def ensure_connection(self):
|
||||
"""
|
||||
Ensure the database connection is alive, reconnecting pool if needed.
|
||||
"""
|
||||
try:
|
||||
with self.get_connection() as connection:
|
||||
with connection.cursor() as cursor:
|
||||
cursor.execute("SELECT 1 FROM dual")
|
||||
except Exception as e:
|
||||
log.exception(
|
||||
f"Connection check failed: {e}, attempting to reconnect pool..."
|
||||
)
|
||||
self._reconnect_pool()
|
||||
|
||||
def _output_type_handler(self, cursor, metadata):
|
||||
"""
|
||||
Handle Oracle vector type conversion.
|
||||
|
||||
Args:
|
||||
cursor: Oracle database cursor
|
||||
metadata: Metadata for the column
|
||||
|
||||
Returns:
|
||||
A variable with appropriate conversion for vector types
|
||||
"""
|
||||
if metadata.type_code is oracledb.DB_TYPE_VECTOR:
|
||||
return cursor.var(
|
||||
metadata.type_code, arraysize=cursor.arraysize, outconverter=list
|
||||
)
|
||||
|
||||
def _initialize_database(self, connection) -> None:
|
||||
"""
|
||||
Initialize database schema, tables and indexes.
|
||||
|
||||
Creates the document_chunk table and necessary indexes if they don't exist.
|
||||
|
||||
Args:
|
||||
connection: Oracle database connection
|
||||
|
||||
Raises:
|
||||
Exception: If schema initialization fails
|
||||
"""
|
||||
with connection.cursor() as cursor:
|
||||
try:
|
||||
log.info("Creating Table document_chunk")
|
||||
cursor.execute(
|
||||
"""
|
||||
BEGIN
|
||||
EXECUTE IMMEDIATE '
|
||||
CREATE TABLE IF NOT EXISTS document_chunk (
|
||||
id VARCHAR2(255) PRIMARY KEY,
|
||||
collection_name VARCHAR2(255) NOT NULL,
|
||||
text CLOB,
|
||||
vmetadata JSON,
|
||||
vector vector(*, float32)
|
||||
)
|
||||
';
|
||||
EXCEPTION
|
||||
WHEN OTHERS THEN
|
||||
IF SQLCODE != -955 THEN
|
||||
RAISE;
|
||||
END IF;
|
||||
END;
|
||||
"""
|
||||
)
|
||||
|
||||
log.info("Creating Index document_chunk_collection_name_idx")
|
||||
cursor.execute(
|
||||
"""
|
||||
BEGIN
|
||||
EXECUTE IMMEDIATE '
|
||||
CREATE INDEX IF NOT EXISTS document_chunk_collection_name_idx
|
||||
ON document_chunk (collection_name)
|
||||
';
|
||||
EXCEPTION
|
||||
WHEN OTHERS THEN
|
||||
IF SQLCODE != -955 THEN
|
||||
RAISE;
|
||||
END IF;
|
||||
END;
|
||||
"""
|
||||
)
|
||||
|
||||
log.info("Creating VECTOR INDEX document_chunk_vector_ivf_idx")
|
||||
cursor.execute(
|
||||
"""
|
||||
BEGIN
|
||||
EXECUTE IMMEDIATE '
|
||||
CREATE VECTOR INDEX IF NOT EXISTS document_chunk_vector_ivf_idx
|
||||
ON document_chunk(vector)
|
||||
ORGANIZATION NEIGHBOR PARTITIONS
|
||||
DISTANCE COSINE
|
||||
WITH TARGET ACCURACY 95
|
||||
PARAMETERS (TYPE IVF, NEIGHBOR PARTITIONS 100)
|
||||
';
|
||||
EXCEPTION
|
||||
WHEN OTHERS THEN
|
||||
IF SQLCODE != -955 THEN
|
||||
RAISE;
|
||||
END IF;
|
||||
END;
|
||||
"""
|
||||
)
|
||||
|
||||
connection.commit()
|
||||
log.info("Database initialization completed successfully.")
|
||||
|
||||
except Exception as e:
|
||||
connection.rollback()
|
||||
log.exception(f"Error during database initialization: {e}")
|
||||
raise
|
||||
|
||||
def check_vector_length(self) -> None:
|
||||
"""
|
||||
Check vector length compatibility (placeholder).
|
||||
|
||||
This method would check if the configured vector length matches the database schema.
|
||||
Currently implemented as a placeholder.
|
||||
"""
|
||||
pass
|
||||
|
||||
def _vector_to_blob(self, vector: List[float]) -> bytes:
|
||||
"""
|
||||
Convert a vector to Oracle BLOB format.
|
||||
|
||||
Args:
|
||||
vector (List[float]): The vector to convert
|
||||
|
||||
Returns:
|
||||
bytes: The vector in Oracle BLOB format
|
||||
"""
|
||||
return array.array("f", vector)
|
||||
|
||||
def adjust_vector_length(self, vector: List[float]) -> List[float]:
|
||||
"""
|
||||
Adjust vector to the expected length if needed.
|
||||
|
||||
Args:
|
||||
vector (List[float]): The vector to adjust
|
||||
|
||||
Returns:
|
||||
List[float]: The adjusted vector
|
||||
"""
|
||||
return vector
|
||||
|
||||
def _decimal_handler(self, obj):
|
||||
"""
|
||||
Handle Decimal objects for JSON serialization.
|
||||
|
||||
Args:
|
||||
obj: Object to serialize
|
||||
|
||||
Returns:
|
||||
float: Converted decimal value
|
||||
|
||||
Raises:
|
||||
TypeError: If object is not JSON serializable
|
||||
"""
|
||||
if isinstance(obj, Decimal):
|
||||
return float(obj)
|
||||
raise TypeError(f"{obj} is not JSON serializable")
|
||||
|
||||
def _metadata_to_json(self, metadata: Dict) -> str:
|
||||
"""
|
||||
Convert metadata dictionary to JSON string.
|
||||
|
||||
Args:
|
||||
metadata (Dict): Metadata dictionary
|
||||
|
||||
Returns:
|
||||
str: JSON representation of metadata
|
||||
"""
|
||||
return json.dumps(metadata, default=self._decimal_handler) if metadata else "{}"
|
||||
|
||||
def _json_to_metadata(self, json_str: str) -> Dict:
|
||||
"""
|
||||
Convert JSON string to metadata dictionary.
|
||||
|
||||
Args:
|
||||
json_str (str): JSON string
|
||||
|
||||
Returns:
|
||||
Dict: Metadata dictionary
|
||||
"""
|
||||
return json.loads(json_str) if json_str else {}
|
||||
|
||||
def insert(self, collection_name: str, items: List[VectorItem]) -> None:
|
||||
"""
|
||||
Insert vector items into the database.
|
||||
|
||||
Args:
|
||||
collection_name (str): Name of the collection
|
||||
items (List[VectorItem]): List of vector items to insert
|
||||
|
||||
Raises:
|
||||
Exception: If insertion fails
|
||||
|
||||
Example:
|
||||
>>> client = Oracle23aiClient()
|
||||
>>> items = [
|
||||
... {"id": "1", "text": "Sample text", "vector": [0.1, 0.2, ...], "metadata": {"source": "doc1"}},
|
||||
... {"id": "2", "text": "Another text", "vector": [0.3, 0.4, ...], "metadata": {"source": "doc2"}}
|
||||
... ]
|
||||
>>> client.insert("my_collection", items)
|
||||
"""
|
||||
log.info(f"Inserting {len(items)} items into collection '{collection_name}'.")
|
||||
|
||||
with self.get_connection() as connection:
|
||||
try:
|
||||
with connection.cursor() as cursor:
|
||||
for item in items:
|
||||
vector_blob = self._vector_to_blob(item["vector"])
|
||||
metadata_json = self._metadata_to_json(item["metadata"])
|
||||
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT INTO document_chunk
|
||||
(id, collection_name, text, vmetadata, vector)
|
||||
VALUES (:id, :collection_name, :text, :metadata, :vector)
|
||||
""",
|
||||
{
|
||||
"id": item["id"],
|
||||
"collection_name": collection_name,
|
||||
"text": item["text"],
|
||||
"metadata": metadata_json,
|
||||
"vector": vector_blob,
|
||||
},
|
||||
)
|
||||
|
||||
connection.commit()
|
||||
log.info(
|
||||
f"Successfully inserted {len(items)} items into collection '{collection_name}'."
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
connection.rollback()
|
||||
log.exception(f"Error during insert: {e}")
|
||||
raise
|
||||
|
||||
def upsert(self, collection_name: str, items: List[VectorItem]) -> None:
|
||||
"""
|
||||
Update or insert vector items into the database.
|
||||
|
||||
If an item with the same ID exists, it will be updated;
|
||||
otherwise, it will be inserted.
|
||||
|
||||
Args:
|
||||
collection_name (str): Name of the collection
|
||||
items (List[VectorItem]): List of vector items to upsert
|
||||
|
||||
Raises:
|
||||
Exception: If upsert operation fails
|
||||
|
||||
Example:
|
||||
>>> client = Oracle23aiClient()
|
||||
>>> items = [
|
||||
... {"id": "1", "text": "Updated text", "vector": [0.1, 0.2, ...], "metadata": {"source": "doc1"}},
|
||||
... {"id": "3", "text": "New item", "vector": [0.5, 0.6, ...], "metadata": {"source": "doc3"}}
|
||||
... ]
|
||||
>>> client.upsert("my_collection", items)
|
||||
"""
|
||||
log.info(f"Upserting {len(items)} items into collection '{collection_name}'.")
|
||||
|
||||
with self.get_connection() as connection:
|
||||
try:
|
||||
with connection.cursor() as cursor:
|
||||
for item in items:
|
||||
vector_blob = self._vector_to_blob(item["vector"])
|
||||
metadata_json = self._metadata_to_json(item["metadata"])
|
||||
|
||||
cursor.execute(
|
||||
"""
|
||||
MERGE INTO document_chunk d
|
||||
USING (SELECT :merge_id as id FROM dual) s
|
||||
ON (d.id = s.id)
|
||||
WHEN MATCHED THEN
|
||||
UPDATE SET
|
||||
collection_name = :upd_collection_name,
|
||||
text = :upd_text,
|
||||
vmetadata = :upd_metadata,
|
||||
vector = :upd_vector
|
||||
WHEN NOT MATCHED THEN
|
||||
INSERT (id, collection_name, text, vmetadata, vector)
|
||||
VALUES (:ins_id, :ins_collection_name, :ins_text, :ins_metadata, :ins_vector)
|
||||
""",
|
||||
{
|
||||
"merge_id": item["id"],
|
||||
"upd_collection_name": collection_name,
|
||||
"upd_text": item["text"],
|
||||
"upd_metadata": metadata_json,
|
||||
"upd_vector": vector_blob,
|
||||
"ins_id": item["id"],
|
||||
"ins_collection_name": collection_name,
|
||||
"ins_text": item["text"],
|
||||
"ins_metadata": metadata_json,
|
||||
"ins_vector": vector_blob,
|
||||
},
|
||||
)
|
||||
|
||||
connection.commit()
|
||||
log.info(
|
||||
f"Successfully upserted {len(items)} items into collection '{collection_name}'."
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
connection.rollback()
|
||||
log.exception(f"Error during upsert: {e}")
|
||||
raise
|
||||
|
||||
def search(
|
||||
self, collection_name: str, vectors: List[List[Union[float, int]]], limit: int
|
||||
) -> Optional[SearchResult]:
|
||||
"""
|
||||
Search for similar vectors in the database.
|
||||
|
||||
Performs vector similarity search using cosine distance.
|
||||
|
||||
Args:
|
||||
collection_name (str): Name of the collection to search
|
||||
vectors (List[List[Union[float, int]]]): Query vectors to find similar items for
|
||||
limit (int): Maximum number of results to return per query
|
||||
|
||||
Returns:
|
||||
Optional[SearchResult]: Search results containing ids, distances, documents, and metadata
|
||||
|
||||
Example:
|
||||
>>> client = Oracle23aiClient()
|
||||
>>> query_vector = [0.1, 0.2, 0.3, ...] # Must match VECTOR_LENGTH
|
||||
>>> results = client.search("my_collection", [query_vector], limit=5)
|
||||
>>> if results:
|
||||
... log.info(f"Found {len(results.ids[0])} matches")
|
||||
... for i, (id, dist) in enumerate(zip(results.ids[0], results.distances[0])):
|
||||
... log.info(f"Match {i+1}: id={id}, distance={dist}")
|
||||
"""
|
||||
log.info(
|
||||
f"Searching items from collection '{collection_name}' with limit {limit}."
|
||||
)
|
||||
|
||||
try:
|
||||
if not vectors:
|
||||
log.warning("No vectors provided for search.")
|
||||
return None
|
||||
|
||||
num_queries = len(vectors)
|
||||
|
||||
ids = [[] for _ in range(num_queries)]
|
||||
distances = [[] for _ in range(num_queries)]
|
||||
documents = [[] for _ in range(num_queries)]
|
||||
metadatas = [[] for _ in range(num_queries)]
|
||||
|
||||
with self.get_connection() as connection:
|
||||
with connection.cursor() as cursor:
|
||||
for qid, vector in enumerate(vectors):
|
||||
vector_blob = self._vector_to_blob(vector)
|
||||
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT dc.id, dc.text,
|
||||
JSON_SERIALIZE(dc.vmetadata RETURNING VARCHAR2(4096)) as vmetadata,
|
||||
VECTOR_DISTANCE(dc.vector, :query_vector, COSINE) as distance
|
||||
FROM document_chunk dc
|
||||
WHERE dc.collection_name = :collection_name
|
||||
ORDER BY VECTOR_DISTANCE(dc.vector, :query_vector, COSINE)
|
||||
FETCH APPROX FIRST :limit ROWS ONLY
|
||||
""",
|
||||
{
|
||||
"query_vector": vector_blob,
|
||||
"collection_name": collection_name,
|
||||
"limit": limit,
|
||||
},
|
||||
)
|
||||
|
||||
results = cursor.fetchall()
|
||||
|
||||
for row in results:
|
||||
ids[qid].append(row[0])
|
||||
documents[qid].append(
|
||||
row[1].read()
|
||||
if isinstance(row[1], oracledb.LOB)
|
||||
else str(row[1])
|
||||
)
|
||||
# 🔧 FIXED: Parse JSON metadata properly
|
||||
metadata_str = (
|
||||
row[2].read()
|
||||
if isinstance(row[2], oracledb.LOB)
|
||||
else row[2]
|
||||
)
|
||||
metadatas[qid].append(self._json_to_metadata(metadata_str))
|
||||
distances[qid].append(float(row[3]))
|
||||
|
||||
log.info(
|
||||
f"Search completed. Found {sum(len(ids[i]) for i in range(num_queries))} total results."
|
||||
)
|
||||
|
||||
return SearchResult(
|
||||
ids=ids, distances=distances, documents=documents, metadatas=metadatas
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
log.exception(f"Error during search: {e}")
|
||||
return None
|
||||
|
||||
def query(
|
||||
self, collection_name: str, filter: Dict, limit: Optional[int] = None
|
||||
) -> Optional[GetResult]:
|
||||
"""
|
||||
Query items based on metadata filters.
|
||||
|
||||
Retrieves items that match specified metadata criteria.
|
||||
|
||||
Args:
|
||||
collection_name (str): Name of the collection to query
|
||||
filter (Dict[str, Any]): Metadata filters to apply
|
||||
limit (Optional[int]): Maximum number of results to return
|
||||
|
||||
Returns:
|
||||
Optional[GetResult]: Query results containing ids, documents, and metadata
|
||||
|
||||
Example:
|
||||
>>> client = Oracle23aiClient()
|
||||
>>> filter = {"source": "doc1", "category": "finance"}
|
||||
>>> results = client.query("my_collection", filter, limit=20)
|
||||
>>> if results:
|
||||
... print(f"Found {len(results.ids[0])} matching documents")
|
||||
"""
|
||||
log.info(f"Querying items from collection '{collection_name}' with filters.")
|
||||
|
||||
try:
|
||||
limit = limit or 100
|
||||
|
||||
query = """
|
||||
SELECT id, text, JSON_SERIALIZE(vmetadata RETURNING VARCHAR2(4096)) as vmetadata
|
||||
FROM document_chunk
|
||||
WHERE collection_name = :collection_name
|
||||
"""
|
||||
|
||||
params = {"collection_name": collection_name}
|
||||
|
||||
for i, (key, value) in enumerate(filter.items()):
|
||||
param_name = f"value_{i}"
|
||||
query += f" AND JSON_VALUE(vmetadata, '$.{key}' RETURNING VARCHAR2(4096)) = :{param_name}"
|
||||
params[param_name] = str(value)
|
||||
|
||||
query += " FETCH FIRST :limit ROWS ONLY"
|
||||
params["limit"] = limit
|
||||
|
||||
with self.get_connection() as connection:
|
||||
with connection.cursor() as cursor:
|
||||
cursor.execute(query, params)
|
||||
results = cursor.fetchall()
|
||||
|
||||
if not results:
|
||||
log.info("No results found for query.")
|
||||
return None
|
||||
|
||||
ids = [[row[0] for row in results]]
|
||||
documents = [
|
||||
[
|
||||
row[1].read() if isinstance(row[1], oracledb.LOB) else str(row[1])
|
||||
for row in results
|
||||
]
|
||||
]
|
||||
# 🔧 FIXED: Parse JSON metadata properly
|
||||
metadatas = [
|
||||
[
|
||||
self._json_to_metadata(
|
||||
row[2].read() if isinstance(row[2], oracledb.LOB) else row[2]
|
||||
)
|
||||
for row in results
|
||||
]
|
||||
]
|
||||
|
||||
log.info(f"Query completed. Found {len(results)} results.")
|
||||
|
||||
return GetResult(ids=ids, documents=documents, metadatas=metadatas)
|
||||
|
||||
except Exception as e:
|
||||
log.exception(f"Error during query: {e}")
|
||||
return None
|
||||
|
||||
def get(self, collection_name: str) -> Optional[GetResult]:
|
||||
"""
|
||||
Get all items in a collection.
|
||||
|
||||
Retrieves items from a specified collection up to the limit.
|
||||
|
||||
Args:
|
||||
collection_name (str): Name of the collection to retrieve
|
||||
limit (Optional[int]): Maximum number of items to retrieve
|
||||
|
||||
Returns:
|
||||
Optional[GetResult]: Result containing ids, documents, and metadata
|
||||
|
||||
Example:
|
||||
>>> client = Oracle23aiClient()
|
||||
>>> results = client.get("my_collection", limit=50)
|
||||
>>> if results:
|
||||
... print(f"Retrieved {len(results.ids[0])} documents from collection")
|
||||
"""
|
||||
log.info(
|
||||
f"Getting items from collection '{collection_name}' with limit {limit}."
|
||||
)
|
||||
|
||||
try:
|
||||
limit = limit or 1000
|
||||
|
||||
with self.get_connection() as connection:
|
||||
with connection.cursor() as cursor:
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT /*+ MONITOR */ id, text, JSON_SERIALIZE(vmetadata RETURNING VARCHAR2(4096)) as vmetadata
|
||||
FROM document_chunk
|
||||
WHERE collection_name = :collection_name
|
||||
FETCH FIRST :limit ROWS ONLY
|
||||
""",
|
||||
{"collection_name": collection_name, "limit": limit},
|
||||
)
|
||||
|
||||
results = cursor.fetchall()
|
||||
|
||||
if not results:
|
||||
log.info("No results found.")
|
||||
return None
|
||||
|
||||
ids = [[row[0] for row in results]]
|
||||
documents = [
|
||||
[
|
||||
row[1].read() if isinstance(row[1], oracledb.LOB) else str(row[1])
|
||||
for row in results
|
||||
]
|
||||
]
|
||||
# 🔧 FIXED: Parse JSON metadata properly
|
||||
metadatas = [
|
||||
[
|
||||
self._json_to_metadata(
|
||||
row[2].read() if isinstance(row[2], oracledb.LOB) else row[2]
|
||||
)
|
||||
for row in results
|
||||
]
|
||||
]
|
||||
|
||||
return GetResult(ids=ids, documents=documents, metadatas=metadatas)
|
||||
|
||||
except Exception as e:
|
||||
log.exception(f"Error during get: {e}")
|
||||
return None
|
||||
|
||||
def delete(
|
||||
self,
|
||||
collection_name: str,
|
||||
ids: Optional[List[str]] = None,
|
||||
filter: Optional[Dict[str, Any]] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Delete items from the database.
|
||||
|
||||
Deletes items from a collection based on IDs or metadata filters.
|
||||
|
||||
Args:
|
||||
collection_name (str): Name of the collection to delete from
|
||||
ids (Optional[List[str]]): Specific item IDs to delete
|
||||
filter (Optional[Dict[str, Any]]): Metadata filters for deletion
|
||||
|
||||
Raises:
|
||||
Exception: If deletion fails
|
||||
|
||||
Example:
|
||||
>>> client = Oracle23aiClient()
|
||||
>>> # Delete specific items by ID
|
||||
>>> client.delete("my_collection", ids=["1", "3", "5"])
|
||||
>>> # Or delete by metadata filter
|
||||
>>> client.delete("my_collection", filter={"source": "deprecated_source"})
|
||||
"""
|
||||
log.info(f"Deleting items from collection '{collection_name}'.")
|
||||
|
||||
try:
|
||||
query = (
|
||||
"DELETE FROM document_chunk WHERE collection_name = :collection_name"
|
||||
)
|
||||
params = {"collection_name": collection_name}
|
||||
|
||||
if ids:
|
||||
# 🔧 FIXED: Use proper parameterized query to prevent SQL injection
|
||||
placeholders = ",".join([f":id_{i}" for i in range(len(ids))])
|
||||
query += f" AND id IN ({placeholders})"
|
||||
for i, id_val in enumerate(ids):
|
||||
params[f"id_{i}"] = id_val
|
||||
|
||||
if filter:
|
||||
for i, (key, value) in enumerate(filter.items()):
|
||||
param_name = f"value_{i}"
|
||||
query += f" AND JSON_VALUE(vmetadata, '$.{key}' RETURNING VARCHAR2(4096)) = :{param_name}"
|
||||
params[param_name] = str(value)
|
||||
|
||||
with self.get_connection() as connection:
|
||||
with connection.cursor() as cursor:
|
||||
cursor.execute(query, params)
|
||||
deleted = cursor.rowcount
|
||||
connection.commit()
|
||||
|
||||
log.info(f"Deleted {deleted} items from collection '{collection_name}'.")
|
||||
|
||||
except Exception as e:
|
||||
log.exception(f"Error during delete: {e}")
|
||||
raise
|
||||
|
||||
def reset(self) -> None:
|
||||
"""
|
||||
Reset the database by deleting all items.
|
||||
|
||||
Deletes all items from the document_chunk table.
|
||||
|
||||
Raises:
|
||||
Exception: If reset fails
|
||||
|
||||
Example:
|
||||
>>> client = Oracle23aiClient()
|
||||
>>> client.reset() # Warning: Removes all data!
|
||||
"""
|
||||
log.info("Resetting database - deleting all items.")
|
||||
|
||||
try:
|
||||
with self.get_connection() as connection:
|
||||
with connection.cursor() as cursor:
|
||||
cursor.execute("DELETE FROM document_chunk")
|
||||
deleted = cursor.rowcount
|
||||
connection.commit()
|
||||
|
||||
log.info(
|
||||
f"Reset complete. Deleted {deleted} items from 'document_chunk' table."
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
log.exception(f"Error during reset: {e}")
|
||||
raise
|
||||
|
||||
def close(self) -> None:
|
||||
"""
|
||||
Close the database connection pool.
|
||||
|
||||
Properly closes the connection pool and releases all resources.
|
||||
|
||||
Example:
|
||||
>>> client = Oracle23aiClient()
|
||||
>>> # After finishing all operations
|
||||
>>> client.close()
|
||||
"""
|
||||
try:
|
||||
if hasattr(self, "pool") and self.pool:
|
||||
self.pool.close()
|
||||
log.info("Oracle Vector Search connection pool closed.")
|
||||
except Exception as e:
|
||||
log.exception(f"Error closing connection pool: {e}")
|
||||
|
||||
def has_collection(self, collection_name: str) -> bool:
|
||||
"""
|
||||
Check if a collection exists.
|
||||
|
||||
Args:
|
||||
collection_name (str): Name of the collection to check
|
||||
|
||||
Returns:
|
||||
bool: True if the collection exists, False otherwise
|
||||
|
||||
Example:
|
||||
>>> client = Oracle23aiClient()
|
||||
>>> if client.has_collection("my_collection"):
|
||||
... print("Collection exists!")
|
||||
... else:
|
||||
... print("Collection does not exist.")
|
||||
"""
|
||||
try:
|
||||
with self.get_connection() as connection:
|
||||
with connection.cursor() as cursor:
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT COUNT(*)
|
||||
FROM document_chunk
|
||||
WHERE collection_name = :collection_name
|
||||
FETCH FIRST 1 ROWS ONLY
|
||||
""",
|
||||
{"collection_name": collection_name},
|
||||
)
|
||||
|
||||
count = cursor.fetchone()[0]
|
||||
|
||||
return count > 0
|
||||
|
||||
except Exception as e:
|
||||
log.exception(f"Error checking collection existence: {e}")
|
||||
return False
|
||||
|
||||
def delete_collection(self, collection_name: str) -> None:
|
||||
"""
|
||||
Delete an entire collection.
|
||||
|
||||
Removes all items belonging to the specified collection.
|
||||
|
||||
Args:
|
||||
collection_name (str): Name of the collection to delete
|
||||
|
||||
Example:
|
||||
>>> client = Oracle23aiClient()
|
||||
>>> client.delete_collection("obsolete_collection")
|
||||
"""
|
||||
log.info(f"Deleting collection '{collection_name}'.")
|
||||
|
||||
try:
|
||||
with self.get_connection() as connection:
|
||||
with connection.cursor() as cursor:
|
||||
cursor.execute(
|
||||
"""
|
||||
DELETE FROM document_chunk
|
||||
WHERE collection_name = :collection_name
|
||||
""",
|
||||
{"collection_name": collection_name},
|
||||
)
|
||||
|
||||
deleted = cursor.rowcount
|
||||
connection.commit()
|
||||
|
||||
log.info(
|
||||
f"Collection '{collection_name}' deleted. Removed {deleted} items."
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
log.exception(f"Error deleting collection '{collection_name}': {e}")
|
||||
raise
|
||||
|
|
@ -26,6 +26,8 @@ from pgvector.sqlalchemy import Vector
|
|||
from sqlalchemy.ext.mutable import MutableDict
|
||||
from sqlalchemy.exc import NoSuchTableError
|
||||
|
||||
|
||||
from open_webui.retrieval.vector.utils import stringify_metadata
|
||||
from open_webui.retrieval.vector.main import (
|
||||
VectorDBBase,
|
||||
VectorItem,
|
||||
|
|
@ -109,11 +111,35 @@ class PgvectorClient(VectorDBBase):
|
|||
|
||||
try:
|
||||
# Ensure the pgvector extension is available
|
||||
self.session.execute(text("CREATE EXTENSION IF NOT EXISTS vector;"))
|
||||
# Use a conditional check to avoid permission issues on Azure PostgreSQL
|
||||
self.session.execute(
|
||||
text(
|
||||
"""
|
||||
DO $$
|
||||
BEGIN
|
||||
IF NOT EXISTS (SELECT 1 FROM pg_extension WHERE extname = 'vector') THEN
|
||||
CREATE EXTENSION IF NOT EXISTS vector;
|
||||
END IF;
|
||||
END $$;
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
if PGVECTOR_PGCRYPTO:
|
||||
# Ensure the pgcrypto extension is available for encryption
|
||||
self.session.execute(text("CREATE EXTENSION IF NOT EXISTS pgcrypto;"))
|
||||
# Use a conditional check to avoid permission issues on Azure PostgreSQL
|
||||
self.session.execute(
|
||||
text(
|
||||
"""
|
||||
DO $$
|
||||
BEGIN
|
||||
IF NOT EXISTS (SELECT 1 FROM pg_extension WHERE extname = 'pgcrypto') THEN
|
||||
CREATE EXTENSION IF NOT EXISTS pgcrypto;
|
||||
END IF;
|
||||
END $$;
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
if not PGVECTOR_PGCRYPTO_KEY:
|
||||
raise ValueError(
|
||||
|
|
@ -201,6 +227,8 @@ class PgvectorClient(VectorDBBase):
|
|||
for item in items:
|
||||
vector = self.adjust_vector_length(item["vector"])
|
||||
# Use raw SQL for BYTEA/pgcrypto
|
||||
# Ensure metadata is converted to its JSON text representation
|
||||
json_metadata = json.dumps(item["metadata"])
|
||||
self.session.execute(
|
||||
text(
|
||||
"""
|
||||
|
|
@ -209,7 +237,7 @@ class PgvectorClient(VectorDBBase):
|
|||
VALUES (
|
||||
:id, :vector, :collection_name,
|
||||
pgp_sym_encrypt(:text, :key),
|
||||
pgp_sym_encrypt(:metadata::text, :key)
|
||||
pgp_sym_encrypt(:metadata_text, :key)
|
||||
)
|
||||
ON CONFLICT (id) DO NOTHING
|
||||
"""
|
||||
|
|
@ -219,7 +247,7 @@ class PgvectorClient(VectorDBBase):
|
|||
"vector": vector,
|
||||
"collection_name": collection_name,
|
||||
"text": item["text"],
|
||||
"metadata": json.dumps(item["metadata"]),
|
||||
"metadata_text": json_metadata,
|
||||
"key": PGVECTOR_PGCRYPTO_KEY,
|
||||
},
|
||||
)
|
||||
|
|
@ -235,7 +263,7 @@ class PgvectorClient(VectorDBBase):
|
|||
vector=vector,
|
||||
collection_name=collection_name,
|
||||
text=item["text"],
|
||||
vmetadata=item["metadata"],
|
||||
vmetadata=stringify_metadata(item["metadata"]),
|
||||
)
|
||||
new_items.append(new_chunk)
|
||||
self.session.bulk_save_objects(new_items)
|
||||
|
|
@ -253,6 +281,7 @@ class PgvectorClient(VectorDBBase):
|
|||
if PGVECTOR_PGCRYPTO:
|
||||
for item in items:
|
||||
vector = self.adjust_vector_length(item["vector"])
|
||||
json_metadata = json.dumps(item["metadata"])
|
||||
self.session.execute(
|
||||
text(
|
||||
"""
|
||||
|
|
@ -261,7 +290,7 @@ class PgvectorClient(VectorDBBase):
|
|||
VALUES (
|
||||
:id, :vector, :collection_name,
|
||||
pgp_sym_encrypt(:text, :key),
|
||||
pgp_sym_encrypt(:metadata::text, :key)
|
||||
pgp_sym_encrypt(:metadata_text, :key)
|
||||
)
|
||||
ON CONFLICT (id) DO UPDATE SET
|
||||
vector = EXCLUDED.vector,
|
||||
|
|
@ -275,7 +304,7 @@ class PgvectorClient(VectorDBBase):
|
|||
"vector": vector,
|
||||
"collection_name": collection_name,
|
||||
"text": item["text"],
|
||||
"metadata": json.dumps(item["metadata"]),
|
||||
"metadata_text": json_metadata,
|
||||
"key": PGVECTOR_PGCRYPTO_KEY,
|
||||
},
|
||||
)
|
||||
|
|
@ -292,7 +321,7 @@ class PgvectorClient(VectorDBBase):
|
|||
if existing:
|
||||
existing.vector = vector
|
||||
existing.text = item["text"]
|
||||
existing.vmetadata = item["metadata"]
|
||||
existing.vmetadata = stringify_metadata(item["metadata"])
|
||||
existing.collection_name = (
|
||||
collection_name # Update collection_name if necessary
|
||||
)
|
||||
|
|
@ -302,7 +331,7 @@ class PgvectorClient(VectorDBBase):
|
|||
vector=vector,
|
||||
collection_name=collection_name,
|
||||
text=item["text"],
|
||||
vmetadata=item["metadata"],
|
||||
vmetadata=stringify_metadata(item["metadata"]),
|
||||
)
|
||||
self.session.add(new_chunk)
|
||||
self.session.commit()
|
||||
|
|
@ -416,10 +445,12 @@ class PgvectorClient(VectorDBBase):
|
|||
documents[qid].append(row.text)
|
||||
metadatas[qid].append(row.vmetadata)
|
||||
|
||||
self.session.rollback() # read-only transaction
|
||||
return SearchResult(
|
||||
ids=ids, distances=distances, documents=documents, metadatas=metadatas
|
||||
)
|
||||
except Exception as e:
|
||||
self.session.rollback()
|
||||
log.exception(f"Error during search: {e}")
|
||||
return None
|
||||
|
||||
|
|
@ -472,12 +503,14 @@ class PgvectorClient(VectorDBBase):
|
|||
documents = [[result.text for result in results]]
|
||||
metadatas = [[result.vmetadata for result in results]]
|
||||
|
||||
self.session.rollback() # read-only transaction
|
||||
return GetResult(
|
||||
ids=ids,
|
||||
documents=documents,
|
||||
metadatas=metadatas,
|
||||
)
|
||||
except Exception as e:
|
||||
self.session.rollback()
|
||||
log.exception(f"Error during query: {e}")
|
||||
return None
|
||||
|
||||
|
|
@ -518,8 +551,10 @@ class PgvectorClient(VectorDBBase):
|
|||
documents = [[result.text for result in results]]
|
||||
metadatas = [[result.vmetadata for result in results]]
|
||||
|
||||
self.session.rollback() # read-only transaction
|
||||
return GetResult(ids=ids, documents=documents, metadatas=metadatas)
|
||||
except Exception as e:
|
||||
self.session.rollback()
|
||||
log.exception(f"Error during get: {e}")
|
||||
return None
|
||||
|
||||
|
|
@ -587,8 +622,10 @@ class PgvectorClient(VectorDBBase):
|
|||
.first()
|
||||
is not None
|
||||
)
|
||||
self.session.rollback() # read-only transaction
|
||||
return exists
|
||||
except Exception as e:
|
||||
self.session.rollback()
|
||||
log.exception(f"Error checking collection existence: {e}")
|
||||
return False
|
||||
|
||||
|
|
|
|||
|
|
@ -32,6 +32,8 @@ from open_webui.config import (
|
|||
PINECONE_CLOUD,
|
||||
)
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
from open_webui.retrieval.vector.utils import stringify_metadata
|
||||
|
||||
|
||||
NO_LIMIT = 10000 # Reasonable limit to avoid overwhelming the system
|
||||
BATCH_SIZE = 100 # Recommended batch size for Pinecone operations
|
||||
|
|
@ -183,7 +185,7 @@ class PineconeClient(VectorDBBase):
|
|||
point = {
|
||||
"id": item["id"],
|
||||
"values": item["vector"],
|
||||
"metadata": metadata,
|
||||
"metadata": stringify_metadata(metadata),
|
||||
}
|
||||
points.append(point)
|
||||
return points
|
||||
|
|
|
|||
|
|
@ -19,6 +19,8 @@ from open_webui.config import (
|
|||
QDRANT_GRPC_PORT,
|
||||
QDRANT_PREFER_GRPC,
|
||||
QDRANT_COLLECTION_PREFIX,
|
||||
QDRANT_TIMEOUT,
|
||||
QDRANT_HNSW_M,
|
||||
)
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
|
||||
|
|
@ -36,6 +38,8 @@ class QdrantClient(VectorDBBase):
|
|||
self.QDRANT_ON_DISK = QDRANT_ON_DISK
|
||||
self.PREFER_GRPC = QDRANT_PREFER_GRPC
|
||||
self.GRPC_PORT = QDRANT_GRPC_PORT
|
||||
self.QDRANT_TIMEOUT = QDRANT_TIMEOUT
|
||||
self.QDRANT_HNSW_M = QDRANT_HNSW_M
|
||||
|
||||
if not self.QDRANT_URI:
|
||||
self.client = None
|
||||
|
|
@ -53,9 +57,14 @@ class QdrantClient(VectorDBBase):
|
|||
grpc_port=self.GRPC_PORT,
|
||||
prefer_grpc=self.PREFER_GRPC,
|
||||
api_key=self.QDRANT_API_KEY,
|
||||
timeout=self.QDRANT_TIMEOUT,
|
||||
)
|
||||
else:
|
||||
self.client = Qclient(url=self.QDRANT_URI, api_key=self.QDRANT_API_KEY)
|
||||
self.client = Qclient(
|
||||
url=self.QDRANT_URI,
|
||||
api_key=self.QDRANT_API_KEY,
|
||||
timeout=QDRANT_TIMEOUT,
|
||||
)
|
||||
|
||||
def _result_to_get_result(self, points) -> GetResult:
|
||||
ids = []
|
||||
|
|
@ -85,6 +94,9 @@ class QdrantClient(VectorDBBase):
|
|||
distance=models.Distance.COSINE,
|
||||
on_disk=self.QDRANT_ON_DISK,
|
||||
),
|
||||
hnsw_config=models.HnswConfigDiff(
|
||||
m=self.QDRANT_HNSW_M,
|
||||
),
|
||||
)
|
||||
|
||||
# Create payload indexes for efficient filtering
|
||||
|
|
@ -171,23 +183,23 @@ class QdrantClient(VectorDBBase):
|
|||
)
|
||||
)
|
||||
|
||||
points = self.client.query_points(
|
||||
points = self.client.scroll(
|
||||
collection_name=f"{self.collection_prefix}_{collection_name}",
|
||||
query_filter=models.Filter(should=field_conditions),
|
||||
scroll_filter=models.Filter(should=field_conditions),
|
||||
limit=limit,
|
||||
)
|
||||
return self._result_to_get_result(points.points)
|
||||
return self._result_to_get_result(points[0])
|
||||
except Exception as e:
|
||||
log.exception(f"Error querying a collection '{collection_name}': {e}")
|
||||
return None
|
||||
|
||||
def get(self, collection_name: str) -> Optional[GetResult]:
|
||||
# Get all the items in the collection.
|
||||
points = self.client.query_points(
|
||||
points = self.client.scroll(
|
||||
collection_name=f"{self.collection_prefix}_{collection_name}",
|
||||
limit=NO_LIMIT, # otherwise qdrant would set limit to 10!
|
||||
)
|
||||
return self._result_to_get_result(points.points)
|
||||
return self._result_to_get_result(points[0])
|
||||
|
||||
def insert(self, collection_name: str, items: list[VectorItem]):
|
||||
# Insert the items into the collection, if the collection does not exist, it will be created.
|
||||
|
|
|
|||
|
|
@ -10,6 +10,8 @@ from open_webui.config import (
|
|||
QDRANT_PREFER_GRPC,
|
||||
QDRANT_URI,
|
||||
QDRANT_COLLECTION_PREFIX,
|
||||
QDRANT_TIMEOUT,
|
||||
QDRANT_HNSW_M,
|
||||
)
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
from open_webui.retrieval.vector.main import (
|
||||
|
|
@ -51,6 +53,8 @@ class QdrantClient(VectorDBBase):
|
|||
self.QDRANT_ON_DISK = QDRANT_ON_DISK
|
||||
self.PREFER_GRPC = QDRANT_PREFER_GRPC
|
||||
self.GRPC_PORT = QDRANT_GRPC_PORT
|
||||
self.QDRANT_TIMEOUT = QDRANT_TIMEOUT
|
||||
self.QDRANT_HNSW_M = QDRANT_HNSW_M
|
||||
|
||||
if not self.QDRANT_URI:
|
||||
raise ValueError(
|
||||
|
|
@ -69,9 +73,14 @@ class QdrantClient(VectorDBBase):
|
|||
grpc_port=self.GRPC_PORT,
|
||||
prefer_grpc=self.PREFER_GRPC,
|
||||
api_key=self.QDRANT_API_KEY,
|
||||
timeout=self.QDRANT_TIMEOUT,
|
||||
)
|
||||
if self.PREFER_GRPC
|
||||
else Qclient(url=self.QDRANT_URI, api_key=self.QDRANT_API_KEY)
|
||||
else Qclient(
|
||||
url=self.QDRANT_URI,
|
||||
api_key=self.QDRANT_API_KEY,
|
||||
timeout=self.QDRANT_TIMEOUT,
|
||||
)
|
||||
)
|
||||
|
||||
# Main collection types for multi-tenancy
|
||||
|
|
@ -133,6 +142,12 @@ class QdrantClient(VectorDBBase):
|
|||
distance=models.Distance.COSINE,
|
||||
on_disk=self.QDRANT_ON_DISK,
|
||||
),
|
||||
# Disable global index building due to multitenancy
|
||||
# For more details https://qdrant.tech/documentation/guides/multiple-partitions/#calibrate-performance
|
||||
hnsw_config=models.HnswConfigDiff(
|
||||
payload_m=self.QDRANT_HNSW_M,
|
||||
m=0,
|
||||
),
|
||||
)
|
||||
log.info(
|
||||
f"Multi-tenant collection {mt_collection_name} created with dimension {dimension}!"
|
||||
|
|
@ -278,12 +293,12 @@ class QdrantClient(VectorDBBase):
|
|||
tenant_filter = _tenant_filter(tenant_id)
|
||||
field_conditions = [_metadata_filter(k, v) for k, v in filter.items()]
|
||||
combined_filter = models.Filter(must=[tenant_filter, *field_conditions])
|
||||
points = self.client.query_points(
|
||||
points = self.client.scroll(
|
||||
collection_name=mt_collection,
|
||||
query_filter=combined_filter,
|
||||
scroll_filter=combined_filter,
|
||||
limit=limit,
|
||||
)
|
||||
return self._result_to_get_result(points.points)
|
||||
return self._result_to_get_result(points[0])
|
||||
|
||||
def get(self, collection_name: str) -> Optional[GetResult]:
|
||||
"""
|
||||
|
|
@ -296,12 +311,12 @@ class QdrantClient(VectorDBBase):
|
|||
log.debug(f"Collection {mt_collection} doesn't exist, get returns None")
|
||||
return None
|
||||
tenant_filter = _tenant_filter(tenant_id)
|
||||
points = self.client.query_points(
|
||||
points = self.client.scroll(
|
||||
collection_name=mt_collection,
|
||||
query_filter=models.Filter(must=[tenant_filter]),
|
||||
scroll_filter=models.Filter(must=[tenant_filter]),
|
||||
limit=NO_LIMIT,
|
||||
)
|
||||
return self._result_to_get_result(points.points)
|
||||
return self._result_to_get_result(points[0])
|
||||
|
||||
def upsert(self, collection_name: str, items: List[VectorItem]):
|
||||
"""
|
||||
|
|
|
|||
775
backend/open_webui/retrieval/vector/dbs/s3vector.py
Normal file
775
backend/open_webui/retrieval/vector/dbs/s3vector.py
Normal file
|
|
@ -0,0 +1,775 @@
|
|||
from open_webui.retrieval.vector.utils import stringify_metadata
|
||||
from open_webui.retrieval.vector.main import (
|
||||
VectorDBBase,
|
||||
VectorItem,
|
||||
GetResult,
|
||||
SearchResult,
|
||||
)
|
||||
from open_webui.config import S3_VECTOR_BUCKET_NAME, S3_VECTOR_REGION
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
from typing import List, Optional, Dict, Any, Union
|
||||
import logging
|
||||
import boto3
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||
|
||||
|
||||
class S3VectorClient(VectorDBBase):
|
||||
"""
|
||||
AWS S3 Vector integration for Open WebUI Knowledge.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.bucket_name = S3_VECTOR_BUCKET_NAME
|
||||
self.region = S3_VECTOR_REGION
|
||||
|
||||
# Simple validation - log warnings instead of raising exceptions
|
||||
if not self.bucket_name:
|
||||
log.warning("S3_VECTOR_BUCKET_NAME not set - S3Vector will not work")
|
||||
if not self.region:
|
||||
log.warning("S3_VECTOR_REGION not set - S3Vector will not work")
|
||||
|
||||
if self.bucket_name and self.region:
|
||||
try:
|
||||
self.client = boto3.client("s3vectors", region_name=self.region)
|
||||
log.info(
|
||||
f"S3Vector client initialized for bucket '{self.bucket_name}' in region '{self.region}'"
|
||||
)
|
||||
except Exception as e:
|
||||
log.error(f"Failed to initialize S3Vector client: {e}")
|
||||
self.client = None
|
||||
else:
|
||||
self.client = None
|
||||
|
||||
def _create_index(
|
||||
self,
|
||||
index_name: str,
|
||||
dimension: int,
|
||||
data_type: str = "float32",
|
||||
distance_metric: str = "cosine",
|
||||
) -> None:
|
||||
"""
|
||||
Create a new index in the S3 vector bucket for the given collection if it does not exist.
|
||||
"""
|
||||
if self.has_collection(index_name):
|
||||
log.debug(f"Index '{index_name}' already exists, skipping creation")
|
||||
return
|
||||
|
||||
try:
|
||||
self.client.create_index(
|
||||
vectorBucketName=self.bucket_name,
|
||||
indexName=index_name,
|
||||
dataType=data_type,
|
||||
dimension=dimension,
|
||||
distanceMetric=distance_metric,
|
||||
)
|
||||
log.info(
|
||||
f"Created S3 index: {index_name} (dim={dimension}, type={data_type}, metric={distance_metric})"
|
||||
)
|
||||
except Exception as e:
|
||||
log.error(f"Error creating S3 index '{index_name}': {e}")
|
||||
raise
|
||||
|
||||
def _filter_metadata(
|
||||
self, metadata: Dict[str, Any], item_id: str
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Filter vector metadata keys to comply with S3 Vector API limit of 10 keys maximum.
|
||||
"""
|
||||
if not isinstance(metadata, dict) or len(metadata) <= 10:
|
||||
return metadata
|
||||
|
||||
# Keep only the first 10 keys, prioritizing important ones based on actual Open WebUI metadata
|
||||
important_keys = [
|
||||
"text", # The actual document content
|
||||
"file_id", # File ID
|
||||
"source", # Document source file
|
||||
"title", # Document title
|
||||
"page", # Page number
|
||||
"total_pages", # Total pages in document
|
||||
"embedding_config", # Embedding configuration
|
||||
"created_by", # User who created it
|
||||
"name", # Document name
|
||||
"hash", # Content hash
|
||||
]
|
||||
filtered_metadata = {}
|
||||
|
||||
# First, add important keys if they exist
|
||||
for key in important_keys:
|
||||
if key in metadata:
|
||||
filtered_metadata[key] = metadata[key]
|
||||
if len(filtered_metadata) >= 10:
|
||||
break
|
||||
|
||||
# If we still have room, add other keys
|
||||
if len(filtered_metadata) < 10:
|
||||
for key, value in metadata.items():
|
||||
if key not in filtered_metadata:
|
||||
filtered_metadata[key] = value
|
||||
if len(filtered_metadata) >= 10:
|
||||
break
|
||||
|
||||
log.warning(
|
||||
f"Metadata for key '{item_id}' had {len(metadata)} keys, limited to 10 keys"
|
||||
)
|
||||
return filtered_metadata
|
||||
|
||||
def has_collection(self, collection_name: str) -> bool:
|
||||
"""
|
||||
Check if a vector index (collection) exists in the S3 vector bucket.
|
||||
"""
|
||||
|
||||
try:
|
||||
response = self.client.list_indexes(vectorBucketName=self.bucket_name)
|
||||
indexes = response.get("indexes", [])
|
||||
return any(idx.get("indexName") == collection_name for idx in indexes)
|
||||
except Exception as e:
|
||||
log.error(f"Error listing indexes: {e}")
|
||||
return False
|
||||
|
||||
def delete_collection(self, collection_name: str) -> None:
|
||||
"""
|
||||
Delete an entire S3 Vector index/collection.
|
||||
"""
|
||||
|
||||
if not self.has_collection(collection_name):
|
||||
log.warning(
|
||||
f"Collection '{collection_name}' does not exist, nothing to delete"
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
log.info(f"Deleting collection '{collection_name}'")
|
||||
self.client.delete_index(
|
||||
vectorBucketName=self.bucket_name, indexName=collection_name
|
||||
)
|
||||
log.info(f"Successfully deleted collection '{collection_name}'")
|
||||
except Exception as e:
|
||||
log.error(f"Error deleting collection '{collection_name}': {e}")
|
||||
raise
|
||||
|
||||
def insert(self, collection_name: str, items: List[VectorItem]) -> None:
|
||||
"""
|
||||
Insert vector items into the S3 Vector index. Create index if it does not exist.
|
||||
"""
|
||||
if not items:
|
||||
log.warning("No items to insert")
|
||||
return
|
||||
|
||||
dimension = len(items[0]["vector"])
|
||||
|
||||
try:
|
||||
if not self.has_collection(collection_name):
|
||||
log.info(f"Index '{collection_name}' does not exist. Creating index.")
|
||||
self._create_index(
|
||||
index_name=collection_name,
|
||||
dimension=dimension,
|
||||
data_type="float32",
|
||||
distance_metric="cosine",
|
||||
)
|
||||
|
||||
# Prepare vectors for insertion
|
||||
vectors = []
|
||||
for item in items:
|
||||
# Ensure vector data is in the correct format for S3 Vector API
|
||||
vector_data = item["vector"]
|
||||
if isinstance(vector_data, list):
|
||||
# Convert list to float32 values as required by S3 Vector API
|
||||
vector_data = [float(x) for x in vector_data]
|
||||
|
||||
# Prepare metadata, ensuring the text field is preserved
|
||||
metadata = item.get("metadata", {}).copy()
|
||||
|
||||
# Add the text field to metadata so it's available for retrieval
|
||||
metadata["text"] = item["text"]
|
||||
|
||||
# Convert metadata to string format for consistency
|
||||
metadata = stringify_metadata(metadata)
|
||||
|
||||
# Filter metadata to comply with S3 Vector API limit of 10 keys
|
||||
metadata = self._filter_metadata(metadata, item["id"])
|
||||
|
||||
vectors.append(
|
||||
{
|
||||
"key": item["id"],
|
||||
"data": {"float32": vector_data},
|
||||
"metadata": metadata,
|
||||
}
|
||||
)
|
||||
|
||||
# Insert vectors in batches of 500 (S3 Vector API limit)
|
||||
batch_size = 500
|
||||
for i in range(0, len(vectors), batch_size):
|
||||
batch = vectors[i : i + batch_size]
|
||||
self.client.put_vectors(
|
||||
vectorBucketName=self.bucket_name,
|
||||
indexName=collection_name,
|
||||
vectors=batch,
|
||||
)
|
||||
log.info(
|
||||
f"Inserted batch {i//batch_size + 1}: {len(batch)} vectors into index '{collection_name}'."
|
||||
)
|
||||
|
||||
log.info(
|
||||
f"Completed insertion of {len(vectors)} vectors into index '{collection_name}'."
|
||||
)
|
||||
except Exception as e:
|
||||
log.error(f"Error inserting vectors: {e}")
|
||||
raise
|
||||
|
||||
def upsert(self, collection_name: str, items: List[VectorItem]) -> None:
|
||||
"""
|
||||
Insert or update vector items in the S3 Vector index. Create index if it does not exist.
|
||||
"""
|
||||
if not items:
|
||||
log.warning("No items to upsert")
|
||||
return
|
||||
|
||||
dimension = len(items[0]["vector"])
|
||||
log.info(f"Upsert dimension: {dimension}")
|
||||
|
||||
try:
|
||||
if not self.has_collection(collection_name):
|
||||
log.info(
|
||||
f"Index '{collection_name}' does not exist. Creating index for upsert."
|
||||
)
|
||||
self._create_index(
|
||||
index_name=collection_name,
|
||||
dimension=dimension,
|
||||
data_type="float32",
|
||||
distance_metric="cosine",
|
||||
)
|
||||
|
||||
# Prepare vectors for upsert
|
||||
vectors = []
|
||||
for item in items:
|
||||
# Ensure vector data is in the correct format for S3 Vector API
|
||||
vector_data = item["vector"]
|
||||
if isinstance(vector_data, list):
|
||||
# Convert list to float32 values as required by S3 Vector API
|
||||
vector_data = [float(x) for x in vector_data]
|
||||
|
||||
# Prepare metadata, ensuring the text field is preserved
|
||||
metadata = item.get("metadata", {}).copy()
|
||||
# Add the text field to metadata so it's available for retrieval
|
||||
metadata["text"] = item["text"]
|
||||
|
||||
# Convert metadata to string format for consistency
|
||||
metadata = stringify_metadata(metadata)
|
||||
|
||||
# Filter metadata to comply with S3 Vector API limit of 10 keys
|
||||
metadata = self._filter_metadata(metadata, item["id"])
|
||||
|
||||
vectors.append(
|
||||
{
|
||||
"key": item["id"],
|
||||
"data": {"float32": vector_data},
|
||||
"metadata": metadata,
|
||||
}
|
||||
)
|
||||
|
||||
# Upsert vectors in batches of 500 (S3 Vector API limit)
|
||||
batch_size = 500
|
||||
for i in range(0, len(vectors), batch_size):
|
||||
batch = vectors[i : i + batch_size]
|
||||
if i == 0: # Log sample info for first batch only
|
||||
log.info(
|
||||
f"Upserting batch 1: {len(batch)} vectors. First vector sample: key={batch[0]['key']}, data_type={type(batch[0]['data']['float32'])}, data_len={len(batch[0]['data']['float32'])}"
|
||||
)
|
||||
else:
|
||||
log.info(
|
||||
f"Upserting batch {i//batch_size + 1}: {len(batch)} vectors."
|
||||
)
|
||||
|
||||
self.client.put_vectors(
|
||||
vectorBucketName=self.bucket_name,
|
||||
indexName=collection_name,
|
||||
vectors=batch,
|
||||
)
|
||||
|
||||
log.info(
|
||||
f"Completed upsert of {len(vectors)} vectors into index '{collection_name}'."
|
||||
)
|
||||
except Exception as e:
|
||||
log.error(f"Error upserting vectors: {e}")
|
||||
raise
|
||||
|
||||
def search(
|
||||
self, collection_name: str, vectors: List[List[Union[float, int]]], limit: int
|
||||
) -> Optional[SearchResult]:
|
||||
"""
|
||||
Search for similar vectors in a collection using multiple query vectors.
|
||||
"""
|
||||
|
||||
if not self.has_collection(collection_name):
|
||||
log.warning(f"Collection '{collection_name}' does not exist")
|
||||
return None
|
||||
|
||||
if not vectors:
|
||||
log.warning("No query vectors provided")
|
||||
return None
|
||||
|
||||
try:
|
||||
log.info(
|
||||
f"Searching collection '{collection_name}' with {len(vectors)} query vectors, limit={limit}"
|
||||
)
|
||||
|
||||
# Initialize result lists
|
||||
all_ids = []
|
||||
all_documents = []
|
||||
all_metadatas = []
|
||||
all_distances = []
|
||||
|
||||
# Process each query vector
|
||||
for i, query_vector in enumerate(vectors):
|
||||
log.debug(f"Processing query vector {i+1}/{len(vectors)}")
|
||||
|
||||
# Prepare the query vector in S3 Vector format
|
||||
query_vector_dict = {"float32": [float(x) for x in query_vector]}
|
||||
|
||||
# Call S3 Vector query API
|
||||
response = self.client.query_vectors(
|
||||
vectorBucketName=self.bucket_name,
|
||||
indexName=collection_name,
|
||||
topK=limit,
|
||||
queryVector=query_vector_dict,
|
||||
returnMetadata=True,
|
||||
returnDistance=True,
|
||||
)
|
||||
|
||||
# Process results for this query
|
||||
query_ids = []
|
||||
query_documents = []
|
||||
query_metadatas = []
|
||||
query_distances = []
|
||||
|
||||
result_vectors = response.get("vectors", [])
|
||||
|
||||
for vector in result_vectors:
|
||||
vector_id = vector.get("key")
|
||||
vector_metadata = vector.get("metadata", {})
|
||||
vector_distance = vector.get("distance", 0.0)
|
||||
|
||||
# Extract document text from metadata
|
||||
document_text = ""
|
||||
if isinstance(vector_metadata, dict):
|
||||
# Get the text field first (highest priority)
|
||||
document_text = vector_metadata.get("text")
|
||||
if not document_text:
|
||||
# Fallback to other possible text fields
|
||||
document_text = (
|
||||
vector_metadata.get("content")
|
||||
or vector_metadata.get("document")
|
||||
or vector_id
|
||||
)
|
||||
else:
|
||||
document_text = vector_id
|
||||
|
||||
query_ids.append(vector_id)
|
||||
query_documents.append(document_text)
|
||||
query_metadatas.append(vector_metadata)
|
||||
query_distances.append(vector_distance)
|
||||
|
||||
# Add this query's results to the overall results
|
||||
all_ids.append(query_ids)
|
||||
all_documents.append(query_documents)
|
||||
all_metadatas.append(query_metadatas)
|
||||
all_distances.append(query_distances)
|
||||
|
||||
log.info(f"Search completed. Found results for {len(all_ids)} queries")
|
||||
|
||||
# Return SearchResult format
|
||||
return SearchResult(
|
||||
ids=all_ids if all_ids else None,
|
||||
documents=all_documents if all_documents else None,
|
||||
metadatas=all_metadatas if all_metadatas else None,
|
||||
distances=all_distances if all_distances else None,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
log.error(f"Error searching collection '{collection_name}': {str(e)}")
|
||||
# Handle specific AWS exceptions
|
||||
if hasattr(e, "response") and "Error" in e.response:
|
||||
error_code = e.response["Error"]["Code"]
|
||||
if error_code == "NotFoundException":
|
||||
log.warning(f"Collection '{collection_name}' not found")
|
||||
return None
|
||||
elif error_code == "ValidationException":
|
||||
log.error(f"Invalid query vector dimensions or parameters")
|
||||
return None
|
||||
elif error_code == "AccessDeniedException":
|
||||
log.error(
|
||||
f"Access denied for collection '{collection_name}'. Check permissions."
|
||||
)
|
||||
return None
|
||||
raise
|
||||
|
||||
def query(
|
||||
self, collection_name: str, filter: Dict, limit: Optional[int] = None
|
||||
) -> Optional[GetResult]:
|
||||
"""
|
||||
Query vectors from a collection using metadata filter.
|
||||
"""
|
||||
|
||||
if not self.has_collection(collection_name):
|
||||
log.warning(f"Collection '{collection_name}' does not exist")
|
||||
return GetResult(ids=[[]], documents=[[]], metadatas=[[]])
|
||||
|
||||
if not filter:
|
||||
log.warning("No filter provided, returning all vectors")
|
||||
return self.get(collection_name)
|
||||
|
||||
try:
|
||||
log.info(f"Querying collection '{collection_name}' with filter: {filter}")
|
||||
|
||||
# For S3 Vector, we need to use list_vectors and then filter results
|
||||
# Since S3 Vector may not support complex server-side filtering,
|
||||
# we'll retrieve all vectors and filter client-side
|
||||
|
||||
# Get all vectors first
|
||||
all_vectors_result = self.get(collection_name)
|
||||
|
||||
if not all_vectors_result or not all_vectors_result.ids:
|
||||
log.warning("No vectors found in collection")
|
||||
return GetResult(ids=[[]], documents=[[]], metadatas=[[]])
|
||||
|
||||
# Extract the lists from the result
|
||||
all_ids = all_vectors_result.ids[0] if all_vectors_result.ids else []
|
||||
all_documents = (
|
||||
all_vectors_result.documents[0] if all_vectors_result.documents else []
|
||||
)
|
||||
all_metadatas = (
|
||||
all_vectors_result.metadatas[0] if all_vectors_result.metadatas else []
|
||||
)
|
||||
|
||||
# Apply client-side filtering
|
||||
filtered_ids = []
|
||||
filtered_documents = []
|
||||
filtered_metadatas = []
|
||||
|
||||
for i, metadata in enumerate(all_metadatas):
|
||||
if self._matches_filter(metadata, filter):
|
||||
if i < len(all_ids):
|
||||
filtered_ids.append(all_ids[i])
|
||||
if i < len(all_documents):
|
||||
filtered_documents.append(all_documents[i])
|
||||
filtered_metadatas.append(metadata)
|
||||
|
||||
# Apply limit if specified
|
||||
if limit and len(filtered_ids) >= limit:
|
||||
break
|
||||
|
||||
log.info(
|
||||
f"Filter applied: {len(filtered_ids)} vectors match out of {len(all_ids)} total"
|
||||
)
|
||||
|
||||
# Return GetResult format
|
||||
if filtered_ids:
|
||||
return GetResult(
|
||||
ids=[filtered_ids],
|
||||
documents=[filtered_documents],
|
||||
metadatas=[filtered_metadatas],
|
||||
)
|
||||
else:
|
||||
return GetResult(ids=[[]], documents=[[]], metadatas=[[]])
|
||||
|
||||
except Exception as e:
|
||||
log.error(f"Error querying collection '{collection_name}': {str(e)}")
|
||||
# Handle specific AWS exceptions
|
||||
if hasattr(e, "response") and "Error" in e.response:
|
||||
error_code = e.response["Error"]["Code"]
|
||||
if error_code == "NotFoundException":
|
||||
log.warning(f"Collection '{collection_name}' not found")
|
||||
return GetResult(ids=[[]], documents=[[]], metadatas=[[]])
|
||||
elif error_code == "AccessDeniedException":
|
||||
log.error(
|
||||
f"Access denied for collection '{collection_name}'. Check permissions."
|
||||
)
|
||||
return GetResult(ids=[[]], documents=[[]], metadatas=[[]])
|
||||
raise
|
||||
|
||||
def get(self, collection_name: str) -> Optional[GetResult]:
|
||||
"""
|
||||
Retrieve all vectors from a collection.
|
||||
"""
|
||||
|
||||
if not self.has_collection(collection_name):
|
||||
log.warning(f"Collection '{collection_name}' does not exist")
|
||||
return GetResult(ids=[[]], documents=[[]], metadatas=[[]])
|
||||
|
||||
try:
|
||||
log.info(f"Retrieving all vectors from collection '{collection_name}'")
|
||||
|
||||
# Initialize result lists
|
||||
all_ids = []
|
||||
all_documents = []
|
||||
all_metadatas = []
|
||||
|
||||
# Handle pagination
|
||||
next_token = None
|
||||
|
||||
while True:
|
||||
# Prepare request parameters
|
||||
request_params = {
|
||||
"vectorBucketName": self.bucket_name,
|
||||
"indexName": collection_name,
|
||||
"returnData": False, # Don't include vector data (not needed for get)
|
||||
"returnMetadata": True, # Include metadata
|
||||
"maxResults": 500, # Use reasonable page size
|
||||
}
|
||||
|
||||
if next_token:
|
||||
request_params["nextToken"] = next_token
|
||||
|
||||
# Call S3 Vector API
|
||||
response = self.client.list_vectors(**request_params)
|
||||
|
||||
# Process vectors in this page
|
||||
vectors = response.get("vectors", [])
|
||||
|
||||
for vector in vectors:
|
||||
vector_id = vector.get("key")
|
||||
vector_data = vector.get("data", {})
|
||||
vector_metadata = vector.get("metadata", {})
|
||||
|
||||
# Extract the actual vector array
|
||||
vector_array = vector_data.get("float32", [])
|
||||
|
||||
# For documents, we try to extract text from metadata or use the vector ID
|
||||
document_text = ""
|
||||
if isinstance(vector_metadata, dict):
|
||||
# Get the text field first (highest priority)
|
||||
document_text = vector_metadata.get("text")
|
||||
if not document_text:
|
||||
# Fallback to other possible text fields
|
||||
document_text = (
|
||||
vector_metadata.get("content")
|
||||
or vector_metadata.get("document")
|
||||
or vector_id
|
||||
)
|
||||
|
||||
# Log the actual content for debugging
|
||||
log.debug(
|
||||
f"Document text preview (first 200 chars): {str(document_text)[:200]}"
|
||||
)
|
||||
else:
|
||||
document_text = vector_id
|
||||
|
||||
all_ids.append(vector_id)
|
||||
all_documents.append(document_text)
|
||||
all_metadatas.append(vector_metadata)
|
||||
|
||||
# Check if there are more pages
|
||||
next_token = response.get("nextToken")
|
||||
if not next_token:
|
||||
break
|
||||
|
||||
log.info(
|
||||
f"Retrieved {len(all_ids)} vectors from collection '{collection_name}'"
|
||||
)
|
||||
|
||||
# Return in GetResult format
|
||||
# The Open WebUI GetResult expects lists of lists, so we wrap each list
|
||||
if all_ids:
|
||||
return GetResult(
|
||||
ids=[all_ids], documents=[all_documents], metadatas=[all_metadatas]
|
||||
)
|
||||
else:
|
||||
return GetResult(ids=[[]], documents=[[]], metadatas=[[]])
|
||||
|
||||
except Exception as e:
|
||||
log.error(
|
||||
f"Error retrieving vectors from collection '{collection_name}': {str(e)}"
|
||||
)
|
||||
# Handle specific AWS exceptions
|
||||
if hasattr(e, "response") and "Error" in e.response:
|
||||
error_code = e.response["Error"]["Code"]
|
||||
if error_code == "NotFoundException":
|
||||
log.warning(f"Collection '{collection_name}' not found")
|
||||
return GetResult(ids=[[]], documents=[[]], metadatas=[[]])
|
||||
elif error_code == "AccessDeniedException":
|
||||
log.error(
|
||||
f"Access denied for collection '{collection_name}'. Check permissions."
|
||||
)
|
||||
return GetResult(ids=[[]], documents=[[]], metadatas=[[]])
|
||||
raise
|
||||
|
||||
def delete(
|
||||
self,
|
||||
collection_name: str,
|
||||
ids: Optional[List[str]] = None,
|
||||
filter: Optional[Dict] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Delete vectors by ID or filter from a collection.
|
||||
"""
|
||||
|
||||
if not self.has_collection(collection_name):
|
||||
log.warning(
|
||||
f"Collection '{collection_name}' does not exist, nothing to delete"
|
||||
)
|
||||
return
|
||||
|
||||
# Check if this is a knowledge collection (not file-specific)
|
||||
is_knowledge_collection = not collection_name.startswith("file-")
|
||||
|
||||
try:
|
||||
if ids:
|
||||
# Delete by specific vector IDs/keys
|
||||
log.info(
|
||||
f"Deleting {len(ids)} vectors by IDs from collection '{collection_name}'"
|
||||
)
|
||||
self.client.delete_vectors(
|
||||
vectorBucketName=self.bucket_name,
|
||||
indexName=collection_name,
|
||||
keys=ids,
|
||||
)
|
||||
log.info(f"Deleted {len(ids)} vectors from index '{collection_name}'")
|
||||
|
||||
elif filter:
|
||||
# Handle filter-based deletion
|
||||
log.info(
|
||||
f"Deleting vectors by filter from collection '{collection_name}': {filter}"
|
||||
)
|
||||
|
||||
# If this is a knowledge collection and we have a file_id filter,
|
||||
# also clean up the corresponding file-specific collection
|
||||
if is_knowledge_collection and "file_id" in filter:
|
||||
file_id = filter["file_id"]
|
||||
file_collection_name = f"file-{file_id}"
|
||||
if self.has_collection(file_collection_name):
|
||||
log.info(
|
||||
f"Found related file-specific collection '{file_collection_name}', deleting it to prevent duplicates"
|
||||
)
|
||||
self.delete_collection(file_collection_name)
|
||||
|
||||
# For the main collection, implement query-then-delete
|
||||
# First, query to get IDs matching the filter
|
||||
query_result = self.query(collection_name, filter)
|
||||
if query_result and query_result.ids and query_result.ids[0]:
|
||||
matching_ids = query_result.ids[0]
|
||||
log.info(
|
||||
f"Found {len(matching_ids)} vectors matching filter, deleting them"
|
||||
)
|
||||
|
||||
# Delete the matching vectors by ID
|
||||
self.client.delete_vectors(
|
||||
vectorBucketName=self.bucket_name,
|
||||
indexName=collection_name,
|
||||
keys=matching_ids,
|
||||
)
|
||||
log.info(
|
||||
f"Deleted {len(matching_ids)} vectors from index '{collection_name}' using filter"
|
||||
)
|
||||
else:
|
||||
log.warning("No vectors found matching the filter criteria")
|
||||
else:
|
||||
log.warning("No IDs or filter provided for deletion")
|
||||
except Exception as e:
|
||||
log.error(
|
||||
f"Error deleting vectors from collection '{collection_name}': {e}"
|
||||
)
|
||||
raise
|
||||
|
||||
def reset(self) -> None:
|
||||
"""
|
||||
Reset/clear all vector data. For S3 Vector, this deletes all indexes.
|
||||
"""
|
||||
|
||||
try:
|
||||
log.warning(
|
||||
"Reset called - this will delete all vector indexes in the S3 bucket"
|
||||
)
|
||||
|
||||
# List all indexes
|
||||
response = self.client.list_indexes(vectorBucketName=self.bucket_name)
|
||||
indexes = response.get("indexes", [])
|
||||
|
||||
if not indexes:
|
||||
log.warning("No indexes found to delete")
|
||||
return
|
||||
|
||||
# Delete all indexes
|
||||
deleted_count = 0
|
||||
for index in indexes:
|
||||
index_name = index.get("indexName")
|
||||
if index_name:
|
||||
try:
|
||||
self.client.delete_index(
|
||||
vectorBucketName=self.bucket_name, indexName=index_name
|
||||
)
|
||||
deleted_count += 1
|
||||
log.info(f"Deleted index: {index_name}")
|
||||
except Exception as e:
|
||||
log.error(f"Error deleting index '{index_name}': {e}")
|
||||
|
||||
log.info(f"Reset completed: deleted {deleted_count} indexes")
|
||||
|
||||
except Exception as e:
|
||||
log.error(f"Error during reset: {e}")
|
||||
raise
|
||||
|
||||
def _matches_filter(self, metadata: Dict[str, Any], filter: Dict[str, Any]) -> bool:
|
||||
"""
|
||||
Check if metadata matches the given filter conditions.
|
||||
"""
|
||||
if not isinstance(metadata, dict) or not isinstance(filter, dict):
|
||||
return False
|
||||
|
||||
# Check each filter condition
|
||||
for key, expected_value in filter.items():
|
||||
# Handle special operators
|
||||
if key.startswith("$"):
|
||||
if key == "$and":
|
||||
# All conditions must match
|
||||
if not isinstance(expected_value, list):
|
||||
continue
|
||||
for condition in expected_value:
|
||||
if not self._matches_filter(metadata, condition):
|
||||
return False
|
||||
elif key == "$or":
|
||||
# At least one condition must match
|
||||
if not isinstance(expected_value, list):
|
||||
continue
|
||||
any_match = False
|
||||
for condition in expected_value:
|
||||
if self._matches_filter(metadata, condition):
|
||||
any_match = True
|
||||
break
|
||||
if not any_match:
|
||||
return False
|
||||
continue
|
||||
|
||||
# Get the actual value from metadata
|
||||
actual_value = metadata.get(key)
|
||||
|
||||
# Handle different types of expected values
|
||||
if isinstance(expected_value, dict):
|
||||
# Handle comparison operators
|
||||
for op, op_value in expected_value.items():
|
||||
if op == "$eq":
|
||||
if actual_value != op_value:
|
||||
return False
|
||||
elif op == "$ne":
|
||||
if actual_value == op_value:
|
||||
return False
|
||||
elif op == "$in":
|
||||
if (
|
||||
not isinstance(op_value, list)
|
||||
or actual_value not in op_value
|
||||
):
|
||||
return False
|
||||
elif op == "$nin":
|
||||
if isinstance(op_value, list) and actual_value in op_value:
|
||||
return False
|
||||
elif op == "$exists":
|
||||
if bool(op_value) != (key in metadata):
|
||||
return False
|
||||
# Add more operators as needed
|
||||
else:
|
||||
# Simple equality check
|
||||
if actual_value != expected_value:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
|
@ -30,6 +30,10 @@ class Vector:
|
|||
from open_webui.retrieval.vector.dbs.pinecone import PineconeClient
|
||||
|
||||
return PineconeClient()
|
||||
case VectorType.S3VECTOR:
|
||||
from open_webui.retrieval.vector.dbs.s3vector import S3VectorClient
|
||||
|
||||
return S3VectorClient()
|
||||
case VectorType.OPENSEARCH:
|
||||
from open_webui.retrieval.vector.dbs.opensearch import OpenSearchClient
|
||||
|
||||
|
|
@ -48,6 +52,10 @@ class Vector:
|
|||
from open_webui.retrieval.vector.dbs.chroma import ChromaClient
|
||||
|
||||
return ChromaClient()
|
||||
case VectorType.ORACLE23AI:
|
||||
from open_webui.retrieval.vector.dbs.oracle23ai import Oracle23aiClient
|
||||
|
||||
return Oracle23aiClient()
|
||||
case _:
|
||||
raise ValueError(f"Unsupported vector type: {vector_type}")
|
||||
|
||||
|
|
|
|||
|
|
@ -9,3 +9,5 @@ class VectorType(StrEnum):
|
|||
ELASTICSEARCH = "elasticsearch"
|
||||
OPENSEARCH = "opensearch"
|
||||
PGVECTOR = "pgvector"
|
||||
ORACLE23AI = "oracle23ai"
|
||||
S3VECTOR = "s3vector"
|
||||
|
|
|
|||
14
backend/open_webui/retrieval/vector/utils.py
Normal file
14
backend/open_webui/retrieval/vector/utils.py
Normal file
|
|
@ -0,0 +1,14 @@
|
|||
from datetime import datetime
|
||||
|
||||
|
||||
def stringify_metadata(
|
||||
metadata: dict[str, any],
|
||||
) -> dict[str, any]:
|
||||
for key, value in metadata.items():
|
||||
if (
|
||||
isinstance(value, datetime)
|
||||
or isinstance(value, list)
|
||||
or isinstance(value, dict)
|
||||
):
|
||||
metadata[key] = str(value)
|
||||
return metadata
|
||||
|
|
@ -11,7 +11,10 @@ log.setLevel(SRC_LOG_LEVELS["RAG"])
|
|||
|
||||
|
||||
def search_duckduckgo(
|
||||
query: str, count: int, filter_list: Optional[list[str]] = None
|
||||
query: str,
|
||||
count: int,
|
||||
filter_list: Optional[list[str]] = None,
|
||||
concurrent_requests: Optional[int] = None,
|
||||
) -> list[SearchResult]:
|
||||
"""
|
||||
Search using DuckDuckGo's Search API and return the results as a list of SearchResult objects.
|
||||
|
|
@ -25,6 +28,9 @@ def search_duckduckgo(
|
|||
# Use the DDGS context manager to create a DDGS object
|
||||
search_results = []
|
||||
with DDGS() as ddgs:
|
||||
if concurrent_requests:
|
||||
ddgs.threads = concurrent_requests
|
||||
|
||||
# Use the ddgs.text() method to perform the search
|
||||
try:
|
||||
search_results = ddgs.text(
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ def get_filtered_results(results, filter_list):
|
|||
return results
|
||||
filtered_results = []
|
||||
for result in results:
|
||||
url = result.get("url") or result.get("link", "")
|
||||
url = result.get("url") or result.get("link", "") or result.get("href", "")
|
||||
if not validators.url(url):
|
||||
continue
|
||||
domain = urlparse(url).netloc
|
||||
|
|
|
|||
|
|
@ -376,9 +376,13 @@ async def speech(request: Request, user=Depends(get_verified_user)):
|
|||
|
||||
if r is not None:
|
||||
status_code = r.status
|
||||
res = await r.json()
|
||||
if "error" in res:
|
||||
detail = f"External: {res['error'].get('message', '')}"
|
||||
|
||||
try:
|
||||
res = await r.json()
|
||||
if "error" in res:
|
||||
detail = f"External: {res['error']}"
|
||||
except Exception:
|
||||
detail = f"External: {e}"
|
||||
|
||||
raise HTTPException(
|
||||
status_code=status_code,
|
||||
|
|
@ -546,6 +550,11 @@ def transcription_handler(request, file_path, metadata):
|
|||
|
||||
metadata = metadata or {}
|
||||
|
||||
languages = [
|
||||
metadata.get("language", None) if WHISPER_LANGUAGE == "" else WHISPER_LANGUAGE,
|
||||
None, # Always fallback to None in case transcription fails
|
||||
]
|
||||
|
||||
if request.app.state.config.STT_ENGINE == "":
|
||||
if request.app.state.faster_whisper_model is None:
|
||||
request.app.state.faster_whisper_model = set_faster_whisper_model(
|
||||
|
|
@ -557,7 +566,7 @@ def transcription_handler(request, file_path, metadata):
|
|||
file_path,
|
||||
beam_size=5,
|
||||
vad_filter=request.app.state.config.WHISPER_VAD_FILTER,
|
||||
language=metadata.get("language") or WHISPER_LANGUAGE,
|
||||
language=languages[0],
|
||||
)
|
||||
log.info(
|
||||
"Detected language '%s' with probability %f"
|
||||
|
|
@ -577,21 +586,26 @@ def transcription_handler(request, file_path, metadata):
|
|||
elif request.app.state.config.STT_ENGINE == "openai":
|
||||
r = None
|
||||
try:
|
||||
r = requests.post(
|
||||
url=f"{request.app.state.config.STT_OPENAI_API_BASE_URL}/audio/transcriptions",
|
||||
headers={
|
||||
"Authorization": f"Bearer {request.app.state.config.STT_OPENAI_API_KEY}"
|
||||
},
|
||||
files={"file": (filename, open(file_path, "rb"))},
|
||||
data={
|
||||
for language in languages:
|
||||
payload = {
|
||||
"model": request.app.state.config.STT_MODEL,
|
||||
**(
|
||||
{"language": metadata.get("language")}
|
||||
if metadata.get("language")
|
||||
else {}
|
||||
),
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
if language:
|
||||
payload["language"] = language
|
||||
|
||||
r = requests.post(
|
||||
url=f"{request.app.state.config.STT_OPENAI_API_BASE_URL}/audio/transcriptions",
|
||||
headers={
|
||||
"Authorization": f"Bearer {request.app.state.config.STT_OPENAI_API_KEY}"
|
||||
},
|
||||
files={"file": (filename, open(file_path, "rb"))},
|
||||
data=payload,
|
||||
)
|
||||
|
||||
if r.status_code == 200:
|
||||
# Successful transcription
|
||||
break
|
||||
|
||||
r.raise_for_status()
|
||||
data = r.json()
|
||||
|
|
@ -633,18 +647,26 @@ def transcription_handler(request, file_path, metadata):
|
|||
"Content-Type": mime,
|
||||
}
|
||||
|
||||
# Add model if specified
|
||||
params = {}
|
||||
if request.app.state.config.STT_MODEL:
|
||||
params["model"] = request.app.state.config.STT_MODEL
|
||||
for language in languages:
|
||||
params = {}
|
||||
if request.app.state.config.STT_MODEL:
|
||||
params["model"] = request.app.state.config.STT_MODEL
|
||||
|
||||
if language:
|
||||
params["language"] = language
|
||||
|
||||
# Make request to Deepgram API
|
||||
r = requests.post(
|
||||
"https://api.deepgram.com/v1/listen?smart_format=true",
|
||||
headers=headers,
|
||||
params=params,
|
||||
data=file_data,
|
||||
)
|
||||
|
||||
if r.status_code == 200:
|
||||
# Successful transcription
|
||||
break
|
||||
|
||||
# Make request to Deepgram API
|
||||
r = requests.post(
|
||||
"https://api.deepgram.com/v1/listen?smart_format=true",
|
||||
headers=headers,
|
||||
params=params,
|
||||
data=file_data,
|
||||
)
|
||||
r.raise_for_status()
|
||||
response_data = r.json()
|
||||
|
||||
|
|
|
|||
|
|
@ -15,10 +15,9 @@ from open_webui.models.auths import (
|
|||
SigninResponse,
|
||||
SignupForm,
|
||||
UpdatePasswordForm,
|
||||
UpdateProfileForm,
|
||||
UserResponse,
|
||||
)
|
||||
from open_webui.models.users import Users
|
||||
from open_webui.models.users import Users, UpdateProfileForm
|
||||
from open_webui.models.groups import Groups
|
||||
|
||||
from open_webui.constants import ERROR_MESSAGES, WEBHOOK_MESSAGES
|
||||
|
|
@ -73,7 +72,13 @@ class SessionUserResponse(Token, UserResponse):
|
|||
permissions: Optional[dict] = None
|
||||
|
||||
|
||||
@router.get("/", response_model=SessionUserResponse)
|
||||
class SessionUserInfoResponse(SessionUserResponse):
|
||||
bio: Optional[str] = None
|
||||
gender: Optional[str] = None
|
||||
date_of_birth: Optional[datetime.date] = None
|
||||
|
||||
|
||||
@router.get("/", response_model=SessionUserInfoResponse)
|
||||
async def get_session_user(
|
||||
request: Request, response: Response, user=Depends(get_current_user)
|
||||
):
|
||||
|
|
@ -121,6 +126,9 @@ async def get_session_user(
|
|||
"name": user.name,
|
||||
"role": user.role,
|
||||
"profile_image_url": user.profile_image_url,
|
||||
"bio": user.bio,
|
||||
"gender": user.gender,
|
||||
"date_of_birth": user.date_of_birth,
|
||||
"permissions": user_permissions,
|
||||
}
|
||||
|
||||
|
|
@ -137,7 +145,7 @@ async def update_profile(
|
|||
if session_user:
|
||||
user = Users.update_user_by_id(
|
||||
session_user.id,
|
||||
{"profile_image_url": form_data.profile_image_url, "name": form_data.name},
|
||||
form_data.model_dump(),
|
||||
)
|
||||
if user:
|
||||
return user
|
||||
|
|
@ -351,11 +359,9 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
|
|||
user = Users.get_user_by_email(email)
|
||||
if not user:
|
||||
try:
|
||||
user_count = Users.get_num_users()
|
||||
|
||||
role = (
|
||||
"admin"
|
||||
if user_count == 0
|
||||
if not Users.has_users()
|
||||
else request.app.state.config.DEFAULT_USER_ROLE
|
||||
)
|
||||
|
||||
|
|
@ -489,7 +495,7 @@ async def signin(request: Request, response: Response, form_data: SigninForm):
|
|||
if Users.get_user_by_email(admin_email.lower()):
|
||||
user = Auths.authenticate_user(admin_email.lower(), admin_password)
|
||||
else:
|
||||
if Users.get_num_users() != 0:
|
||||
if Users.has_users():
|
||||
raise HTTPException(400, detail=ERROR_MESSAGES.EXISTING_USERS)
|
||||
|
||||
await signup(
|
||||
|
|
@ -556,6 +562,7 @@ async def signin(request: Request, response: Response, form_data: SigninForm):
|
|||
|
||||
@router.post("/signup", response_model=SessionUserResponse)
|
||||
async def signup(request: Request, response: Response, form_data: SignupForm):
|
||||
has_users = Users.has_users()
|
||||
|
||||
if WEBUI_AUTH:
|
||||
if (
|
||||
|
|
@ -566,12 +573,11 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
|
|||
status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED
|
||||
)
|
||||
else:
|
||||
if Users.get_num_users() != 0:
|
||||
if has_users:
|
||||
raise HTTPException(
|
||||
status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED
|
||||
)
|
||||
|
||||
user_count = Users.get_num_users()
|
||||
if not validate_email_format(form_data.email.lower()):
|
||||
raise HTTPException(
|
||||
status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT
|
||||
|
|
@ -581,9 +587,7 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
|
|||
raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
|
||||
|
||||
try:
|
||||
role = (
|
||||
"admin" if user_count == 0 else request.app.state.config.DEFAULT_USER_ROLE
|
||||
)
|
||||
role = "admin" if not has_users else request.app.state.config.DEFAULT_USER_ROLE
|
||||
|
||||
# The password passed to bcrypt must be 72 bytes or fewer. If it is longer, it will be truncated before hashing.
|
||||
if len(form_data.password.encode("utf-8")) > 72:
|
||||
|
|
@ -629,7 +633,7 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
|
|||
)
|
||||
|
||||
if request.app.state.config.WEBHOOK_URL:
|
||||
post_webhook(
|
||||
await post_webhook(
|
||||
request.app.state.WEBUI_NAME,
|
||||
request.app.state.config.WEBHOOK_URL,
|
||||
WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
|
||||
|
|
@ -644,7 +648,7 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
|
|||
user.id, request.app.state.config.USER_PERMISSIONS
|
||||
)
|
||||
|
||||
if user_count == 0:
|
||||
if not has_users:
|
||||
# Disable signup after the first user is created
|
||||
request.app.state.config.ENABLE_SIGNUP = False
|
||||
|
||||
|
|
@ -673,7 +677,7 @@ async def signout(request: Request, response: Response):
|
|||
|
||||
if ENABLE_OAUTH_SIGNUP.value:
|
||||
oauth_id_token = request.cookies.get("oauth_id_token")
|
||||
if oauth_id_token:
|
||||
if oauth_id_token and OPENID_PROVIDER_URL.value:
|
||||
try:
|
||||
async with ClientSession(trust_env=True) as session:
|
||||
async with session.get(OPENID_PROVIDER_URL.value) as resp:
|
||||
|
|
|
|||
|
|
@ -209,7 +209,7 @@ async def send_notification(name, webui_url, channel, message, active_user_ids):
|
|||
)
|
||||
|
||||
if webhook_url:
|
||||
post_webhook(
|
||||
await post_webhook(
|
||||
name,
|
||||
webhook_url,
|
||||
f"#{channel.name} - {webui_url}/channels/{channel.id}\n\n{message.content}",
|
||||
|
|
@ -434,13 +434,6 @@ async def update_message_by_id(
|
|||
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
|
||||
)
|
||||
|
||||
if user.role != "admin" and not has_access(
|
||||
user.id, type="read", access_control=channel.access_control
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
|
||||
)
|
||||
|
||||
message = Messages.get_message_by_id(message_id)
|
||||
if not message:
|
||||
raise HTTPException(
|
||||
|
|
@ -452,6 +445,15 @@ async def update_message_by_id(
|
|||
status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
|
||||
)
|
||||
|
||||
if (
|
||||
user.role != "admin"
|
||||
and message.user_id != user.id
|
||||
and not has_access(user.id, type="read", access_control=channel.access_control)
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
|
||||
)
|
||||
|
||||
try:
|
||||
message = Messages.update_message_by_id(message_id, form_data)
|
||||
message = Messages.get_message_by_id(message_id)
|
||||
|
|
@ -641,13 +643,6 @@ async def delete_message_by_id(
|
|||
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
|
||||
)
|
||||
|
||||
if user.role != "admin" and not has_access(
|
||||
user.id, type="read", access_control=channel.access_control
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
|
||||
)
|
||||
|
||||
message = Messages.get_message_by_id(message_id)
|
||||
if not message:
|
||||
raise HTTPException(
|
||||
|
|
@ -659,6 +654,15 @@ async def delete_message_by_id(
|
|||
status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
|
||||
)
|
||||
|
||||
if (
|
||||
user.role != "admin"
|
||||
and message.user_id != user.id
|
||||
and not has_access(user.id, type="read", access_control=channel.access_control)
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
|
||||
)
|
||||
|
||||
try:
|
||||
Messages.delete_message_by_id(message_id)
|
||||
await sio.emit(
|
||||
|
|
|
|||
|
|
@ -36,16 +36,24 @@ router = APIRouter()
|
|||
|
||||
@router.get("/", response_model=list[ChatTitleIdResponse])
|
||||
@router.get("/list", response_model=list[ChatTitleIdResponse])
|
||||
async def get_session_user_chat_list(
|
||||
def get_session_user_chat_list(
|
||||
user=Depends(get_verified_user), page: Optional[int] = None
|
||||
):
|
||||
if page is not None:
|
||||
limit = 60
|
||||
skip = (page - 1) * limit
|
||||
try:
|
||||
if page is not None:
|
||||
limit = 60
|
||||
skip = (page - 1) * limit
|
||||
|
||||
return Chats.get_chat_title_id_list_by_user_id(user.id, skip=skip, limit=limit)
|
||||
else:
|
||||
return Chats.get_chat_title_id_list_by_user_id(user.id)
|
||||
return Chats.get_chat_title_id_list_by_user_id(
|
||||
user.id, skip=skip, limit=limit
|
||||
)
|
||||
else:
|
||||
return Chats.get_chat_title_id_list_by_user_id(user.id)
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
|
||||
)
|
||||
|
||||
|
||||
############################
|
||||
|
|
@ -609,7 +617,18 @@ async def clone_chat_by_id(
|
|||
"title": form_data.title if form_data.title else f"Clone of {chat.title}",
|
||||
}
|
||||
|
||||
chat = Chats.insert_new_chat(user.id, ChatForm(**{"chat": updated_chat}))
|
||||
chat = Chats.import_chat(
|
||||
user.id,
|
||||
ChatImportForm(
|
||||
**{
|
||||
"chat": updated_chat,
|
||||
"meta": chat.meta,
|
||||
"pinned": chat.pinned,
|
||||
"folder_id": chat.folder_id,
|
||||
}
|
||||
),
|
||||
)
|
||||
|
||||
return ChatResponse(**chat.model_dump())
|
||||
else:
|
||||
raise HTTPException(
|
||||
|
|
@ -638,7 +657,17 @@ async def clone_shared_chat_by_id(id: str, user=Depends(get_verified_user)):
|
|||
"title": f"Clone of {chat.title}",
|
||||
}
|
||||
|
||||
chat = Chats.insert_new_chat(user.id, ChatForm(**{"chat": updated_chat}))
|
||||
chat = Chats.import_chat(
|
||||
user.id,
|
||||
ChatImportForm(
|
||||
**{
|
||||
"chat": updated_chat,
|
||||
"meta": chat.meta,
|
||||
"pinned": chat.pinned,
|
||||
"folder_id": chat.folder_id,
|
||||
}
|
||||
),
|
||||
)
|
||||
return ChatResponse(**chat.model_dump())
|
||||
else:
|
||||
raise HTTPException(
|
||||
|
|
|
|||
|
|
@ -7,7 +7,11 @@ from open_webui.utils.auth import get_admin_user, get_verified_user
|
|||
from open_webui.config import get_config, save_config
|
||||
from open_webui.config import BannerModel
|
||||
|
||||
from open_webui.utils.tools import get_tool_server_data, get_tool_servers_data
|
||||
from open_webui.utils.tools import (
|
||||
get_tool_server_data,
|
||||
get_tool_server_url,
|
||||
set_tool_servers,
|
||||
)
|
||||
|
||||
|
||||
router = APIRouter()
|
||||
|
|
@ -110,10 +114,7 @@ async def set_tool_servers_config(
|
|||
request.app.state.config.TOOL_SERVER_CONNECTIONS = [
|
||||
connection.model_dump() for connection in form_data.TOOL_SERVER_CONNECTIONS
|
||||
]
|
||||
|
||||
request.app.state.TOOL_SERVERS = await get_tool_servers_data(
|
||||
request.app.state.config.TOOL_SERVER_CONNECTIONS
|
||||
)
|
||||
await set_tool_servers(request)
|
||||
|
||||
return {
|
||||
"TOOL_SERVER_CONNECTIONS": request.app.state.config.TOOL_SERVER_CONNECTIONS,
|
||||
|
|
@ -135,7 +136,7 @@ async def verify_tool_servers_config(
|
|||
elif form_data.auth_type == "session":
|
||||
token = request.state.token.credentials
|
||||
|
||||
url = f"{form_data.url}/{form_data.path}"
|
||||
url = get_tool_server_url(form_data.url, form_data.path)
|
||||
return await get_tool_server_data(token, url)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
|
|
|
|||
|
|
@ -129,7 +129,10 @@ async def create_feedback(
|
|||
|
||||
@router.get("/feedback/{id}", response_model=FeedbackModel)
|
||||
async def get_feedback_by_id(id: str, user=Depends(get_verified_user)):
|
||||
feedback = Feedbacks.get_feedback_by_id_and_user_id(id=id, user_id=user.id)
|
||||
if user.role == "admin":
|
||||
feedback = Feedbacks.get_feedback_by_id(id=id)
|
||||
else:
|
||||
feedback = Feedbacks.get_feedback_by_id_and_user_id(id=id, user_id=user.id)
|
||||
|
||||
if not feedback:
|
||||
raise HTTPException(
|
||||
|
|
@ -143,9 +146,12 @@ async def get_feedback_by_id(id: str, user=Depends(get_verified_user)):
|
|||
async def update_feedback_by_id(
|
||||
id: str, form_data: FeedbackForm, user=Depends(get_verified_user)
|
||||
):
|
||||
feedback = Feedbacks.update_feedback_by_id_and_user_id(
|
||||
id=id, user_id=user.id, form_data=form_data
|
||||
)
|
||||
if user.role == "admin":
|
||||
feedback = Feedbacks.update_feedback_by_id(id=id, form_data=form_data)
|
||||
else:
|
||||
feedback = Feedbacks.update_feedback_by_id_and_user_id(
|
||||
id=id, user_id=user.id, form_data=form_data
|
||||
)
|
||||
|
||||
if not feedback:
|
||||
raise HTTPException(
|
||||
|
|
|
|||
|
|
@ -6,8 +6,10 @@ from fnmatch import fnmatch
|
|||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from urllib.parse import quote
|
||||
import asyncio
|
||||
|
||||
from fastapi import (
|
||||
BackgroundTasks,
|
||||
APIRouter,
|
||||
Depends,
|
||||
File,
|
||||
|
|
@ -18,9 +20,11 @@ from fastapi import (
|
|||
status,
|
||||
Query,
|
||||
)
|
||||
|
||||
from fastapi.responses import FileResponse, StreamingResponse
|
||||
from open_webui.constants import ERROR_MESSAGES
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
from open_webui.retrieval.vector.factory import VECTOR_DB_CLIENT
|
||||
|
||||
from open_webui.models.users import Users
|
||||
from open_webui.models.files import (
|
||||
|
|
@ -41,7 +45,6 @@ from pydantic import BaseModel
|
|||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["MODELS"])
|
||||
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
|
|
@ -82,14 +85,76 @@ def has_access_to_file(
|
|||
############################
|
||||
|
||||
|
||||
def process_uploaded_file(request, file, file_path, file_item, file_metadata, user):
|
||||
try:
|
||||
if file.content_type:
|
||||
stt_supported_content_types = getattr(
|
||||
request.app.state.config, "STT_SUPPORTED_CONTENT_TYPES", []
|
||||
)
|
||||
|
||||
if any(
|
||||
fnmatch(file.content_type, content_type)
|
||||
for content_type in (
|
||||
stt_supported_content_types
|
||||
if stt_supported_content_types
|
||||
and any(t.strip() for t in stt_supported_content_types)
|
||||
else ["audio/*", "video/webm"]
|
||||
)
|
||||
):
|
||||
file_path = Storage.get_file(file_path)
|
||||
result = transcribe(request, file_path, file_metadata)
|
||||
|
||||
process_file(
|
||||
request,
|
||||
ProcessFileForm(
|
||||
file_id=file_item.id, content=result.get("text", "")
|
||||
),
|
||||
user=user,
|
||||
)
|
||||
elif (not file.content_type.startswith(("image/", "video/"))) or (
|
||||
request.app.state.config.CONTENT_EXTRACTION_ENGINE == "external"
|
||||
):
|
||||
process_file(request, ProcessFileForm(file_id=file_item.id), user=user)
|
||||
else:
|
||||
log.info(
|
||||
f"File type {file.content_type} is not provided, but trying to process anyway"
|
||||
)
|
||||
process_file(request, ProcessFileForm(file_id=file_item.id), user=user)
|
||||
|
||||
Files.update_file_data_by_id(
|
||||
file_item.id,
|
||||
{"status": "completed"},
|
||||
)
|
||||
except Exception as e:
|
||||
log.error(f"Error processing file: {file_item.id}")
|
||||
Files.update_file_data_by_id(
|
||||
file_item.id,
|
||||
{
|
||||
"status": "failed",
|
||||
"error": str(e.detail) if hasattr(e, "detail") else str(e),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.post("/", response_model=FileModelResponse)
|
||||
def upload_file(
|
||||
request: Request,
|
||||
background_tasks: BackgroundTasks,
|
||||
file: UploadFile = File(...),
|
||||
metadata: Optional[dict | str] = Form(None),
|
||||
process: bool = Query(True),
|
||||
user=Depends(get_verified_user),
|
||||
):
|
||||
return upload_file_handler(request, file, metadata, process, user, background_tasks)
|
||||
|
||||
|
||||
def upload_file_handler(
|
||||
request: Request,
|
||||
file: UploadFile = File(...),
|
||||
metadata: Optional[dict | str] = Form(None),
|
||||
process: bool = Query(True),
|
||||
internal: bool = False,
|
||||
user=Depends(get_verified_user),
|
||||
background_tasks: Optional[BackgroundTasks] = None,
|
||||
):
|
||||
log.info(f"file.content_type: {file.content_type}")
|
||||
|
||||
|
|
@ -111,7 +176,7 @@ def upload_file(
|
|||
# Remove the leading dot from the file extension
|
||||
file_extension = file_extension[1:] if file_extension else ""
|
||||
|
||||
if (not internal) and request.app.state.config.ALLOWED_FILE_EXTENSIONS:
|
||||
if process and request.app.state.config.ALLOWED_FILE_EXTENSIONS:
|
||||
request.app.state.config.ALLOWED_FILE_EXTENSIONS = [
|
||||
ext for ext in request.app.state.config.ALLOWED_FILE_EXTENSIONS if ext
|
||||
]
|
||||
|
|
@ -128,13 +193,16 @@ def upload_file(
|
|||
id = str(uuid.uuid4())
|
||||
name = filename
|
||||
filename = f"{id}_{filename}"
|
||||
tags = {
|
||||
"OpenWebUI-User-Email": user.email,
|
||||
"OpenWebUI-User-Id": user.id,
|
||||
"OpenWebUI-User-Name": user.name,
|
||||
"OpenWebUI-File-Id": id,
|
||||
}
|
||||
contents, file_path = Storage.upload_file(file.file, filename, tags)
|
||||
contents, file_path = Storage.upload_file(
|
||||
file.file,
|
||||
filename,
|
||||
{
|
||||
"OpenWebUI-User-Email": user.email,
|
||||
"OpenWebUI-User-Id": user.id,
|
||||
"OpenWebUI-User-Name": user.name,
|
||||
"OpenWebUI-File-Id": id,
|
||||
},
|
||||
)
|
||||
|
||||
file_item = Files.insert_new_file(
|
||||
user.id,
|
||||
|
|
@ -143,6 +211,9 @@ def upload_file(
|
|||
"id": id,
|
||||
"filename": name,
|
||||
"path": file_path,
|
||||
"data": {
|
||||
**({"status": "pending"} if process else {}),
|
||||
},
|
||||
"meta": {
|
||||
"name": name,
|
||||
"content_type": file.content_type,
|
||||
|
|
@ -152,58 +223,37 @@ def upload_file(
|
|||
}
|
||||
),
|
||||
)
|
||||
|
||||
if process:
|
||||
try:
|
||||
if file.content_type:
|
||||
stt_supported_content_types = getattr(
|
||||
request.app.state.config, "STT_SUPPORTED_CONTENT_TYPES", []
|
||||
)
|
||||
|
||||
if any(
|
||||
fnmatch(file.content_type, content_type)
|
||||
for content_type in (
|
||||
stt_supported_content_types
|
||||
if stt_supported_content_types
|
||||
and any(t.strip() for t in stt_supported_content_types)
|
||||
else ["audio/*", "video/webm"]
|
||||
)
|
||||
):
|
||||
file_path = Storage.get_file(file_path)
|
||||
result = transcribe(request, file_path, file_metadata)
|
||||
|
||||
process_file(
|
||||
request,
|
||||
ProcessFileForm(file_id=id, content=result.get("text", "")),
|
||||
user=user,
|
||||
)
|
||||
elif (not file.content_type.startswith(("image/", "video/"))) or (
|
||||
request.app.state.config.CONTENT_EXTRACTION_ENGINE == "external"
|
||||
):
|
||||
process_file(request, ProcessFileForm(file_id=id), user=user)
|
||||
else:
|
||||
log.info(
|
||||
f"File type {file.content_type} is not provided, but trying to process anyway"
|
||||
)
|
||||
process_file(request, ProcessFileForm(file_id=id), user=user)
|
||||
|
||||
file_item = Files.get_file_by_id(id=id)
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
log.error(f"Error processing file: {file_item.id}")
|
||||
file_item = FileModelResponse(
|
||||
**{
|
||||
**file_item.model_dump(),
|
||||
"error": str(e.detail) if hasattr(e, "detail") else str(e),
|
||||
}
|
||||
if background_tasks:
|
||||
background_tasks.add_task(
|
||||
process_uploaded_file,
|
||||
request,
|
||||
file,
|
||||
file_path,
|
||||
file_item,
|
||||
file_metadata,
|
||||
user,
|
||||
)
|
||||
|
||||
if file_item:
|
||||
return file_item
|
||||
return {"status": True, **file_item.model_dump()}
|
||||
else:
|
||||
process_uploaded_file(
|
||||
request,
|
||||
file,
|
||||
file_path,
|
||||
file_item,
|
||||
file_metadata,
|
||||
user,
|
||||
)
|
||||
return {"status": True, **file_item.model_dump()}
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.DEFAULT("Error uploading file"),
|
||||
)
|
||||
if file_item:
|
||||
return file_item
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.DEFAULT("Error uploading file"),
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
|
|
@ -286,6 +336,7 @@ async def delete_all_files(user=Depends(get_admin_user)):
|
|||
if result:
|
||||
try:
|
||||
Storage.delete_all_files()
|
||||
VECTOR_DB_CLIENT.reset()
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
log.error("Error deleting files")
|
||||
|
|
@ -329,6 +380,60 @@ async def get_file_by_id(id: str, user=Depends(get_verified_user)):
|
|||
)
|
||||
|
||||
|
||||
@router.get("/{id}/process/status")
|
||||
async def get_file_process_status(
|
||||
id: str, stream: bool = Query(False), user=Depends(get_verified_user)
|
||||
):
|
||||
file = Files.get_file_by_id(id)
|
||||
|
||||
if not file:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=ERROR_MESSAGES.NOT_FOUND,
|
||||
)
|
||||
|
||||
if (
|
||||
file.user_id == user.id
|
||||
or user.role == "admin"
|
||||
or has_access_to_file(id, "read", user)
|
||||
):
|
||||
if stream:
|
||||
MAX_FILE_PROCESSING_DURATION = 3600 * 2
|
||||
|
||||
async def event_stream(file_item):
|
||||
for _ in range(MAX_FILE_PROCESSING_DURATION):
|
||||
file_item = Files.get_file_by_id(file_item.id)
|
||||
if file_item:
|
||||
data = file_item.model_dump().get("data", {})
|
||||
status = data.get("status")
|
||||
|
||||
if status:
|
||||
event = {"status": status}
|
||||
if status == "failed":
|
||||
event["error"] = data.get("error")
|
||||
|
||||
yield f"data: {json.dumps(event)}\n\n"
|
||||
if status in ("completed", "failed"):
|
||||
break
|
||||
else:
|
||||
# Legacy
|
||||
break
|
||||
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
return StreamingResponse(
|
||||
event_stream(file),
|
||||
media_type="text/event-stream",
|
||||
)
|
||||
else:
|
||||
return {"status": file.data.get("status", "pending")}
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=ERROR_MESSAGES.NOT_FOUND,
|
||||
)
|
||||
|
||||
|
||||
############################
|
||||
# Get File Data Content By Id
|
||||
############################
|
||||
|
|
@ -603,12 +708,12 @@ async def delete_file_by_id(id: str, user=Depends(get_verified_user)):
|
|||
or user.role == "admin"
|
||||
or has_access_to_file(id, "write", user)
|
||||
):
|
||||
# We should add Chroma cleanup here
|
||||
|
||||
result = Files.delete_file_by_id(id)
|
||||
if result:
|
||||
try:
|
||||
Storage.delete_file(file.path)
|
||||
VECTOR_DB_CLIENT.delete(collection_name=f"file-{id}")
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
log.error("Error deleting files")
|
||||
|
|
|
|||
|
|
@ -49,7 +49,7 @@ async def get_folders(user=Depends(get_verified_user)):
|
|||
**folder.model_dump(),
|
||||
"items": {
|
||||
"chats": [
|
||||
{"title": chat.title, "id": chat.id}
|
||||
{"title": chat.title, "id": chat.id, "updated_at": chat.updated_at}
|
||||
for chat in Chats.get_chats_by_folder_id_and_user_id(
|
||||
folder.id, user.id
|
||||
)
|
||||
|
|
@ -78,7 +78,7 @@ def create_folder(form_data: FolderForm, user=Depends(get_verified_user)):
|
|||
)
|
||||
|
||||
try:
|
||||
folder = Folders.insert_new_folder(user.id, form_data.name)
|
||||
folder = Folders.insert_new_folder(user.id, form_data)
|
||||
return folder
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
|
|
@ -244,11 +244,11 @@ async def delete_folder_by_id(
|
|||
folder = Folders.get_folder_by_id_and_user_id(id, user.id)
|
||||
if folder:
|
||||
try:
|
||||
result = Folders.delete_folder_by_id_and_user_id(id, user.id)
|
||||
if result:
|
||||
return result
|
||||
else:
|
||||
raise Exception("Error deleting folder")
|
||||
folder_ids = Folders.delete_folder_by_id_and_user_id(id, user.id)
|
||||
for folder_id in folder_ids:
|
||||
Chats.delete_chats_by_user_id_and_folder_id(user.id, folder_id)
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
log.error(f"Error deleting folder: {id}")
|
||||
|
|
|
|||
|
|
@ -131,15 +131,29 @@ async def load_function_from_url(
|
|||
############################
|
||||
|
||||
|
||||
class SyncFunctionsForm(FunctionForm):
|
||||
class SyncFunctionsForm(BaseModel):
|
||||
functions: list[FunctionModel] = []
|
||||
|
||||
|
||||
@router.post("/sync", response_model=Optional[FunctionModel])
|
||||
@router.post("/sync", response_model=list[FunctionModel])
|
||||
async def sync_functions(
|
||||
request: Request, form_data: SyncFunctionsForm, user=Depends(get_admin_user)
|
||||
):
|
||||
return Functions.sync_functions(user.id, form_data.functions)
|
||||
try:
|
||||
for function in form_data.functions:
|
||||
function.content = replace_imports(function.content)
|
||||
function_module, function_type, frontmatter = load_function_module_by_id(
|
||||
function.id,
|
||||
content=function.content,
|
||||
)
|
||||
|
||||
return Functions.sync_functions(user.id, form_data.functions)
|
||||
except Exception as e:
|
||||
log.exception(f"Failed to load a function: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.DEFAULT(e),
|
||||
)
|
||||
|
||||
|
||||
############################
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ from open_webui.models.groups import (
|
|||
GroupForm,
|
||||
GroupUpdateForm,
|
||||
GroupResponse,
|
||||
UserIdsForm,
|
||||
)
|
||||
|
||||
from open_webui.config import CACHE_DIR
|
||||
|
|
@ -107,6 +108,56 @@ async def update_group_by_id(
|
|||
)
|
||||
|
||||
|
||||
############################
|
||||
# AddUserToGroupByUserIdAndGroupId
|
||||
############################
|
||||
|
||||
|
||||
@router.post("/id/{id}/users/add", response_model=Optional[GroupResponse])
|
||||
async def add_user_to_group(
|
||||
id: str, form_data: UserIdsForm, user=Depends(get_admin_user)
|
||||
):
|
||||
try:
|
||||
if form_data.user_ids:
|
||||
form_data.user_ids = Users.get_valid_user_ids(form_data.user_ids)
|
||||
|
||||
group = Groups.add_users_to_group(id, form_data.user_ids)
|
||||
if group:
|
||||
return group
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.DEFAULT("Error adding users to group"),
|
||||
)
|
||||
except Exception as e:
|
||||
log.exception(f"Error adding users to group {id}: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.DEFAULT(e),
|
||||
)
|
||||
|
||||
|
||||
@router.post("/id/{id}/users/remove", response_model=Optional[GroupResponse])
|
||||
async def remove_users_from_group(
|
||||
id: str, form_data: UserIdsForm, user=Depends(get_admin_user)
|
||||
):
|
||||
try:
|
||||
group = Groups.remove_users_from_group(id, form_data.user_ids)
|
||||
if group:
|
||||
return group
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.DEFAULT("Error removing users from group"),
|
||||
)
|
||||
except Exception as e:
|
||||
log.exception(f"Error removing users from group {id}: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.DEFAULT(e),
|
||||
)
|
||||
|
||||
|
||||
############################
|
||||
# DeleteGroupById
|
||||
############################
|
||||
|
|
|
|||
|
|
@ -10,11 +10,18 @@ from typing import Optional
|
|||
|
||||
from urllib.parse import quote
|
||||
import requests
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, UploadFile
|
||||
from fastapi import (
|
||||
APIRouter,
|
||||
Depends,
|
||||
HTTPException,
|
||||
Request,
|
||||
UploadFile,
|
||||
)
|
||||
|
||||
from open_webui.config import CACHE_DIR
|
||||
from open_webui.constants import ERROR_MESSAGES
|
||||
from open_webui.env import ENABLE_FORWARD_USER_INFO_HEADERS, SRC_LOG_LEVELS
|
||||
from open_webui.routers.files import upload_file
|
||||
from open_webui.routers.files import upload_file_handler
|
||||
from open_webui.utils.auth import get_admin_user, get_verified_user
|
||||
from open_webui.utils.images.comfyui import (
|
||||
ComfyUIGenerateImageForm,
|
||||
|
|
@ -469,7 +476,13 @@ def upload_image(request, image_data, content_type, metadata, user):
|
|||
"content-type": content_type,
|
||||
},
|
||||
)
|
||||
file_item = upload_file(request, file, metadata=metadata, internal=True, user=user)
|
||||
file_item = upload_file_handler(
|
||||
request,
|
||||
file=file,
|
||||
metadata=metadata,
|
||||
process=False,
|
||||
user=user,
|
||||
)
|
||||
url = request.app.url_path_for("get_file_content_by_id", id=file_item.id)
|
||||
return url
|
||||
|
||||
|
|
@ -483,11 +496,18 @@ async def image_generations(
|
|||
# if IMAGE_SIZE = 'auto', default WidthxHeight to the 512x512 default
|
||||
# This is only relevant when the user has set IMAGE_SIZE to 'auto' with an
|
||||
# image model other than gpt-image-1, which is warned about on settings save
|
||||
width, height = (
|
||||
tuple(map(int, request.app.state.config.IMAGE_SIZE.split("x")))
|
||||
if "x" in request.app.state.config.IMAGE_SIZE
|
||||
else (512, 512)
|
||||
)
|
||||
|
||||
size = "512x512"
|
||||
if (
|
||||
request.app.state.config.IMAGE_SIZE
|
||||
and "x" in request.app.state.config.IMAGE_SIZE
|
||||
):
|
||||
size = request.app.state.config.IMAGE_SIZE
|
||||
|
||||
if form_data.size and "x" in form_data.size:
|
||||
size = form_data.size
|
||||
|
||||
width, height = tuple(map(int, size.split("x")))
|
||||
|
||||
r = None
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -25,6 +25,7 @@ from open_webui.utils.access_control import has_access, has_permission
|
|||
|
||||
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
from open_webui.config import BYPASS_ADMIN_ACCESS_CONTROL
|
||||
from open_webui.models.models import Models, ModelForm
|
||||
|
||||
|
||||
|
|
@ -42,7 +43,7 @@ router = APIRouter()
|
|||
async def get_knowledge(user=Depends(get_verified_user)):
|
||||
knowledge_bases = []
|
||||
|
||||
if user.role == "admin":
|
||||
if user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL:
|
||||
knowledge_bases = Knowledges.get_knowledge_bases()
|
||||
else:
|
||||
knowledge_bases = Knowledges.get_knowledge_bases_by_user_id(user.id, "read")
|
||||
|
|
@ -90,7 +91,7 @@ async def get_knowledge(user=Depends(get_verified_user)):
|
|||
async def get_knowledge_list(user=Depends(get_verified_user)):
|
||||
knowledge_bases = []
|
||||
|
||||
if user.role == "admin":
|
||||
if user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL:
|
||||
knowledge_bases = Knowledges.get_knowledge_bases()
|
||||
else:
|
||||
knowledge_bases = Knowledges.get_knowledge_bases_by_user_id(user.id, "write")
|
||||
|
|
|
|||
|
|
@ -82,6 +82,10 @@ class QueryMemoryForm(BaseModel):
|
|||
async def query_memory(
|
||||
request: Request, form_data: QueryMemoryForm, user=Depends(get_verified_user)
|
||||
):
|
||||
memories = Memories.get_memories_by_user_id(user.id)
|
||||
if not memories:
|
||||
raise HTTPException(status_code=404, detail="No memories found for user")
|
||||
|
||||
results = VECTOR_DB_CLIENT.search(
|
||||
collection_name=f"user-memory-{user.id}",
|
||||
vectors=[request.app.state.EMBEDDING_FUNCTION(form_data.content, user=user)],
|
||||
|
|
|
|||
|
|
@ -7,13 +7,15 @@ from open_webui.models.models import (
|
|||
ModelUserResponse,
|
||||
Models,
|
||||
)
|
||||
|
||||
from pydantic import BaseModel
|
||||
from open_webui.constants import ERROR_MESSAGES
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||
|
||||
|
||||
from open_webui.utils.auth import get_admin_user, get_verified_user
|
||||
from open_webui.utils.access_control import has_access, has_permission
|
||||
|
||||
from open_webui.config import BYPASS_ADMIN_ACCESS_CONTROL
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
|
@ -25,7 +27,7 @@ router = APIRouter()
|
|||
|
||||
@router.get("/", response_model=list[ModelUserResponse])
|
||||
async def get_models(id: Optional[str] = None, user=Depends(get_verified_user)):
|
||||
if user.role == "admin":
|
||||
if user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL:
|
||||
return Models.get_models()
|
||||
else:
|
||||
return Models.get_models_by_user_id(user.id)
|
||||
|
|
@ -78,6 +80,32 @@ async def create_new_model(
|
|||
)
|
||||
|
||||
|
||||
############################
|
||||
# ExportModels
|
||||
############################
|
||||
|
||||
|
||||
@router.get("/export", response_model=list[ModelModel])
|
||||
async def export_models(user=Depends(get_admin_user)):
|
||||
return Models.get_models()
|
||||
|
||||
|
||||
############################
|
||||
# SyncModels
|
||||
############################
|
||||
|
||||
|
||||
class SyncModelsForm(BaseModel):
|
||||
models: list[ModelModel] = []
|
||||
|
||||
|
||||
@router.post("/sync", response_model=list[ModelModel])
|
||||
async def sync_models(
|
||||
request: Request, form_data: SyncModelsForm, user=Depends(get_admin_user)
|
||||
):
|
||||
return Models.sync_models(user.id, form_data.models)
|
||||
|
||||
|
||||
###########################
|
||||
# GetModelById
|
||||
###########################
|
||||
|
|
@ -89,7 +117,7 @@ async def get_model_by_id(id: str, user=Depends(get_verified_user)):
|
|||
model = Models.get_model_by_id(id)
|
||||
if model:
|
||||
if (
|
||||
user.role == "admin"
|
||||
(user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL)
|
||||
or model.user_id == user.id
|
||||
or has_access(user.id, "read", model.access_control)
|
||||
):
|
||||
|
|
@ -102,7 +130,7 @@ async def get_model_by_id(id: str, user=Depends(get_verified_user)):
|
|||
|
||||
|
||||
############################
|
||||
# ToggelModelById
|
||||
# ToggleModelById
|
||||
############################
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -6,6 +6,9 @@ from typing import Optional
|
|||
from fastapi import APIRouter, Depends, HTTPException, Request, status, BackgroundTasks
|
||||
from pydantic import BaseModel
|
||||
|
||||
from open_webui.socket.main import sio
|
||||
|
||||
|
||||
from open_webui.models.users import Users, UserResponse
|
||||
from open_webui.models.notes import Notes, NoteModel, NoteForm, NoteUserResponse
|
||||
|
||||
|
|
@ -170,6 +173,12 @@ async def update_note_by_id(
|
|||
|
||||
try:
|
||||
note = Notes.update_note_by_id(id, form_data)
|
||||
await sio.emit(
|
||||
"note-events",
|
||||
note.model_dump(),
|
||||
to=f"note:{note.id}",
|
||||
)
|
||||
|
||||
return note
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
|
|
|
|||
|
|
@ -47,7 +47,7 @@ from open_webui.utils.misc import (
|
|||
from open_webui.utils.payload import (
|
||||
apply_model_params_to_body_ollama,
|
||||
apply_model_params_to_body_openai,
|
||||
apply_model_system_prompt_to_body,
|
||||
apply_system_prompt_to_body,
|
||||
)
|
||||
from open_webui.utils.auth import get_admin_user, get_verified_user
|
||||
from open_webui.utils.access_control import has_access
|
||||
|
|
@ -124,6 +124,7 @@ async def send_post_request(
|
|||
key: Optional[str] = None,
|
||||
content_type: Optional[str] = None,
|
||||
user: UserModel = None,
|
||||
metadata: Optional[dict] = None,
|
||||
):
|
||||
|
||||
r = None
|
||||
|
|
@ -144,6 +145,11 @@ async def send_post_request(
|
|||
"X-OpenWebUI-User-Id": user.id,
|
||||
"X-OpenWebUI-User-Email": user.email,
|
||||
"X-OpenWebUI-User-Role": user.role,
|
||||
**(
|
||||
{"X-OpenWebUI-Chat-Id": metadata.get("chat_id")}
|
||||
if metadata and metadata.get("chat_id")
|
||||
else {}
|
||||
),
|
||||
}
|
||||
if ENABLE_FORWARD_USER_INFO_HEADERS and user
|
||||
else {}
|
||||
|
|
@ -184,7 +190,6 @@ async def send_post_request(
|
|||
)
|
||||
else:
|
||||
res = await r.json()
|
||||
await cleanup_response(r, session)
|
||||
return res
|
||||
|
||||
except HTTPException as e:
|
||||
|
|
@ -196,6 +201,9 @@ async def send_post_request(
|
|||
status_code=r.status if r else 500,
|
||||
detail=detail if e else "Open WebUI: Server Connection Error",
|
||||
)
|
||||
finally:
|
||||
if not stream:
|
||||
await cleanup_response(r, session)
|
||||
|
||||
|
||||
def get_api_key(idx, url, configs):
|
||||
|
|
@ -407,15 +415,15 @@ async def get_all_models(request: Request, user: UserModel = None):
|
|||
try:
|
||||
loaded_models = await get_ollama_loaded_models(request, user=user)
|
||||
expires_map = {
|
||||
m["name"]: m["expires_at"]
|
||||
m["model"]: m["expires_at"]
|
||||
for m in loaded_models["models"]
|
||||
if "expires_at" in m
|
||||
}
|
||||
|
||||
for m in models["models"]:
|
||||
if m["name"] in expires_map:
|
||||
if m["model"] in expires_map:
|
||||
# Parse ISO8601 datetime with offset, get unix timestamp as int
|
||||
dt = datetime.fromisoformat(expires_map[m["name"]])
|
||||
dt = datetime.fromisoformat(expires_map[m["model"]])
|
||||
m["expires_at"] = int(dt.timestamp())
|
||||
except Exception as e:
|
||||
log.debug(f"Failed to get loaded models: {e}")
|
||||
|
|
@ -1322,7 +1330,7 @@ async def generate_chat_completion(
|
|||
system = params.pop("system", None)
|
||||
|
||||
payload = apply_model_params_to_body_ollama(params, payload)
|
||||
payload = apply_model_system_prompt_to_body(system, payload, metadata, user)
|
||||
payload = apply_system_prompt_to_body(system, payload, metadata, user)
|
||||
|
||||
# Check if user has access to the model
|
||||
if not bypass_filter and user.role == "user":
|
||||
|
|
@ -1363,6 +1371,7 @@ async def generate_chat_completion(
|
|||
key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
|
||||
content_type="application/x-ndjson",
|
||||
user=user,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -1401,6 +1410,8 @@ async def generate_openai_completion(
|
|||
url_idx: Optional[int] = None,
|
||||
user=Depends(get_verified_user),
|
||||
):
|
||||
metadata = form_data.pop("metadata", None)
|
||||
|
||||
try:
|
||||
form_data = OpenAICompletionForm(**form_data)
|
||||
except Exception as e:
|
||||
|
|
@ -1466,6 +1477,7 @@ async def generate_openai_completion(
|
|||
stream=payload.get("stream", False),
|
||||
key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
|
||||
user=user,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -1507,7 +1519,7 @@ async def generate_openai_chat_completion(
|
|||
system = params.pop("system", None)
|
||||
|
||||
payload = apply_model_params_to_body_openai(params, payload)
|
||||
payload = apply_model_system_prompt_to_body(system, payload, metadata, user)
|
||||
payload = apply_system_prompt_to_body(system, payload, metadata, user)
|
||||
|
||||
# Check if user has access to the model
|
||||
if user.role == "user":
|
||||
|
|
@ -1547,6 +1559,7 @@ async def generate_openai_chat_completion(
|
|||
stream=payload.get("stream", False),
|
||||
key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
|
||||
user=user,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -2,17 +2,20 @@ import asyncio
|
|||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Literal, Optional, overload
|
||||
from typing import Optional
|
||||
|
||||
import aiohttp
|
||||
from aiocache import cached
|
||||
import requests
|
||||
from urllib.parse import quote
|
||||
|
||||
from fastapi import Depends, FastAPI, HTTPException, Request, APIRouter
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import FileResponse, StreamingResponse
|
||||
from fastapi import Depends, HTTPException, Request, APIRouter
|
||||
from fastapi.responses import (
|
||||
FileResponse,
|
||||
StreamingResponse,
|
||||
JSONResponse,
|
||||
PlainTextResponse,
|
||||
)
|
||||
from pydantic import BaseModel
|
||||
from starlette.background import BackgroundTask
|
||||
|
||||
|
|
@ -31,12 +34,12 @@ from open_webui.env import (
|
|||
from open_webui.models.users import UserModel
|
||||
|
||||
from open_webui.constants import ERROR_MESSAGES
|
||||
from open_webui.env import ENV, SRC_LOG_LEVELS
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
|
||||
|
||||
from open_webui.utils.payload import (
|
||||
apply_model_params_to_body_openai,
|
||||
apply_model_system_prompt_to_body,
|
||||
apply_system_prompt_to_body,
|
||||
)
|
||||
from open_webui.utils.misc import (
|
||||
convert_logit_bias_input_to_json,
|
||||
|
|
@ -95,12 +98,12 @@ async def cleanup_response(
|
|||
await session.close()
|
||||
|
||||
|
||||
def openai_o_series_handler(payload):
|
||||
def openai_reasoning_model_handler(payload):
|
||||
"""
|
||||
Handle "o" series specific parameters
|
||||
Handle reasoning model specific parameters
|
||||
"""
|
||||
if "max_tokens" in payload:
|
||||
# Convert "max_tokens" to "max_completion_tokens" for all o-series models
|
||||
# Convert "max_tokens" to "max_completion_tokens" for all reasoning models
|
||||
payload["max_completion_tokens"] = payload["max_tokens"]
|
||||
del payload["max_tokens"]
|
||||
|
||||
|
|
@ -358,11 +361,22 @@ async def get_all_models_responses(request: Request, user: UserModel) -> list:
|
|||
prefix_id = api_config.get("prefix_id", None)
|
||||
tags = api_config.get("tags", [])
|
||||
|
||||
for model in (
|
||||
model_list = (
|
||||
response if isinstance(response, list) else response.get("data", [])
|
||||
):
|
||||
)
|
||||
if not isinstance(model_list, list):
|
||||
# Catch non-list responses
|
||||
model_list = []
|
||||
|
||||
for model in model_list:
|
||||
# Remove name key if its value is None #16689
|
||||
if "name" in model and model["name"] is None:
|
||||
del model["name"]
|
||||
|
||||
if prefix_id:
|
||||
model["id"] = f"{prefix_id}.{model['id']}"
|
||||
model["id"] = (
|
||||
f"{prefix_id}.{model.get('id', model.get('name', ''))}"
|
||||
)
|
||||
|
||||
if tags:
|
||||
model["tags"] = tags
|
||||
|
|
@ -593,15 +607,21 @@ async def verify_connection(
|
|||
headers=headers,
|
||||
ssl=AIOHTTP_CLIENT_SESSION_SSL,
|
||||
) as r:
|
||||
if r.status != 200:
|
||||
# Extract response error details if available
|
||||
error_detail = f"HTTP Error: {r.status}"
|
||||
res = await r.json()
|
||||
if "error" in res:
|
||||
error_detail = f"External Error: {res['error']}"
|
||||
raise Exception(error_detail)
|
||||
try:
|
||||
response_data = await r.json()
|
||||
except Exception:
|
||||
response_data = await r.text()
|
||||
|
||||
if r.status != 200:
|
||||
if isinstance(response_data, (dict, list)):
|
||||
return JSONResponse(
|
||||
status_code=r.status, content=response_data
|
||||
)
|
||||
else:
|
||||
return PlainTextResponse(
|
||||
status_code=r.status, content=response_data
|
||||
)
|
||||
|
||||
response_data = await r.json()
|
||||
return response_data
|
||||
else:
|
||||
headers["Authorization"] = f"Bearer {key}"
|
||||
|
|
@ -611,15 +631,21 @@ async def verify_connection(
|
|||
headers=headers,
|
||||
ssl=AIOHTTP_CLIENT_SESSION_SSL,
|
||||
) as r:
|
||||
if r.status != 200:
|
||||
# Extract response error details if available
|
||||
error_detail = f"HTTP Error: {r.status}"
|
||||
res = await r.json()
|
||||
if "error" in res:
|
||||
error_detail = f"External Error: {res['error']}"
|
||||
raise Exception(error_detail)
|
||||
try:
|
||||
response_data = await r.json()
|
||||
except Exception:
|
||||
response_data = await r.text()
|
||||
|
||||
if r.status != 200:
|
||||
if isinstance(response_data, (dict, list)):
|
||||
return JSONResponse(
|
||||
status_code=r.status, content=response_data
|
||||
)
|
||||
else:
|
||||
return PlainTextResponse(
|
||||
status_code=r.status, content=response_data
|
||||
)
|
||||
|
||||
response_data = await r.json()
|
||||
return response_data
|
||||
|
||||
except aiohttp.ClientError as e:
|
||||
|
|
@ -630,8 +656,9 @@ async def verify_connection(
|
|||
)
|
||||
except Exception as e:
|
||||
log.exception(f"Unexpected error: {e}")
|
||||
error_detail = f"Unexpected error: {str(e)}"
|
||||
raise HTTPException(status_code=500, detail=error_detail)
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Open WebUI: Server Connection Error"
|
||||
)
|
||||
|
||||
|
||||
def get_azure_allowed_params(api_version: str) -> set[str]:
|
||||
|
|
@ -675,6 +702,10 @@ def get_azure_allowed_params(api_version: str) -> set[str]:
|
|||
return allowed_params
|
||||
|
||||
|
||||
def is_openai_reasoning_model(model: str) -> bool:
|
||||
return model.lower().startswith(("o1", "o3", "o4", "gpt-5"))
|
||||
|
||||
|
||||
def convert_to_azure_payload(url, payload: dict, api_version: str):
|
||||
model = payload.get("model", "")
|
||||
|
||||
|
|
@ -682,7 +713,7 @@ def convert_to_azure_payload(url, payload: dict, api_version: str):
|
|||
allowed_params = get_azure_allowed_params(api_version)
|
||||
|
||||
# Special handling for o-series models
|
||||
if model.startswith("o") and model.endswith("-mini"):
|
||||
if is_openai_reasoning_model(model):
|
||||
# Convert max_tokens to max_completion_tokens for o-series models
|
||||
if "max_tokens" in payload:
|
||||
payload["max_completion_tokens"] = payload["max_tokens"]
|
||||
|
|
@ -732,7 +763,7 @@ async def generate_chat_completion(
|
|||
system = params.pop("system", None)
|
||||
|
||||
payload = apply_model_params_to_body_openai(params, payload)
|
||||
payload = apply_model_system_prompt_to_body(system, payload, metadata, user)
|
||||
payload = apply_system_prompt_to_body(system, payload, metadata, user)
|
||||
|
||||
# Check if user has access to the model
|
||||
if not bypass_filter and user.role == "user":
|
||||
|
|
@ -787,10 +818,9 @@ async def generate_chat_completion(
|
|||
url = request.app.state.config.OPENAI_API_BASE_URLS[idx]
|
||||
key = request.app.state.config.OPENAI_API_KEYS[idx]
|
||||
|
||||
# Check if model is from "o" series
|
||||
is_o_series = payload["model"].lower().startswith(("o1", "o3", "o4"))
|
||||
if is_o_series:
|
||||
payload = openai_o_series_handler(payload)
|
||||
# Check if model is a reasoning model that needs special handling
|
||||
if is_openai_reasoning_model(payload["model"]):
|
||||
payload = openai_reasoning_model_handler(payload)
|
||||
elif "api.openai.com" not in url:
|
||||
# Remove "max_completion_tokens" from the payload for backward compatibility
|
||||
if "max_completion_tokens" in payload:
|
||||
|
|
@ -822,6 +852,11 @@ async def generate_chat_completion(
|
|||
"X-OpenWebUI-User-Id": user.id,
|
||||
"X-OpenWebUI-User-Email": user.email,
|
||||
"X-OpenWebUI-User-Role": user.role,
|
||||
**(
|
||||
{"X-OpenWebUI-Chat-Id": metadata.get("chat_id")}
|
||||
if metadata and metadata.get("chat_id")
|
||||
else {}
|
||||
),
|
||||
}
|
||||
if ENABLE_FORWARD_USER_INFO_HEADERS
|
||||
else {}
|
||||
|
|
@ -876,27 +911,23 @@ async def generate_chat_completion(
|
|||
log.error(e)
|
||||
response = await r.text()
|
||||
|
||||
r.raise_for_status()
|
||||
if r.status >= 400:
|
||||
if isinstance(response, (dict, list)):
|
||||
return JSONResponse(status_code=r.status, content=response)
|
||||
else:
|
||||
return PlainTextResponse(status_code=r.status, content=response)
|
||||
|
||||
return response
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
|
||||
detail = None
|
||||
if isinstance(response, dict):
|
||||
if "error" in response:
|
||||
detail = f"{response['error']['message'] if 'message' in response['error'] else response['error']}"
|
||||
elif isinstance(response, str):
|
||||
detail = response
|
||||
|
||||
raise HTTPException(
|
||||
status_code=r.status if r else 500,
|
||||
detail=detail if detail else "Open WebUI: Server Connection Error",
|
||||
detail="Open WebUI: Server Connection Error",
|
||||
)
|
||||
finally:
|
||||
if not streaming and session:
|
||||
if r:
|
||||
r.close()
|
||||
await session.close()
|
||||
if not streaming:
|
||||
await cleanup_response(r, session)
|
||||
|
||||
|
||||
async def embeddings(request: Request, form_data: dict, user):
|
||||
|
|
@ -946,7 +977,7 @@ async def embeddings(request: Request, form_data: dict, user):
|
|||
),
|
||||
},
|
||||
)
|
||||
r.raise_for_status()
|
||||
|
||||
if "text/event-stream" in r.headers.get("Content-Type", ""):
|
||||
streaming = True
|
||||
return StreamingResponse(
|
||||
|
|
@ -958,27 +989,29 @@ async def embeddings(request: Request, form_data: dict, user):
|
|||
),
|
||||
)
|
||||
else:
|
||||
response_data = await r.json()
|
||||
try:
|
||||
response_data = await r.json()
|
||||
except Exception:
|
||||
response_data = await r.text()
|
||||
|
||||
if r.status >= 400:
|
||||
if isinstance(response_data, (dict, list)):
|
||||
return JSONResponse(status_code=r.status, content=response_data)
|
||||
else:
|
||||
return PlainTextResponse(
|
||||
status_code=r.status, content=response_data
|
||||
)
|
||||
|
||||
return response_data
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
detail = None
|
||||
if r is not None:
|
||||
try:
|
||||
res = await r.json()
|
||||
if "error" in res:
|
||||
detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}"
|
||||
except Exception:
|
||||
detail = f"External: {e}"
|
||||
raise HTTPException(
|
||||
status_code=r.status if r else 500,
|
||||
detail=detail if detail else "Open WebUI: Server Connection Error",
|
||||
detail="Open WebUI: Server Connection Error",
|
||||
)
|
||||
finally:
|
||||
if not streaming and session:
|
||||
if r:
|
||||
r.close()
|
||||
await session.close()
|
||||
if not streaming:
|
||||
await cleanup_response(r, session)
|
||||
|
||||
|
||||
@router.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
|
||||
|
|
@ -1040,7 +1073,6 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
|
|||
headers=headers,
|
||||
ssl=AIOHTTP_CLIENT_SESSION_SSL,
|
||||
)
|
||||
r.raise_for_status()
|
||||
|
||||
# Check if response is SSE
|
||||
if "text/event-stream" in r.headers.get("Content-Type", ""):
|
||||
|
|
@ -1054,27 +1086,27 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
|
|||
),
|
||||
)
|
||||
else:
|
||||
response_data = await r.json()
|
||||
try:
|
||||
response_data = await r.json()
|
||||
except Exception:
|
||||
response_data = await r.text()
|
||||
|
||||
if r.status >= 400:
|
||||
if isinstance(response_data, (dict, list)):
|
||||
return JSONResponse(status_code=r.status, content=response_data)
|
||||
else:
|
||||
return PlainTextResponse(
|
||||
status_code=r.status, content=response_data
|
||||
)
|
||||
|
||||
return response_data
|
||||
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
|
||||
detail = None
|
||||
if r is not None:
|
||||
try:
|
||||
res = await r.json()
|
||||
log.error(res)
|
||||
if "error" in res:
|
||||
detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}"
|
||||
except Exception:
|
||||
detail = f"External: {e}"
|
||||
raise HTTPException(
|
||||
status_code=r.status if r else 500,
|
||||
detail=detail if detail else "Open WebUI: Server Connection Error",
|
||||
detail="Open WebUI: Server Connection Error",
|
||||
)
|
||||
finally:
|
||||
if not streaming and session:
|
||||
if r:
|
||||
r.close()
|
||||
await session.close()
|
||||
if not streaming:
|
||||
await cleanup_response(r, session)
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
from typing import Optional
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Request
|
||||
|
||||
from open_webui.models.prompts import (
|
||||
PromptForm,
|
||||
|
|
@ -7,9 +8,9 @@ from open_webui.models.prompts import (
|
|||
Prompts,
|
||||
)
|
||||
from open_webui.constants import ERROR_MESSAGES
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Request
|
||||
from open_webui.utils.auth import get_admin_user, get_verified_user
|
||||
from open_webui.utils.access_control import has_access, has_permission
|
||||
from open_webui.config import BYPASS_ADMIN_ACCESS_CONTROL
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
|
@ -20,7 +21,7 @@ router = APIRouter()
|
|||
|
||||
@router.get("/", response_model=list[PromptModel])
|
||||
async def get_prompts(user=Depends(get_verified_user)):
|
||||
if user.role == "admin":
|
||||
if user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL:
|
||||
prompts = Prompts.get_prompts()
|
||||
else:
|
||||
prompts = Prompts.get_prompts_by_user_id(user.id, "read")
|
||||
|
|
@ -30,7 +31,7 @@ async def get_prompts(user=Depends(get_verified_user)):
|
|||
|
||||
@router.get("/list", response_model=list[PromptUserResponse])
|
||||
async def get_prompt_list(user=Depends(get_verified_user)):
|
||||
if user.role == "admin":
|
||||
if user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL:
|
||||
prompts = Prompts.get_prompts()
|
||||
else:
|
||||
prompts = Prompts.get_prompts_by_user_id(user.id, "write")
|
||||
|
|
|
|||
|
|
@ -5,7 +5,6 @@ import os
|
|||
import shutil
|
||||
import asyncio
|
||||
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
|
@ -281,6 +280,18 @@ async def update_embedding_config(
|
|||
log.info(
|
||||
f"Updating embedding model: {request.app.state.config.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}"
|
||||
)
|
||||
if request.app.state.config.RAG_EMBEDDING_ENGINE == "":
|
||||
# unloads current internal embedding model and clears VRAM cache
|
||||
request.app.state.ef = None
|
||||
request.app.state.EMBEDDING_FUNCTION = None
|
||||
import gc
|
||||
|
||||
gc.collect()
|
||||
if DEVICE_TYPE == "cuda":
|
||||
import torch
|
||||
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
try:
|
||||
request.app.state.config.RAG_EMBEDDING_ENGINE = form_data.embedding_engine
|
||||
request.app.state.config.RAG_EMBEDDING_MODEL = form_data.embedding_model
|
||||
|
|
@ -401,12 +412,14 @@ async def get_rag_config(request: Request, user=Depends(get_admin_user)):
|
|||
"CONTENT_EXTRACTION_ENGINE": request.app.state.config.CONTENT_EXTRACTION_ENGINE,
|
||||
"PDF_EXTRACT_IMAGES": request.app.state.config.PDF_EXTRACT_IMAGES,
|
||||
"DATALAB_MARKER_API_KEY": request.app.state.config.DATALAB_MARKER_API_KEY,
|
||||
"DATALAB_MARKER_LANGS": request.app.state.config.DATALAB_MARKER_LANGS,
|
||||
"DATALAB_MARKER_API_BASE_URL": request.app.state.config.DATALAB_MARKER_API_BASE_URL,
|
||||
"DATALAB_MARKER_ADDITIONAL_CONFIG": request.app.state.config.DATALAB_MARKER_ADDITIONAL_CONFIG,
|
||||
"DATALAB_MARKER_SKIP_CACHE": request.app.state.config.DATALAB_MARKER_SKIP_CACHE,
|
||||
"DATALAB_MARKER_FORCE_OCR": request.app.state.config.DATALAB_MARKER_FORCE_OCR,
|
||||
"DATALAB_MARKER_PAGINATE": request.app.state.config.DATALAB_MARKER_PAGINATE,
|
||||
"DATALAB_MARKER_STRIP_EXISTING_OCR": request.app.state.config.DATALAB_MARKER_STRIP_EXISTING_OCR,
|
||||
"DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION": request.app.state.config.DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION,
|
||||
"DATALAB_MARKER_FORMAT_LINES": request.app.state.config.DATALAB_MARKER_FORMAT_LINES,
|
||||
"DATALAB_MARKER_USE_LLM": request.app.state.config.DATALAB_MARKER_USE_LLM,
|
||||
"DATALAB_MARKER_OUTPUT_FORMAT": request.app.state.config.DATALAB_MARKER_OUTPUT_FORMAT,
|
||||
"EXTERNAL_DOCUMENT_LOADER_URL": request.app.state.config.EXTERNAL_DOCUMENT_LOADER_URL,
|
||||
|
|
@ -447,6 +460,7 @@ async def get_rag_config(request: Request, user=Depends(get_admin_user)):
|
|||
"WEB_SEARCH_TRUST_ENV": request.app.state.config.WEB_SEARCH_TRUST_ENV,
|
||||
"WEB_SEARCH_RESULT_COUNT": request.app.state.config.WEB_SEARCH_RESULT_COUNT,
|
||||
"WEB_SEARCH_CONCURRENT_REQUESTS": request.app.state.config.WEB_SEARCH_CONCURRENT_REQUESTS,
|
||||
"WEB_LOADER_CONCURRENT_REQUESTS": request.app.state.config.WEB_LOADER_CONCURRENT_REQUESTS,
|
||||
"WEB_SEARCH_DOMAIN_FILTER_LIST": request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST,
|
||||
"BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL": request.app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL,
|
||||
"BYPASS_WEB_SEARCH_WEB_LOADER": request.app.state.config.BYPASS_WEB_SEARCH_WEB_LOADER,
|
||||
|
|
@ -502,6 +516,7 @@ class WebConfig(BaseModel):
|
|||
WEB_SEARCH_TRUST_ENV: Optional[bool] = None
|
||||
WEB_SEARCH_RESULT_COUNT: Optional[int] = None
|
||||
WEB_SEARCH_CONCURRENT_REQUESTS: Optional[int] = None
|
||||
WEB_LOADER_CONCURRENT_REQUESTS: Optional[int] = None
|
||||
WEB_SEARCH_DOMAIN_FILTER_LIST: Optional[List[str]] = []
|
||||
BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL: Optional[bool] = None
|
||||
BYPASS_WEB_SEARCH_WEB_LOADER: Optional[bool] = None
|
||||
|
|
@ -566,12 +581,14 @@ class ConfigForm(BaseModel):
|
|||
CONTENT_EXTRACTION_ENGINE: Optional[str] = None
|
||||
PDF_EXTRACT_IMAGES: Optional[bool] = None
|
||||
DATALAB_MARKER_API_KEY: Optional[str] = None
|
||||
DATALAB_MARKER_LANGS: Optional[str] = None
|
||||
DATALAB_MARKER_API_BASE_URL: Optional[str] = None
|
||||
DATALAB_MARKER_ADDITIONAL_CONFIG: Optional[str] = None
|
||||
DATALAB_MARKER_SKIP_CACHE: Optional[bool] = None
|
||||
DATALAB_MARKER_FORCE_OCR: Optional[bool] = None
|
||||
DATALAB_MARKER_PAGINATE: Optional[bool] = None
|
||||
DATALAB_MARKER_STRIP_EXISTING_OCR: Optional[bool] = None
|
||||
DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION: Optional[bool] = None
|
||||
DATALAB_MARKER_FORMAT_LINES: Optional[bool] = None
|
||||
DATALAB_MARKER_USE_LLM: Optional[bool] = None
|
||||
DATALAB_MARKER_OUTPUT_FORMAT: Optional[str] = None
|
||||
EXTERNAL_DOCUMENT_LOADER_URL: Optional[str] = None
|
||||
|
|
@ -647,9 +664,6 @@ async def update_rag_config(
|
|||
if form_data.ENABLE_RAG_HYBRID_SEARCH is not None
|
||||
else request.app.state.config.ENABLE_RAG_HYBRID_SEARCH
|
||||
)
|
||||
# Free up memory if hybrid search is disabled
|
||||
if not request.app.state.config.ENABLE_RAG_HYBRID_SEARCH:
|
||||
request.app.state.rf = None
|
||||
|
||||
request.app.state.config.TOP_K_RERANKER = (
|
||||
form_data.TOP_K_RERANKER
|
||||
|
|
@ -683,10 +697,15 @@ async def update_rag_config(
|
|||
if form_data.DATALAB_MARKER_API_KEY is not None
|
||||
else request.app.state.config.DATALAB_MARKER_API_KEY
|
||||
)
|
||||
request.app.state.config.DATALAB_MARKER_LANGS = (
|
||||
form_data.DATALAB_MARKER_LANGS
|
||||
if form_data.DATALAB_MARKER_LANGS is not None
|
||||
else request.app.state.config.DATALAB_MARKER_LANGS
|
||||
request.app.state.config.DATALAB_MARKER_API_BASE_URL = (
|
||||
form_data.DATALAB_MARKER_API_BASE_URL
|
||||
if form_data.DATALAB_MARKER_API_BASE_URL is not None
|
||||
else request.app.state.config.DATALAB_MARKER_API_BASE_URL
|
||||
)
|
||||
request.app.state.config.DATALAB_MARKER_ADDITIONAL_CONFIG = (
|
||||
form_data.DATALAB_MARKER_ADDITIONAL_CONFIG
|
||||
if form_data.DATALAB_MARKER_ADDITIONAL_CONFIG is not None
|
||||
else request.app.state.config.DATALAB_MARKER_ADDITIONAL_CONFIG
|
||||
)
|
||||
request.app.state.config.DATALAB_MARKER_SKIP_CACHE = (
|
||||
form_data.DATALAB_MARKER_SKIP_CACHE
|
||||
|
|
@ -713,6 +732,11 @@ async def update_rag_config(
|
|||
if form_data.DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION is not None
|
||||
else request.app.state.config.DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION
|
||||
)
|
||||
request.app.state.config.DATALAB_MARKER_FORMAT_LINES = (
|
||||
form_data.DATALAB_MARKER_FORMAT_LINES
|
||||
if form_data.DATALAB_MARKER_FORMAT_LINES is not None
|
||||
else request.app.state.config.DATALAB_MARKER_FORMAT_LINES
|
||||
)
|
||||
request.app.state.config.DATALAB_MARKER_OUTPUT_FORMAT = (
|
||||
form_data.DATALAB_MARKER_OUTPUT_FORMAT
|
||||
if form_data.DATALAB_MARKER_OUTPUT_FORMAT is not None
|
||||
|
|
@ -793,6 +817,18 @@ async def update_rag_config(
|
|||
)
|
||||
|
||||
# Reranking settings
|
||||
if request.app.state.config.RAG_RERANKING_ENGINE == "":
|
||||
# Unloading the internal reranker and clear VRAM memory
|
||||
request.app.state.rf = None
|
||||
request.app.state.RERANKING_FUNCTION = None
|
||||
import gc
|
||||
|
||||
gc.collect()
|
||||
if DEVICE_TYPE == "cuda":
|
||||
import torch
|
||||
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
request.app.state.config.RAG_RERANKING_ENGINE = (
|
||||
form_data.RAG_RERANKING_ENGINE
|
||||
if form_data.RAG_RERANKING_ENGINE is not None
|
||||
|
|
@ -815,22 +851,30 @@ async def update_rag_config(
|
|||
f"Updating reranking model: {request.app.state.config.RAG_RERANKING_MODEL} to {form_data.RAG_RERANKING_MODEL}"
|
||||
)
|
||||
try:
|
||||
request.app.state.config.RAG_RERANKING_MODEL = form_data.RAG_RERANKING_MODEL
|
||||
request.app.state.config.RAG_RERANKING_MODEL = (
|
||||
form_data.RAG_RERANKING_MODEL
|
||||
if form_data.RAG_RERANKING_MODEL is not None
|
||||
else request.app.state.config.RAG_RERANKING_MODEL
|
||||
)
|
||||
|
||||
try:
|
||||
request.app.state.rf = get_rf(
|
||||
request.app.state.config.RAG_RERANKING_ENGINE,
|
||||
request.app.state.config.RAG_RERANKING_MODEL,
|
||||
request.app.state.config.RAG_EXTERNAL_RERANKER_URL,
|
||||
request.app.state.config.RAG_EXTERNAL_RERANKER_API_KEY,
|
||||
True,
|
||||
)
|
||||
if (
|
||||
request.app.state.config.ENABLE_RAG_HYBRID_SEARCH
|
||||
and not request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL
|
||||
):
|
||||
request.app.state.rf = get_rf(
|
||||
request.app.state.config.RAG_RERANKING_ENGINE,
|
||||
request.app.state.config.RAG_RERANKING_MODEL,
|
||||
request.app.state.config.RAG_EXTERNAL_RERANKER_URL,
|
||||
request.app.state.config.RAG_EXTERNAL_RERANKER_API_KEY,
|
||||
True,
|
||||
)
|
||||
|
||||
request.app.state.RERANKING_FUNCTION = get_reranking_function(
|
||||
request.app.state.config.RAG_RERANKING_ENGINE,
|
||||
request.app.state.config.RAG_RERANKING_MODEL,
|
||||
request.app.state.rf,
|
||||
)
|
||||
request.app.state.RERANKING_FUNCTION = get_reranking_function(
|
||||
request.app.state.config.RAG_RERANKING_ENGINE,
|
||||
request.app.state.config.RAG_RERANKING_MODEL,
|
||||
request.app.state.rf,
|
||||
)
|
||||
except Exception as e:
|
||||
log.error(f"Error loading reranking model: {e}")
|
||||
request.app.state.config.ENABLE_RAG_HYBRID_SEARCH = False
|
||||
|
|
@ -898,6 +942,9 @@ async def update_rag_config(
|
|||
request.app.state.config.WEB_SEARCH_CONCURRENT_REQUESTS = (
|
||||
form_data.web.WEB_SEARCH_CONCURRENT_REQUESTS
|
||||
)
|
||||
request.app.state.config.WEB_LOADER_CONCURRENT_REQUESTS = (
|
||||
form_data.web.WEB_LOADER_CONCURRENT_REQUESTS
|
||||
)
|
||||
request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST = (
|
||||
form_data.web.WEB_SEARCH_DOMAIN_FILTER_LIST
|
||||
)
|
||||
|
|
@ -1002,7 +1049,8 @@ async def update_rag_config(
|
|||
"CONTENT_EXTRACTION_ENGINE": request.app.state.config.CONTENT_EXTRACTION_ENGINE,
|
||||
"PDF_EXTRACT_IMAGES": request.app.state.config.PDF_EXTRACT_IMAGES,
|
||||
"DATALAB_MARKER_API_KEY": request.app.state.config.DATALAB_MARKER_API_KEY,
|
||||
"DATALAB_MARKER_LANGS": request.app.state.config.DATALAB_MARKER_LANGS,
|
||||
"DATALAB_MARKER_API_BASE_URL": request.app.state.config.DATALAB_MARKER_API_BASE_URL,
|
||||
"DATALAB_MARKER_ADDITIONAL_CONFIG": request.app.state.config.DATALAB_MARKER_ADDITIONAL_CONFIG,
|
||||
"DATALAB_MARKER_SKIP_CACHE": request.app.state.config.DATALAB_MARKER_SKIP_CACHE,
|
||||
"DATALAB_MARKER_FORCE_OCR": request.app.state.config.DATALAB_MARKER_FORCE_OCR,
|
||||
"DATALAB_MARKER_PAGINATE": request.app.state.config.DATALAB_MARKER_PAGINATE,
|
||||
|
|
@ -1048,6 +1096,7 @@ async def update_rag_config(
|
|||
"WEB_SEARCH_TRUST_ENV": request.app.state.config.WEB_SEARCH_TRUST_ENV,
|
||||
"WEB_SEARCH_RESULT_COUNT": request.app.state.config.WEB_SEARCH_RESULT_COUNT,
|
||||
"WEB_SEARCH_CONCURRENT_REQUESTS": request.app.state.config.WEB_SEARCH_CONCURRENT_REQUESTS,
|
||||
"WEB_LOADER_CONCURRENT_REQUESTS": request.app.state.config.WEB_LOADER_CONCURRENT_REQUESTS,
|
||||
"WEB_SEARCH_DOMAIN_FILTER_LIST": request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST,
|
||||
"BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL": request.app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL,
|
||||
"BYPASS_WEB_SEARCH_WEB_LOADER": request.app.state.config.BYPASS_WEB_SEARCH_WEB_LOADER,
|
||||
|
|
@ -1225,27 +1274,14 @@ def save_docs_to_vector_db(
|
|||
{
|
||||
**doc.metadata,
|
||||
**(metadata if metadata else {}),
|
||||
"embedding_config": json.dumps(
|
||||
{
|
||||
"engine": request.app.state.config.RAG_EMBEDDING_ENGINE,
|
||||
"model": request.app.state.config.RAG_EMBEDDING_MODEL,
|
||||
}
|
||||
),
|
||||
"embedding_config": {
|
||||
"engine": request.app.state.config.RAG_EMBEDDING_ENGINE,
|
||||
"model": request.app.state.config.RAG_EMBEDDING_MODEL,
|
||||
},
|
||||
}
|
||||
for doc in docs
|
||||
]
|
||||
|
||||
# ChromaDB does not like datetime formats
|
||||
# for meta-data so convert them to string.
|
||||
for metadata in metadatas:
|
||||
for key, value in metadata.items():
|
||||
if (
|
||||
isinstance(value, datetime)
|
||||
or isinstance(value, list)
|
||||
or isinstance(value, dict)
|
||||
):
|
||||
metadata[key] = str(value)
|
||||
|
||||
try:
|
||||
if VECTOR_DB_CLIENT.has_collection(collection_name=collection_name):
|
||||
log.info(f"collection {collection_name} already exists")
|
||||
|
|
@ -1402,12 +1438,14 @@ def process_file(
|
|||
loader = Loader(
|
||||
engine=request.app.state.config.CONTENT_EXTRACTION_ENGINE,
|
||||
DATALAB_MARKER_API_KEY=request.app.state.config.DATALAB_MARKER_API_KEY,
|
||||
DATALAB_MARKER_LANGS=request.app.state.config.DATALAB_MARKER_LANGS,
|
||||
DATALAB_MARKER_API_BASE_URL=request.app.state.config.DATALAB_MARKER_API_BASE_URL,
|
||||
DATALAB_MARKER_ADDITIONAL_CONFIG=request.app.state.config.DATALAB_MARKER_ADDITIONAL_CONFIG,
|
||||
DATALAB_MARKER_SKIP_CACHE=request.app.state.config.DATALAB_MARKER_SKIP_CACHE,
|
||||
DATALAB_MARKER_FORCE_OCR=request.app.state.config.DATALAB_MARKER_FORCE_OCR,
|
||||
DATALAB_MARKER_PAGINATE=request.app.state.config.DATALAB_MARKER_PAGINATE,
|
||||
DATALAB_MARKER_STRIP_EXISTING_OCR=request.app.state.config.DATALAB_MARKER_STRIP_EXISTING_OCR,
|
||||
DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION=request.app.state.config.DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION,
|
||||
DATALAB_MARKER_FORMAT_LINES=request.app.state.config.DATALAB_MARKER_FORMAT_LINES,
|
||||
DATALAB_MARKER_USE_LLM=request.app.state.config.DATALAB_MARKER_USE_LLM,
|
||||
DATALAB_MARKER_OUTPUT_FORMAT=request.app.state.config.DATALAB_MARKER_OUTPUT_FORMAT,
|
||||
EXTERNAL_DOCUMENT_LOADER_URL=request.app.state.config.EXTERNAL_DOCUMENT_LOADER_URL,
|
||||
|
|
@ -1462,7 +1500,7 @@ def process_file(
|
|||
log.debug(f"text_content: {text_content}")
|
||||
Files.update_file_data_by_id(
|
||||
file.id,
|
||||
{"content": text_content},
|
||||
{"status": "completed", "content": text_content},
|
||||
)
|
||||
|
||||
hash = calculate_sha256_string(text_content)
|
||||
|
|
@ -1616,7 +1654,7 @@ def process_web(
|
|||
loader = get_web_loader(
|
||||
form_data.url,
|
||||
verify_ssl=request.app.state.config.ENABLE_WEB_LOADER_SSL_VERIFICATION,
|
||||
requests_per_second=request.app.state.config.WEB_SEARCH_CONCURRENT_REQUESTS,
|
||||
requests_per_second=request.app.state.config.WEB_LOADER_CONCURRENT_REQUESTS,
|
||||
)
|
||||
docs = loader.load()
|
||||
content = " ".join([doc.page_content for doc in docs])
|
||||
|
|
@ -1781,7 +1819,7 @@ def search_web(request: Request, engine: str, query: str) -> list[SearchResult]:
|
|||
request.app.state.config.SERPLY_API_KEY,
|
||||
query,
|
||||
request.app.state.config.WEB_SEARCH_RESULT_COUNT,
|
||||
request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST,
|
||||
filter_list=request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST,
|
||||
)
|
||||
else:
|
||||
raise Exception("No SERPLY_API_KEY found in environment variables")
|
||||
|
|
@ -1790,6 +1828,7 @@ def search_web(request: Request, engine: str, query: str) -> list[SearchResult]:
|
|||
query,
|
||||
request.app.state.config.WEB_SEARCH_RESULT_COUNT,
|
||||
request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST,
|
||||
concurrent_requests=request.app.state.config.WEB_SEARCH_CONCURRENT_REQUESTS,
|
||||
)
|
||||
elif engine == "tavily":
|
||||
if request.app.state.config.TAVILY_API_KEY:
|
||||
|
|
@ -1957,13 +1996,13 @@ async def process_web_search(
|
|||
},
|
||||
)
|
||||
for result in search_results
|
||||
if hasattr(result, "snippet")
|
||||
if hasattr(result, "snippet") and result.snippet is not None
|
||||
]
|
||||
else:
|
||||
loader = get_web_loader(
|
||||
urls,
|
||||
verify_ssl=request.app.state.config.ENABLE_WEB_LOADER_SSL_VERIFICATION,
|
||||
requests_per_second=request.app.state.config.WEB_SEARCH_CONCURRENT_REQUESTS,
|
||||
requests_per_second=request.app.state.config.WEB_LOADER_CONCURRENT_REQUESTS,
|
||||
trust_env=request.app.state.config.WEB_SEARCH_TRUST_ENV,
|
||||
)
|
||||
docs = await loader.aload()
|
||||
|
|
@ -2050,11 +2089,13 @@ def query_doc_handler(
|
|||
),
|
||||
k=form_data.k if form_data.k else request.app.state.config.TOP_K,
|
||||
reranking_function=(
|
||||
lambda sentences: (
|
||||
request.app.state.RERANKING_FUNCTION(sentences, user=user)
|
||||
if request.app.state.RERANKING_FUNCTION
|
||||
else None
|
||||
(
|
||||
lambda sentences: request.app.state.RERANKING_FUNCTION(
|
||||
sentences, user=user
|
||||
)
|
||||
)
|
||||
if request.app.state.RERANKING_FUNCTION
|
||||
else None
|
||||
),
|
||||
k_reranker=form_data.k_reranker
|
||||
or request.app.state.config.TOP_K_RERANKER,
|
||||
|
|
@ -2112,8 +2153,14 @@ def query_collection_handler(
|
|||
query, prefix=prefix, user=user
|
||||
),
|
||||
k=form_data.k if form_data.k else request.app.state.config.TOP_K,
|
||||
reranking_function=lambda sentences: request.app.state.RERANKING_FUNCTION(
|
||||
sentences, user=user
|
||||
reranking_function=(
|
||||
(
|
||||
lambda sentences: request.app.state.RERANKING_FUNCTION(
|
||||
sentences, user=user
|
||||
)
|
||||
)
|
||||
if request.app.state.RERANKING_FUNCTION
|
||||
else None
|
||||
),
|
||||
k_reranker=form_data.k_reranker
|
||||
or request.app.state.config.TOP_K_RERANKER,
|
||||
|
|
|
|||
926
backend/open_webui/routers/scim.py
Normal file
926
backend/open_webui/routers/scim.py
Normal file
|
|
@ -0,0 +1,926 @@
|
|||
"""
|
||||
Experimental SCIM 2.0 Implementation for Open WebUI
|
||||
Provides System for Cross-domain Identity Management endpoints for users and groups
|
||||
|
||||
NOTE: This is an experimental implementation and may not fully comply with SCIM 2.0 standards, and is subject to change.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
import time
|
||||
from typing import Optional, List, Dict, Any
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, Query, Header, status
|
||||
from fastapi.responses import JSONResponse
|
||||
from pydantic import BaseModel, Field, ConfigDict
|
||||
|
||||
from open_webui.models.users import Users, UserModel
|
||||
from open_webui.models.groups import Groups, GroupModel
|
||||
from open_webui.utils.auth import (
|
||||
get_admin_user,
|
||||
get_current_user,
|
||||
decode_token,
|
||||
get_verified_user,
|
||||
)
|
||||
from open_webui.constants import ERROR_MESSAGES
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["MAIN"])
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
# SCIM 2.0 Schema URIs
|
||||
SCIM_USER_SCHEMA = "urn:ietf:params:scim:schemas:core:2.0:User"
|
||||
SCIM_GROUP_SCHEMA = "urn:ietf:params:scim:schemas:core:2.0:Group"
|
||||
SCIM_LIST_RESPONSE_SCHEMA = "urn:ietf:params:scim:api:messages:2.0:ListResponse"
|
||||
SCIM_ERROR_SCHEMA = "urn:ietf:params:scim:api:messages:2.0:Error"
|
||||
|
||||
# SCIM Resource Types
|
||||
SCIM_RESOURCE_TYPE_USER = "User"
|
||||
SCIM_RESOURCE_TYPE_GROUP = "Group"
|
||||
|
||||
|
||||
def scim_error(status_code: int, detail: str, scim_type: Optional[str] = None):
|
||||
"""Create a SCIM-compliant error response"""
|
||||
error_body = {
|
||||
"schemas": [SCIM_ERROR_SCHEMA],
|
||||
"status": str(status_code),
|
||||
"detail": detail,
|
||||
}
|
||||
|
||||
if scim_type:
|
||||
error_body["scimType"] = scim_type
|
||||
elif status_code == 404:
|
||||
error_body["scimType"] = "invalidValue"
|
||||
elif status_code == 409:
|
||||
error_body["scimType"] = "uniqueness"
|
||||
elif status_code == 400:
|
||||
error_body["scimType"] = "invalidSyntax"
|
||||
|
||||
return JSONResponse(status_code=status_code, content=error_body)
|
||||
|
||||
|
||||
class SCIMError(BaseModel):
|
||||
"""SCIM Error Response"""
|
||||
|
||||
schemas: List[str] = [SCIM_ERROR_SCHEMA]
|
||||
status: str
|
||||
scimType: Optional[str] = None
|
||||
detail: Optional[str] = None
|
||||
|
||||
|
||||
class SCIMMeta(BaseModel):
|
||||
"""SCIM Resource Metadata"""
|
||||
|
||||
resourceType: str
|
||||
created: str
|
||||
lastModified: str
|
||||
location: Optional[str] = None
|
||||
version: Optional[str] = None
|
||||
|
||||
|
||||
class SCIMName(BaseModel):
|
||||
"""SCIM User Name"""
|
||||
|
||||
formatted: Optional[str] = None
|
||||
familyName: Optional[str] = None
|
||||
givenName: Optional[str] = None
|
||||
middleName: Optional[str] = None
|
||||
honorificPrefix: Optional[str] = None
|
||||
honorificSuffix: Optional[str] = None
|
||||
|
||||
|
||||
class SCIMEmail(BaseModel):
|
||||
"""SCIM Email"""
|
||||
|
||||
value: str
|
||||
type: Optional[str] = "work"
|
||||
primary: bool = True
|
||||
display: Optional[str] = None
|
||||
|
||||
|
||||
class SCIMPhoto(BaseModel):
|
||||
"""SCIM Photo"""
|
||||
|
||||
value: str
|
||||
type: Optional[str] = "photo"
|
||||
primary: bool = True
|
||||
display: Optional[str] = None
|
||||
|
||||
|
||||
class SCIMGroupMember(BaseModel):
|
||||
"""SCIM Group Member"""
|
||||
|
||||
value: str # User ID
|
||||
ref: Optional[str] = Field(None, alias="$ref")
|
||||
type: Optional[str] = "User"
|
||||
display: Optional[str] = None
|
||||
|
||||
|
||||
class SCIMUser(BaseModel):
|
||||
"""SCIM User Resource"""
|
||||
|
||||
model_config = ConfigDict(populate_by_name=True)
|
||||
|
||||
schemas: List[str] = [SCIM_USER_SCHEMA]
|
||||
id: str
|
||||
externalId: Optional[str] = None
|
||||
userName: str
|
||||
name: Optional[SCIMName] = None
|
||||
displayName: str
|
||||
emails: List[SCIMEmail]
|
||||
active: bool = True
|
||||
photos: Optional[List[SCIMPhoto]] = None
|
||||
groups: Optional[List[Dict[str, str]]] = None
|
||||
meta: SCIMMeta
|
||||
|
||||
|
||||
class SCIMUserCreateRequest(BaseModel):
|
||||
"""SCIM User Create Request"""
|
||||
|
||||
model_config = ConfigDict(populate_by_name=True)
|
||||
|
||||
schemas: List[str] = [SCIM_USER_SCHEMA]
|
||||
externalId: Optional[str] = None
|
||||
userName: str
|
||||
name: Optional[SCIMName] = None
|
||||
displayName: str
|
||||
emails: List[SCIMEmail]
|
||||
active: bool = True
|
||||
password: Optional[str] = None
|
||||
photos: Optional[List[SCIMPhoto]] = None
|
||||
|
||||
|
||||
class SCIMUserUpdateRequest(BaseModel):
|
||||
"""SCIM User Update Request"""
|
||||
|
||||
model_config = ConfigDict(populate_by_name=True)
|
||||
|
||||
schemas: List[str] = [SCIM_USER_SCHEMA]
|
||||
id: Optional[str] = None
|
||||
externalId: Optional[str] = None
|
||||
userName: Optional[str] = None
|
||||
name: Optional[SCIMName] = None
|
||||
displayName: Optional[str] = None
|
||||
emails: Optional[List[SCIMEmail]] = None
|
||||
active: Optional[bool] = None
|
||||
photos: Optional[List[SCIMPhoto]] = None
|
||||
|
||||
|
||||
class SCIMGroup(BaseModel):
|
||||
"""SCIM Group Resource"""
|
||||
|
||||
model_config = ConfigDict(populate_by_name=True)
|
||||
|
||||
schemas: List[str] = [SCIM_GROUP_SCHEMA]
|
||||
id: str
|
||||
displayName: str
|
||||
members: Optional[List[SCIMGroupMember]] = []
|
||||
meta: SCIMMeta
|
||||
|
||||
|
||||
class SCIMGroupCreateRequest(BaseModel):
|
||||
"""SCIM Group Create Request"""
|
||||
|
||||
model_config = ConfigDict(populate_by_name=True)
|
||||
|
||||
schemas: List[str] = [SCIM_GROUP_SCHEMA]
|
||||
displayName: str
|
||||
members: Optional[List[SCIMGroupMember]] = []
|
||||
|
||||
|
||||
class SCIMGroupUpdateRequest(BaseModel):
|
||||
"""SCIM Group Update Request"""
|
||||
|
||||
model_config = ConfigDict(populate_by_name=True)
|
||||
|
||||
schemas: List[str] = [SCIM_GROUP_SCHEMA]
|
||||
displayName: Optional[str] = None
|
||||
members: Optional[List[SCIMGroupMember]] = None
|
||||
|
||||
|
||||
class SCIMListResponse(BaseModel):
|
||||
"""SCIM List Response"""
|
||||
|
||||
schemas: List[str] = [SCIM_LIST_RESPONSE_SCHEMA]
|
||||
totalResults: int
|
||||
itemsPerPage: int
|
||||
startIndex: int
|
||||
Resources: List[Any]
|
||||
|
||||
|
||||
class SCIMPatchOperation(BaseModel):
|
||||
"""SCIM Patch Operation"""
|
||||
|
||||
op: str # "add", "replace", "remove"
|
||||
path: Optional[str] = None
|
||||
value: Optional[Any] = None
|
||||
|
||||
|
||||
class SCIMPatchRequest(BaseModel):
|
||||
"""SCIM Patch Request"""
|
||||
|
||||
schemas: List[str] = ["urn:ietf:params:scim:api:messages:2.0:PatchOp"]
|
||||
Operations: List[SCIMPatchOperation]
|
||||
|
||||
|
||||
def get_scim_auth(
|
||||
request: Request, authorization: Optional[str] = Header(None)
|
||||
) -> bool:
|
||||
"""
|
||||
Verify SCIM authentication
|
||||
Checks for SCIM-specific bearer token configured in the system
|
||||
"""
|
||||
if not authorization:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Authorization header required",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
try:
|
||||
parts = authorization.split()
|
||||
if len(parts) != 2:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid authorization format. Expected: Bearer <token>",
|
||||
)
|
||||
|
||||
scheme, token = parts
|
||||
if scheme.lower() != "bearer":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid authentication scheme",
|
||||
)
|
||||
|
||||
# Check if SCIM is enabled
|
||||
scim_enabled = getattr(request.app.state, "SCIM_ENABLED", False)
|
||||
log.info(
|
||||
f"SCIM auth check - raw SCIM_ENABLED: {scim_enabled}, type: {type(scim_enabled)}"
|
||||
)
|
||||
# Handle both PersistentConfig and direct value
|
||||
if hasattr(scim_enabled, "value"):
|
||||
scim_enabled = scim_enabled.value
|
||||
log.info(f"SCIM enabled status after conversion: {scim_enabled}")
|
||||
if not scim_enabled:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="SCIM is not enabled",
|
||||
)
|
||||
|
||||
# Verify the SCIM token
|
||||
scim_token = getattr(request.app.state, "SCIM_TOKEN", None)
|
||||
# Handle both PersistentConfig and direct value
|
||||
if hasattr(scim_token, "value"):
|
||||
scim_token = scim_token.value
|
||||
log.debug(f"SCIM token configured: {bool(scim_token)}")
|
||||
if not scim_token or token != scim_token:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid SCIM token",
|
||||
)
|
||||
|
||||
return True
|
||||
except HTTPException:
|
||||
# Re-raise HTTP exceptions as-is
|
||||
raise
|
||||
except Exception as e:
|
||||
log.error(f"SCIM authentication error: {e}")
|
||||
import traceback
|
||||
|
||||
log.error(f"Traceback: {traceback.format_exc()}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Authentication failed",
|
||||
)
|
||||
|
||||
|
||||
def user_to_scim(user: UserModel, request: Request) -> SCIMUser:
|
||||
"""Convert internal User model to SCIM User"""
|
||||
# Parse display name into name components
|
||||
name_parts = user.name.split(" ", 1) if user.name else ["", ""]
|
||||
given_name = name_parts[0] if name_parts else ""
|
||||
family_name = name_parts[1] if len(name_parts) > 1 else ""
|
||||
|
||||
# Get user's groups
|
||||
user_groups = Groups.get_groups_by_member_id(user.id)
|
||||
groups = [
|
||||
{
|
||||
"value": group.id,
|
||||
"display": group.name,
|
||||
"$ref": f"{request.base_url}api/v1/scim/v2/Groups/{group.id}",
|
||||
"type": "direct",
|
||||
}
|
||||
for group in user_groups
|
||||
]
|
||||
|
||||
return SCIMUser(
|
||||
id=user.id,
|
||||
userName=user.email,
|
||||
name=SCIMName(
|
||||
formatted=user.name,
|
||||
givenName=given_name,
|
||||
familyName=family_name,
|
||||
),
|
||||
displayName=user.name,
|
||||
emails=[SCIMEmail(value=user.email)],
|
||||
active=user.role != "pending",
|
||||
photos=(
|
||||
[SCIMPhoto(value=user.profile_image_url)]
|
||||
if user.profile_image_url
|
||||
else None
|
||||
),
|
||||
groups=groups if groups else None,
|
||||
meta=SCIMMeta(
|
||||
resourceType=SCIM_RESOURCE_TYPE_USER,
|
||||
created=datetime.fromtimestamp(
|
||||
user.created_at, tz=timezone.utc
|
||||
).isoformat(),
|
||||
lastModified=datetime.fromtimestamp(
|
||||
user.updated_at, tz=timezone.utc
|
||||
).isoformat(),
|
||||
location=f"{request.base_url}api/v1/scim/v2/Users/{user.id}",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def group_to_scim(group: GroupModel, request: Request) -> SCIMGroup:
|
||||
"""Convert internal Group model to SCIM Group"""
|
||||
members = []
|
||||
for user_id in group.user_ids:
|
||||
user = Users.get_user_by_id(user_id)
|
||||
if user:
|
||||
members.append(
|
||||
SCIMGroupMember(
|
||||
value=user.id,
|
||||
ref=f"{request.base_url}api/v1/scim/v2/Users/{user.id}",
|
||||
display=user.name,
|
||||
)
|
||||
)
|
||||
|
||||
return SCIMGroup(
|
||||
id=group.id,
|
||||
displayName=group.name,
|
||||
members=members,
|
||||
meta=SCIMMeta(
|
||||
resourceType=SCIM_RESOURCE_TYPE_GROUP,
|
||||
created=datetime.fromtimestamp(
|
||||
group.created_at, tz=timezone.utc
|
||||
).isoformat(),
|
||||
lastModified=datetime.fromtimestamp(
|
||||
group.updated_at, tz=timezone.utc
|
||||
).isoformat(),
|
||||
location=f"{request.base_url}api/v1/scim/v2/Groups/{group.id}",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
# SCIM Service Provider Config
|
||||
@router.get("/ServiceProviderConfig")
|
||||
async def get_service_provider_config():
|
||||
"""Get SCIM Service Provider Configuration"""
|
||||
return {
|
||||
"schemas": ["urn:ietf:params:scim:schemas:core:2.0:ServiceProviderConfig"],
|
||||
"patch": {"supported": True},
|
||||
"bulk": {"supported": False, "maxOperations": 1000, "maxPayloadSize": 1048576},
|
||||
"filter": {"supported": True, "maxResults": 200},
|
||||
"changePassword": {"supported": False},
|
||||
"sort": {"supported": False},
|
||||
"etag": {"supported": False},
|
||||
"authenticationSchemes": [
|
||||
{
|
||||
"type": "oauthbearertoken",
|
||||
"name": "OAuth Bearer Token",
|
||||
"description": "Authentication using OAuth 2.0 Bearer Token",
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
# SCIM Resource Types
|
||||
@router.get("/ResourceTypes")
|
||||
async def get_resource_types(request: Request):
|
||||
"""Get SCIM Resource Types"""
|
||||
return [
|
||||
{
|
||||
"schemas": ["urn:ietf:params:scim:schemas:core:2.0:ResourceType"],
|
||||
"id": "User",
|
||||
"name": "User",
|
||||
"endpoint": "/Users",
|
||||
"schema": SCIM_USER_SCHEMA,
|
||||
"meta": {
|
||||
"location": f"{request.base_url}api/v1/scim/v2/ResourceTypes/User",
|
||||
"resourceType": "ResourceType",
|
||||
},
|
||||
},
|
||||
{
|
||||
"schemas": ["urn:ietf:params:scim:schemas:core:2.0:ResourceType"],
|
||||
"id": "Group",
|
||||
"name": "Group",
|
||||
"endpoint": "/Groups",
|
||||
"schema": SCIM_GROUP_SCHEMA,
|
||||
"meta": {
|
||||
"location": f"{request.base_url}api/v1/scim/v2/ResourceTypes/Group",
|
||||
"resourceType": "ResourceType",
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
# SCIM Schemas
|
||||
@router.get("/Schemas")
|
||||
async def get_schemas():
|
||||
"""Get SCIM Schemas"""
|
||||
return [
|
||||
{
|
||||
"schemas": ["urn:ietf:params:scim:schemas:core:2.0:Schema"],
|
||||
"id": SCIM_USER_SCHEMA,
|
||||
"name": "User",
|
||||
"description": "User Account",
|
||||
"attributes": [
|
||||
{
|
||||
"name": "userName",
|
||||
"type": "string",
|
||||
"required": True,
|
||||
"uniqueness": "server",
|
||||
},
|
||||
{"name": "displayName", "type": "string", "required": True},
|
||||
{
|
||||
"name": "emails",
|
||||
"type": "complex",
|
||||
"multiValued": True,
|
||||
"required": True,
|
||||
},
|
||||
{"name": "active", "type": "boolean", "required": False},
|
||||
],
|
||||
},
|
||||
{
|
||||
"schemas": ["urn:ietf:params:scim:schemas:core:2.0:Schema"],
|
||||
"id": SCIM_GROUP_SCHEMA,
|
||||
"name": "Group",
|
||||
"description": "Group",
|
||||
"attributes": [
|
||||
{"name": "displayName", "type": "string", "required": True},
|
||||
{
|
||||
"name": "members",
|
||||
"type": "complex",
|
||||
"multiValued": True,
|
||||
"required": False,
|
||||
},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
# Users endpoints
|
||||
@router.get("/Users", response_model=SCIMListResponse)
|
||||
async def get_users(
|
||||
request: Request,
|
||||
startIndex: int = Query(1, ge=1),
|
||||
count: int = Query(20, ge=1, le=100),
|
||||
filter: Optional[str] = None,
|
||||
_: bool = Depends(get_scim_auth),
|
||||
):
|
||||
"""List SCIM Users"""
|
||||
skip = startIndex - 1
|
||||
limit = count
|
||||
|
||||
# Get users from database
|
||||
if filter:
|
||||
# Simple filter parsing - supports userName eq "email"
|
||||
# In production, you'd want a more robust filter parser
|
||||
if "userName eq" in filter:
|
||||
email = filter.split('"')[1]
|
||||
user = Users.get_user_by_email(email)
|
||||
users_list = [user] if user else []
|
||||
total = 1 if user else 0
|
||||
else:
|
||||
response = Users.get_users(skip=skip, limit=limit)
|
||||
users_list = response["users"]
|
||||
total = response["total"]
|
||||
else:
|
||||
response = Users.get_users(skip=skip, limit=limit)
|
||||
users_list = response["users"]
|
||||
total = response["total"]
|
||||
|
||||
# Convert to SCIM format
|
||||
scim_users = [user_to_scim(user, request) for user in users_list]
|
||||
|
||||
return SCIMListResponse(
|
||||
totalResults=total,
|
||||
itemsPerPage=len(scim_users),
|
||||
startIndex=startIndex,
|
||||
Resources=scim_users,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/Users/{user_id}", response_model=SCIMUser)
|
||||
async def get_user(
|
||||
user_id: str,
|
||||
request: Request,
|
||||
_: bool = Depends(get_scim_auth),
|
||||
):
|
||||
"""Get SCIM User by ID"""
|
||||
user = Users.get_user_by_id(user_id)
|
||||
if not user:
|
||||
return scim_error(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail=f"User {user_id} not found"
|
||||
)
|
||||
|
||||
return user_to_scim(user, request)
|
||||
|
||||
|
||||
@router.post("/Users", response_model=SCIMUser, status_code=status.HTTP_201_CREATED)
|
||||
async def create_user(
|
||||
request: Request,
|
||||
user_data: SCIMUserCreateRequest,
|
||||
_: bool = Depends(get_scim_auth),
|
||||
):
|
||||
"""Create SCIM User"""
|
||||
# Check if user already exists
|
||||
existing_user = Users.get_user_by_email(user_data.userName)
|
||||
if existing_user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail=f"User with email {user_data.userName} already exists",
|
||||
)
|
||||
|
||||
# Create user
|
||||
user_id = str(uuid.uuid4())
|
||||
email = user_data.emails[0].value if user_data.emails else user_data.userName
|
||||
|
||||
# Parse name if provided
|
||||
name = user_data.displayName
|
||||
if user_data.name:
|
||||
if user_data.name.formatted:
|
||||
name = user_data.name.formatted
|
||||
elif user_data.name.givenName or user_data.name.familyName:
|
||||
name = f"{user_data.name.givenName or ''} {user_data.name.familyName or ''}".strip()
|
||||
|
||||
# Get profile image if provided
|
||||
profile_image = "/user.png"
|
||||
if user_data.photos and len(user_data.photos) > 0:
|
||||
profile_image = user_data.photos[0].value
|
||||
|
||||
# Create user
|
||||
new_user = Users.insert_new_user(
|
||||
id=user_id,
|
||||
name=name,
|
||||
email=email,
|
||||
profile_image_url=profile_image,
|
||||
role="user" if user_data.active else "pending",
|
||||
)
|
||||
|
||||
if not new_user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to create user",
|
||||
)
|
||||
|
||||
return user_to_scim(new_user, request)
|
||||
|
||||
|
||||
@router.put("/Users/{user_id}", response_model=SCIMUser)
|
||||
async def update_user(
|
||||
user_id: str,
|
||||
request: Request,
|
||||
user_data: SCIMUserUpdateRequest,
|
||||
_: bool = Depends(get_scim_auth),
|
||||
):
|
||||
"""Update SCIM User (full update)"""
|
||||
user = Users.get_user_by_id(user_id)
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"User {user_id} not found",
|
||||
)
|
||||
|
||||
# Build update dict
|
||||
update_data = {}
|
||||
|
||||
if user_data.userName:
|
||||
update_data["email"] = user_data.userName
|
||||
|
||||
if user_data.displayName:
|
||||
update_data["name"] = user_data.displayName
|
||||
elif user_data.name:
|
||||
if user_data.name.formatted:
|
||||
update_data["name"] = user_data.name.formatted
|
||||
elif user_data.name.givenName or user_data.name.familyName:
|
||||
update_data["name"] = (
|
||||
f"{user_data.name.givenName or ''} {user_data.name.familyName or ''}".strip()
|
||||
)
|
||||
|
||||
if user_data.emails and len(user_data.emails) > 0:
|
||||
update_data["email"] = user_data.emails[0].value
|
||||
|
||||
if user_data.active is not None:
|
||||
update_data["role"] = "user" if user_data.active else "pending"
|
||||
|
||||
if user_data.photos and len(user_data.photos) > 0:
|
||||
update_data["profile_image_url"] = user_data.photos[0].value
|
||||
|
||||
# Update user
|
||||
updated_user = Users.update_user_by_id(user_id, update_data)
|
||||
if not updated_user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to update user",
|
||||
)
|
||||
|
||||
return user_to_scim(updated_user, request)
|
||||
|
||||
|
||||
@router.patch("/Users/{user_id}", response_model=SCIMUser)
|
||||
async def patch_user(
|
||||
user_id: str,
|
||||
request: Request,
|
||||
patch_data: SCIMPatchRequest,
|
||||
_: bool = Depends(get_scim_auth),
|
||||
):
|
||||
"""Update SCIM User (partial update)"""
|
||||
user = Users.get_user_by_id(user_id)
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"User {user_id} not found",
|
||||
)
|
||||
|
||||
update_data = {}
|
||||
|
||||
for operation in patch_data.Operations:
|
||||
op = operation.op.lower()
|
||||
path = operation.path
|
||||
value = operation.value
|
||||
|
||||
if op == "replace":
|
||||
if path == "active":
|
||||
update_data["role"] = "user" if value else "pending"
|
||||
elif path == "userName":
|
||||
update_data["email"] = value
|
||||
elif path == "displayName":
|
||||
update_data["name"] = value
|
||||
elif path == "emails[primary eq true].value":
|
||||
update_data["email"] = value
|
||||
elif path == "name.formatted":
|
||||
update_data["name"] = value
|
||||
|
||||
# Update user
|
||||
if update_data:
|
||||
updated_user = Users.update_user_by_id(user_id, update_data)
|
||||
if not updated_user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to update user",
|
||||
)
|
||||
else:
|
||||
updated_user = user
|
||||
|
||||
return user_to_scim(updated_user, request)
|
||||
|
||||
|
||||
@router.delete("/Users/{user_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def delete_user(
|
||||
user_id: str,
|
||||
request: Request,
|
||||
_: bool = Depends(get_scim_auth),
|
||||
):
|
||||
"""Delete SCIM User"""
|
||||
user = Users.get_user_by_id(user_id)
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"User {user_id} not found",
|
||||
)
|
||||
|
||||
success = Users.delete_user_by_id(user_id)
|
||||
if not success:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to delete user",
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
# Groups endpoints
|
||||
@router.get("/Groups", response_model=SCIMListResponse)
|
||||
async def get_groups(
|
||||
request: Request,
|
||||
startIndex: int = Query(1, ge=1),
|
||||
count: int = Query(20, ge=1, le=100),
|
||||
filter: Optional[str] = None,
|
||||
_: bool = Depends(get_scim_auth),
|
||||
):
|
||||
"""List SCIM Groups"""
|
||||
# Get all groups
|
||||
groups_list = Groups.get_groups()
|
||||
|
||||
# Apply pagination
|
||||
total = len(groups_list)
|
||||
start = startIndex - 1
|
||||
end = start + count
|
||||
paginated_groups = groups_list[start:end]
|
||||
|
||||
# Convert to SCIM format
|
||||
scim_groups = [group_to_scim(group, request) for group in paginated_groups]
|
||||
|
||||
return SCIMListResponse(
|
||||
totalResults=total,
|
||||
itemsPerPage=len(scim_groups),
|
||||
startIndex=startIndex,
|
||||
Resources=scim_groups,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/Groups/{group_id}", response_model=SCIMGroup)
|
||||
async def get_group(
|
||||
group_id: str,
|
||||
request: Request,
|
||||
_: bool = Depends(get_scim_auth),
|
||||
):
|
||||
"""Get SCIM Group by ID"""
|
||||
group = Groups.get_group_by_id(group_id)
|
||||
if not group:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Group {group_id} not found",
|
||||
)
|
||||
|
||||
return group_to_scim(group, request)
|
||||
|
||||
|
||||
@router.post("/Groups", response_model=SCIMGroup, status_code=status.HTTP_201_CREATED)
|
||||
async def create_group(
|
||||
request: Request,
|
||||
group_data: SCIMGroupCreateRequest,
|
||||
_: bool = Depends(get_scim_auth),
|
||||
):
|
||||
"""Create SCIM Group"""
|
||||
# Extract member IDs
|
||||
member_ids = []
|
||||
if group_data.members:
|
||||
for member in group_data.members:
|
||||
member_ids.append(member.value)
|
||||
|
||||
# Create group
|
||||
from open_webui.models.groups import GroupForm
|
||||
|
||||
form = GroupForm(
|
||||
name=group_data.displayName,
|
||||
description="",
|
||||
)
|
||||
|
||||
# Need to get the creating user's ID - we'll use the first admin
|
||||
admin_user = Users.get_super_admin_user()
|
||||
if not admin_user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="No admin user found",
|
||||
)
|
||||
|
||||
new_group = Groups.insert_new_group(admin_user.id, form)
|
||||
if not new_group:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to create group",
|
||||
)
|
||||
|
||||
# Add members if provided
|
||||
if member_ids:
|
||||
from open_webui.models.groups import GroupUpdateForm
|
||||
|
||||
update_form = GroupUpdateForm(
|
||||
name=new_group.name,
|
||||
description=new_group.description,
|
||||
user_ids=member_ids,
|
||||
)
|
||||
Groups.update_group_by_id(new_group.id, update_form)
|
||||
new_group = Groups.get_group_by_id(new_group.id)
|
||||
|
||||
return group_to_scim(new_group, request)
|
||||
|
||||
|
||||
@router.put("/Groups/{group_id}", response_model=SCIMGroup)
|
||||
async def update_group(
|
||||
group_id: str,
|
||||
request: Request,
|
||||
group_data: SCIMGroupUpdateRequest,
|
||||
_: bool = Depends(get_scim_auth),
|
||||
):
|
||||
"""Update SCIM Group (full update)"""
|
||||
group = Groups.get_group_by_id(group_id)
|
||||
if not group:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Group {group_id} not found",
|
||||
)
|
||||
|
||||
# Build update form
|
||||
from open_webui.models.groups import GroupUpdateForm
|
||||
|
||||
update_form = GroupUpdateForm(
|
||||
name=group_data.displayName if group_data.displayName else group.name,
|
||||
description=group.description,
|
||||
)
|
||||
|
||||
# Handle members if provided
|
||||
if group_data.members is not None:
|
||||
member_ids = [member.value for member in group_data.members]
|
||||
update_form.user_ids = member_ids
|
||||
|
||||
# Update group
|
||||
updated_group = Groups.update_group_by_id(group_id, update_form)
|
||||
if not updated_group:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to update group",
|
||||
)
|
||||
|
||||
return group_to_scim(updated_group, request)
|
||||
|
||||
|
||||
@router.patch("/Groups/{group_id}", response_model=SCIMGroup)
|
||||
async def patch_group(
|
||||
group_id: str,
|
||||
request: Request,
|
||||
patch_data: SCIMPatchRequest,
|
||||
_: bool = Depends(get_scim_auth),
|
||||
):
|
||||
"""Update SCIM Group (partial update)"""
|
||||
group = Groups.get_group_by_id(group_id)
|
||||
if not group:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Group {group_id} not found",
|
||||
)
|
||||
|
||||
from open_webui.models.groups import GroupUpdateForm
|
||||
|
||||
update_form = GroupUpdateForm(
|
||||
name=group.name,
|
||||
description=group.description,
|
||||
user_ids=group.user_ids.copy() if group.user_ids else [],
|
||||
)
|
||||
|
||||
for operation in patch_data.Operations:
|
||||
op = operation.op.lower()
|
||||
path = operation.path
|
||||
value = operation.value
|
||||
|
||||
if op == "replace":
|
||||
if path == "displayName":
|
||||
update_form.name = value
|
||||
elif path == "members":
|
||||
# Replace all members
|
||||
update_form.user_ids = [member["value"] for member in value]
|
||||
elif op == "add":
|
||||
if path == "members":
|
||||
# Add members
|
||||
if isinstance(value, list):
|
||||
for member in value:
|
||||
if isinstance(member, dict) and "value" in member:
|
||||
if member["value"] not in update_form.user_ids:
|
||||
update_form.user_ids.append(member["value"])
|
||||
elif op == "remove":
|
||||
if path and path.startswith("members[value eq"):
|
||||
# Remove specific member
|
||||
member_id = path.split('"')[1]
|
||||
if member_id in update_form.user_ids:
|
||||
update_form.user_ids.remove(member_id)
|
||||
|
||||
# Update group
|
||||
updated_group = Groups.update_group_by_id(group_id, update_form)
|
||||
if not updated_group:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to update group",
|
||||
)
|
||||
|
||||
return group_to_scim(updated_group, request)
|
||||
|
||||
|
||||
@router.delete("/Groups/{group_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def delete_group(
|
||||
group_id: str,
|
||||
request: Request,
|
||||
_: bool = Depends(get_scim_auth),
|
||||
):
|
||||
"""Delete SCIM Group"""
|
||||
group = Groups.get_group_by_id(group_id)
|
||||
if not group:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Group {group_id} not found",
|
||||
)
|
||||
|
||||
success = Groups.delete_group_by_id(group_id)
|
||||
if not success:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to delete group",
|
||||
)
|
||||
|
||||
return None
|
||||
|
|
@ -198,14 +198,7 @@ async def generate_title(
|
|||
else:
|
||||
template = DEFAULT_TITLE_GENERATION_PROMPT_TEMPLATE
|
||||
|
||||
content = title_generation_template(
|
||||
template,
|
||||
form_data["messages"],
|
||||
{
|
||||
"name": user.name,
|
||||
"location": user.info.get("location") if user.info else None,
|
||||
},
|
||||
)
|
||||
content = title_generation_template(template, form_data["messages"], user)
|
||||
|
||||
max_tokens = (
|
||||
models[task_model_id].get("info", {}).get("params", {}).get("max_tokens", 1000)
|
||||
|
|
@ -289,14 +282,7 @@ async def generate_follow_ups(
|
|||
else:
|
||||
template = DEFAULT_FOLLOW_UP_GENERATION_PROMPT_TEMPLATE
|
||||
|
||||
content = follow_up_generation_template(
|
||||
template,
|
||||
form_data["messages"],
|
||||
{
|
||||
"name": user.name,
|
||||
"location": user.info.get("location") if user.info else None,
|
||||
},
|
||||
)
|
||||
content = follow_up_generation_template(template, form_data["messages"], user)
|
||||
|
||||
payload = {
|
||||
"model": task_model_id,
|
||||
|
|
@ -369,9 +355,7 @@ async def generate_chat_tags(
|
|||
else:
|
||||
template = DEFAULT_TAGS_GENERATION_PROMPT_TEMPLATE
|
||||
|
||||
content = tags_generation_template(
|
||||
template, form_data["messages"], {"name": user.name}
|
||||
)
|
||||
content = tags_generation_template(template, form_data["messages"], user)
|
||||
|
||||
payload = {
|
||||
"model": task_model_id,
|
||||
|
|
@ -437,13 +421,7 @@ async def generate_image_prompt(
|
|||
else:
|
||||
template = DEFAULT_IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE
|
||||
|
||||
content = image_prompt_generation_template(
|
||||
template,
|
||||
form_data["messages"],
|
||||
user={
|
||||
"name": user.name,
|
||||
},
|
||||
)
|
||||
content = image_prompt_generation_template(template, form_data["messages"], user)
|
||||
|
||||
payload = {
|
||||
"model": task_model_id,
|
||||
|
|
@ -524,9 +502,7 @@ async def generate_queries(
|
|||
else:
|
||||
template = DEFAULT_QUERY_GENERATION_PROMPT_TEMPLATE
|
||||
|
||||
content = query_generation_template(
|
||||
template, form_data["messages"], {"name": user.name}
|
||||
)
|
||||
content = query_generation_template(template, form_data["messages"], user)
|
||||
|
||||
payload = {
|
||||
"model": task_model_id,
|
||||
|
|
@ -611,9 +587,7 @@ async def generate_autocompletion(
|
|||
else:
|
||||
template = DEFAULT_AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE
|
||||
|
||||
content = autocomplete_generation_template(
|
||||
template, prompt, messages, type, {"name": user.name}
|
||||
)
|
||||
content = autocomplete_generation_template(template, prompt, messages, type, user)
|
||||
|
||||
payload = {
|
||||
"model": task_model_id,
|
||||
|
|
@ -675,14 +649,7 @@ async def generate_emoji(
|
|||
|
||||
template = DEFAULT_EMOJI_GENERATION_PROMPT_TEMPLATE
|
||||
|
||||
content = emoji_generation_template(
|
||||
template,
|
||||
form_data["prompt"],
|
||||
{
|
||||
"name": user.name,
|
||||
"location": user.info.get("location") if user.info else None,
|
||||
},
|
||||
)
|
||||
content = emoji_generation_template(template, form_data["prompt"], user)
|
||||
|
||||
payload = {
|
||||
"model": task_model_id,
|
||||
|
|
@ -695,11 +662,11 @@ async def generate_emoji(
|
|||
"max_completion_tokens": 4,
|
||||
}
|
||||
),
|
||||
"chat_id": form_data.get("chat_id", None),
|
||||
"metadata": {
|
||||
**(request.state.metadata if hasattr(request.state, "metadata") else {}),
|
||||
"task": str(TASKS.EMOJI_GENERATION),
|
||||
"task_body": form_data,
|
||||
"chat_id": form_data.get("chat_id", None),
|
||||
},
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -5,6 +5,8 @@ import time
|
|||
import re
|
||||
import aiohttp
|
||||
from pydantic import BaseModel, HttpUrl
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||
|
||||
|
||||
from open_webui.models.tools import (
|
||||
ToolForm,
|
||||
|
|
@ -14,15 +16,14 @@ from open_webui.models.tools import (
|
|||
Tools,
|
||||
)
|
||||
from open_webui.utils.plugin import load_tool_module_by_id, replace_imports
|
||||
from open_webui.config import CACHE_DIR
|
||||
from open_webui.constants import ERROR_MESSAGES
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||
from open_webui.utils.tools import get_tool_specs
|
||||
from open_webui.utils.auth import get_admin_user, get_verified_user
|
||||
from open_webui.utils.access_control import has_access, has_permission
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
from open_webui.utils.tools import get_tool_servers
|
||||
|
||||
from open_webui.utils.tools import get_tool_servers_data
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
from open_webui.config import CACHE_DIR, BYPASS_ADMIN_ACCESS_CONTROL
|
||||
from open_webui.constants import ERROR_MESSAGES
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
|
@ -31,6 +32,7 @@ log.setLevel(SRC_LOG_LEVELS["MAIN"])
|
|||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
############################
|
||||
# GetTools
|
||||
############################
|
||||
|
|
@ -38,23 +40,14 @@ router = APIRouter()
|
|||
|
||||
@router.get("/", response_model=list[ToolUserResponse])
|
||||
async def get_tools(request: Request, user=Depends(get_verified_user)):
|
||||
|
||||
if not request.app.state.TOOL_SERVERS:
|
||||
# If the tool servers are not set, we need to set them
|
||||
# This is done only once when the server starts
|
||||
# This is done to avoid loading the tool servers every time
|
||||
|
||||
request.app.state.TOOL_SERVERS = await get_tool_servers_data(
|
||||
request.app.state.config.TOOL_SERVER_CONNECTIONS
|
||||
)
|
||||
|
||||
tools = Tools.get_tools()
|
||||
for server in request.app.state.TOOL_SERVERS:
|
||||
|
||||
for server in await get_tool_servers(request):
|
||||
tools.append(
|
||||
ToolUserResponse(
|
||||
**{
|
||||
"id": f"server:{server['idx']}",
|
||||
"user_id": f"server:{server['idx']}",
|
||||
"id": f"server:{server.get('id')}",
|
||||
"user_id": f"server:{server.get('id')}",
|
||||
"name": server.get("openapi", {})
|
||||
.get("info", {})
|
||||
.get("title", "Tool Server"),
|
||||
|
|
@ -64,7 +57,7 @@ async def get_tools(request: Request, user=Depends(get_verified_user)):
|
|||
.get("description", ""),
|
||||
},
|
||||
"access_control": request.app.state.config.TOOL_SERVER_CONNECTIONS[
|
||||
server["idx"]
|
||||
server.get("idx", 0)
|
||||
]
|
||||
.get("config", {})
|
||||
.get("access_control", None),
|
||||
|
|
@ -74,15 +67,17 @@ async def get_tools(request: Request, user=Depends(get_verified_user)):
|
|||
)
|
||||
)
|
||||
|
||||
if user.role != "admin":
|
||||
if user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL:
|
||||
# Admin can see all tools
|
||||
return tools
|
||||
else:
|
||||
tools = [
|
||||
tool
|
||||
for tool in tools
|
||||
if tool.user_id == user.id
|
||||
or has_access(user.id, "read", tool.access_control)
|
||||
]
|
||||
|
||||
return tools
|
||||
return tools
|
||||
|
||||
|
||||
############################
|
||||
|
|
@ -92,7 +87,7 @@ async def get_tools(request: Request, user=Depends(get_verified_user)):
|
|||
|
||||
@router.get("/list", response_model=list[ToolUserResponse])
|
||||
async def get_tool_list(user=Depends(get_verified_user)):
|
||||
if user.role == "admin":
|
||||
if user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL:
|
||||
tools = Tools.get_tools()
|
||||
else:
|
||||
tools = Tools.get_tools_by_user_id(user.id, "write")
|
||||
|
|
|
|||
|
|
@ -1,5 +1,13 @@
|
|||
import logging
|
||||
from typing import Optional
|
||||
import base64
|
||||
import io
|
||||
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||
from fastapi.responses import Response, StreamingResponse, FileResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
from open_webui.models.auths import Auths
|
||||
from open_webui.models.groups import Groups
|
||||
|
|
@ -7,6 +15,7 @@ from open_webui.models.chats import Chats
|
|||
from open_webui.models.users import (
|
||||
UserModel,
|
||||
UserListResponse,
|
||||
UserInfoListResponse,
|
||||
UserRoleUpdateForm,
|
||||
Users,
|
||||
UserSettings,
|
||||
|
|
@ -20,9 +29,8 @@ from open_webui.socket.main import (
|
|||
get_user_active_status,
|
||||
)
|
||||
from open_webui.constants import ERROR_MESSAGES
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||
from pydantic import BaseModel
|
||||
from open_webui.env import SRC_LOG_LEVELS, STATIC_DIR
|
||||
|
||||
|
||||
from open_webui.utils.auth import get_admin_user, get_password_hash, get_verified_user
|
||||
from open_webui.utils.access_control import get_permissions, has_permission
|
||||
|
|
@ -83,7 +91,7 @@ async def get_users(
|
|||
return Users.get_users(filter=filter, skip=skip, limit=limit)
|
||||
|
||||
|
||||
@router.get("/all", response_model=UserListResponse)
|
||||
@router.get("/all", response_model=UserInfoListResponse)
|
||||
async def get_all_users(
|
||||
user=Depends(get_admin_user),
|
||||
):
|
||||
|
|
@ -133,7 +141,9 @@ class SharingPermissions(BaseModel):
|
|||
|
||||
class ChatPermissions(BaseModel):
|
||||
controls: bool = True
|
||||
valves: bool = True
|
||||
system_prompt: bool = True
|
||||
params: bool = True
|
||||
file_upload: bool = True
|
||||
delete: bool = True
|
||||
edit: bool = True
|
||||
|
|
@ -326,6 +336,43 @@ async def get_user_by_id(user_id: str, user=Depends(get_verified_user)):
|
|||
)
|
||||
|
||||
|
||||
############################
|
||||
# GetUserProfileImageById
|
||||
############################
|
||||
|
||||
|
||||
@router.get("/{user_id}/profile/image")
|
||||
async def get_user_profile_image_by_id(user_id: str, user=Depends(get_verified_user)):
|
||||
user = Users.get_user_by_id(user_id)
|
||||
if user:
|
||||
if user.profile_image_url:
|
||||
# check if it's url or base64
|
||||
if user.profile_image_url.startswith("http"):
|
||||
return Response(
|
||||
status_code=status.HTTP_302_FOUND,
|
||||
headers={"Location": user.profile_image_url},
|
||||
)
|
||||
elif user.profile_image_url.startswith("data:image"):
|
||||
try:
|
||||
header, base64_data = user.profile_image_url.split(",", 1)
|
||||
image_data = base64.b64decode(base64_data)
|
||||
image_buffer = io.BytesIO(image_data)
|
||||
|
||||
return StreamingResponse(
|
||||
image_buffer,
|
||||
media_type="image/png",
|
||||
headers={"Content-Disposition": "inline; filename=image.png"},
|
||||
)
|
||||
except Exception as e:
|
||||
pass
|
||||
return FileResponse(f"{STATIC_DIR}/user.png")
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.USER_NOT_FOUND,
|
||||
)
|
||||
|
||||
|
||||
############################
|
||||
# GetUserActiveStatusById
|
||||
############################
|
||||
|
|
@ -454,3 +501,13 @@ async def delete_user_by_id(user_id: str, user=Depends(get_admin_user)):
|
|||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=ERROR_MESSAGES.ACTION_PROHIBITED,
|
||||
)
|
||||
|
||||
|
||||
############################
|
||||
# GetUserGroupsById
|
||||
############################
|
||||
|
||||
|
||||
@router.get("/{user_id}/groups")
|
||||
async def get_user_groups_by_id(user_id: str, user=Depends(get_admin_user)):
|
||||
return Groups.get_groups_by_member_id(user_id)
|
||||
|
|
|
|||
|
|
@ -22,9 +22,11 @@ from open_webui.env import (
|
|||
ENABLE_WEBSOCKET_SUPPORT,
|
||||
WEBSOCKET_MANAGER,
|
||||
WEBSOCKET_REDIS_URL,
|
||||
WEBSOCKET_REDIS_CLUSTER,
|
||||
WEBSOCKET_REDIS_LOCK_TIMEOUT,
|
||||
WEBSOCKET_SENTINEL_PORT,
|
||||
WEBSOCKET_SENTINEL_HOSTS,
|
||||
REDIS_KEY_PREFIX,
|
||||
)
|
||||
from open_webui.utils.auth import decode_token
|
||||
from open_webui.socket.utils import RedisDict, RedisLock, YdocManager
|
||||
|
|
@ -85,6 +87,7 @@ if WEBSOCKET_MANAGER == "redis":
|
|||
redis_sentinels=get_sentinels_from_env(
|
||||
WEBSOCKET_SENTINEL_HOSTS, WEBSOCKET_SENTINEL_PORT
|
||||
),
|
||||
redis_cluster=WEBSOCKET_REDIS_CLUSTER,
|
||||
async_mode=True,
|
||||
)
|
||||
|
||||
|
|
@ -92,19 +95,22 @@ if WEBSOCKET_MANAGER == "redis":
|
|||
WEBSOCKET_SENTINEL_HOSTS, WEBSOCKET_SENTINEL_PORT
|
||||
)
|
||||
SESSION_POOL = RedisDict(
|
||||
"open-webui:session_pool",
|
||||
f"{REDIS_KEY_PREFIX}:session_pool",
|
||||
redis_url=WEBSOCKET_REDIS_URL,
|
||||
redis_sentinels=redis_sentinels,
|
||||
redis_cluster=WEBSOCKET_REDIS_CLUSTER,
|
||||
)
|
||||
USER_POOL = RedisDict(
|
||||
"open-webui:user_pool",
|
||||
f"{REDIS_KEY_PREFIX}:user_pool",
|
||||
redis_url=WEBSOCKET_REDIS_URL,
|
||||
redis_sentinels=redis_sentinels,
|
||||
redis_cluster=WEBSOCKET_REDIS_CLUSTER,
|
||||
)
|
||||
USAGE_POOL = RedisDict(
|
||||
"open-webui:usage_pool",
|
||||
f"{REDIS_KEY_PREFIX}:usage_pool",
|
||||
redis_url=WEBSOCKET_REDIS_URL,
|
||||
redis_sentinels=redis_sentinels,
|
||||
redis_cluster=WEBSOCKET_REDIS_CLUSTER,
|
||||
)
|
||||
|
||||
clean_up_lock = RedisLock(
|
||||
|
|
@ -112,6 +118,7 @@ if WEBSOCKET_MANAGER == "redis":
|
|||
lock_name="usage_cleanup_lock",
|
||||
timeout_secs=WEBSOCKET_REDIS_LOCK_TIMEOUT,
|
||||
redis_sentinels=redis_sentinels,
|
||||
redis_cluster=WEBSOCKET_REDIS_CLUSTER,
|
||||
)
|
||||
aquire_func = clean_up_lock.aquire_lock
|
||||
renew_func = clean_up_lock.renew_lock
|
||||
|
|
@ -126,7 +133,7 @@ else:
|
|||
|
||||
YDOC_MANAGER = YdocManager(
|
||||
redis=REDIS,
|
||||
redis_key_prefix="open-webui:ydoc:documents",
|
||||
redis_key_prefix=f"{REDIS_KEY_PREFIX}:ydoc:documents",
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -259,7 +266,9 @@ async def connect(sid, environ, auth):
|
|||
user = Users.get_user_by_id(data["id"])
|
||||
|
||||
if user:
|
||||
SESSION_POOL[sid] = user.model_dump()
|
||||
SESSION_POOL[sid] = user.model_dump(
|
||||
exclude=["date_of_birth", "bio", "gender"]
|
||||
)
|
||||
if user.id in USER_POOL:
|
||||
USER_POOL[user.id] = USER_POOL[user.id] + [sid]
|
||||
else:
|
||||
|
|
@ -281,7 +290,7 @@ async def user_join(sid, data):
|
|||
if not user:
|
||||
return
|
||||
|
||||
SESSION_POOL[sid] = user.model_dump()
|
||||
SESSION_POOL[sid] = user.model_dump(exclude=["date_of_birth", "bio", "gender"])
|
||||
if user.id in USER_POOL:
|
||||
USER_POOL[user.id] = USER_POOL[user.id] + [sid]
|
||||
else:
|
||||
|
|
@ -316,6 +325,37 @@ async def join_channel(sid, data):
|
|||
await sio.enter_room(sid, f"channel:{channel.id}")
|
||||
|
||||
|
||||
@sio.on("join-note")
|
||||
async def join_note(sid, data):
|
||||
auth = data["auth"] if "auth" in data else None
|
||||
if not auth or "token" not in auth:
|
||||
return
|
||||
|
||||
token_data = decode_token(auth["token"])
|
||||
if token_data is None or "id" not in token_data:
|
||||
return
|
||||
|
||||
user = Users.get_user_by_id(token_data["id"])
|
||||
if not user:
|
||||
return
|
||||
|
||||
note = Notes.get_note_by_id(data["note_id"])
|
||||
if not note:
|
||||
log.error(f"Note {data['note_id']} not found for user {user.id}")
|
||||
return
|
||||
|
||||
if (
|
||||
user.role != "admin"
|
||||
and user.id != note.user_id
|
||||
and not has_access(user.id, type="read", access_control=note.access_control)
|
||||
):
|
||||
log.error(f"User {user.id} does not have access to note {data['note_id']}")
|
||||
return
|
||||
|
||||
log.debug(f"Joining note {note.id} for user {user.id}")
|
||||
await sio.enter_room(sid, f"note:{note.id}")
|
||||
|
||||
|
||||
@sio.on("channel-events")
|
||||
async def channel_events(sid, data):
|
||||
room = f"channel:{data['channel_id']}"
|
||||
|
|
@ -450,7 +490,7 @@ async def yjs_document_state(sid, data):
|
|||
room = f"doc_{document_id}"
|
||||
|
||||
active_session_ids = get_session_ids_from_room(room)
|
||||
print(active_session_ids)
|
||||
|
||||
if sid not in active_session_ids:
|
||||
log.warning(f"Session {sid} not in room {room}. Cannot send state.")
|
||||
return
|
||||
|
|
@ -520,7 +560,8 @@ async def yjs_document_update(sid, data):
|
|||
document_id, data.get("data", {}), SESSION_POOL.get(sid)
|
||||
)
|
||||
|
||||
await create_task(REDIS, debounced_save(), document_id)
|
||||
if data.get("data"):
|
||||
await create_task(REDIS, debounced_save(), document_id)
|
||||
|
||||
except Exception as e:
|
||||
log.error(f"Error in yjs_document_update: {e}")
|
||||
|
|
@ -549,7 +590,7 @@ async def yjs_document_leave(sid, data):
|
|||
)
|
||||
|
||||
if (
|
||||
YDOC_MANAGER.document_exists(document_id)
|
||||
await YDOC_MANAGER.document_exists(document_id)
|
||||
and len(await YDOC_MANAGER.get_users(document_id)) == 0
|
||||
):
|
||||
log.info(f"Cleaning up document {document_id} as no users are left")
|
||||
|
|
|
|||
|
|
@ -1,18 +1,30 @@
|
|||
import json
|
||||
import uuid
|
||||
from open_webui.utils.redis import get_redis_connection
|
||||
from open_webui.env import REDIS_KEY_PREFIX
|
||||
from typing import Optional, List, Tuple
|
||||
import pycrdt as Y
|
||||
|
||||
|
||||
class RedisLock:
|
||||
def __init__(self, redis_url, lock_name, timeout_secs, redis_sentinels=[]):
|
||||
def __init__(
|
||||
self,
|
||||
redis_url,
|
||||
lock_name,
|
||||
timeout_secs,
|
||||
redis_sentinels=[],
|
||||
redis_cluster=False,
|
||||
):
|
||||
|
||||
self.lock_name = lock_name
|
||||
self.lock_id = str(uuid.uuid4())
|
||||
self.timeout_secs = timeout_secs
|
||||
self.lock_obtained = False
|
||||
self.redis = get_redis_connection(
|
||||
redis_url, redis_sentinels, decode_responses=True
|
||||
redis_url,
|
||||
redis_sentinels,
|
||||
redis_cluster=redis_cluster,
|
||||
decode_responses=True,
|
||||
)
|
||||
|
||||
def aquire_lock(self):
|
||||
|
|
@ -35,10 +47,13 @@ class RedisLock:
|
|||
|
||||
|
||||
class RedisDict:
|
||||
def __init__(self, name, redis_url, redis_sentinels=[]):
|
||||
def __init__(self, name, redis_url, redis_sentinels=[], redis_cluster=False):
|
||||
self.name = name
|
||||
self.redis = get_redis_connection(
|
||||
redis_url, redis_sentinels, decode_responses=True
|
||||
redis_url,
|
||||
redis_sentinels,
|
||||
redis_cluster=redis_cluster,
|
||||
decode_responses=True,
|
||||
)
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
|
|
@ -97,7 +112,7 @@ class YdocManager:
|
|||
def __init__(
|
||||
self,
|
||||
redis=None,
|
||||
redis_key_prefix: str = "open-webui:ydoc:documents",
|
||||
redis_key_prefix: str = f"{REDIS_KEY_PREFIX}:ydoc:documents",
|
||||
):
|
||||
self._updates = {}
|
||||
self._users = {}
|
||||
|
|
|
|||
BIN
backend/open_webui/static/user.png
Normal file
BIN
backend/open_webui/static/user.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 7.7 KiB |
|
|
@ -112,6 +112,9 @@ class S3StorageProvider(StorageProvider):
|
|||
"use_accelerate_endpoint": S3_USE_ACCELERATE_ENDPOINT,
|
||||
"addressing_style": S3_ADDRESSING_STYLE,
|
||||
},
|
||||
# KIT change - see https://github.com/boto/boto3/issues/4400#issuecomment-2600742103∆
|
||||
request_checksum_calculation="when_required",
|
||||
response_checksum_validation="when_required",
|
||||
)
|
||||
|
||||
# If access key and secret are provided, use them for authentication
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ from redis.asyncio import Redis
|
|||
from fastapi import Request
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
from open_webui.env import SRC_LOG_LEVELS, REDIS_KEY_PREFIX
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
|
@ -19,9 +19,9 @@ tasks: Dict[str, asyncio.Task] = {}
|
|||
item_tasks = {}
|
||||
|
||||
|
||||
REDIS_TASKS_KEY = "open-webui:tasks"
|
||||
REDIS_ITEM_TASKS_KEY = "open-webui:tasks:item"
|
||||
REDIS_PUBSUB_CHANNEL = "open-webui:tasks:commands"
|
||||
REDIS_TASKS_KEY = f"{REDIS_KEY_PREFIX}:tasks"
|
||||
REDIS_ITEM_TASKS_KEY = f"{REDIS_KEY_PREFIX}:tasks:item"
|
||||
REDIS_PUBSUB_CHANNEL = f"{REDIS_KEY_PREFIX}:tasks:commands"
|
||||
|
||||
|
||||
async def redis_task_command_listener(app):
|
||||
|
|
|
|||
793
backend/open_webui/test/util/test_redis.py
Normal file
793
backend/open_webui/test/util/test_redis.py
Normal file
|
|
@ -0,0 +1,793 @@
|
|||
import pytest
|
||||
from unittest.mock import Mock, patch, AsyncMock
|
||||
import redis
|
||||
from open_webui.utils.redis import (
|
||||
SentinelRedisProxy,
|
||||
parse_redis_service_url,
|
||||
get_redis_connection,
|
||||
get_sentinels_from_env,
|
||||
MAX_RETRY_COUNT,
|
||||
)
|
||||
import inspect
|
||||
|
||||
|
||||
class TestSentinelRedisProxy:
|
||||
"""Test Redis Sentinel failover functionality"""
|
||||
|
||||
def test_parse_redis_service_url_valid(self):
|
||||
"""Test parsing valid Redis service URL"""
|
||||
url = "redis://user:pass@mymaster:6379/0"
|
||||
result = parse_redis_service_url(url)
|
||||
|
||||
assert result["username"] == "user"
|
||||
assert result["password"] == "pass"
|
||||
assert result["service"] == "mymaster"
|
||||
assert result["port"] == 6379
|
||||
assert result["db"] == 0
|
||||
|
||||
def test_parse_redis_service_url_defaults(self):
|
||||
"""Test parsing Redis service URL with defaults"""
|
||||
url = "redis://mymaster"
|
||||
result = parse_redis_service_url(url)
|
||||
|
||||
assert result["username"] is None
|
||||
assert result["password"] is None
|
||||
assert result["service"] == "mymaster"
|
||||
assert result["port"] == 6379
|
||||
assert result["db"] == 0
|
||||
|
||||
def test_parse_redis_service_url_invalid_scheme(self):
|
||||
"""Test parsing invalid URL scheme"""
|
||||
with pytest.raises(ValueError, match="Invalid Redis URL scheme"):
|
||||
parse_redis_service_url("http://invalid")
|
||||
|
||||
def test_get_sentinels_from_env(self):
|
||||
"""Test parsing sentinel hosts from environment"""
|
||||
hosts = "sentinel1,sentinel2,sentinel3"
|
||||
port = "26379"
|
||||
|
||||
result = get_sentinels_from_env(hosts, port)
|
||||
expected = [("sentinel1", 26379), ("sentinel2", 26379), ("sentinel3", 26379)]
|
||||
|
||||
assert result == expected
|
||||
|
||||
def test_get_sentinels_from_env_empty(self):
|
||||
"""Test empty sentinel hosts"""
|
||||
result = get_sentinels_from_env(None, "26379")
|
||||
assert result == []
|
||||
|
||||
@patch("redis.sentinel.Sentinel")
|
||||
def test_sentinel_redis_proxy_sync_success(self, mock_sentinel_class):
|
||||
"""Test successful sync operation with SentinelRedisProxy"""
|
||||
mock_sentinel = Mock()
|
||||
mock_master = Mock()
|
||||
mock_master.get.return_value = "test_value"
|
||||
mock_sentinel.master_for.return_value = mock_master
|
||||
|
||||
proxy = SentinelRedisProxy(mock_sentinel, "mymaster", async_mode=False)
|
||||
|
||||
# Test attribute access
|
||||
get_method = proxy.__getattr__("get")
|
||||
result = get_method("test_key")
|
||||
|
||||
assert result == "test_value"
|
||||
mock_sentinel.master_for.assert_called_with("mymaster")
|
||||
mock_master.get.assert_called_with("test_key")
|
||||
|
||||
@patch("redis.sentinel.Sentinel")
|
||||
@pytest.mark.asyncio
|
||||
async def test_sentinel_redis_proxy_async_success(self, mock_sentinel_class):
|
||||
"""Test successful async operation with SentinelRedisProxy"""
|
||||
mock_sentinel = Mock()
|
||||
mock_master = Mock()
|
||||
mock_master.get = AsyncMock(return_value="test_value")
|
||||
mock_sentinel.master_for.return_value = mock_master
|
||||
|
||||
proxy = SentinelRedisProxy(mock_sentinel, "mymaster", async_mode=True)
|
||||
|
||||
# Test async attribute access
|
||||
get_method = proxy.__getattr__("get")
|
||||
result = await get_method("test_key")
|
||||
|
||||
assert result == "test_value"
|
||||
mock_sentinel.master_for.assert_called_with("mymaster")
|
||||
mock_master.get.assert_called_with("test_key")
|
||||
|
||||
@patch("redis.sentinel.Sentinel")
|
||||
def test_sentinel_redis_proxy_failover_retry(self, mock_sentinel_class):
|
||||
"""Test retry mechanism during failover"""
|
||||
mock_sentinel = Mock()
|
||||
mock_master = Mock()
|
||||
|
||||
# First call fails, second succeeds
|
||||
mock_master.get.side_effect = [
|
||||
redis.exceptions.ConnectionError("Master down"),
|
||||
"test_value",
|
||||
]
|
||||
mock_sentinel.master_for.return_value = mock_master
|
||||
|
||||
proxy = SentinelRedisProxy(mock_sentinel, "mymaster", async_mode=False)
|
||||
|
||||
get_method = proxy.__getattr__("get")
|
||||
result = get_method("test_key")
|
||||
|
||||
assert result == "test_value"
|
||||
assert mock_master.get.call_count == 2
|
||||
|
||||
@patch("redis.sentinel.Sentinel")
|
||||
def test_sentinel_redis_proxy_max_retries_exceeded(self, mock_sentinel_class):
|
||||
"""Test failure after max retries exceeded"""
|
||||
mock_sentinel = Mock()
|
||||
mock_master = Mock()
|
||||
|
||||
# All calls fail
|
||||
mock_master.get.side_effect = redis.exceptions.ConnectionError("Master down")
|
||||
mock_sentinel.master_for.return_value = mock_master
|
||||
|
||||
proxy = SentinelRedisProxy(mock_sentinel, "mymaster", async_mode=False)
|
||||
|
||||
get_method = proxy.__getattr__("get")
|
||||
|
||||
with pytest.raises(redis.exceptions.ConnectionError):
|
||||
get_method("test_key")
|
||||
|
||||
assert mock_master.get.call_count == MAX_RETRY_COUNT
|
||||
|
||||
@patch("redis.sentinel.Sentinel")
|
||||
def test_sentinel_redis_proxy_readonly_error_retry(self, mock_sentinel_class):
|
||||
"""Test retry on ReadOnlyError"""
|
||||
mock_sentinel = Mock()
|
||||
mock_master = Mock()
|
||||
|
||||
# First call gets ReadOnlyError (old master), second succeeds (new master)
|
||||
mock_master.get.side_effect = [
|
||||
redis.exceptions.ReadOnlyError("Read only"),
|
||||
"test_value",
|
||||
]
|
||||
mock_sentinel.master_for.return_value = mock_master
|
||||
|
||||
proxy = SentinelRedisProxy(mock_sentinel, "mymaster", async_mode=False)
|
||||
|
||||
get_method = proxy.__getattr__("get")
|
||||
result = get_method("test_key")
|
||||
|
||||
assert result == "test_value"
|
||||
assert mock_master.get.call_count == 2
|
||||
|
||||
@patch("redis.sentinel.Sentinel")
|
||||
def test_sentinel_redis_proxy_factory_methods(self, mock_sentinel_class):
|
||||
"""Test factory methods are passed through directly"""
|
||||
mock_sentinel = Mock()
|
||||
mock_master = Mock()
|
||||
mock_pipeline = Mock()
|
||||
mock_master.pipeline.return_value = mock_pipeline
|
||||
mock_sentinel.master_for.return_value = mock_master
|
||||
|
||||
proxy = SentinelRedisProxy(mock_sentinel, "mymaster", async_mode=False)
|
||||
|
||||
# Factory methods should be passed through without wrapping
|
||||
pipeline_method = proxy.__getattr__("pipeline")
|
||||
result = pipeline_method()
|
||||
|
||||
assert result == mock_pipeline
|
||||
mock_master.pipeline.assert_called_once()
|
||||
|
||||
@patch("redis.sentinel.Sentinel")
|
||||
@patch("redis.from_url")
|
||||
def test_get_redis_connection_with_sentinel(
|
||||
self, mock_from_url, mock_sentinel_class
|
||||
):
|
||||
"""Test getting Redis connection with Sentinel"""
|
||||
mock_sentinel = Mock()
|
||||
mock_sentinel_class.return_value = mock_sentinel
|
||||
|
||||
sentinels = [("sentinel1", 26379), ("sentinel2", 26379)]
|
||||
redis_url = "redis://user:pass@mymaster:6379/0"
|
||||
|
||||
result = get_redis_connection(
|
||||
redis_url=redis_url, redis_sentinels=sentinels, async_mode=False
|
||||
)
|
||||
|
||||
assert isinstance(result, SentinelRedisProxy)
|
||||
mock_sentinel_class.assert_called_once()
|
||||
mock_from_url.assert_not_called()
|
||||
|
||||
@patch("redis.Redis.from_url")
|
||||
def test_get_redis_connection_without_sentinel(self, mock_from_url):
|
||||
"""Test getting Redis connection without Sentinel"""
|
||||
mock_redis = Mock()
|
||||
mock_from_url.return_value = mock_redis
|
||||
|
||||
redis_url = "redis://localhost:6379/0"
|
||||
|
||||
result = get_redis_connection(
|
||||
redis_url=redis_url, redis_sentinels=None, async_mode=False
|
||||
)
|
||||
|
||||
assert result == mock_redis
|
||||
mock_from_url.assert_called_once_with(redis_url, decode_responses=True)
|
||||
|
||||
@patch("redis.asyncio.from_url")
|
||||
def test_get_redis_connection_without_sentinel_async(self, mock_from_url):
|
||||
"""Test getting async Redis connection without Sentinel"""
|
||||
mock_redis = Mock()
|
||||
mock_from_url.return_value = mock_redis
|
||||
|
||||
redis_url = "redis://localhost:6379/0"
|
||||
|
||||
result = get_redis_connection(
|
||||
redis_url=redis_url, redis_sentinels=None, async_mode=True
|
||||
)
|
||||
|
||||
assert result == mock_redis
|
||||
mock_from_url.assert_called_once_with(redis_url, decode_responses=True)
|
||||
|
||||
|
||||
class TestSentinelRedisProxyCommands:
|
||||
"""Test Redis commands through SentinelRedisProxy"""
|
||||
|
||||
@patch("redis.sentinel.Sentinel")
|
||||
def test_hash_commands_sync(self, mock_sentinel_class):
|
||||
"""Test Redis hash commands in sync mode"""
|
||||
mock_sentinel = Mock()
|
||||
mock_master = Mock()
|
||||
|
||||
# Mock hash command responses
|
||||
mock_master.hset.return_value = 1
|
||||
mock_master.hget.return_value = "test_value"
|
||||
mock_master.hgetall.return_value = {"key1": "value1", "key2": "value2"}
|
||||
mock_master.hdel.return_value = 1
|
||||
|
||||
mock_sentinel.master_for.return_value = mock_master
|
||||
|
||||
proxy = SentinelRedisProxy(mock_sentinel, "mymaster", async_mode=False)
|
||||
|
||||
# Test hset
|
||||
hset_method = proxy.__getattr__("hset")
|
||||
result = hset_method("test_hash", "field1", "value1")
|
||||
assert result == 1
|
||||
mock_master.hset.assert_called_with("test_hash", "field1", "value1")
|
||||
|
||||
# Test hget
|
||||
hget_method = proxy.__getattr__("hget")
|
||||
result = hget_method("test_hash", "field1")
|
||||
assert result == "test_value"
|
||||
mock_master.hget.assert_called_with("test_hash", "field1")
|
||||
|
||||
# Test hgetall
|
||||
hgetall_method = proxy.__getattr__("hgetall")
|
||||
result = hgetall_method("test_hash")
|
||||
assert result == {"key1": "value1", "key2": "value2"}
|
||||
mock_master.hgetall.assert_called_with("test_hash")
|
||||
|
||||
# Test hdel
|
||||
hdel_method = proxy.__getattr__("hdel")
|
||||
result = hdel_method("test_hash", "field1")
|
||||
assert result == 1
|
||||
mock_master.hdel.assert_called_with("test_hash", "field1")
|
||||
|
||||
@patch("redis.sentinel.Sentinel")
|
||||
@pytest.mark.asyncio
|
||||
async def test_hash_commands_async(self, mock_sentinel_class):
|
||||
"""Test Redis hash commands in async mode"""
|
||||
mock_sentinel = Mock()
|
||||
mock_master = Mock()
|
||||
|
||||
# Mock async hash command responses
|
||||
mock_master.hset = AsyncMock(return_value=1)
|
||||
mock_master.hget = AsyncMock(return_value="test_value")
|
||||
mock_master.hgetall = AsyncMock(
|
||||
return_value={"key1": "value1", "key2": "value2"}
|
||||
)
|
||||
|
||||
mock_sentinel.master_for.return_value = mock_master
|
||||
|
||||
proxy = SentinelRedisProxy(mock_sentinel, "mymaster", async_mode=True)
|
||||
|
||||
# Test hset
|
||||
hset_method = proxy.__getattr__("hset")
|
||||
result = await hset_method("test_hash", "field1", "value1")
|
||||
assert result == 1
|
||||
mock_master.hset.assert_called_with("test_hash", "field1", "value1")
|
||||
|
||||
# Test hget
|
||||
hget_method = proxy.__getattr__("hget")
|
||||
result = await hget_method("test_hash", "field1")
|
||||
assert result == "test_value"
|
||||
mock_master.hget.assert_called_with("test_hash", "field1")
|
||||
|
||||
# Test hgetall
|
||||
hgetall_method = proxy.__getattr__("hgetall")
|
||||
result = await hgetall_method("test_hash")
|
||||
assert result == {"key1": "value1", "key2": "value2"}
|
||||
mock_master.hgetall.assert_called_with("test_hash")
|
||||
|
||||
@patch("redis.sentinel.Sentinel")
|
||||
def test_string_commands_sync(self, mock_sentinel_class):
|
||||
"""Test Redis string commands in sync mode"""
|
||||
mock_sentinel = Mock()
|
||||
mock_master = Mock()
|
||||
|
||||
# Mock string command responses
|
||||
mock_master.set.return_value = True
|
||||
mock_master.get.return_value = "test_value"
|
||||
mock_master.delete.return_value = 1
|
||||
mock_master.exists.return_value = True
|
||||
|
||||
mock_sentinel.master_for.return_value = mock_master
|
||||
|
||||
proxy = SentinelRedisProxy(mock_sentinel, "mymaster", async_mode=False)
|
||||
|
||||
# Test set
|
||||
set_method = proxy.__getattr__("set")
|
||||
result = set_method("test_key", "test_value")
|
||||
assert result is True
|
||||
mock_master.set.assert_called_with("test_key", "test_value")
|
||||
|
||||
# Test get
|
||||
get_method = proxy.__getattr__("get")
|
||||
result = get_method("test_key")
|
||||
assert result == "test_value"
|
||||
mock_master.get.assert_called_with("test_key")
|
||||
|
||||
# Test delete
|
||||
delete_method = proxy.__getattr__("delete")
|
||||
result = delete_method("test_key")
|
||||
assert result == 1
|
||||
mock_master.delete.assert_called_with("test_key")
|
||||
|
||||
# Test exists
|
||||
exists_method = proxy.__getattr__("exists")
|
||||
result = exists_method("test_key")
|
||||
assert result is True
|
||||
mock_master.exists.assert_called_with("test_key")
|
||||
|
||||
@patch("redis.sentinel.Sentinel")
|
||||
def test_list_commands_sync(self, mock_sentinel_class):
|
||||
"""Test Redis list commands in sync mode"""
|
||||
mock_sentinel = Mock()
|
||||
mock_master = Mock()
|
||||
|
||||
# Mock list command responses
|
||||
mock_master.lpush.return_value = 1
|
||||
mock_master.rpop.return_value = "test_value"
|
||||
mock_master.llen.return_value = 5
|
||||
mock_master.lrange.return_value = ["item1", "item2", "item3"]
|
||||
|
||||
mock_sentinel.master_for.return_value = mock_master
|
||||
|
||||
proxy = SentinelRedisProxy(mock_sentinel, "mymaster", async_mode=False)
|
||||
|
||||
# Test lpush
|
||||
lpush_method = proxy.__getattr__("lpush")
|
||||
result = lpush_method("test_list", "item1")
|
||||
assert result == 1
|
||||
mock_master.lpush.assert_called_with("test_list", "item1")
|
||||
|
||||
# Test rpop
|
||||
rpop_method = proxy.__getattr__("rpop")
|
||||
result = rpop_method("test_list")
|
||||
assert result == "test_value"
|
||||
mock_master.rpop.assert_called_with("test_list")
|
||||
|
||||
# Test llen
|
||||
llen_method = proxy.__getattr__("llen")
|
||||
result = llen_method("test_list")
|
||||
assert result == 5
|
||||
mock_master.llen.assert_called_with("test_list")
|
||||
|
||||
# Test lrange
|
||||
lrange_method = proxy.__getattr__("lrange")
|
||||
result = lrange_method("test_list", 0, -1)
|
||||
assert result == ["item1", "item2", "item3"]
|
||||
mock_master.lrange.assert_called_with("test_list", 0, -1)
|
||||
|
||||
@patch("redis.sentinel.Sentinel")
|
||||
def test_pubsub_commands_sync(self, mock_sentinel_class):
|
||||
"""Test Redis pubsub commands in sync mode"""
|
||||
mock_sentinel = Mock()
|
||||
mock_master = Mock()
|
||||
mock_pubsub = Mock()
|
||||
|
||||
# Mock pubsub responses
|
||||
mock_master.pubsub.return_value = mock_pubsub
|
||||
mock_master.publish.return_value = 1
|
||||
mock_pubsub.subscribe.return_value = None
|
||||
mock_pubsub.get_message.return_value = {"type": "message", "data": "test_data"}
|
||||
|
||||
mock_sentinel.master_for.return_value = mock_master
|
||||
|
||||
proxy = SentinelRedisProxy(mock_sentinel, "mymaster", async_mode=False)
|
||||
|
||||
# Test pubsub (factory method - should pass through)
|
||||
pubsub_method = proxy.__getattr__("pubsub")
|
||||
result = pubsub_method()
|
||||
assert result == mock_pubsub
|
||||
mock_master.pubsub.assert_called_once()
|
||||
|
||||
# Test publish
|
||||
publish_method = proxy.__getattr__("publish")
|
||||
result = publish_method("test_channel", "test_message")
|
||||
assert result == 1
|
||||
mock_master.publish.assert_called_with("test_channel", "test_message")
|
||||
|
||||
@patch("redis.sentinel.Sentinel")
|
||||
def test_pipeline_commands_sync(self, mock_sentinel_class):
|
||||
"""Test Redis pipeline commands in sync mode"""
|
||||
mock_sentinel = Mock()
|
||||
mock_master = Mock()
|
||||
mock_pipeline = Mock()
|
||||
|
||||
# Mock pipeline responses
|
||||
mock_master.pipeline.return_value = mock_pipeline
|
||||
mock_pipeline.set.return_value = mock_pipeline
|
||||
mock_pipeline.get.return_value = mock_pipeline
|
||||
mock_pipeline.execute.return_value = [True, "test_value"]
|
||||
|
||||
mock_sentinel.master_for.return_value = mock_master
|
||||
|
||||
proxy = SentinelRedisProxy(mock_sentinel, "mymaster", async_mode=False)
|
||||
|
||||
# Test pipeline (factory method - should pass through)
|
||||
pipeline_method = proxy.__getattr__("pipeline")
|
||||
result = pipeline_method()
|
||||
assert result == mock_pipeline
|
||||
mock_master.pipeline.assert_called_once()
|
||||
|
||||
@patch("redis.sentinel.Sentinel")
|
||||
def test_commands_with_failover_retry(self, mock_sentinel_class):
|
||||
"""Test Redis commands with failover retry mechanism"""
|
||||
mock_sentinel = Mock()
|
||||
mock_master = Mock()
|
||||
|
||||
# First call fails with connection error, second succeeds
|
||||
mock_master.hget.side_effect = [
|
||||
redis.exceptions.ConnectionError("Connection failed"),
|
||||
"recovered_value",
|
||||
]
|
||||
|
||||
mock_sentinel.master_for.return_value = mock_master
|
||||
|
||||
proxy = SentinelRedisProxy(mock_sentinel, "mymaster", async_mode=False)
|
||||
|
||||
# Test hget with retry
|
||||
hget_method = proxy.__getattr__("hget")
|
||||
result = hget_method("test_hash", "field1")
|
||||
|
||||
assert result == "recovered_value"
|
||||
assert mock_master.hget.call_count == 2
|
||||
|
||||
# Verify both calls were made with same parameters
|
||||
expected_calls = [(("test_hash", "field1"),), (("test_hash", "field1"),)]
|
||||
actual_calls = [call.args for call in mock_master.hget.call_args_list]
|
||||
assert actual_calls == expected_calls
|
||||
|
||||
@patch("redis.sentinel.Sentinel")
|
||||
def test_commands_with_readonly_error_retry(self, mock_sentinel_class):
|
||||
"""Test Redis commands with ReadOnlyError retry mechanism"""
|
||||
mock_sentinel = Mock()
|
||||
mock_master = Mock()
|
||||
|
||||
# First call fails with ReadOnlyError, second succeeds
|
||||
mock_master.hset.side_effect = [
|
||||
redis.exceptions.ReadOnlyError(
|
||||
"READONLY You can't write against a read only replica"
|
||||
),
|
||||
1,
|
||||
]
|
||||
|
||||
mock_sentinel.master_for.return_value = mock_master
|
||||
|
||||
proxy = SentinelRedisProxy(mock_sentinel, "mymaster", async_mode=False)
|
||||
|
||||
# Test hset with retry
|
||||
hset_method = proxy.__getattr__("hset")
|
||||
result = hset_method("test_hash", "field1", "value1")
|
||||
|
||||
assert result == 1
|
||||
assert mock_master.hset.call_count == 2
|
||||
|
||||
# Verify both calls were made with same parameters
|
||||
expected_calls = [
|
||||
(("test_hash", "field1", "value1"),),
|
||||
(("test_hash", "field1", "value1"),),
|
||||
]
|
||||
actual_calls = [call.args for call in mock_master.hset.call_args_list]
|
||||
assert actual_calls == expected_calls
|
||||
|
||||
@patch("redis.sentinel.Sentinel")
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_commands_with_failover_retry(self, mock_sentinel_class):
|
||||
"""Test async Redis commands with failover retry mechanism"""
|
||||
mock_sentinel = Mock()
|
||||
mock_master = Mock()
|
||||
|
||||
# First call fails with connection error, second succeeds
|
||||
mock_master.hget = AsyncMock(
|
||||
side_effect=[
|
||||
redis.exceptions.ConnectionError("Connection failed"),
|
||||
"recovered_value",
|
||||
]
|
||||
)
|
||||
|
||||
mock_sentinel.master_for.return_value = mock_master
|
||||
|
||||
proxy = SentinelRedisProxy(mock_sentinel, "mymaster", async_mode=True)
|
||||
|
||||
# Test async hget with retry
|
||||
hget_method = proxy.__getattr__("hget")
|
||||
result = await hget_method("test_hash", "field1")
|
||||
|
||||
assert result == "recovered_value"
|
||||
assert mock_master.hget.call_count == 2
|
||||
|
||||
# Verify both calls were made with same parameters
|
||||
expected_calls = [(("test_hash", "field1"),), (("test_hash", "field1"),)]
|
||||
actual_calls = [call.args for call in mock_master.hget.call_args_list]
|
||||
assert actual_calls == expected_calls
|
||||
|
||||
|
||||
class TestSentinelRedisProxyFactoryMethods:
|
||||
"""Test Redis factory methods in async mode - these are special cases that remain sync"""
|
||||
|
||||
@patch("redis.sentinel.Sentinel")
|
||||
@pytest.mark.asyncio
|
||||
async def test_pubsub_factory_method_async(self, mock_sentinel_class):
|
||||
"""Test pubsub factory method in async mode - should pass through without wrapping"""
|
||||
mock_sentinel = Mock()
|
||||
mock_master = Mock()
|
||||
mock_pubsub = Mock()
|
||||
|
||||
# Mock pubsub factory method
|
||||
mock_master.pubsub.return_value = mock_pubsub
|
||||
mock_sentinel.master_for.return_value = mock_master
|
||||
|
||||
proxy = SentinelRedisProxy(mock_sentinel, "mymaster", async_mode=True)
|
||||
|
||||
# Test pubsub factory method - should NOT be wrapped as async
|
||||
pubsub_method = proxy.__getattr__("pubsub")
|
||||
result = pubsub_method()
|
||||
|
||||
assert result == mock_pubsub
|
||||
mock_master.pubsub.assert_called_once()
|
||||
|
||||
# Verify it's not wrapped as async (no await needed)
|
||||
assert not inspect.iscoroutine(result)
|
||||
|
||||
@patch("redis.sentinel.Sentinel")
|
||||
@pytest.mark.asyncio
|
||||
async def test_pipeline_factory_method_async(self, mock_sentinel_class):
|
||||
"""Test pipeline factory method in async mode - should pass through without wrapping"""
|
||||
mock_sentinel = Mock()
|
||||
mock_master = Mock()
|
||||
mock_pipeline = Mock()
|
||||
|
||||
# Mock pipeline factory method
|
||||
mock_master.pipeline.return_value = mock_pipeline
|
||||
mock_pipeline.set.return_value = mock_pipeline
|
||||
mock_pipeline.get.return_value = mock_pipeline
|
||||
mock_pipeline.execute.return_value = [True, "test_value"]
|
||||
|
||||
mock_sentinel.master_for.return_value = mock_master
|
||||
|
||||
proxy = SentinelRedisProxy(mock_sentinel, "mymaster", async_mode=True)
|
||||
|
||||
# Test pipeline factory method - should NOT be wrapped as async
|
||||
pipeline_method = proxy.__getattr__("pipeline")
|
||||
result = pipeline_method()
|
||||
|
||||
assert result == mock_pipeline
|
||||
mock_master.pipeline.assert_called_once()
|
||||
|
||||
# Verify it's not wrapped as async (no await needed)
|
||||
assert not inspect.iscoroutine(result)
|
||||
|
||||
# Test pipeline usage (these should also be sync)
|
||||
pipeline_result = result.set("key", "value").get("key").execute()
|
||||
assert pipeline_result == [True, "test_value"]
|
||||
|
||||
@patch("redis.sentinel.Sentinel")
|
||||
@pytest.mark.asyncio
|
||||
async def test_factory_methods_vs_regular_commands_async(self, mock_sentinel_class):
|
||||
"""Test that factory methods behave differently from regular commands in async mode"""
|
||||
mock_sentinel = Mock()
|
||||
mock_master = Mock()
|
||||
|
||||
# Mock both factory method and regular command
|
||||
mock_pubsub = Mock()
|
||||
mock_master.pubsub.return_value = mock_pubsub
|
||||
mock_master.get = AsyncMock(return_value="test_value")
|
||||
|
||||
mock_sentinel.master_for.return_value = mock_master
|
||||
|
||||
proxy = SentinelRedisProxy(mock_sentinel, "mymaster", async_mode=True)
|
||||
|
||||
# Test factory method - should NOT be wrapped
|
||||
pubsub_method = proxy.__getattr__("pubsub")
|
||||
pubsub_result = pubsub_method()
|
||||
|
||||
# Test regular command - should be wrapped as async
|
||||
get_method = proxy.__getattr__("get")
|
||||
get_result = get_method("test_key")
|
||||
|
||||
# Factory method returns directly
|
||||
assert pubsub_result == mock_pubsub
|
||||
assert not inspect.iscoroutine(pubsub_result)
|
||||
|
||||
# Regular command returns coroutine
|
||||
assert inspect.iscoroutine(get_result)
|
||||
|
||||
# Regular command needs await
|
||||
actual_value = await get_result
|
||||
assert actual_value == "test_value"
|
||||
|
||||
@patch("redis.sentinel.Sentinel")
|
||||
@pytest.mark.asyncio
|
||||
async def test_factory_methods_with_failover_async(self, mock_sentinel_class):
|
||||
"""Test factory methods with failover in async mode"""
|
||||
mock_sentinel = Mock()
|
||||
mock_master = Mock()
|
||||
|
||||
# First call fails, second succeeds
|
||||
mock_pubsub = Mock()
|
||||
mock_master.pubsub.side_effect = [
|
||||
redis.exceptions.ConnectionError("Connection failed"),
|
||||
mock_pubsub,
|
||||
]
|
||||
|
||||
mock_sentinel.master_for.return_value = mock_master
|
||||
|
||||
proxy = SentinelRedisProxy(mock_sentinel, "mymaster", async_mode=True)
|
||||
|
||||
# Test pubsub factory method with failover
|
||||
pubsub_method = proxy.__getattr__("pubsub")
|
||||
result = pubsub_method()
|
||||
|
||||
assert result == mock_pubsub
|
||||
assert mock_master.pubsub.call_count == 2 # Retry happened
|
||||
|
||||
# Verify it's still not wrapped as async after retry
|
||||
assert not inspect.iscoroutine(result)
|
||||
|
||||
@patch("redis.sentinel.Sentinel")
|
||||
@pytest.mark.asyncio
|
||||
async def test_monitor_factory_method_async(self, mock_sentinel_class):
|
||||
"""Test monitor factory method in async mode - should pass through without wrapping"""
|
||||
mock_sentinel = Mock()
|
||||
mock_master = Mock()
|
||||
mock_monitor = Mock()
|
||||
|
||||
# Mock monitor factory method
|
||||
mock_master.monitor.return_value = mock_monitor
|
||||
mock_sentinel.master_for.return_value = mock_master
|
||||
|
||||
proxy = SentinelRedisProxy(mock_sentinel, "mymaster", async_mode=True)
|
||||
|
||||
# Test monitor factory method - should NOT be wrapped as async
|
||||
monitor_method = proxy.__getattr__("monitor")
|
||||
result = monitor_method()
|
||||
|
||||
assert result == mock_monitor
|
||||
mock_master.monitor.assert_called_once()
|
||||
|
||||
# Verify it's not wrapped as async (no await needed)
|
||||
assert not inspect.iscoroutine(result)
|
||||
|
||||
@patch("redis.sentinel.Sentinel")
|
||||
@pytest.mark.asyncio
|
||||
async def test_client_factory_method_async(self, mock_sentinel_class):
|
||||
"""Test client factory method in async mode - should pass through without wrapping"""
|
||||
mock_sentinel = Mock()
|
||||
mock_master = Mock()
|
||||
mock_client = Mock()
|
||||
|
||||
# Mock client factory method
|
||||
mock_master.client.return_value = mock_client
|
||||
mock_sentinel.master_for.return_value = mock_master
|
||||
|
||||
proxy = SentinelRedisProxy(mock_sentinel, "mymaster", async_mode=True)
|
||||
|
||||
# Test client factory method - should NOT be wrapped as async
|
||||
client_method = proxy.__getattr__("client")
|
||||
result = client_method()
|
||||
|
||||
assert result == mock_client
|
||||
mock_master.client.assert_called_once()
|
||||
|
||||
# Verify it's not wrapped as async (no await needed)
|
||||
assert not inspect.iscoroutine(result)
|
||||
|
||||
@patch("redis.sentinel.Sentinel")
|
||||
@pytest.mark.asyncio
|
||||
async def test_transaction_factory_method_async(self, mock_sentinel_class):
|
||||
"""Test transaction factory method in async mode - should pass through without wrapping"""
|
||||
mock_sentinel = Mock()
|
||||
mock_master = Mock()
|
||||
mock_transaction = Mock()
|
||||
|
||||
# Mock transaction factory method
|
||||
mock_master.transaction.return_value = mock_transaction
|
||||
mock_sentinel.master_for.return_value = mock_master
|
||||
|
||||
proxy = SentinelRedisProxy(mock_sentinel, "mymaster", async_mode=True)
|
||||
|
||||
# Test transaction factory method - should NOT be wrapped as async
|
||||
transaction_method = proxy.__getattr__("transaction")
|
||||
result = transaction_method()
|
||||
|
||||
assert result == mock_transaction
|
||||
mock_master.transaction.assert_called_once()
|
||||
|
||||
# Verify it's not wrapped as async (no await needed)
|
||||
assert not inspect.iscoroutine(result)
|
||||
|
||||
@patch("redis.sentinel.Sentinel")
|
||||
@pytest.mark.asyncio
|
||||
async def test_all_factory_methods_async(self, mock_sentinel_class):
|
||||
"""Test all factory methods in async mode - comprehensive test"""
|
||||
mock_sentinel = Mock()
|
||||
mock_master = Mock()
|
||||
|
||||
# Mock all factory methods
|
||||
mock_objects = {
|
||||
"pipeline": Mock(),
|
||||
"pubsub": Mock(),
|
||||
"monitor": Mock(),
|
||||
"client": Mock(),
|
||||
"transaction": Mock(),
|
||||
}
|
||||
|
||||
for method_name, mock_obj in mock_objects.items():
|
||||
setattr(mock_master, method_name, Mock(return_value=mock_obj))
|
||||
|
||||
mock_sentinel.master_for.return_value = mock_master
|
||||
|
||||
proxy = SentinelRedisProxy(mock_sentinel, "mymaster", async_mode=True)
|
||||
|
||||
# Test all factory methods
|
||||
for method_name, expected_obj in mock_objects.items():
|
||||
method = proxy.__getattr__(method_name)
|
||||
result = method()
|
||||
|
||||
assert result == expected_obj
|
||||
assert not inspect.iscoroutine(result)
|
||||
getattr(mock_master, method_name).assert_called_once()
|
||||
|
||||
# Reset mock for next iteration
|
||||
getattr(mock_master, method_name).reset_mock()
|
||||
|
||||
@patch("redis.sentinel.Sentinel")
|
||||
@pytest.mark.asyncio
|
||||
async def test_mixed_factory_and_regular_commands_async(self, mock_sentinel_class):
|
||||
"""Test using both factory methods and regular commands in async mode"""
|
||||
mock_sentinel = Mock()
|
||||
mock_master = Mock()
|
||||
|
||||
# Mock pipeline factory and regular commands
|
||||
mock_pipeline = Mock()
|
||||
mock_master.pipeline.return_value = mock_pipeline
|
||||
mock_pipeline.set.return_value = mock_pipeline
|
||||
mock_pipeline.get.return_value = mock_pipeline
|
||||
mock_pipeline.execute.return_value = [True, "pipeline_value"]
|
||||
|
||||
mock_master.get = AsyncMock(return_value="regular_value")
|
||||
|
||||
mock_sentinel.master_for.return_value = mock_master
|
||||
|
||||
proxy = SentinelRedisProxy(mock_sentinel, "mymaster", async_mode=True)
|
||||
|
||||
# Use factory method (sync)
|
||||
pipeline = proxy.__getattr__("pipeline")()
|
||||
pipeline_result = pipeline.set("key1", "value1").get("key1").execute()
|
||||
|
||||
# Use regular command (async)
|
||||
get_method = proxy.__getattr__("get")
|
||||
regular_result = await get_method("key2")
|
||||
|
||||
# Verify both work correctly
|
||||
assert pipeline_result == [True, "pipeline_value"]
|
||||
assert regular_result == "regular_value"
|
||||
|
||||
# Verify calls
|
||||
mock_master.pipeline.assert_called_once()
|
||||
mock_master.get.assert_called_with("key2")
|
||||
|
|
@ -60,8 +60,7 @@ def get_permissions(
|
|||
|
||||
# Combine permissions from all user groups
|
||||
for group in user_groups:
|
||||
group_permissions = group.permissions or {}
|
||||
permissions = combine_permissions(permissions, group_permissions)
|
||||
permissions = combine_permissions(permissions, group.permissions or {})
|
||||
|
||||
# Ensure all fields from default_permissions are present and filled in
|
||||
permissions = fill_missing_permissions(permissions, default_permissions)
|
||||
|
|
@ -96,8 +95,7 @@ def has_permission(
|
|||
user_groups = Groups.get_groups_by_member_id(user_id)
|
||||
|
||||
for group in user_groups:
|
||||
group_permissions = group.permissions
|
||||
if get_permission(group_permissions, permission_hierarchy):
|
||||
if get_permission(group.permissions or {}, permission_hierarchy):
|
||||
return True
|
||||
|
||||
# Check default permissions afterward if the group permissions don't allow it
|
||||
|
|
|
|||
|
|
@ -195,7 +195,7 @@ class AuditLoggingMiddleware:
|
|||
|
||||
try:
|
||||
user = get_current_user(
|
||||
request, None, get_http_authorization_cred(auth_header)
|
||||
request, None, None, get_http_authorization_cred(auth_header)
|
||||
)
|
||||
return user
|
||||
except Exception as e:
|
||||
|
|
|
|||
|
|
@ -8,6 +8,12 @@ import requests
|
|||
import os
|
||||
|
||||
|
||||
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
|
||||
from cryptography.hazmat.primitives.asymmetric import ed25519
|
||||
from cryptography.hazmat.primitives import serialization
|
||||
import json
|
||||
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
import pytz
|
||||
from pytz import UTC
|
||||
|
|
@ -18,7 +24,11 @@ from opentelemetry import trace
|
|||
from open_webui.models.users import Users
|
||||
|
||||
from open_webui.constants import ERROR_MESSAGES
|
||||
|
||||
from open_webui.env import (
|
||||
OFFLINE_MODE,
|
||||
LICENSE_BLOB,
|
||||
pk,
|
||||
WEBUI_SECRET_KEY,
|
||||
TRUSTED_SIGNATURE_KEY,
|
||||
STATIC_DIR,
|
||||
|
|
@ -74,6 +84,18 @@ def override_static(path: str, content: str):
|
|||
|
||||
|
||||
def get_license_data(app, key):
|
||||
def data_handler(data):
|
||||
for k, v in data.items():
|
||||
if k == "resources":
|
||||
for p, c in v.items():
|
||||
globals().get("override_static", lambda a, b: None)(p, c)
|
||||
elif k == "count":
|
||||
setattr(app.state, "USER_COUNT", v)
|
||||
elif k == "name":
|
||||
setattr(app.state, "WEBUI_NAME", v)
|
||||
elif k == "metadata":
|
||||
setattr(app.state, "LICENSE_METADATA", v)
|
||||
|
||||
def handler(u):
|
||||
res = requests.post(
|
||||
f"{u}/api/v1/license/",
|
||||
|
|
@ -83,16 +105,7 @@ def get_license_data(app, key):
|
|||
|
||||
if getattr(res, "ok", False):
|
||||
payload = getattr(res, "json", lambda: {})()
|
||||
for k, v in payload.items():
|
||||
if k == "resources":
|
||||
for p, c in v.items():
|
||||
globals().get("override_static", lambda a, b: None)(p, c)
|
||||
elif k == "count":
|
||||
setattr(app.state, "USER_COUNT", v)
|
||||
elif k == "name":
|
||||
setattr(app.state, "WEBUI_NAME", v)
|
||||
elif k == "metadata":
|
||||
setattr(app.state, "LICENSE_METADATA", v)
|
||||
data_handler(payload)
|
||||
return True
|
||||
else:
|
||||
log.error(
|
||||
|
|
@ -100,13 +113,44 @@ def get_license_data(app, key):
|
|||
)
|
||||
|
||||
if key:
|
||||
us = ["https://api.openwebui.com", "https://licenses.api.openwebui.com"]
|
||||
us = [
|
||||
"https://api.openwebui.com",
|
||||
"https://licenses.api.openwebui.com",
|
||||
]
|
||||
try:
|
||||
for u in us:
|
||||
if handler(u):
|
||||
return True
|
||||
except Exception as ex:
|
||||
log.exception(f"License: Uncaught Exception: {ex}")
|
||||
|
||||
try:
|
||||
if LICENSE_BLOB:
|
||||
nl = 12
|
||||
kb = hashlib.sha256((key.replace("-", "").upper()).encode()).digest()
|
||||
|
||||
def nt(b):
|
||||
return b[:nl], b[nl:]
|
||||
|
||||
lb = base64.b64decode(LICENSE_BLOB)
|
||||
ln, lt = nt(lb)
|
||||
|
||||
aesgcm = AESGCM(kb)
|
||||
p = json.loads(aesgcm.decrypt(ln, lt, None))
|
||||
pk.verify(base64.b64decode(p["s"]), p["p"].encode())
|
||||
|
||||
pb = base64.b64decode(p["p"])
|
||||
pn, pt = nt(pb)
|
||||
|
||||
data = json.loads(aesgcm.decrypt(pn, pt, None).decode())
|
||||
if not data.get("exp") and data.get("exp") < datetime.now().date():
|
||||
return False
|
||||
|
||||
data_handler(data)
|
||||
return True
|
||||
except Exception as e:
|
||||
log.error(f"License: {e}")
|
||||
|
||||
return False
|
||||
|
||||
|
||||
|
|
@ -177,7 +221,7 @@ def get_current_user(
|
|||
token = request.cookies.get("token")
|
||||
|
||||
if token is None:
|
||||
raise HTTPException(status_code=403, detail="Not authenticated")
|
||||
raise HTTPException(status_code=401, detail="Not authenticated")
|
||||
|
||||
# auth by api key
|
||||
if token.startswith("sk-"):
|
||||
|
|
|
|||
|
|
@ -4,14 +4,15 @@ import sys
|
|||
from typing import TYPE_CHECKING
|
||||
|
||||
from loguru import logger
|
||||
|
||||
|
||||
from opentelemetry import trace
|
||||
from open_webui.env import (
|
||||
AUDIT_UVICORN_LOGGER_NAMES,
|
||||
AUDIT_LOG_FILE_ROTATION_SIZE,
|
||||
AUDIT_LOG_LEVEL,
|
||||
AUDIT_LOGS_FILE_PATH,
|
||||
GLOBAL_LOG_LEVEL,
|
||||
ENABLE_OTEL,
|
||||
ENABLE_OTEL_LOGS,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -28,13 +29,16 @@ def stdout_format(record: "Record") -> str:
|
|||
Returns:
|
||||
str: A formatted log string intended for stdout.
|
||||
"""
|
||||
record["extra"]["extra_json"] = json.dumps(record["extra"])
|
||||
if record["extra"]:
|
||||
record["extra"]["extra_json"] = json.dumps(record["extra"])
|
||||
extra_format = " - {extra[extra_json]}"
|
||||
else:
|
||||
extra_format = ""
|
||||
return (
|
||||
"<green>{time:YYYY-MM-DD HH:mm:ss.SSS}</green> | "
|
||||
"<level>{level: <8}</level> | "
|
||||
"<cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> - "
|
||||
"<level>{message}</level> - {extra[extra_json]}"
|
||||
"\n{exception}"
|
||||
"<level>{message}</level>" + extra_format + "\n{exception}"
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -60,9 +64,24 @@ class InterceptHandler(logging.Handler):
|
|||
frame = frame.f_back
|
||||
depth += 1
|
||||
|
||||
logger.opt(depth=depth, exception=record.exc_info).log(
|
||||
level, record.getMessage()
|
||||
)
|
||||
logger.opt(depth=depth, exception=record.exc_info).bind(
|
||||
**self._get_extras()
|
||||
).log(level, record.getMessage())
|
||||
if ENABLE_OTEL and ENABLE_OTEL_LOGS:
|
||||
from open_webui.utils.telemetry.logs import otel_handler
|
||||
|
||||
otel_handler.emit(record)
|
||||
|
||||
def _get_extras(self):
|
||||
if not ENABLE_OTEL:
|
||||
return {}
|
||||
|
||||
extras = {}
|
||||
context = trace.get_current_span().get_span_context()
|
||||
if context.is_valid:
|
||||
extras["trace_id"] = trace.format_trace_id(context.trace_id)
|
||||
extras["span_id"] = trace.format_span_id(context.span_id)
|
||||
return extras
|
||||
|
||||
|
||||
def file_format(record: "Record"):
|
||||
|
|
@ -113,7 +132,6 @@ def start_logger():
|
|||
format=stdout_format,
|
||||
filter=lambda record: "auditable" not in record["extra"],
|
||||
)
|
||||
|
||||
if AUDIT_LOG_LEVEL != "NONE":
|
||||
try:
|
||||
logger.add(
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ import logging
|
|||
import sys
|
||||
import os
|
||||
import base64
|
||||
import textwrap
|
||||
|
||||
import asyncio
|
||||
from aiocache import cached
|
||||
|
|
@ -19,7 +20,7 @@ from concurrent.futures import ThreadPoolExecutor
|
|||
|
||||
|
||||
from fastapi import Request, HTTPException
|
||||
from starlette.responses import Response, StreamingResponse
|
||||
from starlette.responses import Response, StreamingResponse, JSONResponse
|
||||
|
||||
|
||||
from open_webui.models.chats import Chats
|
||||
|
|
@ -73,6 +74,7 @@ from open_webui.utils.misc import (
|
|||
add_or_update_user_message,
|
||||
get_last_user_message,
|
||||
get_last_assistant_message,
|
||||
get_system_message,
|
||||
prepend_to_first_user_message_content,
|
||||
convert_logit_bias_input_to_json,
|
||||
)
|
||||
|
|
@ -83,17 +85,19 @@ from open_webui.utils.filter import (
|
|||
process_filter_functions,
|
||||
)
|
||||
from open_webui.utils.code_interpreter import execute_code_jupyter
|
||||
from open_webui.utils.payload import apply_system_prompt_to_body
|
||||
|
||||
from open_webui.tasks import create_task
|
||||
|
||||
from open_webui.config import (
|
||||
CACHE_DIR,
|
||||
DEFAULT_TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
|
||||
DEFAULT_CODE_INTERPRETER_PROMPT,
|
||||
CODE_INTERPRETER_BLOCKED_MODULES,
|
||||
)
|
||||
from open_webui.env import (
|
||||
SRC_LOG_LEVELS,
|
||||
GLOBAL_LOG_LEVEL,
|
||||
CHAT_RESPONSE_STREAM_DELTA_CHUNK_SIZE,
|
||||
BYPASS_MODEL_ACCESS_CONTROL,
|
||||
ENABLE_REALTIME_CHAT_SAVE,
|
||||
)
|
||||
|
|
@ -653,13 +657,13 @@ async def chat_completion_files_handler(
|
|||
),
|
||||
k=request.app.state.config.TOP_K,
|
||||
reranking_function=(
|
||||
lambda sentences: (
|
||||
request.app.state.RERANKING_FUNCTION(
|
||||
(
|
||||
lambda sentences: request.app.state.RERANKING_FUNCTION(
|
||||
sentences, user=user
|
||||
)
|
||||
if request.app.state.RERANKING_FUNCTION
|
||||
else None
|
||||
)
|
||||
if request.app.state.RERANKING_FUNCTION
|
||||
else None
|
||||
),
|
||||
k_reranker=request.app.state.config.TOP_K_RERANKER,
|
||||
r=request.app.state.config.RELEVANCE_THRESHOLD,
|
||||
|
|
@ -683,6 +687,7 @@ def apply_params_to_form_data(form_data, model):
|
|||
|
||||
open_webui_params = {
|
||||
"stream_response": bool,
|
||||
"stream_delta_chunk_size": int,
|
||||
"function_calling": str,
|
||||
"system": str,
|
||||
}
|
||||
|
|
@ -733,6 +738,15 @@ async def process_chat_payload(request, form_data, user, metadata, model):
|
|||
form_data = apply_params_to_form_data(form_data, model)
|
||||
log.debug(f"form_data: {form_data}")
|
||||
|
||||
system_message = get_system_message(form_data.get("messages", []))
|
||||
if system_message:
|
||||
try:
|
||||
form_data = apply_system_prompt_to_body(
|
||||
system_message.get("content"), form_data, metadata, user
|
||||
)
|
||||
except:
|
||||
pass
|
||||
|
||||
event_emitter = get_event_emitter(metadata)
|
||||
event_call = get_event_call(metadata)
|
||||
|
||||
|
|
@ -774,8 +788,8 @@ async def process_chat_payload(request, form_data, user, metadata, model):
|
|||
|
||||
if folder and folder.data:
|
||||
if "system_prompt" in folder.data:
|
||||
form_data["messages"] = add_or_update_system_message(
|
||||
folder.data["system_prompt"], form_data["messages"]
|
||||
form_data = apply_system_prompt_to_body(
|
||||
folder.data["system_prompt"], form_data, metadata, user
|
||||
)
|
||||
if "files" in folder.data:
|
||||
form_data["files"] = [
|
||||
|
|
@ -905,7 +919,7 @@ async def process_chat_payload(request, form_data, user, metadata, model):
|
|||
tools_dict = {}
|
||||
|
||||
if tool_ids:
|
||||
tools_dict = get_tools(
|
||||
tools_dict = await get_tools(
|
||||
request,
|
||||
tool_ids,
|
||||
user,
|
||||
|
|
@ -929,7 +943,7 @@ async def process_chat_payload(request, form_data, user, metadata, model):
|
|||
}
|
||||
|
||||
if tools_dict:
|
||||
if metadata.get("function_calling") == "native":
|
||||
if metadata.get("params", {}).get("function_calling") == "native":
|
||||
# If the function calling is native, then call the tools function calling handler
|
||||
metadata["tools"] = tools_dict
|
||||
form_data["tools"] = [
|
||||
|
|
@ -986,25 +1000,24 @@ async def process_chat_payload(request, form_data, user, metadata, model):
|
|||
if prompt is None:
|
||||
raise Exception("No user message found")
|
||||
|
||||
if context_string == "":
|
||||
if request.app.state.config.RELEVANCE_THRESHOLD == 0:
|
||||
log.debug(
|
||||
f"With a 0 relevancy threshold for RAG, the context cannot be empty"
|
||||
)
|
||||
else:
|
||||
if context_string != "":
|
||||
# Workaround for Ollama 2.0+ system prompt issue
|
||||
# TODO: replace with add_or_update_system_message
|
||||
if model.get("owned_by") == "ollama":
|
||||
form_data["messages"] = prepend_to_first_user_message_content(
|
||||
rag_template(
|
||||
request.app.state.config.RAG_TEMPLATE, context_string, prompt
|
||||
request.app.state.config.RAG_TEMPLATE,
|
||||
context_string,
|
||||
prompt,
|
||||
),
|
||||
form_data["messages"],
|
||||
)
|
||||
else:
|
||||
form_data["messages"] = add_or_update_system_message(
|
||||
rag_template(
|
||||
request.app.state.config.RAG_TEMPLATE, context_string, prompt
|
||||
request.app.state.config.RAG_TEMPLATE,
|
||||
context_string,
|
||||
prompt,
|
||||
),
|
||||
form_data["messages"],
|
||||
)
|
||||
|
|
@ -1251,91 +1264,111 @@ async def process_chat_response(
|
|||
# Non-streaming response
|
||||
if not isinstance(response, StreamingResponse):
|
||||
if event_emitter:
|
||||
if "error" in response:
|
||||
error = response["error"].get("detail", response["error"])
|
||||
Chats.upsert_message_to_chat_by_id_and_message_id(
|
||||
metadata["chat_id"],
|
||||
metadata["message_id"],
|
||||
{
|
||||
"error": {"content": error},
|
||||
},
|
||||
)
|
||||
if isinstance(response, dict) or isinstance(response, JSONResponse):
|
||||
|
||||
if "selected_model_id" in response:
|
||||
Chats.upsert_message_to_chat_by_id_and_message_id(
|
||||
metadata["chat_id"],
|
||||
metadata["message_id"],
|
||||
{
|
||||
"selectedModelId": response["selected_model_id"],
|
||||
},
|
||||
)
|
||||
if isinstance(response, JSONResponse) and isinstance(
|
||||
response.body, bytes
|
||||
):
|
||||
try:
|
||||
response_data = json.loads(response.body.decode("utf-8"))
|
||||
except json.JSONDecodeError:
|
||||
response_data = {"error": {"detail": "Invalid JSON response"}}
|
||||
else:
|
||||
response_data = response
|
||||
|
||||
choices = response.get("choices", [])
|
||||
if choices and choices[0].get("message", {}).get("content"):
|
||||
content = response["choices"][0]["message"]["content"]
|
||||
|
||||
if content:
|
||||
|
||||
await event_emitter(
|
||||
{
|
||||
"type": "chat:completion",
|
||||
"data": response,
|
||||
}
|
||||
)
|
||||
|
||||
title = Chats.get_chat_title_by_id(metadata["chat_id"])
|
||||
|
||||
await event_emitter(
|
||||
{
|
||||
"type": "chat:completion",
|
||||
"data": {
|
||||
"done": True,
|
||||
"content": content,
|
||||
"title": title,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
# Save message in the database
|
||||
if "error" in response_data:
|
||||
error = response_data["error"].get("detail", response_data["error"])
|
||||
Chats.upsert_message_to_chat_by_id_and_message_id(
|
||||
metadata["chat_id"],
|
||||
metadata["message_id"],
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": content,
|
||||
"error": {"content": error},
|
||||
},
|
||||
)
|
||||
|
||||
# Send a webhook notification if the user is not active
|
||||
if not get_active_status_by_user_id(user.id):
|
||||
webhook_url = Users.get_user_webhook_url_by_id(user.id)
|
||||
if webhook_url:
|
||||
post_webhook(
|
||||
request.app.state.WEBUI_NAME,
|
||||
webhook_url,
|
||||
f"{title} - {request.app.state.config.WEBUI_URL}/c/{metadata['chat_id']}\n\n{content}",
|
||||
{
|
||||
"action": "chat",
|
||||
"message": content,
|
||||
if "selected_model_id" in response_data:
|
||||
Chats.upsert_message_to_chat_by_id_and_message_id(
|
||||
metadata["chat_id"],
|
||||
metadata["message_id"],
|
||||
{
|
||||
"selectedModelId": response_data["selected_model_id"],
|
||||
},
|
||||
)
|
||||
|
||||
choices = response_data.get("choices", [])
|
||||
if choices and choices[0].get("message", {}).get("content"):
|
||||
content = response_data["choices"][0]["message"]["content"]
|
||||
|
||||
if content:
|
||||
await event_emitter(
|
||||
{
|
||||
"type": "chat:completion",
|
||||
"data": response_data,
|
||||
}
|
||||
)
|
||||
|
||||
title = Chats.get_chat_title_by_id(metadata["chat_id"])
|
||||
|
||||
await event_emitter(
|
||||
{
|
||||
"type": "chat:completion",
|
||||
"data": {
|
||||
"done": True,
|
||||
"content": content,
|
||||
"title": title,
|
||||
"url": f"{request.app.state.config.WEBUI_URL}/c/{metadata['chat_id']}",
|
||||
},
|
||||
)
|
||||
}
|
||||
)
|
||||
|
||||
await background_tasks_handler()
|
||||
# Save message in the database
|
||||
Chats.upsert_message_to_chat_by_id_and_message_id(
|
||||
metadata["chat_id"],
|
||||
metadata["message_id"],
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": content,
|
||||
},
|
||||
)
|
||||
|
||||
if events and isinstance(events, list) and isinstance(response, dict):
|
||||
extra_response = {}
|
||||
for event in events:
|
||||
if isinstance(event, dict):
|
||||
extra_response.update(event)
|
||||
else:
|
||||
extra_response[event] = True
|
||||
# Send a webhook notification if the user is not active
|
||||
if not get_active_status_by_user_id(user.id):
|
||||
webhook_url = Users.get_user_webhook_url_by_id(user.id)
|
||||
if webhook_url:
|
||||
await post_webhook(
|
||||
request.app.state.WEBUI_NAME,
|
||||
webhook_url,
|
||||
f"{title} - {request.app.state.config.WEBUI_URL}/c/{metadata['chat_id']}\n\n{content}",
|
||||
{
|
||||
"action": "chat",
|
||||
"message": content,
|
||||
"title": title,
|
||||
"url": f"{request.app.state.config.WEBUI_URL}/c/{metadata['chat_id']}",
|
||||
},
|
||||
)
|
||||
|
||||
response = {
|
||||
**extra_response,
|
||||
**response,
|
||||
}
|
||||
await background_tasks_handler()
|
||||
|
||||
if events and isinstance(events, list):
|
||||
extra_response = {}
|
||||
for event in events:
|
||||
if isinstance(event, dict):
|
||||
extra_response.update(event)
|
||||
else:
|
||||
extra_response[event] = True
|
||||
|
||||
response_data = {
|
||||
**extra_response,
|
||||
**response_data,
|
||||
}
|
||||
|
||||
if isinstance(response, dict):
|
||||
response = response_data
|
||||
if isinstance(response, JSONResponse):
|
||||
response = JSONResponse(
|
||||
content=response_data,
|
||||
headers=response.headers,
|
||||
status_code=response.status_code,
|
||||
)
|
||||
|
||||
return response
|
||||
else:
|
||||
|
|
@ -1381,14 +1414,6 @@ async def process_chat_response(
|
|||
task_id = str(uuid4()) # Create a unique task ID.
|
||||
model_id = form_data.get("model", "")
|
||||
|
||||
Chats.upsert_message_to_chat_by_id_and_message_id(
|
||||
metadata["chat_id"],
|
||||
metadata["message_id"],
|
||||
{
|
||||
"model": model_id,
|
||||
},
|
||||
)
|
||||
|
||||
def split_content_and_whitespace(content):
|
||||
content_stripped = content.rstrip()
|
||||
original_whitespace = (
|
||||
|
|
@ -1410,13 +1435,18 @@ async def process_chat_response(
|
|||
|
||||
for block in content_blocks:
|
||||
if block["type"] == "text":
|
||||
content = f"{content}{block['content'].strip()}\n"
|
||||
block_content = block["content"].strip()
|
||||
if block_content:
|
||||
content = f"{content}{block_content}\n"
|
||||
elif block["type"] == "tool_calls":
|
||||
attributes = block.get("attributes", {})
|
||||
|
||||
tool_calls = block.get("content", [])
|
||||
results = block.get("results", [])
|
||||
|
||||
if content and not content.endswith("\n"):
|
||||
content += "\n"
|
||||
|
||||
if results:
|
||||
|
||||
tool_calls_display_content = ""
|
||||
|
|
@ -1439,12 +1469,12 @@ async def process_chat_response(
|
|||
break
|
||||
|
||||
if tool_result:
|
||||
tool_calls_display_content = f'{tool_calls_display_content}\n<details type="tool_calls" done="true" id="{tool_call_id}" name="{tool_name}" arguments="{html.escape(json.dumps(tool_arguments))}" result="{html.escape(json.dumps(tool_result, ensure_ascii=False))}" files="{html.escape(json.dumps(tool_result_files)) if tool_result_files else ""}">\n<summary>Tool Executed</summary>\n</details>\n'
|
||||
tool_calls_display_content = f'{tool_calls_display_content}<details type="tool_calls" done="true" id="{tool_call_id}" name="{tool_name}" arguments="{html.escape(json.dumps(tool_arguments))}" result="{html.escape(json.dumps(tool_result, ensure_ascii=False))}" files="{html.escape(json.dumps(tool_result_files)) if tool_result_files else ""}">\n<summary>Tool Executed</summary>\n</details>\n'
|
||||
else:
|
||||
tool_calls_display_content = f'{tool_calls_display_content}\n<details type="tool_calls" done="false" id="{tool_call_id}" name="{tool_name}" arguments="{html.escape(json.dumps(tool_arguments))}">\n<summary>Executing...</summary>\n</details>'
|
||||
tool_calls_display_content = f'{tool_calls_display_content}<details type="tool_calls" done="false" id="{tool_call_id}" name="{tool_name}" arguments="{html.escape(json.dumps(tool_arguments))}">\n<summary>Executing...</summary>\n</details>\n'
|
||||
|
||||
if not raw:
|
||||
content = f"{content}\n{tool_calls_display_content}\n\n"
|
||||
content = f"{content}{tool_calls_display_content}"
|
||||
else:
|
||||
tool_calls_display_content = ""
|
||||
|
||||
|
|
@ -1457,10 +1487,10 @@ async def process_chat_response(
|
|||
"arguments", ""
|
||||
)
|
||||
|
||||
tool_calls_display_content = f'{tool_calls_display_content}\n<details type="tool_calls" done="false" id="{tool_call_id}" name="{tool_name}" arguments="{html.escape(json.dumps(tool_arguments))}">\n<summary>Executing...</summary>\n</details>'
|
||||
tool_calls_display_content = f'{tool_calls_display_content}\n<details type="tool_calls" done="false" id="{tool_call_id}" name="{tool_name}" arguments="{html.escape(json.dumps(tool_arguments))}">\n<summary>Executing...</summary>\n</details>\n'
|
||||
|
||||
if not raw:
|
||||
content = f"{content}\n{tool_calls_display_content}\n\n"
|
||||
content = f"{content}{tool_calls_display_content}"
|
||||
|
||||
elif block["type"] == "reasoning":
|
||||
reasoning_display_content = "\n".join(
|
||||
|
|
@ -1470,16 +1500,26 @@ async def process_chat_response(
|
|||
|
||||
reasoning_duration = block.get("duration", None)
|
||||
|
||||
start_tag = block.get("start_tag", "")
|
||||
end_tag = block.get("end_tag", "")
|
||||
|
||||
if content and not content.endswith("\n"):
|
||||
content += "\n"
|
||||
|
||||
if reasoning_duration is not None:
|
||||
if raw:
|
||||
content = f'{content}\n<{block["start_tag"]}>{block["content"]}<{block["end_tag"]}>\n'
|
||||
content = (
|
||||
f'{content}{start_tag}{block["content"]}{end_tag}\n'
|
||||
)
|
||||
else:
|
||||
content = f'{content}\n<details type="reasoning" done="true" duration="{reasoning_duration}">\n<summary>Thought for {reasoning_duration} seconds</summary>\n{reasoning_display_content}\n</details>\n'
|
||||
content = f'{content}<details type="reasoning" done="true" duration="{reasoning_duration}">\n<summary>Thought for {reasoning_duration} seconds</summary>\n{reasoning_display_content}\n</details>\n'
|
||||
else:
|
||||
if raw:
|
||||
content = f'{content}\n<{block["start_tag"]}>{block["content"]}<{block["end_tag"]}>\n'
|
||||
content = (
|
||||
f'{content}{start_tag}{block["content"]}{end_tag}\n'
|
||||
)
|
||||
else:
|
||||
content = f'{content}\n<details type="reasoning" done="false">\n<summary>Thinking…</summary>\n{reasoning_display_content}\n</details>\n'
|
||||
content = f'{content}<details type="reasoning" done="false">\n<summary>Thinking…</summary>\n{reasoning_display_content}\n</details>\n'
|
||||
|
||||
elif block["type"] == "code_interpreter":
|
||||
attributes = block.get("attributes", {})
|
||||
|
|
@ -1499,26 +1539,30 @@ async def process_chat_response(
|
|||
# Keep content as is - either closing backticks or no backticks
|
||||
content = content_stripped + original_whitespace
|
||||
|
||||
if content and not content.endswith("\n"):
|
||||
content += "\n"
|
||||
|
||||
if output:
|
||||
output = html.escape(json.dumps(output))
|
||||
|
||||
if raw:
|
||||
content = f'{content}\n<code_interpreter type="code" lang="{lang}">\n{block["content"]}\n</code_interpreter>\n```output\n{output}\n```\n'
|
||||
content = f'{content}<code_interpreter type="code" lang="{lang}">\n{block["content"]}\n</code_interpreter>\n```output\n{output}\n```\n'
|
||||
else:
|
||||
content = f'{content}\n<details type="code_interpreter" done="true" output="{output}">\n<summary>Analyzed</summary>\n```{lang}\n{block["content"]}\n```\n</details>\n'
|
||||
content = f'{content}<details type="code_interpreter" done="true" output="{output}">\n<summary>Analyzed</summary>\n```{lang}\n{block["content"]}\n```\n</details>\n'
|
||||
else:
|
||||
if raw:
|
||||
content = f'{content}\n<code_interpreter type="code" lang="{lang}">\n{block["content"]}\n</code_interpreter>\n'
|
||||
content = f'{content}<code_interpreter type="code" lang="{lang}">\n{block["content"]}\n</code_interpreter>\n'
|
||||
else:
|
||||
content = f'{content}\n<details type="code_interpreter" done="false">\n<summary>Analyzing...</summary>\n```{lang}\n{block["content"]}\n```\n</details>\n'
|
||||
content = f'{content}<details type="code_interpreter" done="false">\n<summary>Analyzing...</summary>\n```{lang}\n{block["content"]}\n```\n</details>\n'
|
||||
|
||||
else:
|
||||
block_content = str(block["content"]).strip()
|
||||
content = f"{content}{block['type']}: {block_content}\n"
|
||||
if block_content:
|
||||
content = f"{content}{block['type']}: {block_content}\n"
|
||||
|
||||
return content.strip()
|
||||
|
||||
def convert_content_blocks_to_messages(content_blocks):
|
||||
def convert_content_blocks_to_messages(content_blocks, raw=False):
|
||||
messages = []
|
||||
|
||||
temp_blocks = []
|
||||
|
|
@ -1527,7 +1571,7 @@ async def process_chat_response(
|
|||
messages.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": serialize_content_blocks(temp_blocks),
|
||||
"content": serialize_content_blocks(temp_blocks, raw),
|
||||
"tool_calls": block.get("content"),
|
||||
}
|
||||
)
|
||||
|
|
@ -1547,7 +1591,7 @@ async def process_chat_response(
|
|||
temp_blocks.append(block)
|
||||
|
||||
if temp_blocks:
|
||||
content = serialize_content_blocks(temp_blocks)
|
||||
content = serialize_content_blocks(temp_blocks, raw)
|
||||
if content:
|
||||
messages.append(
|
||||
{
|
||||
|
|
@ -1574,13 +1618,25 @@ async def process_chat_response(
|
|||
|
||||
if content_blocks[-1]["type"] == "text":
|
||||
for start_tag, end_tag in tags:
|
||||
# Match start tag e.g., <tag> or <tag attr="value">
|
||||
start_tag_pattern = rf"<{re.escape(start_tag)}(\s.*?)?>"
|
||||
|
||||
start_tag_pattern = rf"{re.escape(start_tag)}"
|
||||
if start_tag.startswith("<") and start_tag.endswith(">"):
|
||||
# Match start tag e.g., <tag> or <tag attr="value">
|
||||
# remove both '<' and '>' from start_tag
|
||||
# Match start tag with attributes
|
||||
start_tag_pattern = (
|
||||
rf"<{re.escape(start_tag[1:-1])}(\s.*?)?>"
|
||||
)
|
||||
|
||||
match = re.search(start_tag_pattern, content)
|
||||
if match:
|
||||
attr_content = (
|
||||
match.group(1) if match.group(1) else ""
|
||||
) # Ensure it's not None
|
||||
try:
|
||||
attr_content = (
|
||||
match.group(1) if match.group(1) else ""
|
||||
) # Ensure it's not None
|
||||
except:
|
||||
attr_content = ""
|
||||
|
||||
attributes = extract_attributes(
|
||||
attr_content
|
||||
) # Extract attributes safely
|
||||
|
|
@ -1626,8 +1682,13 @@ async def process_chat_response(
|
|||
elif content_blocks[-1]["type"] == content_type:
|
||||
start_tag = content_blocks[-1]["start_tag"]
|
||||
end_tag = content_blocks[-1]["end_tag"]
|
||||
# Match end tag e.g., </tag>
|
||||
end_tag_pattern = rf"<{re.escape(end_tag)}>"
|
||||
|
||||
if end_tag.startswith("<") and end_tag.endswith(">"):
|
||||
# Match end tag e.g., </tag>
|
||||
end_tag_pattern = rf"{re.escape(end_tag)}"
|
||||
else:
|
||||
# Handle cases where end_tag is just a tag name
|
||||
end_tag_pattern = rf"{re.escape(end_tag)}"
|
||||
|
||||
# Check if the content has the end tag
|
||||
if re.search(end_tag_pattern, content):
|
||||
|
|
@ -1699,8 +1760,17 @@ async def process_chat_response(
|
|||
)
|
||||
|
||||
# Clean processed content
|
||||
start_tag_pattern = rf"{re.escape(start_tag)}"
|
||||
if start_tag.startswith("<") and start_tag.endswith(">"):
|
||||
# Match start tag e.g., <tag> or <tag attr="value">
|
||||
# remove both '<' and '>' from start_tag
|
||||
# Match start tag with attributes
|
||||
start_tag_pattern = (
|
||||
rf"<{re.escape(start_tag[1:-1])}(\s.*?)?>"
|
||||
)
|
||||
|
||||
content = re.sub(
|
||||
rf"<{re.escape(start_tag)}(.*?)>(.|\n)*?<{re.escape(end_tag)}>",
|
||||
rf"{start_tag_pattern}(.|\n)*?{re.escape(end_tag)}",
|
||||
"",
|
||||
content,
|
||||
flags=re.DOTALL,
|
||||
|
|
@ -1744,18 +1814,19 @@ async def process_chat_response(
|
|||
)
|
||||
|
||||
reasoning_tags = [
|
||||
("think", "/think"),
|
||||
("thinking", "/thinking"),
|
||||
("reason", "/reason"),
|
||||
("reasoning", "/reasoning"),
|
||||
("thought", "/thought"),
|
||||
("Thought", "/Thought"),
|
||||
("|begin_of_thought|", "|end_of_thought|"),
|
||||
("<think>", "</think>"),
|
||||
("<thinking>", "</thinking>"),
|
||||
("<reason>", "</reason>"),
|
||||
("<reasoning>", "</reasoning>"),
|
||||
("<thought>", "</thought>"),
|
||||
("<Thought>", "</Thought>"),
|
||||
("<|begin_of_thought|>", "<|end_of_thought|>"),
|
||||
("◁think▷", "◁/think▷"),
|
||||
]
|
||||
|
||||
code_interpreter_tags = [("code_interpreter", "/code_interpreter")]
|
||||
code_interpreter_tags = [("<code_interpreter>", "</code_interpreter>")]
|
||||
|
||||
solution_tags = [("|begin_of_solution|", "|end_of_solution|")]
|
||||
solution_tags = [("<|begin_of_solution|>", "<|end_of_solution|>")]
|
||||
|
||||
try:
|
||||
for event in events:
|
||||
|
|
@ -1781,6 +1852,30 @@ async def process_chat_response(
|
|||
|
||||
response_tool_calls = []
|
||||
|
||||
delta_count = 0
|
||||
delta_chunk_size = max(
|
||||
CHAT_RESPONSE_STREAM_DELTA_CHUNK_SIZE,
|
||||
int(
|
||||
metadata.get("params", {}).get("stream_delta_chunk_size")
|
||||
or 1
|
||||
),
|
||||
)
|
||||
last_delta_data = None
|
||||
|
||||
async def flush_pending_delta_data(threshold: int = 0):
|
||||
nonlocal delta_count
|
||||
nonlocal last_delta_data
|
||||
|
||||
if delta_count >= threshold and last_delta_data:
|
||||
await event_emitter(
|
||||
{
|
||||
"type": "chat:completion",
|
||||
"data": last_delta_data,
|
||||
}
|
||||
)
|
||||
delta_count = 0
|
||||
last_delta_data = None
|
||||
|
||||
async for line in response.body_iterator:
|
||||
line = line.decode("utf-8") if isinstance(line, bytes) else line
|
||||
data = line
|
||||
|
|
@ -1820,6 +1915,12 @@ async def process_chat_response(
|
|||
"selectedModelId": model_id,
|
||||
},
|
||||
)
|
||||
await event_emitter(
|
||||
{
|
||||
"type": "chat:completion",
|
||||
"data": data,
|
||||
}
|
||||
)
|
||||
else:
|
||||
choices = data.get("choices", [])
|
||||
if not choices:
|
||||
|
|
@ -1920,8 +2021,8 @@ async def process_chat_response(
|
|||
):
|
||||
reasoning_block = {
|
||||
"type": "reasoning",
|
||||
"start_tag": "think",
|
||||
"end_tag": "/think",
|
||||
"start_tag": "<think>",
|
||||
"end_tag": "</think>",
|
||||
"attributes": {
|
||||
"type": "reasoning_content"
|
||||
},
|
||||
|
|
@ -2028,19 +2129,26 @@ async def process_chat_response(
|
|||
),
|
||||
}
|
||||
|
||||
await event_emitter(
|
||||
{
|
||||
"type": "chat:completion",
|
||||
"data": data,
|
||||
}
|
||||
)
|
||||
if delta:
|
||||
delta_count += 1
|
||||
last_delta_data = data
|
||||
if delta_count >= delta_chunk_size:
|
||||
await flush_pending_delta_data(delta_chunk_size)
|
||||
else:
|
||||
await event_emitter(
|
||||
{
|
||||
"type": "chat:completion",
|
||||
"data": data,
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
done = "data: [DONE]" in line
|
||||
if done:
|
||||
pass
|
||||
else:
|
||||
log.debug("Error: ", e)
|
||||
log.debug(f"Error: {e}")
|
||||
continue
|
||||
await flush_pending_delta_data()
|
||||
|
||||
if content_blocks:
|
||||
# Clean up the last text block
|
||||
|
|
@ -2060,6 +2168,15 @@ async def process_chat_response(
|
|||
}
|
||||
)
|
||||
|
||||
if content_blocks[-1]["type"] == "reasoning":
|
||||
reasoning_block = content_blocks[-1]
|
||||
if reasoning_block.get("ended_at") is None:
|
||||
reasoning_block["ended_at"] = time.time()
|
||||
reasoning_block["duration"] = int(
|
||||
reasoning_block["ended_at"]
|
||||
- reasoning_block["started_at"]
|
||||
)
|
||||
|
||||
if response_tool_calls:
|
||||
tool_calls.append(response_tool_calls)
|
||||
|
||||
|
|
@ -2072,6 +2189,7 @@ async def process_chat_response(
|
|||
tool_call_retries = 0
|
||||
|
||||
while len(tool_calls) > 0 and tool_call_retries < MAX_TOOL_CALL_RETRIES:
|
||||
|
||||
tool_call_retries += 1
|
||||
|
||||
response_tool_calls = tool_calls.pop(0)
|
||||
|
|
@ -2223,7 +2341,9 @@ async def process_chat_response(
|
|||
"tools": form_data["tools"],
|
||||
"messages": [
|
||||
*form_data["messages"],
|
||||
*convert_content_blocks_to_messages(content_blocks),
|
||||
*convert_content_blocks_to_messages(
|
||||
content_blocks, True
|
||||
),
|
||||
],
|
||||
}
|
||||
|
||||
|
|
@ -2266,6 +2386,27 @@ async def process_chat_response(
|
|||
try:
|
||||
if content_blocks[-1]["attributes"].get("type") == "code":
|
||||
code = content_blocks[-1]["content"]
|
||||
if CODE_INTERPRETER_BLOCKED_MODULES:
|
||||
blocking_code = textwrap.dedent(
|
||||
f"""
|
||||
import builtins
|
||||
|
||||
BLOCKED_MODULES = {CODE_INTERPRETER_BLOCKED_MODULES}
|
||||
|
||||
_real_import = builtins.__import__
|
||||
def restricted_import(name, globals=None, locals=None, fromlist=(), level=0):
|
||||
if name.split('.')[0] in BLOCKED_MODULES:
|
||||
importer_name = globals.get('__name__') if globals else None
|
||||
if importer_name == '__main__':
|
||||
raise ImportError(
|
||||
f"Direct import of module {{name}} is restricted."
|
||||
)
|
||||
return _real_import(name, globals, locals, fromlist, level)
|
||||
|
||||
builtins.__import__ = restricted_import
|
||||
"""
|
||||
)
|
||||
code = blocking_code + "\n" + code
|
||||
|
||||
if (
|
||||
request.app.state.config.CODE_INTERPRETER_ENGINE
|
||||
|
|
@ -2431,7 +2572,7 @@ async def process_chat_response(
|
|||
if not get_active_status_by_user_id(user.id):
|
||||
webhook_url = Users.get_user_webhook_url_by_id(user.id)
|
||||
if webhook_url:
|
||||
post_webhook(
|
||||
await post_webhook(
|
||||
request.app.state.WEBUI_NAME,
|
||||
webhook_url,
|
||||
f"{title} - {request.app.state.config.WEBUI_URL}/c/{metadata['chat_id']}\n\n{content}",
|
||||
|
|
@ -2468,13 +2609,7 @@ async def process_chat_response(
|
|||
if response.background is not None:
|
||||
await response.background()
|
||||
|
||||
# background_tasks.add_task(response_handler, response, events)
|
||||
task_id, _ = await create_task(
|
||||
request.app.state.redis,
|
||||
response_handler(response, events),
|
||||
id=metadata["chat_id"],
|
||||
)
|
||||
return {"status": True, "task_id": task_id}
|
||||
return await response_handler(response, events)
|
||||
|
||||
else:
|
||||
# Fallback to the original response
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
import hashlib
|
||||
import re
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
import logging
|
||||
|
|
@ -227,7 +228,7 @@ def openai_chat_chunk_message_template(
|
|||
if tool_calls:
|
||||
template["choices"][0]["delta"]["tool_calls"] = tool_calls
|
||||
|
||||
if not content and not tool_calls:
|
||||
if not content and not reasoning_content and not tool_calls:
|
||||
template["choices"][0]["finish_reason"] = "stop"
|
||||
|
||||
if usage:
|
||||
|
|
@ -478,3 +479,46 @@ def convert_logit_bias_input_to_json(user_input):
|
|||
bias = 100 if bias > 100 else -100 if bias < -100 else bias
|
||||
logit_bias_json[token] = bias
|
||||
return json.dumps(logit_bias_json)
|
||||
|
||||
|
||||
def freeze(value):
|
||||
"""
|
||||
Freeze a value to make it hashable.
|
||||
"""
|
||||
if isinstance(value, dict):
|
||||
return frozenset((k, freeze(v)) for k, v in value.items())
|
||||
elif isinstance(value, list):
|
||||
return tuple(freeze(v) for v in value)
|
||||
return value
|
||||
|
||||
|
||||
def throttle(interval: float = 10.0):
|
||||
"""
|
||||
Decorator to prevent a function from being called more than once within a specified duration.
|
||||
If the function is called again within the duration, it returns None. To avoid returning
|
||||
different types, the return type of the function should be Optional[T].
|
||||
|
||||
:param interval: Duration in seconds to wait before allowing the function to be called again.
|
||||
"""
|
||||
|
||||
def decorator(func):
|
||||
last_calls = {}
|
||||
lock = threading.Lock()
|
||||
|
||||
def wrapper(*args, **kwargs):
|
||||
if interval is None:
|
||||
return func(*args, **kwargs)
|
||||
|
||||
key = (args, freeze(kwargs))
|
||||
now = time.time()
|
||||
if now - last_calls.get(key, 0) < interval:
|
||||
return None
|
||||
with lock:
|
||||
if now - last_calls.get(key, 0) < interval:
|
||||
return None
|
||||
last_calls[key] = now
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
|
|
|||
|
|
@ -27,6 +27,7 @@ from open_webui.config import (
|
|||
ENABLE_OAUTH_GROUP_CREATION,
|
||||
OAUTH_BLOCKED_GROUPS,
|
||||
OAUTH_ROLES_CLAIM,
|
||||
OAUTH_SUB_CLAIM,
|
||||
OAUTH_GROUPS_CLAIM,
|
||||
OAUTH_EMAIL_CLAIM,
|
||||
OAUTH_PICTURE_CLAIM,
|
||||
|
|
@ -65,6 +66,7 @@ auth_manager_config.ENABLE_OAUTH_GROUP_MANAGEMENT = ENABLE_OAUTH_GROUP_MANAGEMEN
|
|||
auth_manager_config.ENABLE_OAUTH_GROUP_CREATION = ENABLE_OAUTH_GROUP_CREATION
|
||||
auth_manager_config.OAUTH_BLOCKED_GROUPS = OAUTH_BLOCKED_GROUPS
|
||||
auth_manager_config.OAUTH_ROLES_CLAIM = OAUTH_ROLES_CLAIM
|
||||
auth_manager_config.OAUTH_SUB_CLAIM = OAUTH_SUB_CLAIM
|
||||
auth_manager_config.OAUTH_GROUPS_CLAIM = OAUTH_GROUPS_CLAIM
|
||||
auth_manager_config.OAUTH_EMAIL_CLAIM = OAUTH_EMAIL_CLAIM
|
||||
auth_manager_config.OAUTH_PICTURE_CLAIM = OAUTH_PICTURE_CLAIM
|
||||
|
|
@ -88,11 +90,12 @@ class OAuthManager:
|
|||
return self.oauth.create_client(provider_name)
|
||||
|
||||
def get_user_role(self, user, user_data):
|
||||
if user and Users.get_num_users() == 1:
|
||||
user_count = Users.get_num_users()
|
||||
if user and user_count == 1:
|
||||
# If the user is the only user, assign the role "admin" - actually repairs role for single user on login
|
||||
log.debug("Assigning the only user the admin role")
|
||||
return "admin"
|
||||
if not user and Users.get_num_users() == 0:
|
||||
if not user and user_count == 0:
|
||||
# If there are no users, assign the role "admin", as the first user will be an admin
|
||||
log.debug("Assigning the first user the admin role")
|
||||
return "admin"
|
||||
|
|
@ -112,7 +115,13 @@ class OAuthManager:
|
|||
nested_claims = oauth_claim.split(".")
|
||||
for nested_claim in nested_claims:
|
||||
claim_data = claim_data.get(nested_claim, {})
|
||||
oauth_roles = claim_data if isinstance(claim_data, list) else []
|
||||
|
||||
oauth_roles = []
|
||||
|
||||
if isinstance(claim_data, list):
|
||||
oauth_roles = claim_data
|
||||
if isinstance(claim_data, str) or isinstance(claim_data, int):
|
||||
oauth_roles = [str(claim_data)]
|
||||
|
||||
log.debug(f"Oauth Roles claim: {oauth_claim}")
|
||||
log.debug(f"User roles from oauth: {oauth_roles}")
|
||||
|
|
@ -352,17 +361,28 @@ class OAuthManager:
|
|||
log.warning(f"OAuth callback error: {e}")
|
||||
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
|
||||
user_data: UserInfo = token.get("userinfo")
|
||||
if not user_data or auth_manager_config.OAUTH_EMAIL_CLAIM not in user_data:
|
||||
if (
|
||||
(not user_data)
|
||||
or (auth_manager_config.OAUTH_EMAIL_CLAIM not in user_data)
|
||||
or (auth_manager_config.OAUTH_USERNAME_CLAIM not in user_data)
|
||||
):
|
||||
user_data: UserInfo = await client.userinfo(token=token)
|
||||
if not user_data:
|
||||
log.warning(f"OAuth callback failed, user data is missing: {token}")
|
||||
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
|
||||
|
||||
sub = user_data.get(OAUTH_PROVIDERS[provider].get("sub_claim", "sub"))
|
||||
if auth_manager_config.OAUTH_SUB_CLAIM:
|
||||
sub = user_data.get(auth_manager_config.OAUTH_SUB_CLAIM)
|
||||
else:
|
||||
# Fallback to the default sub claim if not configured
|
||||
sub = user_data.get(OAUTH_PROVIDERS[provider].get("sub_claim", "sub"))
|
||||
|
||||
if not sub:
|
||||
log.warning(f"OAuth callback failed, sub is missing: {user_data}")
|
||||
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
|
||||
|
||||
provider_sub = f"{provider}@{sub}"
|
||||
|
||||
email_claim = auth_manager_config.OAUTH_EMAIL_CLAIM
|
||||
email = user_data.get(email_claim, "")
|
||||
# We currently mandate that email addresses are provided
|
||||
|
|
@ -449,8 +469,6 @@ class OAuthManager:
|
|||
log.debug(f"Updated profile picture for user {user.email}")
|
||||
|
||||
if not user:
|
||||
user_count = Users.get_num_users()
|
||||
|
||||
# If the user does not exist, check if signups are enabled
|
||||
if auth_manager_config.ENABLE_OAUTH_SIGNUP:
|
||||
# Check if an existing user with the same email already exists
|
||||
|
|
@ -490,7 +508,7 @@ class OAuthManager:
|
|||
)
|
||||
|
||||
if auth_manager_config.WEBHOOK_URL:
|
||||
post_webhook(
|
||||
await post_webhook(
|
||||
WEBUI_NAME,
|
||||
auth_manager_config.WEBHOOK_URL,
|
||||
WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
|
||||
|
|
@ -517,11 +535,19 @@ class OAuthManager:
|
|||
default_permissions=request.app.state.config.USER_PERMISSIONS,
|
||||
)
|
||||
|
||||
redirect_base_url = str(request.app.state.config.WEBUI_URL or request.base_url)
|
||||
if redirect_base_url.endswith("/"):
|
||||
redirect_base_url = redirect_base_url[:-1]
|
||||
redirect_url = f"{redirect_base_url}/auth"
|
||||
|
||||
response = RedirectResponse(url=redirect_url, headers=response.headers)
|
||||
|
||||
# Set the cookie token
|
||||
# Redirect back to the frontend with the JWT token
|
||||
response.set_cookie(
|
||||
key="token",
|
||||
value=jwt_token,
|
||||
httponly=True, # Ensures the cookie is not accessible via JavaScript
|
||||
httponly=False, # Required for frontend access
|
||||
samesite=WEBUI_AUTH_COOKIE_SAME_SITE,
|
||||
secure=WEBUI_AUTH_COOKIE_SECURE,
|
||||
)
|
||||
|
|
@ -535,11 +561,4 @@ class OAuthManager:
|
|||
samesite=WEBUI_AUTH_COOKIE_SAME_SITE,
|
||||
secure=WEBUI_AUTH_COOKIE_SECURE,
|
||||
)
|
||||
# Redirect back to the frontend with the JWT token
|
||||
|
||||
redirect_base_url = str(request.app.state.config.WEBUI_URL or request.base_url)
|
||||
if redirect_base_url.endswith("/"):
|
||||
redirect_base_url = redirect_base_url[:-1]
|
||||
redirect_url = f"{redirect_base_url}/auth#token={jwt_token}"
|
||||
|
||||
return RedirectResponse(url=redirect_url, headers=response.headers)
|
||||
return response
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ import json
|
|||
|
||||
|
||||
# inplace function: form_data is modified
|
||||
def apply_model_system_prompt_to_body(
|
||||
def apply_system_prompt_to_body(
|
||||
system: Optional[str], form_data: dict, metadata: Optional[dict] = None, user=None
|
||||
) -> dict:
|
||||
if not system:
|
||||
|
|
@ -22,15 +22,7 @@ def apply_model_system_prompt_to_body(
|
|||
system = prompt_variables_template(system, variables)
|
||||
|
||||
# Legacy (API Usage)
|
||||
if user:
|
||||
template_params = {
|
||||
"user_name": user.name,
|
||||
"user_location": user.info.get("location") if user.info else None,
|
||||
}
|
||||
else:
|
||||
template_params = {}
|
||||
|
||||
system = prompt_template(system, **template_params)
|
||||
system = prompt_template(system, user)
|
||||
|
||||
form_data["messages"] = add_or_update_system_message(
|
||||
system, form_data.get("messages", [])
|
||||
|
|
@ -69,6 +61,7 @@ def remove_open_webui_params(params: dict) -> dict:
|
|||
"""
|
||||
open_webui_params = {
|
||||
"stream_response": bool,
|
||||
"stream_delta_chunk_size": int,
|
||||
"function_calling": str,
|
||||
"system": str,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -260,7 +260,7 @@ def install_tool_and_function_dependencies():
|
|||
all_dependencies += f"{dependencies}, "
|
||||
for tool in tool_list:
|
||||
# Only install requirements for admin tools
|
||||
if tool.user.role == "admin":
|
||||
if tool.user and tool.user.role == "admin":
|
||||
frontmatter = extract_frontmatter(replace_imports(tool.content))
|
||||
if dependencies := frontmatter.get("requirements"):
|
||||
all_dependencies += f"{dependencies}, "
|
||||
|
|
|
|||
|
|
@ -1,12 +1,103 @@
|
|||
import socketio
|
||||
import inspect
|
||||
from urllib.parse import urlparse
|
||||
from typing import Optional
|
||||
|
||||
import logging
|
||||
|
||||
import redis
|
||||
|
||||
from open_webui.env import REDIS_SENTINEL_MAX_RETRY_COUNT
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
_CONNECTION_CACHE = {}
|
||||
|
||||
|
||||
class SentinelRedisProxy:
|
||||
def __init__(self, sentinel, service, *, async_mode: bool = True, **kw):
|
||||
self._sentinel = sentinel
|
||||
self._service = service
|
||||
self._kw = kw
|
||||
self._async_mode = async_mode
|
||||
|
||||
def _master(self):
|
||||
return self._sentinel.master_for(self._service, **self._kw)
|
||||
|
||||
def __getattr__(self, item):
|
||||
master = self._master()
|
||||
orig_attr = getattr(master, item)
|
||||
|
||||
if not callable(orig_attr):
|
||||
return orig_attr
|
||||
|
||||
FACTORY_METHODS = {"pipeline", "pubsub", "monitor", "client", "transaction"}
|
||||
if item in FACTORY_METHODS:
|
||||
return orig_attr
|
||||
|
||||
if self._async_mode:
|
||||
|
||||
async def _wrapped(*args, **kwargs):
|
||||
for i in range(REDIS_SENTINEL_MAX_RETRY_COUNT):
|
||||
try:
|
||||
method = getattr(self._master(), item)
|
||||
result = method(*args, **kwargs)
|
||||
if inspect.iscoroutine(result):
|
||||
return await result
|
||||
return result
|
||||
except (
|
||||
redis.exceptions.ConnectionError,
|
||||
redis.exceptions.ReadOnlyError,
|
||||
) as e:
|
||||
if i < REDIS_SENTINEL_MAX_RETRY_COUNT - 1:
|
||||
log.debug(
|
||||
"Redis sentinel fail-over (%s). Retry %s/%s",
|
||||
type(e).__name__,
|
||||
i + 1,
|
||||
REDIS_SENTINEL_MAX_RETRY_COUNT,
|
||||
)
|
||||
continue
|
||||
log.error(
|
||||
"Redis operation failed after %s retries: %s",
|
||||
REDIS_SENTINEL_MAX_RETRY_COUNT,
|
||||
e,
|
||||
)
|
||||
raise e from e
|
||||
|
||||
return _wrapped
|
||||
|
||||
else:
|
||||
|
||||
def _wrapped(*args, **kwargs):
|
||||
for i in range(REDIS_SENTINEL_MAX_RETRY_COUNT):
|
||||
try:
|
||||
method = getattr(self._master(), item)
|
||||
return method(*args, **kwargs)
|
||||
except (
|
||||
redis.exceptions.ConnectionError,
|
||||
redis.exceptions.ReadOnlyError,
|
||||
) as e:
|
||||
if i < REDIS_SENTINEL_MAX_RETRY_COUNT - 1:
|
||||
log.debug(
|
||||
"Redis sentinel fail-over (%s). Retry %s/%s",
|
||||
type(e).__name__,
|
||||
i + 1,
|
||||
REDIS_SENTINEL_MAX_RETRY_COUNT,
|
||||
)
|
||||
continue
|
||||
log.error(
|
||||
"Redis operation failed after %s retries: %s",
|
||||
REDIS_SENTINEL_MAX_RETRY_COUNT,
|
||||
e,
|
||||
)
|
||||
raise e from e
|
||||
|
||||
return _wrapped
|
||||
|
||||
|
||||
def parse_redis_service_url(redis_url):
|
||||
parsed_url = urlparse(redis_url)
|
||||
if parsed_url.scheme != "redis":
|
||||
raise ValueError("Invalid Redis URL scheme. Must be 'redis'.")
|
||||
if parsed_url.scheme != "redis" and parsed_url.scheme != "rediss":
|
||||
raise ValueError("Invalid Redis URL scheme. Must be 'redis' or 'rediss'.")
|
||||
|
||||
return {
|
||||
"username": parsed_url.username or None,
|
||||
|
|
@ -18,8 +109,25 @@ def parse_redis_service_url(redis_url):
|
|||
|
||||
|
||||
def get_redis_connection(
|
||||
redis_url, redis_sentinels, async_mode=False, decode_responses=True
|
||||
redis_url,
|
||||
redis_sentinels,
|
||||
redis_cluster=False,
|
||||
async_mode=False,
|
||||
decode_responses=True,
|
||||
):
|
||||
|
||||
cache_key = (
|
||||
redis_url,
|
||||
tuple(redis_sentinels) if redis_sentinels else (),
|
||||
async_mode,
|
||||
decode_responses,
|
||||
)
|
||||
|
||||
if cache_key in _CONNECTION_CACHE:
|
||||
return _CONNECTION_CACHE[cache_key]
|
||||
|
||||
connection = None
|
||||
|
||||
if async_mode:
|
||||
import redis.asyncio as redis
|
||||
|
||||
|
|
@ -34,11 +142,19 @@ def get_redis_connection(
|
|||
password=redis_config["password"],
|
||||
decode_responses=decode_responses,
|
||||
)
|
||||
return sentinel.master_for(redis_config["service"])
|
||||
connection = SentinelRedisProxy(
|
||||
sentinel,
|
||||
redis_config["service"],
|
||||
async_mode=async_mode,
|
||||
)
|
||||
elif redis_cluster:
|
||||
if not redis_url:
|
||||
raise ValueError("Redis URL must be provided for cluster mode.")
|
||||
return redis.cluster.RedisCluster.from_url(
|
||||
redis_url, decode_responses=decode_responses
|
||||
)
|
||||
elif redis_url:
|
||||
return redis.from_url(redis_url, decode_responses=decode_responses)
|
||||
else:
|
||||
return None
|
||||
connection = redis.from_url(redis_url, decode_responses=decode_responses)
|
||||
else:
|
||||
import redis
|
||||
|
||||
|
|
@ -52,11 +168,24 @@ def get_redis_connection(
|
|||
password=redis_config["password"],
|
||||
decode_responses=decode_responses,
|
||||
)
|
||||
return sentinel.master_for(redis_config["service"])
|
||||
connection = SentinelRedisProxy(
|
||||
sentinel,
|
||||
redis_config["service"],
|
||||
async_mode=async_mode,
|
||||
)
|
||||
elif redis_cluster:
|
||||
if not redis_url:
|
||||
raise ValueError("Redis URL must be provided for cluster mode.")
|
||||
return redis.cluster.RedisCluster.from_url(
|
||||
redis_url, decode_responses=decode_responses
|
||||
)
|
||||
elif redis_url:
|
||||
return redis.Redis.from_url(redis_url, decode_responses=decode_responses)
|
||||
else:
|
||||
return None
|
||||
connection = redis.Redis.from_url(
|
||||
redis_url, decode_responses=decode_responses
|
||||
)
|
||||
|
||||
_CONNECTION_CACHE[cache_key] = connection
|
||||
return connection
|
||||
|
||||
|
||||
def get_sentinels_from_env(sentinel_hosts_env, sentinel_port_env):
|
||||
|
|
|
|||
|
|
@ -6,18 +6,17 @@ from open_webui.utils.misc import (
|
|||
)
|
||||
|
||||
|
||||
def convert_ollama_tool_call_to_openai(tool_calls: dict) -> dict:
|
||||
def convert_ollama_tool_call_to_openai(tool_calls: list) -> list:
|
||||
openai_tool_calls = []
|
||||
for tool_call in tool_calls:
|
||||
function = tool_call.get("function", {})
|
||||
openai_tool_call = {
|
||||
"index": tool_call.get("index", 0),
|
||||
"index": tool_call.get("index", function.get("index", 0)),
|
||||
"id": tool_call.get("id", f"call_{str(uuid4())}"),
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tool_call.get("function", {}).get("name", ""),
|
||||
"arguments": json.dumps(
|
||||
tool_call.get("function", {}).get("arguments", {})
|
||||
),
|
||||
"name": function.get("name", ""),
|
||||
"arguments": json.dumps(function.get("arguments", {})),
|
||||
},
|
||||
}
|
||||
openai_tool_calls.append(openai_tool_call)
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ import logging
|
|||
import math
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
from typing import Optional, Any
|
||||
import uuid
|
||||
|
||||
|
||||
|
|
@ -38,9 +38,46 @@ def prompt_variables_template(template: str, variables: dict[str, str]) -> str:
|
|||
return template
|
||||
|
||||
|
||||
def prompt_template(
|
||||
template: str, user_name: Optional[str] = None, user_location: Optional[str] = None
|
||||
) -> str:
|
||||
def prompt_template(template: str, user: Optional[Any] = None) -> str:
|
||||
|
||||
USER_VARIABLES = {}
|
||||
|
||||
if user:
|
||||
if hasattr(user, "model_dump"):
|
||||
user = user.model_dump()
|
||||
|
||||
if isinstance(user, dict):
|
||||
user_info = user.get("info", {}) or {}
|
||||
birth_date = user.get("date_of_birth")
|
||||
age = None
|
||||
|
||||
if birth_date:
|
||||
try:
|
||||
# If birth_date is str, convert to datetime
|
||||
if isinstance(birth_date, str):
|
||||
birth_date = datetime.strptime(birth_date, "%Y-%m-%d")
|
||||
|
||||
today = datetime.now()
|
||||
age = (
|
||||
today.year
|
||||
- birth_date.year
|
||||
- (
|
||||
(today.month, today.day)
|
||||
< (birth_date.month, birth_date.day)
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
USER_VARIABLES = {
|
||||
"name": str(user.get("name")),
|
||||
"location": str(user_info.get("location")),
|
||||
"bio": str(user.get("bio")),
|
||||
"gender": str(user.get("gender")),
|
||||
"birth_date": str(birth_date),
|
||||
"age": str(age),
|
||||
}
|
||||
|
||||
# Get the current date
|
||||
current_date = datetime.now()
|
||||
|
||||
|
|
@ -56,19 +93,20 @@ def prompt_template(
|
|||
)
|
||||
template = template.replace("{{CURRENT_WEEKDAY}}", formatted_weekday)
|
||||
|
||||
if user_name:
|
||||
# Replace {{USER_NAME}} in the template with the user's name
|
||||
template = template.replace("{{USER_NAME}}", user_name)
|
||||
else:
|
||||
# Replace {{USER_NAME}} in the template with "Unknown"
|
||||
template = template.replace("{{USER_NAME}}", "Unknown")
|
||||
|
||||
if user_location:
|
||||
# Replace {{USER_LOCATION}} in the template with the current location
|
||||
template = template.replace("{{USER_LOCATION}}", user_location)
|
||||
else:
|
||||
# Replace {{USER_LOCATION}} in the template with "Unknown"
|
||||
template = template.replace("{{USER_LOCATION}}", "Unknown")
|
||||
template = template.replace("{{USER_NAME}}", USER_VARIABLES.get("name", "Unknown"))
|
||||
template = template.replace("{{USER_BIO}}", USER_VARIABLES.get("bio", "Unknown"))
|
||||
template = template.replace(
|
||||
"{{USER_GENDER}}", USER_VARIABLES.get("gender", "Unknown")
|
||||
)
|
||||
template = template.replace(
|
||||
"{{USER_BIRTH_DATE}}", USER_VARIABLES.get("birth_date", "Unknown")
|
||||
)
|
||||
template = template.replace(
|
||||
"{{USER_AGE}}", str(USER_VARIABLES.get("age", "Unknown"))
|
||||
)
|
||||
template = template.replace(
|
||||
"{{USER_LOCATION}}", USER_VARIABLES.get("location", "Unknown")
|
||||
)
|
||||
|
||||
return template
|
||||
|
||||
|
|
@ -189,90 +227,56 @@ def rag_template(template: str, context: str, query: str):
|
|||
|
||||
|
||||
def title_generation_template(
|
||||
template: str, messages: list[dict], user: Optional[dict] = None
|
||||
template: str, messages: list[dict], user: Optional[Any] = None
|
||||
) -> str:
|
||||
|
||||
prompt = get_last_user_message(messages)
|
||||
template = replace_prompt_variable(template, prompt)
|
||||
template = replace_messages_variable(template, messages)
|
||||
|
||||
template = prompt_template(
|
||||
template,
|
||||
**(
|
||||
{"user_name": user.get("name"), "user_location": user.get("location")}
|
||||
if user
|
||||
else {}
|
||||
),
|
||||
)
|
||||
template = prompt_template(template, user)
|
||||
|
||||
return template
|
||||
|
||||
|
||||
def follow_up_generation_template(
|
||||
template: str, messages: list[dict], user: Optional[dict] = None
|
||||
template: str, messages: list[dict], user: Optional[Any] = None
|
||||
) -> str:
|
||||
prompt = get_last_user_message(messages)
|
||||
template = replace_prompt_variable(template, prompt)
|
||||
template = replace_messages_variable(template, messages)
|
||||
|
||||
template = prompt_template(
|
||||
template,
|
||||
**(
|
||||
{"user_name": user.get("name"), "user_location": user.get("location")}
|
||||
if user
|
||||
else {}
|
||||
),
|
||||
)
|
||||
template = prompt_template(template, user)
|
||||
return template
|
||||
|
||||
|
||||
def tags_generation_template(
|
||||
template: str, messages: list[dict], user: Optional[dict] = None
|
||||
template: str, messages: list[dict], user: Optional[Any] = None
|
||||
) -> str:
|
||||
prompt = get_last_user_message(messages)
|
||||
template = replace_prompt_variable(template, prompt)
|
||||
template = replace_messages_variable(template, messages)
|
||||
|
||||
template = prompt_template(
|
||||
template,
|
||||
**(
|
||||
{"user_name": user.get("name"), "user_location": user.get("location")}
|
||||
if user
|
||||
else {}
|
||||
),
|
||||
)
|
||||
template = prompt_template(template, user)
|
||||
return template
|
||||
|
||||
|
||||
def image_prompt_generation_template(
|
||||
template: str, messages: list[dict], user: Optional[dict] = None
|
||||
template: str, messages: list[dict], user: Optional[Any] = None
|
||||
) -> str:
|
||||
prompt = get_last_user_message(messages)
|
||||
template = replace_prompt_variable(template, prompt)
|
||||
template = replace_messages_variable(template, messages)
|
||||
|
||||
template = prompt_template(
|
||||
template,
|
||||
**(
|
||||
{"user_name": user.get("name"), "user_location": user.get("location")}
|
||||
if user
|
||||
else {}
|
||||
),
|
||||
)
|
||||
template = prompt_template(template, user)
|
||||
return template
|
||||
|
||||
|
||||
def emoji_generation_template(
|
||||
template: str, prompt: str, user: Optional[dict] = None
|
||||
template: str, prompt: str, user: Optional[Any] = None
|
||||
) -> str:
|
||||
template = replace_prompt_variable(template, prompt)
|
||||
template = prompt_template(
|
||||
template,
|
||||
**(
|
||||
{"user_name": user.get("name"), "user_location": user.get("location")}
|
||||
if user
|
||||
else {}
|
||||
),
|
||||
)
|
||||
template = prompt_template(template, user)
|
||||
|
||||
return template
|
||||
|
||||
|
|
@ -282,38 +286,24 @@ def autocomplete_generation_template(
|
|||
prompt: str,
|
||||
messages: Optional[list[dict]] = None,
|
||||
type: Optional[str] = None,
|
||||
user: Optional[dict] = None,
|
||||
user: Optional[Any] = None,
|
||||
) -> str:
|
||||
template = template.replace("{{TYPE}}", type if type else "")
|
||||
template = replace_prompt_variable(template, prompt)
|
||||
template = replace_messages_variable(template, messages)
|
||||
|
||||
template = prompt_template(
|
||||
template,
|
||||
**(
|
||||
{"user_name": user.get("name"), "user_location": user.get("location")}
|
||||
if user
|
||||
else {}
|
||||
),
|
||||
)
|
||||
template = prompt_template(template, user)
|
||||
return template
|
||||
|
||||
|
||||
def query_generation_template(
|
||||
template: str, messages: list[dict], user: Optional[dict] = None
|
||||
template: str, messages: list[dict], user: Optional[Any] = None
|
||||
) -> str:
|
||||
prompt = get_last_user_message(messages)
|
||||
template = replace_prompt_variable(template, prompt)
|
||||
template = replace_messages_variable(template, messages)
|
||||
|
||||
template = prompt_template(
|
||||
template,
|
||||
**(
|
||||
{"user_name": user.get("name"), "user_location": user.get("location")}
|
||||
if user
|
||||
else {}
|
||||
),
|
||||
)
|
||||
template = prompt_template(template, user)
|
||||
return template
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,31 +0,0 @@
|
|||
import threading
|
||||
|
||||
from opentelemetry.sdk.trace import ReadableSpan
|
||||
from opentelemetry.sdk.trace.export import BatchSpanProcessor
|
||||
|
||||
|
||||
class LazyBatchSpanProcessor(BatchSpanProcessor):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.done = True
|
||||
with self.condition:
|
||||
self.condition.notify_all()
|
||||
self.worker_thread.join()
|
||||
self.done = False
|
||||
self.worker_thread = None
|
||||
|
||||
def on_end(self, span: ReadableSpan) -> None:
|
||||
if self.worker_thread is None:
|
||||
self.worker_thread = threading.Thread(
|
||||
name=self.__class__.__name__, target=self.worker, daemon=True
|
||||
)
|
||||
self.worker_thread.start()
|
||||
super().on_end(span)
|
||||
|
||||
def shutdown(self) -> None:
|
||||
self.done = True
|
||||
with self.condition:
|
||||
self.condition.notify_all()
|
||||
if self.worker_thread:
|
||||
self.worker_thread.join()
|
||||
self.span_exporter.shutdown()
|
||||
53
backend/open_webui/utils/telemetry/logs.py
Normal file
53
backend/open_webui/utils/telemetry/logs.py
Normal file
|
|
@ -0,0 +1,53 @@
|
|||
import logging
|
||||
from base64 import b64encode
|
||||
from opentelemetry.sdk._logs import (
|
||||
LoggingHandler,
|
||||
LoggerProvider,
|
||||
)
|
||||
from opentelemetry.exporter.otlp.proto.grpc._log_exporter import OTLPLogExporter
|
||||
from opentelemetry.exporter.otlp.proto.http._log_exporter import (
|
||||
OTLPLogExporter as HttpOTLPLogExporter,
|
||||
)
|
||||
from opentelemetry.sdk._logs.export import BatchLogRecordProcessor
|
||||
from opentelemetry._logs import set_logger_provider
|
||||
from opentelemetry.sdk.resources import SERVICE_NAME, Resource
|
||||
from open_webui.env import (
|
||||
OTEL_SERVICE_NAME,
|
||||
OTEL_LOGS_EXPORTER_OTLP_ENDPOINT,
|
||||
OTEL_LOGS_EXPORTER_OTLP_INSECURE,
|
||||
OTEL_LOGS_BASIC_AUTH_USERNAME,
|
||||
OTEL_LOGS_BASIC_AUTH_PASSWORD,
|
||||
OTEL_LOGS_OTLP_SPAN_EXPORTER,
|
||||
)
|
||||
|
||||
|
||||
def setup_logging():
|
||||
headers = []
|
||||
if OTEL_LOGS_BASIC_AUTH_USERNAME and OTEL_LOGS_BASIC_AUTH_PASSWORD:
|
||||
auth_string = f"{OTEL_LOGS_BASIC_AUTH_USERNAME}:{OTEL_LOGS_BASIC_AUTH_PASSWORD}"
|
||||
auth_header = b64encode(auth_string.encode()).decode()
|
||||
headers = [("authorization", f"Basic {auth_header}")]
|
||||
resource = Resource.create(attributes={SERVICE_NAME: OTEL_SERVICE_NAME})
|
||||
|
||||
if OTEL_LOGS_OTLP_SPAN_EXPORTER == "http":
|
||||
exporter = HttpOTLPLogExporter(
|
||||
endpoint=OTEL_LOGS_EXPORTER_OTLP_ENDPOINT,
|
||||
headers=headers,
|
||||
)
|
||||
else:
|
||||
exporter = OTLPLogExporter(
|
||||
endpoint=OTEL_LOGS_EXPORTER_OTLP_ENDPOINT,
|
||||
insecure=OTEL_LOGS_EXPORTER_OTLP_INSECURE,
|
||||
headers=headers,
|
||||
)
|
||||
logger_provider = LoggerProvider(resource=resource)
|
||||
set_logger_provider(logger_provider)
|
||||
|
||||
logger_provider.add_log_record_processor(BatchLogRecordProcessor(exporter))
|
||||
|
||||
otel_handler = LoggingHandler(logger_provider=logger_provider)
|
||||
|
||||
return otel_handler
|
||||
|
||||
|
||||
otel_handler = setup_logging()
|
||||
|
|
@ -19,35 +19,69 @@ from __future__ import annotations
|
|||
|
||||
import time
|
||||
from typing import Dict, List, Sequence, Any
|
||||
from base64 import b64encode
|
||||
|
||||
from fastapi import FastAPI, Request
|
||||
from opentelemetry import metrics
|
||||
from opentelemetry.exporter.otlp.proto.grpc.metric_exporter import (
|
||||
OTLPMetricExporter,
|
||||
)
|
||||
|
||||
from opentelemetry.exporter.otlp.proto.http.metric_exporter import (
|
||||
OTLPMetricExporter as OTLPHttpMetricExporter,
|
||||
)
|
||||
from opentelemetry.sdk.metrics import MeterProvider
|
||||
from opentelemetry.sdk.metrics.view import View
|
||||
from opentelemetry.sdk.metrics.export import (
|
||||
PeriodicExportingMetricReader,
|
||||
)
|
||||
from opentelemetry.sdk.resources import SERVICE_NAME, Resource
|
||||
|
||||
from open_webui.env import OTEL_SERVICE_NAME, OTEL_EXPORTER_OTLP_ENDPOINT
|
||||
from opentelemetry.sdk.resources import Resource
|
||||
|
||||
from open_webui.env import (
|
||||
OTEL_SERVICE_NAME,
|
||||
OTEL_METRICS_EXPORTER_OTLP_ENDPOINT,
|
||||
OTEL_METRICS_BASIC_AUTH_USERNAME,
|
||||
OTEL_METRICS_BASIC_AUTH_PASSWORD,
|
||||
OTEL_METRICS_OTLP_SPAN_EXPORTER,
|
||||
OTEL_METRICS_EXPORTER_OTLP_INSECURE,
|
||||
)
|
||||
from open_webui.socket.main import get_active_user_ids
|
||||
from open_webui.models.users import Users
|
||||
|
||||
_EXPORT_INTERVAL_MILLIS = 10_000 # 10 seconds
|
||||
|
||||
|
||||
def _build_meter_provider() -> MeterProvider:
|
||||
def _build_meter_provider(resource: Resource) -> MeterProvider:
|
||||
"""Return a configured MeterProvider."""
|
||||
headers = []
|
||||
if OTEL_METRICS_BASIC_AUTH_USERNAME and OTEL_METRICS_BASIC_AUTH_PASSWORD:
|
||||
auth_string = (
|
||||
f"{OTEL_METRICS_BASIC_AUTH_USERNAME}:{OTEL_METRICS_BASIC_AUTH_PASSWORD}"
|
||||
)
|
||||
auth_header = b64encode(auth_string.encode()).decode()
|
||||
headers = [("authorization", f"Basic {auth_header}")]
|
||||
|
||||
# Periodic reader pushes metrics over OTLP/gRPC to collector
|
||||
readers: List[PeriodicExportingMetricReader] = [
|
||||
PeriodicExportingMetricReader(
|
||||
OTLPMetricExporter(endpoint=OTEL_EXPORTER_OTLP_ENDPOINT),
|
||||
export_interval_millis=_EXPORT_INTERVAL_MILLIS,
|
||||
)
|
||||
]
|
||||
if OTEL_METRICS_OTLP_SPAN_EXPORTER == "http":
|
||||
readers: List[PeriodicExportingMetricReader] = [
|
||||
PeriodicExportingMetricReader(
|
||||
OTLPHttpMetricExporter(
|
||||
endpoint=OTEL_METRICS_EXPORTER_OTLP_ENDPOINT, headers=headers
|
||||
),
|
||||
export_interval_millis=_EXPORT_INTERVAL_MILLIS,
|
||||
)
|
||||
]
|
||||
else:
|
||||
readers: List[PeriodicExportingMetricReader] = [
|
||||
PeriodicExportingMetricReader(
|
||||
OTLPMetricExporter(
|
||||
endpoint=OTEL_METRICS_EXPORTER_OTLP_ENDPOINT,
|
||||
insecure=OTEL_METRICS_EXPORTER_OTLP_INSECURE,
|
||||
headers=headers,
|
||||
),
|
||||
export_interval_millis=_EXPORT_INTERVAL_MILLIS,
|
||||
)
|
||||
]
|
||||
|
||||
# Optional view to limit cardinality: drop user-agent etc.
|
||||
views: List[View] = [
|
||||
|
|
@ -59,20 +93,26 @@ def _build_meter_provider() -> MeterProvider:
|
|||
instrument_name="http.server.requests",
|
||||
attribute_keys=["http.method", "http.route", "http.status_code"],
|
||||
),
|
||||
View(
|
||||
instrument_name="webui.users.total",
|
||||
),
|
||||
View(
|
||||
instrument_name="webui.users.active",
|
||||
),
|
||||
]
|
||||
|
||||
provider = MeterProvider(
|
||||
resource=Resource.create({SERVICE_NAME: OTEL_SERVICE_NAME}),
|
||||
resource=resource,
|
||||
metric_readers=list(readers),
|
||||
views=views,
|
||||
)
|
||||
return provider
|
||||
|
||||
|
||||
def setup_metrics(app: FastAPI) -> None:
|
||||
def setup_metrics(app: FastAPI, resource: Resource) -> None:
|
||||
"""Attach OTel metrics middleware to *app* and initialise provider."""
|
||||
|
||||
metrics.set_meter_provider(_build_meter_provider())
|
||||
metrics.set_meter_provider(_build_meter_provider(resource))
|
||||
meter = metrics.get_meter(__name__)
|
||||
|
||||
# Instruments
|
||||
|
|
@ -87,6 +127,38 @@ def setup_metrics(app: FastAPI) -> None:
|
|||
unit="ms",
|
||||
)
|
||||
|
||||
def observe_active_users(
|
||||
options: metrics.CallbackOptions,
|
||||
) -> Sequence[metrics.Observation]:
|
||||
return [
|
||||
metrics.Observation(
|
||||
value=len(get_active_user_ids()),
|
||||
)
|
||||
]
|
||||
|
||||
def observe_total_registered_users(
|
||||
options: metrics.CallbackOptions,
|
||||
) -> Sequence[metrics.Observation]:
|
||||
return [
|
||||
metrics.Observation(
|
||||
value=len(Users.get_users()["users"]),
|
||||
)
|
||||
]
|
||||
|
||||
meter.create_observable_gauge(
|
||||
name="webui.users.total",
|
||||
description="Total number of registered users",
|
||||
unit="users",
|
||||
callbacks=[observe_total_registered_users],
|
||||
)
|
||||
|
||||
meter.create_observable_gauge(
|
||||
name="webui.users.active",
|
||||
description="Number of currently active users",
|
||||
unit="users",
|
||||
callbacks=[observe_active_users],
|
||||
)
|
||||
|
||||
# FastAPI middleware
|
||||
@app.middleware("http")
|
||||
async def _metrics_middleware(request: Request, call_next):
|
||||
|
|
|
|||
|
|
@ -1,21 +1,23 @@
|
|||
from fastapi import FastAPI
|
||||
from opentelemetry import trace
|
||||
|
||||
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
|
||||
from opentelemetry.exporter.otlp.proto.http.trace_exporter import (
|
||||
OTLPSpanExporter as HttpOTLPSpanExporter,
|
||||
)
|
||||
from opentelemetry.sdk.resources import SERVICE_NAME, Resource
|
||||
from opentelemetry.sdk.trace import TracerProvider
|
||||
from opentelemetry.sdk.trace.export import BatchSpanProcessor
|
||||
from sqlalchemy import Engine
|
||||
from base64 import b64encode
|
||||
|
||||
from open_webui.utils.telemetry.exporters import LazyBatchSpanProcessor
|
||||
from open_webui.utils.telemetry.instrumentors import Instrumentor
|
||||
from open_webui.utils.telemetry.metrics import setup_metrics
|
||||
from open_webui.env import (
|
||||
OTEL_SERVICE_NAME,
|
||||
OTEL_EXPORTER_OTLP_ENDPOINT,
|
||||
OTEL_EXPORTER_OTLP_INSECURE,
|
||||
ENABLE_OTEL_TRACES,
|
||||
ENABLE_OTEL_METRICS,
|
||||
OTEL_BASIC_AUTH_USERNAME,
|
||||
OTEL_BASIC_AUTH_PASSWORD,
|
||||
|
|
@ -25,35 +27,32 @@ from open_webui.env import (
|
|||
|
||||
def setup(app: FastAPI, db_engine: Engine):
|
||||
# set up trace
|
||||
trace.set_tracer_provider(
|
||||
TracerProvider(
|
||||
resource=Resource.create(attributes={SERVICE_NAME: OTEL_SERVICE_NAME})
|
||||
)
|
||||
)
|
||||
resource = Resource.create(attributes={SERVICE_NAME: OTEL_SERVICE_NAME})
|
||||
if ENABLE_OTEL_TRACES:
|
||||
trace.set_tracer_provider(TracerProvider(resource=resource))
|
||||
|
||||
# Add basic auth header only if both username and password are not empty
|
||||
headers = []
|
||||
if OTEL_BASIC_AUTH_USERNAME and OTEL_BASIC_AUTH_PASSWORD:
|
||||
auth_string = f"{OTEL_BASIC_AUTH_USERNAME}:{OTEL_BASIC_AUTH_PASSWORD}"
|
||||
auth_header = b64encode(auth_string.encode()).decode()
|
||||
headers = [("authorization", f"Basic {auth_header}")]
|
||||
# Add basic auth header only if both username and password are not empty
|
||||
headers = []
|
||||
if OTEL_BASIC_AUTH_USERNAME and OTEL_BASIC_AUTH_PASSWORD:
|
||||
auth_string = f"{OTEL_BASIC_AUTH_USERNAME}:{OTEL_BASIC_AUTH_PASSWORD}"
|
||||
auth_header = b64encode(auth_string.encode()).decode()
|
||||
headers = [("authorization", f"Basic {auth_header}")]
|
||||
|
||||
# otlp export
|
||||
if OTEL_OTLP_SPAN_EXPORTER == "http":
|
||||
exporter = HttpOTLPSpanExporter(
|
||||
endpoint=OTEL_EXPORTER_OTLP_ENDPOINT,
|
||||
insecure=OTEL_EXPORTER_OTLP_INSECURE,
|
||||
headers=headers,
|
||||
)
|
||||
else:
|
||||
exporter = OTLPSpanExporter(
|
||||
endpoint=OTEL_EXPORTER_OTLP_ENDPOINT,
|
||||
insecure=OTEL_EXPORTER_OTLP_INSECURE,
|
||||
headers=headers,
|
||||
)
|
||||
trace.get_tracer_provider().add_span_processor(LazyBatchSpanProcessor(exporter))
|
||||
Instrumentor(app=app, db_engine=db_engine).instrument()
|
||||
# otlp export
|
||||
if OTEL_OTLP_SPAN_EXPORTER == "http":
|
||||
exporter = HttpOTLPSpanExporter(
|
||||
endpoint=OTEL_EXPORTER_OTLP_ENDPOINT,
|
||||
headers=headers,
|
||||
)
|
||||
else:
|
||||
exporter = OTLPSpanExporter(
|
||||
endpoint=OTEL_EXPORTER_OTLP_ENDPOINT,
|
||||
insecure=OTEL_EXPORTER_OTLP_INSECURE,
|
||||
headers=headers,
|
||||
)
|
||||
trace.get_tracer_provider().add_span_processor(BatchSpanProcessor(exporter))
|
||||
Instrumentor(app=app, db_engine=db_engine).instrument()
|
||||
|
||||
# set up metrics only if enabled
|
||||
if ENABLE_OTEL_METRICS:
|
||||
setup_metrics(app)
|
||||
setup_metrics(app, resource)
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ import inspect
|
|||
import aiohttp
|
||||
import asyncio
|
||||
import yaml
|
||||
import json
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic.fields import FieldInfo
|
||||
|
|
@ -38,6 +39,7 @@ from open_webui.models.users import UserModel
|
|||
from open_webui.utils.plugin import load_tool_module_by_id
|
||||
from open_webui.env import (
|
||||
SRC_LOG_LEVELS,
|
||||
AIOHTTP_CLIENT_TIMEOUT,
|
||||
AIOHTTP_CLIENT_TIMEOUT_TOOL_SERVER_DATA,
|
||||
AIOHTTP_CLIENT_SESSION_TOOL_SERVER_SSL,
|
||||
)
|
||||
|
|
@ -55,19 +57,38 @@ def get_async_tool_function_and_apply_extra_params(
|
|||
extra_params = {k: v for k, v in extra_params.items() if k in sig.parameters}
|
||||
partial_func = partial(function, **extra_params)
|
||||
|
||||
# Remove the 'frozen' keyword arguments from the signature
|
||||
# python-genai uses the signature to infer the tool properties for native function calling
|
||||
parameters = []
|
||||
for name, parameter in sig.parameters.items():
|
||||
# Exclude keyword arguments that are frozen
|
||||
if name in extra_params:
|
||||
continue
|
||||
# Keep remaining parameters
|
||||
parameters.append(parameter)
|
||||
|
||||
new_sig = inspect.Signature(
|
||||
parameters=parameters, return_annotation=sig.return_annotation
|
||||
)
|
||||
|
||||
if inspect.iscoroutinefunction(function):
|
||||
update_wrapper(partial_func, function)
|
||||
return partial_func
|
||||
# wrap the functools.partial as python-genai has trouble with it
|
||||
# https://github.com/googleapis/python-genai/issues/907
|
||||
async def new_function(*args, **kwargs):
|
||||
return await partial_func(*args, **kwargs)
|
||||
|
||||
else:
|
||||
# Make it a coroutine function
|
||||
# Make it a coroutine function when it is not already
|
||||
async def new_function(*args, **kwargs):
|
||||
return partial_func(*args, **kwargs)
|
||||
|
||||
update_wrapper(new_function, function)
|
||||
return new_function
|
||||
update_wrapper(new_function, function)
|
||||
new_function.__signature__ = new_sig
|
||||
|
||||
return new_function
|
||||
|
||||
|
||||
def get_tools(
|
||||
async def get_tools(
|
||||
request: Request, tool_ids: list[str], user: UserModel, extra_params: dict
|
||||
) -> dict[str, dict]:
|
||||
tools_dict = {}
|
||||
|
|
@ -76,18 +97,24 @@ def get_tools(
|
|||
tool = Tools.get_tool_by_id(tool_id)
|
||||
if tool is None:
|
||||
if tool_id.startswith("server:"):
|
||||
server_idx = int(tool_id.split(":")[1])
|
||||
tool_server_connection = (
|
||||
request.app.state.config.TOOL_SERVER_CONNECTIONS[server_idx]
|
||||
)
|
||||
server_id = tool_id.split(":")[1]
|
||||
|
||||
tool_server_data = None
|
||||
for server in request.app.state.TOOL_SERVERS:
|
||||
if server["idx"] == server_idx:
|
||||
for server in await get_tool_servers(request):
|
||||
if server["id"] == server_id:
|
||||
tool_server_data = server
|
||||
break
|
||||
assert tool_server_data is not None
|
||||
specs = tool_server_data.get("specs", [])
|
||||
|
||||
if tool_server_data is None:
|
||||
log.warning(f"Tool server data not found for {server_id}")
|
||||
continue
|
||||
|
||||
tool_server_idx = tool_server_data.get("idx", 0)
|
||||
tool_server_connection = (
|
||||
request.app.state.config.TOOL_SERVER_CONNECTIONS[tool_server_idx]
|
||||
)
|
||||
|
||||
specs = tool_server_data.get("specs", [])
|
||||
for spec in specs:
|
||||
function_name = spec["name"]
|
||||
|
||||
|
|
@ -126,14 +153,15 @@ def get_tools(
|
|||
"spec": spec,
|
||||
}
|
||||
|
||||
# TODO: if collision, prepend toolkit name
|
||||
if function_name in tools_dict:
|
||||
# Handle function name collisions
|
||||
while function_name in tools_dict:
|
||||
log.warning(
|
||||
f"Tool {function_name} already exists in another tools!"
|
||||
)
|
||||
log.warning(f"Discarding {tool_id}.{function_name}")
|
||||
else:
|
||||
tools_dict[function_name] = tool_dict
|
||||
# Prepend server ID to function name
|
||||
function_name = f"{server_id}_{function_name}"
|
||||
|
||||
tools_dict[function_name] = tool_dict
|
||||
else:
|
||||
continue
|
||||
else:
|
||||
|
|
@ -193,14 +221,15 @@ def get_tools(
|
|||
},
|
||||
}
|
||||
|
||||
# TODO: if collision, prepend toolkit name
|
||||
if function_name in tools_dict:
|
||||
# Handle function name collisions
|
||||
while function_name in tools_dict:
|
||||
log.warning(
|
||||
f"Tool {function_name} already exists in another tools!"
|
||||
)
|
||||
log.warning(f"Discarding {tool_id}.{function_name}")
|
||||
else:
|
||||
tools_dict[function_name] = tool_dict
|
||||
# Prepend tool ID to function name
|
||||
function_name = f"{tool_id}_{function_name}"
|
||||
|
||||
tools_dict[function_name] = tool_dict
|
||||
|
||||
return tools_dict
|
||||
|
||||
|
|
@ -283,15 +312,15 @@ def convert_function_to_pydantic_model(func: Callable) -> type[BaseModel]:
|
|||
|
||||
field_defs = {}
|
||||
for name, param in parameters.items():
|
||||
|
||||
type_hint = type_hints.get(name, Any)
|
||||
default_value = param.default if param.default is not param.empty else ...
|
||||
|
||||
param_description = function_param_descriptions.get(name, None)
|
||||
|
||||
if param_description:
|
||||
field_defs[name] = type_hint, Field(
|
||||
default_value, description=param_description
|
||||
field_defs[name] = (
|
||||
type_hint,
|
||||
Field(default_value, description=param_description),
|
||||
)
|
||||
else:
|
||||
field_defs[name] = type_hint, default_value
|
||||
|
|
@ -377,7 +406,6 @@ def convert_openapi_to_tool_payload(openapi_spec):
|
|||
for method, operation in methods.items():
|
||||
if operation.get("operationId"):
|
||||
tool = {
|
||||
"type": "function",
|
||||
"name": operation.get("operationId"),
|
||||
"description": operation.get(
|
||||
"description",
|
||||
|
|
@ -399,10 +427,16 @@ def convert_openapi_to_tool_payload(openapi_spec):
|
|||
description += (
|
||||
f". Possible values: {', '.join(param_schema.get('enum'))}"
|
||||
)
|
||||
tool["parameters"]["properties"][param_name] = {
|
||||
param_property = {
|
||||
"type": param_schema.get("type"),
|
||||
"description": description,
|
||||
}
|
||||
|
||||
# Include items property for array types (required by OpenAI)
|
||||
if param_schema.get("type") == "array" and "items" in param_schema:
|
||||
param_property["items"] = param_schema["items"]
|
||||
|
||||
tool["parameters"]["properties"][param_name] = param_property
|
||||
if param.get("required"):
|
||||
tool["parameters"]["required"].append(param_name)
|
||||
|
||||
|
|
@ -437,6 +471,34 @@ def convert_openapi_to_tool_payload(openapi_spec):
|
|||
return tool_payload
|
||||
|
||||
|
||||
async def set_tool_servers(request: Request):
|
||||
request.app.state.TOOL_SERVERS = await get_tool_servers_data(
|
||||
request.app.state.config.TOOL_SERVER_CONNECTIONS
|
||||
)
|
||||
|
||||
if request.app.state.redis is not None:
|
||||
await request.app.state.redis.set(
|
||||
"tool_servers", json.dumps(request.app.state.TOOL_SERVERS)
|
||||
)
|
||||
|
||||
return request.app.state.TOOL_SERVERS
|
||||
|
||||
|
||||
async def get_tool_servers(request: Request):
|
||||
tool_servers = []
|
||||
if request.app.state.redis is not None:
|
||||
try:
|
||||
tool_servers = json.loads(await request.app.state.redis.get("tool_servers"))
|
||||
request.app.state.TOOL_SERVERS = tool_servers
|
||||
except Exception as e:
|
||||
log.error(f"Error fetching tool_servers from Redis: {e}")
|
||||
|
||||
if not tool_servers:
|
||||
tool_servers = await set_tool_servers(request)
|
||||
|
||||
return tool_servers
|
||||
|
||||
|
||||
async def get_tool_server_data(token: str, url: str) -> Dict[str, Any]:
|
||||
headers = {
|
||||
"Accept": "application/json",
|
||||
|
|
@ -489,15 +551,7 @@ async def get_tool_servers_data(
|
|||
if server.get("config", {}).get("enable"):
|
||||
# Path (to OpenAPI spec URL) can be either a full URL or a path to append to the base URL
|
||||
openapi_path = server.get("path", "openapi.json")
|
||||
if "://" in openapi_path:
|
||||
# If it contains "://", it's a full URL
|
||||
full_url = openapi_path
|
||||
else:
|
||||
if not openapi_path.startswith("/"):
|
||||
# Ensure the path starts with a slash
|
||||
openapi_path = f"/{openapi_path}"
|
||||
|
||||
full_url = f"{server.get('url')}{openapi_path}"
|
||||
full_url = get_tool_server_url(server.get("url"), openapi_path)
|
||||
|
||||
info = server.get("info", {})
|
||||
|
||||
|
|
@ -508,11 +562,16 @@ async def get_tool_servers_data(
|
|||
token = server.get("key", "")
|
||||
elif auth_type == "session":
|
||||
token = session_token
|
||||
server_entries.append((idx, server, full_url, info, token))
|
||||
|
||||
id = info.get("id")
|
||||
if not id:
|
||||
id = str(idx)
|
||||
|
||||
server_entries.append((id, idx, server, full_url, info, token))
|
||||
|
||||
# Create async tasks to fetch data
|
||||
tasks = [
|
||||
get_tool_server_data(token, url) for (_, _, url, _, token) in server_entries
|
||||
get_tool_server_data(token, url) for (_, _, _, url, _, token) in server_entries
|
||||
]
|
||||
|
||||
# Execute tasks concurrently
|
||||
|
|
@ -520,7 +579,7 @@ async def get_tool_servers_data(
|
|||
|
||||
# Build final results with index and server metadata
|
||||
results = []
|
||||
for (idx, server, url, info, _), response in zip(server_entries, responses):
|
||||
for (id, idx, server, url, info, _), response in zip(server_entries, responses):
|
||||
if isinstance(response, Exception):
|
||||
log.error(f"Failed to connect to {url} OpenAPI tool server")
|
||||
continue
|
||||
|
|
@ -528,6 +587,8 @@ async def get_tool_servers_data(
|
|||
openapi_data = response.get("openapi", {})
|
||||
|
||||
if info and isinstance(openapi_data, dict):
|
||||
openapi_data["info"] = openapi_data.get("info", {})
|
||||
|
||||
if "name" in info:
|
||||
openapi_data["info"]["title"] = info.get("name", "Tool Server")
|
||||
|
||||
|
|
@ -536,6 +597,7 @@ async def get_tool_servers_data(
|
|||
|
||||
results.append(
|
||||
{
|
||||
"id": str(id),
|
||||
"idx": idx,
|
||||
"url": server.get("url"),
|
||||
"openapi": openapi_data,
|
||||
|
|
@ -614,7 +676,9 @@ async def execute_tool_server(
|
|||
if token:
|
||||
headers["Authorization"] = f"Bearer {token}"
|
||||
|
||||
async with aiohttp.ClientSession(trust_env=True) as session:
|
||||
async with aiohttp.ClientSession(
|
||||
trust_env=True, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
|
||||
) as session:
|
||||
request_method = getattr(session, http_method.lower())
|
||||
|
||||
if http_method in ["post", "put", "patch"]:
|
||||
|
|
@ -627,7 +691,13 @@ async def execute_tool_server(
|
|||
if response.status >= 400:
|
||||
text = await response.text()
|
||||
raise Exception(f"HTTP error {response.status}: {text}")
|
||||
return await response.json()
|
||||
|
||||
try:
|
||||
response_data = await response.json()
|
||||
except Exception:
|
||||
response_data = await response.text()
|
||||
|
||||
return response_data
|
||||
else:
|
||||
async with request_method(
|
||||
final_url,
|
||||
|
|
@ -637,9 +707,28 @@ async def execute_tool_server(
|
|||
if response.status >= 400:
|
||||
text = await response.text()
|
||||
raise Exception(f"HTTP error {response.status}: {text}")
|
||||
return await response.json()
|
||||
|
||||
try:
|
||||
response_data = await response.json()
|
||||
except Exception:
|
||||
response_data = await response.text()
|
||||
|
||||
return response_data
|
||||
|
||||
except Exception as err:
|
||||
error = str(err)
|
||||
log.exception(f"API Request Error: {error}")
|
||||
return {"error": error}
|
||||
|
||||
|
||||
def get_tool_server_url(url: Optional[str], path: str) -> str:
|
||||
"""
|
||||
Build the full URL for a tool server, given a base url and a path.
|
||||
"""
|
||||
if "://" in path:
|
||||
# If it contains "://", it's a full URL
|
||||
return path
|
||||
if not path.startswith("/"):
|
||||
# Ensure the path starts with a slash
|
||||
path = f"/{path}"
|
||||
return f"{url}{path}"
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
import json
|
||||
import logging
|
||||
import aiohttp
|
||||
|
||||
import requests
|
||||
from open_webui.config import WEBUI_FAVICON_URL
|
||||
from open_webui.env import SRC_LOG_LEVELS, VERSION
|
||||
|
||||
|
|
@ -9,7 +9,7 @@ log = logging.getLogger(__name__)
|
|||
log.setLevel(SRC_LOG_LEVELS["WEBHOOK"])
|
||||
|
||||
|
||||
def post_webhook(name: str, url: str, message: str, event_data: dict) -> bool:
|
||||
async def post_webhook(name: str, url: str, message: str, event_data: dict) -> bool:
|
||||
try:
|
||||
log.debug(f"post_webhook: {url}, {message}, {event_data}")
|
||||
payload = {}
|
||||
|
|
@ -51,9 +51,12 @@ def post_webhook(name: str, url: str, message: str, event_data: dict) -> bool:
|
|||
payload = {**event_data}
|
||||
|
||||
log.debug(f"payload: {payload}")
|
||||
r = requests.post(url, json=payload)
|
||||
r.raise_for_status()
|
||||
log.debug(f"r.text: {r.text}")
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(url, json=payload) as r:
|
||||
r_text = await r.text()
|
||||
r.raise_for_status()
|
||||
log.debug(f"r.text: {r_text}")
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ passlib[bcrypt]==1.7.4
|
|||
cryptography
|
||||
|
||||
requests==2.32.4
|
||||
aiohttp==3.11.11
|
||||
aiohttp==3.12.15
|
||||
async-timeout
|
||||
aiocache
|
||||
aiofiles
|
||||
|
|
@ -27,7 +27,7 @@ bcrypt==4.3.0
|
|||
|
||||
pymongo
|
||||
redis
|
||||
boto3==1.35.53
|
||||
boto3==1.40.5
|
||||
|
||||
argon2-cffi==23.1.0
|
||||
APScheduler==3.10.4
|
||||
|
|
@ -42,26 +42,30 @@ asgiref==3.8.1
|
|||
# AI libraries
|
||||
openai
|
||||
anthropic
|
||||
google-genai==1.15.0
|
||||
google-genai==1.28.0
|
||||
google-generativeai==0.8.5
|
||||
tiktoken
|
||||
|
||||
langchain==0.3.26
|
||||
langchain-community==0.3.26
|
||||
|
||||
fake-useragent==2.1.0
|
||||
fake-useragent==2.2.0
|
||||
chromadb==0.6.3
|
||||
posthog==5.4.0
|
||||
pymilvus==2.5.0
|
||||
qdrant-client==1.14.3
|
||||
opensearch-py==2.8.0
|
||||
playwright==1.49.1 # Caution: version must match docker-compose.playwright.yaml
|
||||
elasticsearch==9.0.1
|
||||
pinecone==6.0.2
|
||||
oracledb==3.2.0
|
||||
|
||||
av==14.0.1 # Caution: Set due to FATAL FIPS SELFTEST FAILURE, see discussion https://github.com/open-webui/open-webui/discussions/15720
|
||||
transformers
|
||||
sentence-transformers==4.1.0
|
||||
accelerate
|
||||
colbert-ai==0.2.21
|
||||
pyarrow==20.0.0
|
||||
einops==0.8.1
|
||||
|
||||
|
||||
|
|
@ -73,7 +77,7 @@ docx2txt==0.8
|
|||
python-pptx==1.0.2
|
||||
unstructured==0.16.17
|
||||
nltk==3.9.1
|
||||
Markdown==3.7
|
||||
Markdown==3.8.2
|
||||
pypandoc==1.15
|
||||
pandas==2.2.3
|
||||
openpyxl==3.1.5
|
||||
|
|
@ -85,7 +89,7 @@ sentencepiece
|
|||
soundfile==0.13.1
|
||||
azure-ai-documentintelligence==1.0.2
|
||||
|
||||
pillow==11.2.1
|
||||
pillow==11.3.0
|
||||
opencv-python-headless==4.11.0.86
|
||||
rapidocr-onnxruntime==1.4.4
|
||||
rank-bm25==0.2.2
|
||||
|
|
@ -95,7 +99,7 @@ onnxruntime==1.20.1
|
|||
faster-whisper==1.1.1
|
||||
|
||||
PyJWT[crypto]==2.10.1
|
||||
authlib==1.4.1
|
||||
authlib==1.6.1
|
||||
|
||||
black==25.1.0
|
||||
langfuse==2.44.0
|
||||
|
|
@ -132,14 +136,14 @@ firecrawl-py==1.12.0
|
|||
tencentcloud-sdk-python==3.0.1336
|
||||
|
||||
## Trace
|
||||
opentelemetry-api==1.32.1
|
||||
opentelemetry-sdk==1.32.1
|
||||
opentelemetry-exporter-otlp==1.32.1
|
||||
opentelemetry-instrumentation==0.53b1
|
||||
opentelemetry-instrumentation-fastapi==0.53b1
|
||||
opentelemetry-instrumentation-sqlalchemy==0.53b1
|
||||
opentelemetry-instrumentation-redis==0.53b1
|
||||
opentelemetry-instrumentation-requests==0.53b1
|
||||
opentelemetry-instrumentation-logging==0.53b1
|
||||
opentelemetry-instrumentation-httpx==0.53b1
|
||||
opentelemetry-instrumentation-aiohttp-client==0.53b1
|
||||
opentelemetry-api==1.36.0
|
||||
opentelemetry-sdk==1.36.0
|
||||
opentelemetry-exporter-otlp==1.36.0
|
||||
opentelemetry-instrumentation==0.57b0
|
||||
opentelemetry-instrumentation-fastapi==0.57b0
|
||||
opentelemetry-instrumentation-sqlalchemy==0.57b0
|
||||
opentelemetry-instrumentation-redis==0.57b0
|
||||
opentelemetry-instrumentation-requests==0.57b0
|
||||
opentelemetry-instrumentation-logging==0.57b0
|
||||
opentelemetry-instrumentation-httpx==0.57b0
|
||||
opentelemetry-instrumentation-aiohttp-client==0.57b0
|
||||
|
|
|
|||
|
|
@ -8,17 +8,28 @@ services:
|
|||
- "4318:4318" # OTLP/HTTP
|
||||
restart: unless-stopped
|
||||
|
||||
|
||||
open-webui:
|
||||
image: ghcr.io/open-webui/open-webui:main
|
||||
build:
|
||||
context: .
|
||||
dockerfile: Dockerfile
|
||||
image: ghcr.io/open-webui/open-webui:${WEBUI_DOCKER_TAG-main}
|
||||
container_name: open-webui
|
||||
depends_on: [grafana]
|
||||
volumes:
|
||||
- open-webui:/app/backend/data
|
||||
depends_on:
|
||||
- grafana
|
||||
ports:
|
||||
- ${OPEN_WEBUI_PORT-8088}:8080
|
||||
environment:
|
||||
- ENABLE_OTEL=true
|
||||
- OTEL_EXPORTER_OTLP_ENDPOINT=http://grafana:4317
|
||||
- ENABLE_OTEL_METRICS=true
|
||||
- OTEL_EXPORTER_OTLP_INSECURE=true # Use insecure connection for OTLP, remove in production
|
||||
- OTEL_EXPORTER_OTLP_ENDPOINT=http://grafana:4317
|
||||
- OTEL_SERVICE_NAME=open-webui
|
||||
ports:
|
||||
- "8088:8080"
|
||||
networks: [default]
|
||||
extra_hosts:
|
||||
- host.docker.internal:host-gateway
|
||||
restart: unless-stopped
|
||||
|
||||
networks:
|
||||
default:
|
||||
volumes:
|
||||
open-webui: {}
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ class CustomBuildHook(BuildHookInterface):
|
|||
"NodeJS `npm` is required for building Open Webui but it was not found"
|
||||
)
|
||||
stderr.write("### npm install\n")
|
||||
subprocess.run([npm, "install"], check=True) # noqa: S603
|
||||
subprocess.run([npm, "install", "--force"], check=True) # noqa: S603
|
||||
stderr.write("\n### npm run build\n")
|
||||
os.environ["APP_BUILD_HASH"] = version
|
||||
subprocess.run([npm, "run", "build"], check=True) # noqa: S603
|
||||
|
|
|
|||
1401
package-lock.json
generated
1401
package-lock.json
generated
File diff suppressed because it is too large
Load diff
41
package.json
41
package.json
|
|
@ -1,6 +1,6 @@
|
|||
{
|
||||
"name": "open-webui",
|
||||
"version": "0.6.16",
|
||||
"version": "0.6.25",
|
||||
"private": true,
|
||||
"scripts": {
|
||||
"dev": "npm run pyodide:fetch && vite dev --host",
|
||||
|
|
@ -57,30 +57,30 @@
|
|||
"@codemirror/lang-python": "^6.1.6",
|
||||
"@codemirror/language-data": "^6.5.1",
|
||||
"@codemirror/theme-one-dark": "^6.1.2",
|
||||
"@floating-ui/dom": "^1.7.2",
|
||||
"@huggingface/transformers": "^3.0.0",
|
||||
"@joplin/turndown-plugin-gfm": "^1.0.62",
|
||||
"@mediapipe/tasks-vision": "^0.10.17",
|
||||
"@pyscript/core": "^0.4.32",
|
||||
"@sveltejs/adapter-node": "^2.0.0",
|
||||
"@sveltejs/svelte-virtual-list": "^3.0.1",
|
||||
"@tiptap/core": "^2.11.9",
|
||||
"@tiptap/extension-bubble-menu": "^2.25.0",
|
||||
"@tiptap/extension-character-count": "^2.25.0",
|
||||
"@tiptap/extension-code-block-lowlight": "^2.11.9",
|
||||
"@tiptap/extension-floating-menu": "^2.25.0",
|
||||
"@tiptap/extension-highlight": "^2.10.0",
|
||||
"@tiptap/extension-history": "^2.25.1",
|
||||
"@tiptap/extension-link": "^2.25.0",
|
||||
"@tiptap/extension-placeholder": "^2.10.0",
|
||||
"@tiptap/extension-table": "^2.12.0",
|
||||
"@tiptap/extension-table-cell": "^2.12.0",
|
||||
"@tiptap/extension-table-header": "^2.12.0",
|
||||
"@tiptap/extension-table-row": "^2.12.0",
|
||||
"@tiptap/extension-task-item": "^2.25.0",
|
||||
"@tiptap/extension-task-list": "^2.25.0",
|
||||
"@tiptap/extension-typography": "^2.10.0",
|
||||
"@tiptap/extension-underline": "^2.25.0",
|
||||
"@tiptap/pm": "^2.11.7",
|
||||
"@tiptap/starter-kit": "^2.10.0",
|
||||
"@tiptap/core": "^3.0.7",
|
||||
"@tiptap/extension-bubble-menu": "^2.26.1",
|
||||
"@tiptap/extension-code-block-lowlight": "^3.0.7",
|
||||
"@tiptap/extension-drag-handle": "^3.0.7",
|
||||
"@tiptap/extension-file-handler": "^3.0.7",
|
||||
"@tiptap/extension-floating-menu": "^2.26.1",
|
||||
"@tiptap/extension-highlight": "^3.0.7",
|
||||
"@tiptap/extension-image": "^3.0.7",
|
||||
"@tiptap/extension-link": "^3.0.7",
|
||||
"@tiptap/extension-list": "^3.0.7",
|
||||
"@tiptap/extension-mention": "^3.0.9",
|
||||
"@tiptap/extension-table": "^3.0.7",
|
||||
"@tiptap/extension-typography": "^3.0.7",
|
||||
"@tiptap/extension-youtube": "^3.0.7",
|
||||
"@tiptap/extensions": "^3.0.7",
|
||||
"@tiptap/pm": "^3.0.7",
|
||||
"@tiptap/starter-kit": "^3.0.7",
|
||||
"@xyflow/svelte": "^0.1.19",
|
||||
"async": "^3.2.5",
|
||||
"bits-ui": "^0.21.15",
|
||||
|
|
@ -108,6 +108,7 @@
|
|||
"katex": "^0.16.22",
|
||||
"kokoro-js": "^1.1.1",
|
||||
"leaflet": "^1.9.4",
|
||||
"lowlight": "^3.3.0",
|
||||
"marked": "^9.1.0",
|
||||
"mermaid": "^11.6.0",
|
||||
"paneforge": "^0.0.6",
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ authors = [
|
|||
license = { file = "LICENSE" }
|
||||
dependencies = [
|
||||
"fastapi==0.115.7",
|
||||
"uvicorn[standard]==0.34.2",
|
||||
"uvicorn[standard]==0.35.0",
|
||||
"pydantic==2.11.7",
|
||||
"python-multipart==0.0.20",
|
||||
|
||||
|
|
@ -17,7 +17,7 @@ dependencies = [
|
|||
"cryptography",
|
||||
|
||||
"requests==2.32.4",
|
||||
"aiohttp==3.11.11",
|
||||
"aiohttp==3.12.15",
|
||||
"async-timeout",
|
||||
"aiocache",
|
||||
"aiofiles",
|
||||
|
|
@ -35,7 +35,7 @@ dependencies = [
|
|||
|
||||
"pymongo",
|
||||
"redis",
|
||||
"boto3==1.35.53",
|
||||
"boto3==1.40.5",
|
||||
|
||||
"argon2-cffi==23.1.0",
|
||||
"APScheduler==3.10.4",
|
||||
|
|
@ -50,14 +50,14 @@ dependencies = [
|
|||
|
||||
"openai",
|
||||
"anthropic",
|
||||
"google-genai==1.15.0",
|
||||
"google-genai==1.28.0",
|
||||
"google-generativeai==0.8.5",
|
||||
"tiktoken",
|
||||
|
||||
"langchain==0.3.26",
|
||||
"langchain-community==0.3.26",
|
||||
|
||||
"fake-useragent==2.1.0",
|
||||
"fake-useragent==2.2.0",
|
||||
"chromadb==0.6.3",
|
||||
"pymilvus==2.5.0",
|
||||
"qdrant-client==1.14.3",
|
||||
|
|
@ -65,11 +65,13 @@ dependencies = [
|
|||
"playwright==1.49.1",
|
||||
"elasticsearch==9.0.1",
|
||||
"pinecone==6.0.2",
|
||||
"oracledb==3.2.0",
|
||||
|
||||
"transformers",
|
||||
"sentence-transformers==4.1.0",
|
||||
"accelerate",
|
||||
"colbert-ai==0.2.21",
|
||||
"pyarrow==20.0.0",
|
||||
"einops==0.8.1",
|
||||
|
||||
"ftfy==6.2.3",
|
||||
|
|
@ -80,7 +82,7 @@ dependencies = [
|
|||
"python-pptx==1.0.2",
|
||||
"unstructured==0.16.17",
|
||||
"nltk==3.9.1",
|
||||
"Markdown==3.7",
|
||||
"Markdown==3.8.2",
|
||||
"pypandoc==1.15",
|
||||
"pandas==2.2.3",
|
||||
"openpyxl==3.1.5",
|
||||
|
|
@ -92,7 +94,7 @@ dependencies = [
|
|||
"soundfile==0.13.1",
|
||||
"azure-ai-documentintelligence==1.0.2",
|
||||
|
||||
"pillow==11.2.1",
|
||||
"pillow==11.3.0",
|
||||
"opencv-python-headless==4.11.0.86",
|
||||
"rapidocr-onnxruntime==1.4.4",
|
||||
"rank-bm25==0.2.2",
|
||||
|
|
@ -102,7 +104,7 @@ dependencies = [
|
|||
"faster-whisper==1.1.1",
|
||||
|
||||
"PyJWT[crypto]==2.10.1",
|
||||
"authlib==1.4.1",
|
||||
"authlib==1.6.1",
|
||||
|
||||
"black==25.1.0",
|
||||
"langfuse==2.44.0",
|
||||
|
|
@ -135,6 +137,8 @@ dependencies = [
|
|||
"gcp-storage-emulator>=2024.8.3",
|
||||
|
||||
"moto[s3]>=5.0.26",
|
||||
"oracledb>=3.2.0",
|
||||
"posthog==5.4.0",
|
||||
|
||||
]
|
||||
readme = "README.md"
|
||||
|
|
@ -191,3 +195,8 @@ skip = '.git*,*.svg,package-lock.json,i18n,*.lock,*.css,*-bundle.js,locales,exam
|
|||
check-hidden = true
|
||||
# ignore-regex = ''
|
||||
ignore-words-list = 'ans'
|
||||
|
||||
[dependency-groups]
|
||||
dev = [
|
||||
"pytest-asyncio>=1.0.0",
|
||||
]
|
||||
|
|
|
|||
26
src/app.css
26
src/app.css
|
|
@ -40,6 +40,11 @@ code {
|
|||
width: auto;
|
||||
}
|
||||
|
||||
.editor-selection {
|
||||
background: rgba(180, 213, 255, 0.5);
|
||||
border-radius: 2px;
|
||||
}
|
||||
|
||||
.font-secondary {
|
||||
font-family: 'InstrumentSerif', sans-serif;
|
||||
}
|
||||
|
|
@ -396,6 +401,17 @@ input[type='number'] {
|
|||
}
|
||||
}
|
||||
|
||||
.tiptap .mention {
|
||||
border-radius: 0.4rem;
|
||||
box-decoration-break: clone;
|
||||
padding: 0.1rem 0.3rem;
|
||||
@apply text-blue-900 dark:text-blue-100 bg-blue-300/20 dark:bg-blue-500/20;
|
||||
}
|
||||
|
||||
.tiptap .mention::after {
|
||||
content: '\200B';
|
||||
}
|
||||
|
||||
.input-prose .tiptap ul[data-type='taskList'] {
|
||||
list-style: none;
|
||||
margin-left: 0;
|
||||
|
|
@ -611,3 +627,13 @@ input[type='number'] {
|
|||
padding-right: 2px;
|
||||
white-space: nowrap;
|
||||
}
|
||||
|
||||
body {
|
||||
background: #fff;
|
||||
color: #000;
|
||||
}
|
||||
|
||||
.dark body {
|
||||
background: #171717;
|
||||
color: #eee;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -56,7 +56,6 @@
|
|||
document.documentElement.classList.add('light');
|
||||
metaThemeColorTag.setAttribute('content', '#ffffff');
|
||||
} else if (localStorage.theme === 'her') {
|
||||
document.documentElement.classList.add('dark');
|
||||
document.documentElement.classList.add('her');
|
||||
metaThemeColorTag.setAttribute('content', '#983724');
|
||||
} else {
|
||||
|
|
@ -87,6 +86,10 @@
|
|||
|
||||
document.addEventListener('DOMContentLoaded', function () {
|
||||
const splash = document.getElementById('splash-screen');
|
||||
if (document.documentElement.classList.contains('her')) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (splash) splash.prepend(logo);
|
||||
});
|
||||
})();
|
||||
|
|
@ -168,6 +171,7 @@
|
|||
<style type="text/css" nonce="">
|
||||
html {
|
||||
overflow-y: hidden !important;
|
||||
overscroll-behavior-y: none;
|
||||
}
|
||||
|
||||
#splash-screen {
|
||||
|
|
|
|||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Reference in a new issue