mirror of https://github.com/open-webui/open-webui.git
synced 2025-12-12 04:15:25 +00:00
commit 2c3655a969
254 changed files with 12815 additions and 4214 deletions

.gitignore (vendored): 5 changes

@@ -1,3 +1,5 @@
+x.py
+yarn.lock
 .DS_Store
 node_modules
 /build

@@ -12,7 +14,8 @@ vite.config.ts.timestamp-*
 __pycache__/
 *.py[cod]
 *$py.class
+.nvmrc
+CLAUDE.md
 # C extensions
 *.so

CHANGELOG.md: 114 changes

@@ -5,6 +5,120 @@ All notable changes to this project will be documented in this file.
 The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
 and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
 
+## [0.6.19] - 2025-08-09
+
+### Added
+
+- ✨ **Modernized Sidebar and Major UI Refinements**: The main navigation sidebar has been completely redesigned with a modern, cleaner aesthetic, featuring a sticky header and footer to keep key controls accessible. Core sidebar logic, like the pinned models list, was also refactored into dedicated components for better performance and maintainability.
+- 🪄 **Guided Response Regeneration**: The "Regenerate" button has been transformed into a powerful new menu. You can now guide the AI's next attempt by suggesting changes in a text prompt, or use one-click options like "Try Again," "Add Details," or "More Concise" to instantly refine and reshape the response to better fit your needs.
+- 🛠️ **Improved Tool Call Handling for GPT-OSS Models**: Implemented robust handling for tool calls specifically for GPT-OSS models, ensuring proper function execution and integration.
+- 🛑 **Stop Button for Merge Responses**: Added a dedicated stop button to immediately halt the generation of merged AI responses, providing users with more control over ongoing outputs.
+- 🔄 **Experimental SCIM 2.0 Support**: Implemented SCIM 2.0 (System for Cross-domain Identity Management) protocol support, enabling enterprise-grade automated user and group provisioning from identity providers like Okta, Azure AD, and Google Workspace for seamless user lifecycle management. Configuration is managed securely via environment variables.
+- 🗂️ **Amazon S3 Vector Support**: You can now use Amazon S3 Vector as a high-performance vector database for your Retrieval-Augmented Generation (RAG) workflows. This provides a scalable, cloud-native storage option for users deeply integrated into the AWS ecosystem, simplifying infrastructure and enabling enterprise-scale knowledge management.
+- 🗄️ **Oracle 23ai Vector Search Support**: Added support for Oracle 23ai's new vector search capabilities as a supported vector database, providing a robust and scalable option for managing large-scale documents and integrating vector search with existing business data at the database level.
+- ⚡ **Qdrant Performance and Configuration Enhancements**: The Qdrant client has been significantly improved with faster data retrieval logic for 'get' and 'query' operations. New environment variables ('QDRANT_TIMEOUT', 'QDRANT_HNSW_M') provide administrators with finer control over query timeouts and HNSW index parameters, enabling better performance tuning for large-scale deployments.
+- 🔐 **Encrypted SQLite Database with SQLCipher**: You can now encrypt your entire SQLite database at rest using SQLCipher. By setting the 'DATABASE_TYPE' to 'sqlite+sqlcipher' and providing a 'DATABASE_PASSWORD', all data is transparently encrypted, providing an essential security layer for protecting sensitive information in self-hosted deployments. Note that this requires additional system libraries and the 'sqlcipher3-wheels' Python package.
+- 🚀 **Efficient Redis Connection Management**: Implemented a shared connection pool cache to reuse Redis connections, dramatically reducing the number of active clients. This prevents connection exhaustion errors, improves performance, and ensures greater stability in high-concurrency deployments and those using Redis Sentinel.
+- ⚡ **Batched Response Streaming for High Performance**: Dramatically improves performance and stability during high-speed response streaming by batching multiple tokens together before sending them to the client. A new 'Stream Delta Chunk Size' advanced parameter can be set per-model or in user/chat settings, significantly reducing CPU load on the server, Redis, and client, and preventing connection issues in high-concurrency environments.
+- ⚙️ **Global Batched Streaming Configuration**: Administrators can now set a system-wide default for response streaming using the new 'CHAT_RESPONSE_STREAM_DELTA_CHUNK_SIZE' environment variable. This allows for global performance tuning, while still letting per-model or per-chat settings override the default for more granular control.
+- 🔎 **Advanced Chat Search with Status Filters**: Quickly find any conversation with powerful new search filters. You can now instantly narrow down your chats using prefixes like 'pinned:true', 'shared:true', and 'archived:true' directly in the search bar. An intelligent dropdown menu assists you by suggesting available filter options as you type, streamlining your workflow and making chat management more efficient than ever.
+- 🛂 **Granular Chat Controls Permissions**: Administrators can now manage chat settings with greater detail. The main "Chat Controls" permission now acts as a master switch, while new granular toggles for "Valves", "System Prompts", and "Advanced Parameters" allow for more specific control over which sections are visible to users inside the panel.
+- ✍️ **Formatting Toolbar for Chat Input**: Introduced a dedicated formatting toolbar for the rich text chat input field, providing users with more accessible options for text styling and editing, configurable via interface settings.
+- 📑 **Tabbed View for Multi-Model Responses**: You can now enable a new tabbed interface to view responses from multiple models. Instead of side-scrolling cards, this compact view organizes each model's response into its own tab, making it easier to compare outputs and saving vertical space. This feature can be toggled on or off in Interface settings.
+- ↕️ **Reorder Pinned Models via Drag-and-Drop**: You can now organize your pinned models in the sidebar by simply dragging and dropping them into your preferred order. This custom layout is saved automatically, giving you more flexible control over your workspace.
+- 📌 **Quick Model Unpin Shortcut**: You can now quickly unpin a model by holding the Shift key and hovering over it to reveal an instant unpin button, streamlining your workspace customization.
+- ⚡ **Improved Chat Input Performance**: The chat input is now significantly more responsive, especially when pasting or typing large amounts of text. This was achieved by implementing a debounce mechanism for the auto-save feature, which prevents UI lag and ensures a smooth, uninterrupted typing experience.
+- ✍️ **Customizable Floating Quick Actions with Tool Support**: Take full control of your text interaction workflow with new customizable floating quick actions. In Settings, you can create, edit, or disable these actions and even integrate tools using the '{{TOOL:tool_id}}' syntax in your prompts, enabling powerful one-click automations on selected text. This is in addition to using placeholders like '{{CONTENT}}' and '{{INPUT_CONTENT}}' for custom text transformations.
+- 🔒 **Admin Workspace Privacy Control**: Introduced the 'ENABLE_ADMIN_WORKSPACE_CONTENT_ACCESS' environment variable (defaults to 'True') allowing administrators to control their access privileges to workspace items (Knowledge, Models, Prompts, Tools). When disabled, administrators adhere to the same access control rules as regular users, enhancing data separation for multi-tenant deployments.
+- 🗄️ **Comprehensive Model Configuration Management**: Administrators can now export the entire model configuration to a file and use a new declarative sync endpoint to manage models in bulk. This powerful feature enables seamless backups, migrations, and state replication across multiple instances.
+- 📦 **Native Redis Cluster Mode Support**: Added full support for connecting to Redis in cluster mode, allowing for scalable and highly available Redis deployments beyond Sentinel-managed setups. New environment variables 'REDIS_CLUSTER' and 'WEBSOCKET_REDIS_CLUSTER' enable the use of 'redis.cluster.RedisCluster' clients.
+- 📊 **Granular OpenTelemetry Metrics Configuration**: Introduced dedicated environment variables and enhanced configuration options for OpenTelemetry metrics, allowing for separate OTLP endpoints, basic authentication credentials, and protocol (HTTP/gRPC) specifically for metrics export, independent of trace settings. This provides greater flexibility for integrating with diverse observability stacks.
+- 🪵 **Granular OpenTelemetry Logging Configuration**: Enhanced the OpenTelemetry logging integration by introducing dedicated environment variables for logs, allowing separate OTLP endpoints, basic authentication credentials, and protocol (HTTP/gRPC) specifically for log export, independent of general OTel settings. The application's default Python logger now leverages this configuration to automatically send logs to your OTel endpoint when enabled via 'ENABLE_OTEL_LOGS'.
+- 📁 **Enhanced Folder Chat Management with Sorting and Time Blocks**: The chat list within folders now supports comprehensive sorting options by title and updated time, along with intelligent time-based grouping (e.g., "Today," "Yesterday") similar to the main chat view, making navigation and organization of project-specific conversations significantly easier.
+- ⚙️ **Configurable Datalab Marker API & Advanced Processing Options**: Enhanced Datalab Marker API integration, allowing administrators to configure custom API base URLs for self-hosting and to specify comprehensive processing options via a new 'additional_config' JSON parameter. This replaces the deprecated language selection feature and provides granular control over document extraction, with streamlined API endpoint resolution for more robust self-hosted deployments.
+- 🧑‍💼 **Export All Users to CSV**: Administrators can now export a complete list of all users to a CSV file directly from the Admin Panel's database settings. This provides a simple, one-click way to generate user data for auditing, reporting, or management purposes.
+- 🛂 **Customizable OAuth 'sub' Claim**: Administrators can now use the 'OAUTH_SUB_CLAIM_OVERRIDE' environment variable to specify which claim from the identity provider should be used as the unique user identifier ('sub'). This provides greater flexibility and control for complex enterprise authentication setups where modifying the IDP's default claims is not possible.
+- 👁️ **Password Visibility Toggle for Input Fields**: Password fields across the application (login, registration, user management, and account settings) now utilize a new 'SensitiveInput' component, providing a consistent toggle to reveal/hide passwords for improved usability and security.
+- 🛂 **Optional "Confirm Password" on Sign-Up**: To help prevent password typos during account creation, administrators can now enable a "Confirm Password" field on the sign-up page. This feature is disabled by default and can be activated via an environment variable for enhanced user experience.
+- 💬 **View Full Chat from User Feedback**: Administrators can now easily navigate to the full conversation associated with a user feedback entry directly from the feedback modal, streamlining the review and troubleshooting process.
+- 🎚️ **Intuitive Hybrid Search BM25-Weight Slider**: The numerical input for the BM25-Weight parameter in Hybrid Search has been replaced with an interactive slider, offering a more intuitive way to adjust the balance between lexical and semantic search. A "Default/Custom" toggle and clearer labels enhance usability and understanding of this key parameter.
+- ⚙️ **Enhanced Bulk Function Synchronization**: The API endpoint for synchronizing functions has been significantly improved to reliably handle bulk updates. This ensures that importing and managing large libraries of functions is more robust and error-free for administrators.
+- 🖼️ **Option to Disable Image Compression in Channels**: Introduced a new setting under Interface options to allow users to force-disable image compression specifically for images posted in channels, ensuring higher resolution for critical visual content.
+- 🔗 **Custom CORS Scheme Support**: Introduced a new environment variable 'CORS_ALLOW_CUSTOM_SCHEME' that allows administrators to define custom URL schemes (e.g., 'app://') for CORS origins, enabling greater flexibility for local development or desktop client integrations.
+- ♿ **Translatable and Accessible Banners**: Enhanced banner elements with translatable badge text and proper ARIA attributes (aria-label, aria-hidden) for SVG icons, significantly improving accessibility and screen reader compatibility.
+- ⚠️ **OAuth Configuration Warning for Missing OPENID_PROVIDER_URL**: Added a proactive startup warning that notifies administrators when OAuth providers (Google, Microsoft, or GitHub) are configured but the essential 'OPENID_PROVIDER_URL' environment variable is missing. This prevents silent OAuth logout failures and guides administrators to complete their setup correctly.
+- ♿ **Major Accessibility Enhancements**: Key parts of the interface have been made significantly more accessible. The user profile menu is now fully navigable via keyboard, essential controls in the Playground now include proper ARIA labels for screen readers, and decorative images have been hidden from assistive technologies to reduce audio clutter. Menu buttons also feature enhanced accessibility with 'aria-label', 'aria-hidden' for SVGs, and 'aria-pressed' for toggle buttons.
+- ⚙️ **General Backend Refactoring**: Implemented various backend improvements to enhance performance, stability, and security, ensuring a more resilient and reliable platform for all users, including refining logging output to be cleaner and more efficient by conditionally including 'extra_json' fields, improving consistent metadata handling in vector database operations, and laying preliminary scaffolding for future analytics features.
+- 🌐 **Localization & Internationalization Improvements**: Refined and expanded translations for Catalan, Danish, Korean, Persian, Polish, Simplified Chinese, and Spanish, ensuring a more fluent and native experience for global users across all supported languages.
+
+### Fixed
+
+- 🛡️ **Hardened Channel Message Security**: Fixed a key permission flaw that allowed users with channel access to edit or delete messages belonging to others. The system now correctly enforces that users can only modify their own messages, protecting data integrity in shared channels.
+- 🛡️ **Hardened OAuth Security by Removing JWT from URL**: Fixed a critical security vulnerability where the authentication token was exposed in the URL after a successful OAuth login. The token is now transferred via a browser cookie, preventing potential leaks through browser history or server logs and protecting user sessions.
+- 🛡️ **Hardened Chat Completion API Security**: The chat completion API endpoint now includes an explicit ownership check, ensuring non-admin users cannot access chats that do not belong to them and preventing potential unauthorized access.
+- 🛠️ **Resilient Model Loading**: Fixed an issue where a failure in loading the model list (e.g., from a misconfigured provider) would prevent the entire user interface, including the admin panel, from loading. The application now gracefully handles these errors, ensuring the UI remains accessible.
+- 🔒 **Resolved FIPS Self-Test Failure**: Fixed a critical issue that prevented Open WebUI from running on FIPS-compliant systems, specifically resolving the "FATAL FIPS SELFTEST FAILURE" error related to OpenSSL and SentenceTransformers, restoring compatibility with secure environments.
+- 📦 **Redis Cluster Connection Restored**: Fixed an issue where the backend was unable to connect to Redis in cluster mode, now ensuring seamless integration with scalable Redis cluster deployments.
+- 📦 **PGVector Connection Stability**: Fixed an issue where read-only operations could leave database transactions idle, preventing potential connection errors and improving overall database stability and resource management.
+- 🛠️ **OpenAPI Tool Integration for Array Parameters Fixed**: Resolved a critical bug where external tools using array parameters (e.g., for tags) would fail when used with OpenAI models. The system now correctly generates the required 'items' property in the function schema, restoring functionality and preventing '400 Bad Request' errors.
+- 🛠️ **Tool Creation for Users Restored**: Fixed a bug in the code editor where status messages were incorrectly prepended to tool scripts, causing a syntax error upon saving. All authorized users can now reliably create and save new tools.
+- 📁 **Folder Knowledge Processing Restored**: Fixed a bug where files uploaded to folder and model knowledge bases were not being extracted or analyzed for Retrieval-Augmented Generation (RAG) when the 'Max Upload Count' setting was empty, ensuring seamless document processing and knowledge augmentation.
+- 🧠 **Custom Model Knowledge Base Updates Recognized**: Fixed a bug where custom models linked to knowledge bases did not automatically recognize files newly added to those knowledge bases. Models now correctly incorporate the latest information from updated knowledge collections.
+- 📦 **Comprehensive Redis Key Prefixing**: Corrected hardcoded prefixes to ensure the REDIS_KEY_PREFIX is now respected across all WebSocket and task management keys. This prevents data collisions in multi-instance deployments and improves compatibility with Redis cluster mode.
+- ✨ **More Descriptive OpenAI Router Errors**: The OpenAI-compatible API router now propagates detailed upstream error messages instead of returning a generic 'Bad Request'. This provides clear, actionable feedback for developers and API users, making it significantly easier to debug and resolve issues with model requests.
+- 🔐 **Hardened OIDC Signout Flow**: The OpenID Connect signout process now verifies that the 'OPENID_PROVIDER_URL' is configured before attempting to communicate with it, preventing potential errors and ensuring a more reliable logout experience.
+- 🍓 **Raspberry Pi Compatibility Restored**: Pinned the pyarrow library to version 20.0.0, resolving an "Illegal Instruction" crash on ARM-based devices like the Raspberry Pi and ensuring stable operation on this hardware.
+- 📁 **Folder System Prompt Variables Restored**: Fixed a bug where prompt variables (e.g., '{{CURRENT_DATETIME}}') were not being rendered in Folder-level System Prompts. This restores an important capability for creating dynamic, context-aware instructions for all chats within a project folder.
+- 📝 **Note Access in Knowledge Retrieval Fixed**: Corrected a permission oversight in knowledge retrieval, ensuring users can always use their own notes as a source for RAG without needing explicit sharing permissions.
+- 🤖 **Title Generation Compatibility for GPT-5 Models**: Added support for 'gpt-5' models in the payload handler, which correctly converts the deprecated 'max_tokens' parameter to 'max_completion_tokens'. This resolves title generation failures and ensures seamless operation with the latest generation of models.
+- ⚙️ **Correct API 'finish_reason' in Streaming Responses**: Fixed an issue where intermediate 'reasoning_content' chunks in streaming API responses incorrectly reported a 'finish_reason' of 'stop'. The 'finish_reason' is now correctly set to 'null' for these chunks, ensuring compatibility with third-party applications that rely on this field.
+- 📈 **Evaluation Pages Stability**: Resolved a crash on the Leaderboard and Feedbacks pages when processing legacy feedback entries that were missing a 'rating' field. The system now gracefully handles this older data, ensuring both pages load reliably for all users.
+- 🤝 **Reliable Collaborative Session Cleanup**: Fixed an asynchronous bug in the real-time collaboration engine that prevented document sessions from being properly cleaned up after all users had left. This ensures greater stability and resource management for features like Collaborative Notes.
+- 🧠 **Enhanced Memory Stability and Security**: Refactored memory update and delete operations to strictly enforce user ownership, preventing potential data integrity issues. Additionally, improved error handling for memory queries now provides clearer feedback when no memories exist.
+- 🧑‍⚖️ **Restored Admin Access to User Feedback**: Fixed a permission issue that blocked administrators from viewing or editing user feedback they didn't create, ensuring they can properly manage all evaluations across the platform.
+- 🔐 **PGVector Encryption Fix for Metadata**: Corrected a SQL syntax error in the experimental 'PGVECTOR_PGCRYPTO' feature that prevented encrypted metadata from being saved. Document uploads to encrypted PGVector collections now work as intended.
+- 🔍 **Serply Web Search Integration Restored**: Fixed an issue where incorrect parameters were passed to the Serply web search engine, restoring its functionality for RAG and web search workflows.
+- 🔍 **Resilient Web Search Processing**: Web search retrieval now gracefully handles search results that are missing a 'snippet', preventing crashes and ensuring that RAG workflows complete successfully even with incomplete data from search engines.
+- 🖼️ **Table Pasting in Rich Text Input Displayed Correctly**: Fixed an issue where pasting table text into the rich text input would incorrectly display it as code. Tables are now properly rendered as expected, improving content formatting and user experience.
+- ✍️ **Rich Text Input TypeError Resolution**: Addressed a potential 'TypeError: ue.getWordAtDocPos is not a function' in 'MessageInput.svelte' by refactoring how the 'getWordAtDocPos' function is accessed and referenced from 'RichTextInput.svelte', ensuring stable rich text input behavior, especially after production restarts.
+- ✏️ **Manual Code Block Creation in Chat Restored**: Fixed an issue where typing three backticks and then pressing Shift+Enter would incorrectly remove the backticks when "Enter to Send" mode was active. This ensures users can reliably create multi-line code blocks manually.
+- 🎨 **Consistent Dark Mode Background**: Fixed an issue where the application background could incorrectly flash or remain white during page loads and refreshes in dark mode, ensuring a seamless and consistent visual experience.
+- 🎨 **'Her' Theme Rendering Fixed**: Corrected a bug that caused the "Her" theme to incorrectly render as a dark theme in some situations. The theme now reliably applies its intended light appearance across all sessions.
+- 📜 **Corrected Markdown Table Line Break Rendering**: Fixed an issue where line breaks ('<br>') within Markdown tables were displayed as raw HTML instead of being rendered correctly. This ensures that tables with multi-line cell content are now displayed as intended.
+- 🚦 **Corrected App Configuration for Pending Users**: Fixed an issue where users awaiting approval could incorrectly load the full application interface, leading to a confusing or broken UI. This ensures that only fully approved users receive the standard app 'config', resulting in a smoother and more reliable onboarding experience.
+- 🔄 **Chat Cloning Now Includes Tags, Folder Status, and Pinned Status**: When cloning a chat or shared chat, its associated tags, folder organization, and pinned status are now correctly replicated, ensuring consistent chat management.
+- ⚙️ **Enhanced Backend Reliability**: Resolved a potential crash in knowledge base retrieval when referencing a deleted note. Additionally, chat processing was refactored to ensure model information is saved more reliably, enhancing overall system stability.
+- ⚙️ **Floating 'Ask/Explain' Modal Stability**: Fixed an issue that spammed the console with errors when navigating away while a model was generating a response in the floating 'Ask' or 'Explain' modals. In-flight requests are now properly cancelled, improving application stability.
+- ⚡ **Optimized User Count Checks**: Improved performance for user count and existence checks across the application by replacing resource-intensive 'COUNT' queries with more efficient 'EXISTS' queries, reducing database load.
+- 🔐 **Hardened OpenTelemetry Exporter Configuration**: The OTLP HTTP exporter no longer uses a potentially insecure explicit flag, improving security by relying on the connection URL's protocol (HTTP/HTTPS) to ensure transport safety.
+- 📱 **Mobile User Menu Closing Behavior Fixed**: Resolved an issue where the user menu would remain open on mobile devices after selecting an option, ensuring the menu correctly closes and returns focus to the main interface for a smoother mobile experience.
+- 📱 **OnBoarding Page Display Fixed on Mobile**: Resolved an issue where buttons on the OnBoarding page were not consistently visible on certain mobile browsers, ensuring a functional and complete user experience across devices.
+- ↕️ **Improved Pinned Models Drag-and-Drop Behavior**: The drag-and-drop functionality for reordering pinned models is now explicitly disabled on mobile devices, ensuring better usability and preventing potential UI conflicts or unexpected behavior.
+- 📱 **PWA Rotation Behavior Corrected**: The Progressive Web App now correctly respects the device's screen orientation lock, preventing unwanted rotation and ensuring a more native mobile experience.
+- ✏️ **Improved Chat Title Editing Behavior**: Changes to a chat title are now reliably saved when the user clicks away or presses Enter, replacing a less intuitive behavior that could accidentally discard edits. This makes renaming chats a smoother and more predictable experience.
+- ✏️ **Underscores Allowed in Prompt Commands**: Fixed the validation for prompt commands to correctly allow the use of underscores ('\_'), aligning with documentation examples and improving flexibility in naming custom prompts.
+- 💡 **Title Generation Button Behavior Fixed**: Resolved an issue where clicking the "Generate Title" button while editing a chat or note title would incorrectly save the title before generation could start. The focus is now managed correctly, ensuring a smooth and predictable user experience.
+- ✏️ **Consistent Chat Input Height**: Fixed a minor visual bug where the chat input field's height would change slightly when toggling the "Rich Text Input for Chat" setting, ensuring a more stable and consistent layout.
+- 🙈 **Admin UI Toggle Stability**: Fixed a visual glitch in the Admin settings where toggle switches could briefly display an incorrect state on page load, ensuring the UI always accurately reflects the saved settings.
+- 🙈 **Community Sharing Button Visibility**: The "Share to Community" button on the feedback page is now correctly hidden when the Enable Community Sharing feature is disabled in the admin settings, ensuring the UI respects the configured sharing policy.
+- 🙈 **"Help Us Translate" Link Visibility**: The "Help us translate" link in settings is now correctly hidden in deployments with specific license configurations, ensuring a cleaner interface for enterprise users.
+- 🔗 **Robust Tool Server URL Handling**: Fixed an issue where providing a full URL for a tool server's OpenAPI specification resulted in an invalid path. The system now correctly handles both absolute URLs and relative paths, improving configuration flexibility.
+- 🔧 **Improved Azure URL Detection**: The logic for identifying Azure OpenAI endpoints has been made more robust, ensuring all valid Azure URLs are now correctly detected for a smoother connection setup.
+- ⚙️ **Corrected Direct Connection Save Logic**: Fixed a bug in the Admin Connections settings page by removing a redundant save action for 'Direct Connections', leading to more reliable and predictable behavior when updating settings.
+- 🔗 **Corrected "Discover" Links**: The "Discover" links for models, prompts, tools, and functions now point to their specific, relevant pages on openwebui.com, improving content discovery for users.
+- ⏱️ **Refined Display of AI Thought Duration**: Adjusted the display logic for AI thought (reasoning) durations to more accurately show very short thought times as "less than a second," improving clarity in AI process feedback.
+- 📜 **Markdown Line Break Rendering Refinement**: Improved handling of line breaks within Markdown rendering for better visual consistency.
+- 🛠️ **Corrected OpenTelemetry Docker Compose Example**: The docker-compose.otel.yaml file has been fixed and enhanced by removing duplicates, adding necessary environment variables, and hardening security settings, ensuring a more reliable out-of-box observability setup.
+- 🛠️ **Development Script CORS Fix**: Corrected the CORS origin URL in the local development script (dev.sh) by removing the trailing slash, ensuring a more reliable and consistent setup for developers.
+- ⬆️ **OpenTelemetry Libraries Updated**: Upgraded all OpenTelemetry-related libraries to their latest versions, ensuring better performance, stability, and compatibility for observability.
+
+### Changed
+
+- ❗ **Docling Integration Upgraded to v1 API (Breaking Change)**: The integration with the Docling document processing engine has been updated to its new, stable '/v1' API. This is required for compatibility with Docling version 1.0.0 and newer. As a result, older versions of Docling are no longer supported. Users who rely on Docling for document ingestion **must upgrade** their docling-serve instance to ensure continued functionality.
+- 🗣️ **Admin-First Whisper Language Priority**: The global WHISPER_LANGUAGE setting now acts as a strict override for audio transcriptions. If set, it will be used for all speech-to-text tasks, ignoring any language specified by the user on a per-request basis. This gives administrators more control over transcription consistency.
+- ✂️ **Datalab Marker API Language Selection Removed**: The separate language selection option for the Datalab Marker API has been removed, as its functionality is now integrated and superseded by the more comprehensive 'additional_config' parameter. Users should transition to using 'additional_config' for relevant language and processing settings.
+- 📄 **Documentation and Releases Links Visibility**: The "Documentation" and "Releases" links in the user menu are now visible only to admin users, streamlining the user interface for non-admin roles.
+
 ## [0.6.18] - 2025-07-19
 
 ### Fixed

README.md:

@@ -31,6 +31,8 @@ For more information, be sure to check out our [Open WebUI Documentation](https:
 - 🛡️ **Granular Permissions and User Groups**: By allowing administrators to create detailed user roles and permissions, we ensure a secure user environment. This granularity not only enhances security but also allows for customized user experiences, fostering a sense of ownership and responsibility amongst users.
 
+- 🔄 **SCIM 2.0 Support**: Enterprise-grade user and group provisioning through SCIM 2.0 protocol, enabling seamless integration with identity providers like Okta, Azure AD, and Google Workspace for automated user lifecycle management.
+
 - 📱 **Responsive Design**: Enjoy a seamless experience across Desktop PC, Laptop, and Mobile devices.
 
 - 📱 **Progressive Web App (PWA) for Mobile**: Enjoy a native app-like experience on your mobile device with our PWA, providing offline access on localhost and a seamless user interface.

dev.sh:

@@ -1,3 +1,3 @@
-export CORS_ALLOW_ORIGIN=http://localhost:5173/
+export CORS_ALLOW_ORIGIN="http://localhost:5173"
 PORT="${PORT:-8080}"
 uvicorn open_webui.main:app --port $PORT --host 0.0.0.0 --forwarded-allow-ips '*' --reload

backend/open_webui/config.py:

@@ -7,7 +7,7 @@ import redis
 
 from datetime import datetime
 from pathlib import Path
-from typing import Generic, Optional, TypeVar
+from typing import Generic, Union, Optional, TypeVar
 from urllib.parse import urlparse
 
 import requests

@@ -168,7 +168,17 @@ class PersistentConfig(Generic[T]):
         self.config_path = config_path
         self.env_value = env_value
         self.config_value = get_config_value(config_path)
 
         if self.config_value is not None and ENABLE_PERSISTENT_CONFIG:
-            log.info(f"'{env_name}' loaded from the latest database entry")
-            self.value = self.config_value
+            if (
+                self.config_path.startswith("oauth.")
+                and not ENABLE_OAUTH_PERSISTENT_CONFIG
+            ):
+                log.info(
+                    f"Skipping loading of '{env_name}' as OAuth persistent config is disabled"
+                )
+                self.value = env_value
+            else:
+                log.info(f"'{env_name}' loaded from the latest database entry")
+                self.value = self.config_value
         else:
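
For illustration, a minimal self-contained sketch (not the actual open-webui module) of the gating rule this hunk adds: a persisted "oauth.*" value is only honored when OAuth persistent config is enabled; otherwise the environment value wins.

```python
# Standalone sketch of the resolution order added to PersistentConfig above.
def resolve(config_path, env_value, db_value, persistent=True, oauth_persistent=False):
    if db_value is not None and persistent:
        if config_path.startswith("oauth.") and not oauth_persistent:
            return env_value  # skip the stored OAuth value
        return db_value  # "loaded from the latest database entry"
    return env_value

assert resolve("oauth.oidc.sub_claim", "from-env", "from-db") == "from-env"
assert resolve("ui.default_locale", "from-env", "from-db") == "from-db"
```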

@@ -213,13 +223,14 @@ class PersistentConfig(Generic[T]):
 
 class AppConfig:
     _state: dict[str, PersistentConfig]
-    _redis: Optional[redis.Redis] = None
+    _redis: Union[redis.Redis, redis.cluster.RedisCluster] = None
     _redis_key_prefix: str
 
     def __init__(
         self,
         redis_url: Optional[str] = None,
         redis_sentinels: Optional[list] = [],
+        redis_cluster: Optional[bool] = False,
         redis_key_prefix: str = "open-webui",
     ):
         super().__setattr__("_state", {})

@@ -227,7 +238,12 @@ class AppConfig:
         if redis_url:
             super().__setattr__(
                 "_redis",
-                get_redis_connection(redis_url, redis_sentinels, decode_responses=True),
+                get_redis_connection(
+                    redis_url,
+                    redis_sentinels,
+                    redis_cluster,
+                    decode_responses=True,
+                ),
             )
 
     def __setattr__(self, key, value):
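
The call above passes the new redis_cluster flag through to get_redis_connection, which lives elsewhere in the codebase. A hedged sketch of what such a helper plausibly does with the flag, assuming redis-py (Sentinel handling omitted):

```python
import redis

def get_redis_connection_sketch(url, sentinels=None, cluster=False, decode_responses=True):
    if cluster:
        # redis-py's RedisCluster discovers the remaining nodes from the seed URL.
        return redis.cluster.RedisCluster.from_url(url, decode_responses=decode_responses)
    return redis.Redis.from_url(url, decode_responses=decode_responses)

# Mirroring AppConfig above (illustrative URL):
# conn = get_redis_connection_sketch("redis://localhost:6379/0", cluster=True)
```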

@@ -296,6 +312,9 @@ JWT_EXPIRES_IN = PersistentConfig(
 # OAuth config
 ####################################
 
+ENABLE_OAUTH_PERSISTENT_CONFIG = (
+    os.environ.get("ENABLE_OAUTH_PERSISTENT_CONFIG", "True").lower() == "true"
+)
 
 ENABLE_OAUTH_SIGNUP = PersistentConfig(
     "ENABLE_OAUTH_SIGNUP",

@@ -463,6 +482,12 @@ OAUTH_PROVIDER_NAME = PersistentConfig(
     os.environ.get("OAUTH_PROVIDER_NAME", "SSO"),
 )
 
+OAUTH_SUB_CLAIM = PersistentConfig(
+    "OAUTH_SUB_CLAIM",
+    "oauth.oidc.sub_claim",
+    os.environ.get("OAUTH_SUB_CLAIM", None),
+)
+
 OAUTH_USERNAME_CLAIM = PersistentConfig(
     "OAUTH_USERNAME_CLAIM",
     "oauth.oidc.username_claim",

@@ -680,6 +705,23 @@ def load_oauth_providers():
             "register": oidc_oauth_register,
         }
 
+    configured_providers = []
+    if GOOGLE_CLIENT_ID.value:
+        configured_providers.append("Google")
+    if MICROSOFT_CLIENT_ID.value:
+        configured_providers.append("Microsoft")
+    if GITHUB_CLIENT_ID.value:
+        configured_providers.append("GitHub")
+
+    if configured_providers and not OPENID_PROVIDER_URL.value:
+        provider_list = ", ".join(configured_providers)
+        log.warning(
+            f"⚠️ OAuth providers configured ({provider_list}) but OPENID_PROVIDER_URL not set - logout will not work!"
+        )
+        log.warning(
+            f"Set OPENID_PROVIDER_URL to your OAuth provider's OpenID Connect discovery endpoint to fix logout functionality."
+        )
 
 
 load_oauth_providers()

@@ -1143,10 +1185,18 @@ USER_PERMISSIONS_CHAT_CONTROLS = (
     os.environ.get("USER_PERMISSIONS_CHAT_CONTROLS", "True").lower() == "true"
 )
 
+USER_PERMISSIONS_CHAT_VALVES = (
+    os.environ.get("USER_PERMISSIONS_CHAT_VALVES", "True").lower() == "true"
+)
+
 USER_PERMISSIONS_CHAT_SYSTEM_PROMPT = (
     os.environ.get("USER_PERMISSIONS_CHAT_SYSTEM_PROMPT", "True").lower() == "true"
 )
 
+USER_PERMISSIONS_CHAT_PARAMS = (
+    os.environ.get("USER_PERMISSIONS_CHAT_PARAMS", "True").lower() == "true"
+)
+
 USER_PERMISSIONS_CHAT_FILE_UPLOAD = (
     os.environ.get("USER_PERMISSIONS_CHAT_FILE_UPLOAD", "True").lower() == "true"
 )

@@ -1232,7 +1282,9 @@ DEFAULT_USER_PERMISSIONS = {
     },
     "chat": {
         "controls": USER_PERMISSIONS_CHAT_CONTROLS,
+        "valves": USER_PERMISSIONS_CHAT_VALVES,
         "system_prompt": USER_PERMISSIONS_CHAT_SYSTEM_PROMPT,
+        "params": USER_PERMISSIONS_CHAT_PARAMS,
         "file_upload": USER_PERMISSIONS_CHAT_FILE_UPLOAD,
         "delete": USER_PERMISSIONS_CHAT_DELETE,
         "edit": USER_PERMISSIONS_CHAT_EDIT,

@@ -1299,6 +1351,10 @@ WEBHOOK_URL = PersistentConfig(
 
 ENABLE_ADMIN_EXPORT = os.environ.get("ENABLE_ADMIN_EXPORT", "True").lower() == "true"
 
+ENABLE_ADMIN_WORKSPACE_CONTENT_ACCESS = (
+    os.environ.get("ENABLE_ADMIN_WORKSPACE_CONTENT_ACCESS", "True").lower() == "true"
+)
+
 ENABLE_ADMIN_CHAT_ACCESS = (
     os.environ.get("ENABLE_ADMIN_CHAT_ACCESS", "True").lower() == "true"
 )

@@ -1337,10 +1393,11 @@ if THREAD_POOL_SIZE is not None and isinstance(THREAD_POOL_SIZE, str):
 def validate_cors_origin(origin):
     parsed_url = urlparse(origin)
 
-    # Check if the scheme is either http or https
-    if parsed_url.scheme not in ["http", "https"]:
+    # Check if the scheme is either http or https, or a custom scheme
+    schemes = ["http", "https"] + CORS_ALLOW_CUSTOM_SCHEME
+    if parsed_url.scheme not in schemes:
         raise ValueError(
-            f"Invalid scheme in CORS_ALLOW_ORIGIN: '{origin}'. Only 'http' and 'https' are allowed."
+            f"Invalid scheme in CORS_ALLOW_ORIGIN: '{origin}'. Only 'http' and 'https' and CORS_ALLOW_CUSTOM_SCHEME are allowed."
         )
 
     # Ensure that the netloc (domain + port) is present, indicating it's a valid URL

@@ -1355,6 +1412,11 @@ def validate_cors_origin(origin):
 # in your .env file depending on your frontend port, 5173 in this case.
 CORS_ALLOW_ORIGIN = os.environ.get("CORS_ALLOW_ORIGIN", "*").split(";")
 
+# Allows custom URL schemes (e.g., app://) to be used as origins for CORS.
+# Useful for local development or desktop clients with schemes like app:// or other custom protocols.
+# Provide a semicolon-separated list of allowed schemes in the environment variable CORS_ALLOW_CUSTOM_SCHEMES.
+CORS_ALLOW_CUSTOM_SCHEME = os.environ.get("CORS_ALLOW_CUSTOM_SCHEME", "").split(";")
+
 if CORS_ALLOW_ORIGIN == ["*"]:
     log.warning(
         "\n\nWARNING: CORS_ALLOW_ORIGIN IS SET TO '*' - NOT RECOMMENDED FOR PRODUCTION DEPLOYMENTS.\n"
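
A usage sketch with illustrative values: allowing a desktop client's custom "app" scheme as a CORS origin. These must be set before the backend starts.

```python
import os

os.environ["CORS_ALLOW_ORIGIN"] = "http://localhost:5173;app://desktop"
os.environ["CORS_ALLOW_CUSTOM_SCHEME"] = "app"
# With these set, validate_cors_origin("app://desktop") passes, because "app"
# is appended to the ["http", "https"] scheme allowlist shown above.
```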

@@ -1862,6 +1924,8 @@ QDRANT_API_KEY = os.environ.get("QDRANT_API_KEY", None)
 QDRANT_ON_DISK = os.environ.get("QDRANT_ON_DISK", "false").lower() == "true"
 QDRANT_PREFER_GRPC = os.environ.get("QDRANT_PREFER_GRPC", "false").lower() == "true"
 QDRANT_GRPC_PORT = int(os.environ.get("QDRANT_GRPC_PORT", "6334"))
+QDRANT_TIMEOUT = int(os.environ.get("QDRANT_TIMEOUT", "5"))
+QDRANT_HNSW_M = int(os.environ.get("QDRANT_HNSW_M", "16"))
 ENABLE_QDRANT_MULTITENANCY_MODE = (
     os.environ.get("ENABLE_QDRANT_MULTITENANCY_MODE", "true").lower() == "true"
 )
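
A hedged sketch of how these two knobs plausibly reach the Qdrant client; the actual wiring lives in open-webui's Qdrant adapter, and the collection name below is illustrative.

```python
from qdrant_client import QdrantClient, models

client = QdrantClient(url="http://localhost:6333", timeout=5)  # QDRANT_TIMEOUT
client.create_collection(
    collection_name="open-webui-example",  # illustrative collection name
    vectors_config=models.VectorParams(size=768, distance=models.Distance.COSINE),
    hnsw_config=models.HnswConfigDiff(m=16),  # QDRANT_HNSW_M
)
```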

@@ -1951,6 +2015,37 @@ PINECONE_DIMENSION = int(os.getenv("PINECONE_DIMENSION", 1536)) # or 3072, 1024
 PINECONE_METRIC = os.getenv("PINECONE_METRIC", "cosine")
 PINECONE_CLOUD = os.getenv("PINECONE_CLOUD", "aws")  # or "gcp" or "azure"
 
+# ORACLE23AI (Oracle23ai Vector Search)
+ORACLE_DB_USE_WALLET = os.environ.get("ORACLE_DB_USE_WALLET", "false").lower() == "true"
+ORACLE_DB_USER = os.environ.get("ORACLE_DB_USER", None)
+ORACLE_DB_PASSWORD = os.environ.get("ORACLE_DB_PASSWORD", None)
+ORACLE_DB_DSN = os.environ.get("ORACLE_DB_DSN", None)
+ORACLE_WALLET_DIR = os.environ.get("ORACLE_WALLET_DIR", None)
+ORACLE_WALLET_PASSWORD = os.environ.get("ORACLE_WALLET_PASSWORD", None)
+ORACLE_VECTOR_LENGTH = os.environ.get("ORACLE_VECTOR_LENGTH", 768)
+
+ORACLE_DB_POOL_MIN = int(os.environ.get("ORACLE_DB_POOL_MIN", 2))
+ORACLE_DB_POOL_MAX = int(os.environ.get("ORACLE_DB_POOL_MAX", 10))
+ORACLE_DB_POOL_INCREMENT = int(os.environ.get("ORACLE_DB_POOL_INCREMENT", 1))
+
+if VECTOR_DB == "oracle23ai":
+    if not ORACLE_DB_USER or not ORACLE_DB_PASSWORD or not ORACLE_DB_DSN:
+        raise ValueError(
+            "Oracle23ai requires setting ORACLE_DB_USER, ORACLE_DB_PASSWORD, and ORACLE_DB_DSN."
+        )
+    if ORACLE_DB_USE_WALLET and (not ORACLE_WALLET_DIR or not ORACLE_WALLET_PASSWORD):
+        raise ValueError(
+            "Oracle23ai requires setting ORACLE_WALLET_DIR and ORACLE_WALLET_PASSWORD when using wallet authentication."
+        )
+
+log.info(f"VECTOR_DB: {VECTOR_DB}")
+
+# S3 Vector
+S3_VECTOR_BUCKET_NAME = os.environ.get("S3_VECTOR_BUCKET_NAME", None)
+S3_VECTOR_REGION = os.environ.get("S3_VECTOR_REGION", None)
 
 ####################################
 # Information Retrieval (RAG)
 ####################################
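
An illustrative environment for the new Oracle 23ai backend; every value below is a placeholder, not a real credential.

```python
import os

os.environ["VECTOR_DB"] = "oracle23ai"
os.environ["ORACLE_DB_USER"] = "webui"                 # placeholder
os.environ["ORACLE_DB_PASSWORD"] = "change-me"         # placeholder
os.environ["ORACLE_DB_DSN"] = "dbhost:1521/FREEPDB1"   # placeholder DSN
# With ORACLE_DB_USE_WALLET=true, ORACLE_WALLET_DIR and ORACLE_WALLET_PASSWORD
# must also be set, or the startup check above raises ValueError.
```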

@@ -2012,10 +2107,16 @@ DATALAB_MARKER_API_KEY = PersistentConfig(
     os.environ.get("DATALAB_MARKER_API_KEY", ""),
 )
 
-DATALAB_MARKER_LANGS = PersistentConfig(
-    "DATALAB_MARKER_LANGS",
-    "rag.datalab_marker_langs",
-    os.environ.get("DATALAB_MARKER_LANGS", ""),
+DATALAB_MARKER_API_BASE_URL = PersistentConfig(
+    "DATALAB_MARKER_API_BASE_URL",
+    "rag.datalab_marker_api_base_url",
+    os.environ.get("DATALAB_MARKER_API_BASE_URL", ""),
+)
+
+DATALAB_MARKER_ADDITIONAL_CONFIG = PersistentConfig(
+    "DATALAB_MARKER_ADDITIONAL_CONFIG",
+    "rag.datalab_marker_additional_config",
+    os.environ.get("DATALAB_MARKER_ADDITIONAL_CONFIG", ""),
 )
 
 DATALAB_MARKER_USE_LLM = PersistentConfig(

@@ -2055,6 +2156,12 @@ DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION = PersistentConfig(
     == "true",
 )
 
+DATALAB_MARKER_FORMAT_LINES = PersistentConfig(
+    "DATALAB_MARKER_FORMAT_LINES",
+    "rag.datalab_marker_format_lines",
+    os.environ.get("DATALAB_MARKER_FORMAT_LINES", "false").lower() == "true",
+)
+
 DATALAB_MARKER_OUTPUT_FORMAT = PersistentConfig(
     "DATALAB_MARKER_OUTPUT_FORMAT",
     "rag.datalab_marker_output_format",
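
A hypothetical additional_config payload for a self-hosted Marker instance; the accepted keys are defined by the Datalab Marker API, so the option names below are assumptions, not documented values.

```python
import json
import os

os.environ["DATALAB_MARKER_API_BASE_URL"] = "http://marker.internal:8001"  # illustrative
os.environ["DATALAB_MARKER_ADDITIONAL_CONFIG"] = json.dumps(
    {"langs": ["en", "de"], "paginate": True}  # assumed option names
)
```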
|
||||||
|
|
@ -288,6 +288,9 @@ DB_VARS = {
|
||||||
|
|
||||||
if all(DB_VARS.values()):
|
if all(DB_VARS.values()):
|
||||||
DATABASE_URL = f"{DB_VARS['db_type']}://{DB_VARS['db_cred']}@{DB_VARS['db_host']}:{DB_VARS['db_port']}/{DB_VARS['db_name']}"
|
DATABASE_URL = f"{DB_VARS['db_type']}://{DB_VARS['db_cred']}@{DB_VARS['db_host']}:{DB_VARS['db_port']}/{DB_VARS['db_name']}"
|
||||||
|
elif DATABASE_TYPE == "sqlite+sqlcipher" and not os.environ.get("DATABASE_URL"):
|
||||||
|
# Handle SQLCipher with local file when DATABASE_URL wasn't explicitly set
|
||||||
|
DATABASE_URL = f"sqlite+sqlcipher:///{DATA_DIR}/webui.db"
|
||||||
|
|
||||||
# Replace the postgres:// with postgresql://
|
# Replace the postgres:// with postgresql://
|
||||||
if "postgres://" in DATABASE_URL:
|
if "postgres://" in DATABASE_URL:
|
||||||
|
|

@@ -346,7 +349,10 @@ ENABLE_REALTIME_CHAT_SAVE = (
 ####################################
 
 REDIS_URL = os.environ.get("REDIS_URL", "")
+REDIS_CLUSTER = os.environ.get("REDIS_CLUSTER", "False").lower() == "true"
+
 REDIS_KEY_PREFIX = os.environ.get("REDIS_KEY_PREFIX", "open-webui")
 
 REDIS_SENTINEL_HOSTS = os.environ.get("REDIS_SENTINEL_HOSTS", "")
 REDIS_SENTINEL_PORT = os.environ.get("REDIS_SENTINEL_PORT", "26379")

@@ -378,6 +384,10 @@ except ValueError:
 ####################################
 
 WEBUI_AUTH = os.environ.get("WEBUI_AUTH", "True").lower() == "true"
+ENABLE_SIGNUP_PASSWORD_CONFIRMATION = (
+    os.environ.get("ENABLE_SIGNUP_PASSWORD_CONFIRMATION", "False").lower() == "true"
+)
 WEBUI_AUTH_TRUSTED_EMAIL_HEADER = os.environ.get(
     "WEBUI_AUTH_TRUSTED_EMAIL_HEADER", None
 )

@@ -432,6 +442,13 @@ ENABLE_COMPRESSION_MIDDLEWARE = (
 )
 
 
+####################################
+# SCIM Configuration
+####################################
+
+SCIM_ENABLED = os.environ.get("SCIM_ENABLED", "False").lower() == "true"
+SCIM_TOKEN = os.environ.get("SCIM_TOKEN", "")
+
 ####################################
 # LICENSE_KEY
 ####################################
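
A hedged probe of the SCIM 2.0 endpoint using the bearer token configured above. The mount path is an assumption here; consult the documentation for the real one.

```python
import requests

headers = {"Authorization": "Bearer your-scim-token"}  # matches SCIM_TOKEN
resp = requests.get(
    "http://localhost:8080/api/v1/scim/v2/Users",  # assumed path
    headers=headers,
)
print(resp.status_code)
```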

@@ -473,6 +490,25 @@ else:
     MODELS_CACHE_TTL = 1
 
 
+####################################
+# CHAT
+####################################
+
+CHAT_RESPONSE_STREAM_DELTA_CHUNK_SIZE = os.environ.get(
+    "CHAT_RESPONSE_STREAM_DELTA_CHUNK_SIZE", "1"
+)
+
+if CHAT_RESPONSE_STREAM_DELTA_CHUNK_SIZE == "":
+    CHAT_RESPONSE_STREAM_DELTA_CHUNK_SIZE = 1
+else:
+    try:
+        CHAT_RESPONSE_STREAM_DELTA_CHUNK_SIZE = int(
+            CHAT_RESPONSE_STREAM_DELTA_CHUNK_SIZE
+        )
+    except Exception:
+        CHAT_RESPONSE_STREAM_DELTA_CHUNK_SIZE = 1
+
 ####################################
 # WEBSOCKET SUPPORT
 ####################################
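
A minimal sketch of what a delta chunk size of N means for streaming: accumulate N tokens before flushing one combined delta, plus any remainder at the end. The real batching happens in the chat response streaming path, not in this standalone helper.

```python
def batch_deltas(tokens, chunk_size=4):
    buf = []
    for tok in tokens:
        buf.append(tok)
        if len(buf) >= chunk_size:
            yield "".join(buf)  # one combined delta instead of four
            buf.clear()
    if buf:
        yield "".join(buf)  # flush the remainder

assert list(batch_deltas(list("abcdefghij"), 4)) == ["abcd", "efgh", "ij"]
```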

@@ -485,6 +521,9 @@ ENABLE_WEBSOCKET_SUPPORT = (
 WEBSOCKET_MANAGER = os.environ.get("WEBSOCKET_MANAGER", "")
 
 WEBSOCKET_REDIS_URL = os.environ.get("WEBSOCKET_REDIS_URL", REDIS_URL)
+WEBSOCKET_REDIS_CLUSTER = (
+    os.environ.get("WEBSOCKET_REDIS_CLUSTER", str(REDIS_CLUSTER)).lower() == "true"
+)
 
 websocket_redis_lock_timeout = os.environ.get("WEBSOCKET_REDIS_LOCK_TIMEOUT", "60")

@@ -494,9 +533,9 @@ except ValueError:
     WEBSOCKET_REDIS_LOCK_TIMEOUT = 60
 
 WEBSOCKET_SENTINEL_HOSTS = os.environ.get("WEBSOCKET_SENTINEL_HOSTS", "")
 WEBSOCKET_SENTINEL_PORT = os.environ.get("WEBSOCKET_SENTINEL_PORT", "26379")
 
 AIOHTTP_CLIENT_TIMEOUT = os.environ.get("AIOHTTP_CLIENT_TIMEOUT", "")
 
 if AIOHTTP_CLIENT_TIMEOUT == "":

@@ -639,12 +678,26 @@ AUDIT_EXCLUDED_PATHS = [path.lstrip("/") for path in AUDIT_EXCLUDED_PATHS]
 
 ENABLE_OTEL = os.environ.get("ENABLE_OTEL", "False").lower() == "true"
 ENABLE_OTEL_METRICS = os.environ.get("ENABLE_OTEL_METRICS", "False").lower() == "true"
+ENABLE_OTEL_LOGS = os.environ.get("ENABLE_OTEL_LOGS", "False").lower() == "true"
 
 OTEL_EXPORTER_OTLP_ENDPOINT = os.environ.get(
     "OTEL_EXPORTER_OTLP_ENDPOINT", "http://localhost:4317"
 )
+OTEL_METRICS_EXPORTER_OTLP_ENDPOINT = os.environ.get(
+    "OTEL_METRICS_EXPORTER_OTLP_ENDPOINT", OTEL_EXPORTER_OTLP_ENDPOINT
+)
+OTEL_LOGS_EXPORTER_OTLP_ENDPOINT = os.environ.get(
+    "OTEL_LOGS_EXPORTER_OTLP_ENDPOINT", OTEL_EXPORTER_OTLP_ENDPOINT
+)
 OTEL_EXPORTER_OTLP_INSECURE = (
     os.environ.get("OTEL_EXPORTER_OTLP_INSECURE", "False").lower() == "true"
 )
+OTEL_METRICS_EXPORTER_OTLP_INSECURE = (
+    os.environ.get("OTEL_METRICS_EXPORTER_OTLP_INSECURE", "False").lower() == "true"
+)
+OTEL_LOGS_EXPORTER_OTLP_INSECURE = (
+    os.environ.get("OTEL_LOGS_EXPORTER_OTLP_INSECURE", "False").lower() == "true"
+)
 OTEL_SERVICE_NAME = os.environ.get("OTEL_SERVICE_NAME", "open-webui")
 OTEL_RESOURCE_ATTRIBUTES = os.environ.get(
     "OTEL_RESOURCE_ATTRIBUTES", ""

@@ -655,11 +708,30 @@ OTEL_TRACES_SAMPLER = os.environ.get(
 OTEL_BASIC_AUTH_USERNAME = os.environ.get("OTEL_BASIC_AUTH_USERNAME", "")
 OTEL_BASIC_AUTH_PASSWORD = os.environ.get("OTEL_BASIC_AUTH_PASSWORD", "")
 
+OTEL_METRICS_BASIC_AUTH_USERNAME = os.environ.get(
+    "OTEL_METRICS_BASIC_AUTH_USERNAME", OTEL_BASIC_AUTH_USERNAME
+)
+OTEL_METRICS_BASIC_AUTH_PASSWORD = os.environ.get(
+    "OTEL_METRICS_BASIC_AUTH_PASSWORD", OTEL_BASIC_AUTH_PASSWORD
+)
+OTEL_LOGS_BASIC_AUTH_USERNAME = os.environ.get(
+    "OTEL_LOGS_BASIC_AUTH_USERNAME", OTEL_BASIC_AUTH_USERNAME
+)
+OTEL_LOGS_BASIC_AUTH_PASSWORD = os.environ.get(
+    "OTEL_LOGS_BASIC_AUTH_PASSWORD", OTEL_BASIC_AUTH_PASSWORD
+)
+
 OTEL_OTLP_SPAN_EXPORTER = os.environ.get(
     "OTEL_OTLP_SPAN_EXPORTER", "grpc"
 ).lower()  # grpc or http
+
+OTEL_METRICS_OTLP_SPAN_EXPORTER = os.environ.get(
+    "OTEL_METRICS_OTLP_SPAN_EXPORTER", OTEL_OTLP_SPAN_EXPORTER
+).lower()  # grpc or http
+
+OTEL_LOGS_OTLP_SPAN_EXPORTER = os.environ.get(
+    "OTEL_LOGS_OTLP_SPAN_EXPORTER", OTEL_OTLP_SPAN_EXPORTER
+).lower()  # grpc or http
 
 ####################################
 # TOOLS/FUNCTIONS PIP OPTIONS
|
|
@@ -1,3 +1,4 @@
+import os
 import json
 import logging
 from contextlib import contextmanager
@@ -79,7 +80,37 @@ handle_peewee_migration(DATABASE_URL)
 
 
 SQLALCHEMY_DATABASE_URL = DATABASE_URL
-if "sqlite" in SQLALCHEMY_DATABASE_URL:
+
+# Handle SQLCipher URLs
+if SQLALCHEMY_DATABASE_URL.startswith("sqlite+sqlcipher://"):
+    database_password = os.environ.get("DATABASE_PASSWORD")
+    if not database_password or database_password.strip() == "":
+        raise ValueError(
+            "DATABASE_PASSWORD is required when using sqlite+sqlcipher:// URLs"
+        )
+
+    # Extract database path from SQLCipher URL
+    db_path = SQLALCHEMY_DATABASE_URL.replace("sqlite+sqlcipher://", "")
+    if db_path.startswith("/"):
+        db_path = db_path[1:]  # Remove leading slash for relative paths
+
+    # Create a custom creator function that uses sqlcipher3
+    def create_sqlcipher_connection():
+        import sqlcipher3
+
+        conn = sqlcipher3.connect(db_path, check_same_thread=False)
+        conn.execute(f"PRAGMA key = '{database_password}'")
+        return conn
+
+    engine = create_engine(
+        "sqlite://",  # Dummy URL since we're using creator
+        creator=create_sqlcipher_connection,
+        echo=False,
+    )
+
+    log.info("Connected to encrypted SQLite database using SQLCipher")
+
+elif "sqlite" in SQLALCHEMY_DATABASE_URL:
     engine = create_engine(
         SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False}
     )
@@ -1,4 +1,5 @@
 import logging
+import os
 from contextvars import ContextVar
 
 from open_webui.env import SRC_LOG_LEVELS
@@ -43,6 +44,29 @@ class ReconnectingPostgresqlDatabase(CustomReconnectMixin, PostgresqlDatabase):
 
 
 def register_connection(db_url):
+    # Check if using SQLCipher protocol
+    if db_url.startswith("sqlite+sqlcipher://"):
+        database_password = os.environ.get("DATABASE_PASSWORD")
+        if not database_password or database_password.strip() == "":
+            raise ValueError(
+                "DATABASE_PASSWORD is required when using sqlite+sqlcipher:// URLs"
+            )
+        from playhouse.sqlcipher_ext import SqlCipherDatabase
+
+        # Parse the database path from SQLCipher URL
+        # Convert sqlite+sqlcipher:///path/to/db.sqlite to /path/to/db.sqlite
+        db_path = db_url.replace("sqlite+sqlcipher://", "")
+        if db_path.startswith("/"):
+            db_path = db_path[1:]  # Remove leading slash for relative paths
+
+        # Use Peewee's native SqlCipherDatabase with encryption
+        db = SqlCipherDatabase(db_path, passphrase=database_password)
+        db.autoconnect = True
+        db.reuse_if_open = True
+        log.info("Connected to encrypted SQLite database using SQLCipher")
+
+    else:
+        # Standard database connection (existing logic)
         db = connect(db_url, unquote_user=True, unquote_password=True)
         if isinstance(db, PostgresqlDatabase):
             # Enable autoconnect for SQLite databases, managed by Peewee
@@ -85,6 +85,7 @@ from open_webui.routers import (
     tools,
     users,
     utils,
+    scim,
 )
 
 from open_webui.routers.retrieval import (
@@ -226,12 +227,14 @@ from open_webui.config import (
     CHUNK_SIZE,
     CONTENT_EXTRACTION_ENGINE,
     DATALAB_MARKER_API_KEY,
-    DATALAB_MARKER_LANGS,
+    DATALAB_MARKER_API_BASE_URL,
+    DATALAB_MARKER_ADDITIONAL_CONFIG,
     DATALAB_MARKER_SKIP_CACHE,
     DATALAB_MARKER_FORCE_OCR,
     DATALAB_MARKER_PAGINATE,
     DATALAB_MARKER_STRIP_EXISTING_OCR,
     DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION,
+    DATALAB_MARKER_FORMAT_LINES,
     DATALAB_MARKER_OUTPUT_FORMAT,
     DATALAB_MARKER_USE_LLM,
     EXTERNAL_DOCUMENT_LOADER_URL,
@@ -399,6 +402,7 @@ from open_webui.env import (
     AUDIT_LOG_LEVEL,
     CHANGELOG,
     REDIS_URL,
+    REDIS_CLUSTER,
     REDIS_KEY_PREFIX,
     REDIS_SENTINEL_HOSTS,
     REDIS_SENTINEL_PORT,
@@ -412,9 +416,13 @@ from open_webui.env import (
     WEBUI_SECRET_KEY,
     WEBUI_SESSION_COOKIE_SAME_SITE,
     WEBUI_SESSION_COOKIE_SECURE,
+    ENABLE_SIGNUP_PASSWORD_CONFIRMATION,
     WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
     WEBUI_AUTH_TRUSTED_NAME_HEADER,
     WEBUI_AUTH_SIGNOUT_REDIRECT_URL,
+    # SCIM
+    SCIM_ENABLED,
+    SCIM_TOKEN,
     ENABLE_COMPRESSION_MIDDLEWARE,
     ENABLE_WEBSOCKET_SUPPORT,
     BYPASS_MODEL_ACCESS_CONTROL,
@@ -462,6 +470,9 @@ from open_webui.tasks import (
 from open_webui.utils.redis import get_sentinels_from_env
 
 
+from open_webui.constants import ERROR_MESSAGES
+
+
 if SAFE_MODE:
     print("SAFE MODE ENABLED")
     Functions.deactivate_all_functions()
@@ -524,6 +535,7 @@ async def lifespan(app: FastAPI):
             redis_sentinels=get_sentinels_from_env(
                 REDIS_SENTINEL_HOSTS, REDIS_SENTINEL_PORT
             ),
+            redis_cluster=REDIS_CLUSTER,
             async_mode=True,
         )
@@ -579,6 +591,7 @@ app.state.instance_id = None
 app.state.config = AppConfig(
     redis_url=REDIS_URL,
     redis_sentinels=get_sentinels_from_env(REDIS_SENTINEL_HOSTS, REDIS_SENTINEL_PORT),
+    redis_cluster=REDIS_CLUSTER,
     redis_key_prefix=REDIS_KEY_PREFIX,
 )
 app.state.redis = None
@@ -642,6 +655,15 @@ app.state.TOOL_SERVERS = []
 
 app.state.config.ENABLE_DIRECT_CONNECTIONS = ENABLE_DIRECT_CONNECTIONS
 
+########################################
+#
+# SCIM
+#
+########################################
+
+app.state.SCIM_ENABLED = SCIM_ENABLED
+app.state.SCIM_TOKEN = SCIM_TOKEN
+
 ########################################
 #
 # MODELS
@@ -767,7 +789,8 @@ app.state.config.ENABLE_WEB_LOADER_SSL_VERIFICATION = ENABLE_WEB_LOADER_SSL_VERI
 
 app.state.config.CONTENT_EXTRACTION_ENGINE = CONTENT_EXTRACTION_ENGINE
 app.state.config.DATALAB_MARKER_API_KEY = DATALAB_MARKER_API_KEY
-app.state.config.DATALAB_MARKER_LANGS = DATALAB_MARKER_LANGS
+app.state.config.DATALAB_MARKER_API_BASE_URL = DATALAB_MARKER_API_BASE_URL
+app.state.config.DATALAB_MARKER_ADDITIONAL_CONFIG = DATALAB_MARKER_ADDITIONAL_CONFIG
 app.state.config.DATALAB_MARKER_SKIP_CACHE = DATALAB_MARKER_SKIP_CACHE
 app.state.config.DATALAB_MARKER_FORCE_OCR = DATALAB_MARKER_FORCE_OCR
 app.state.config.DATALAB_MARKER_PAGINATE = DATALAB_MARKER_PAGINATE
@@ -775,6 +798,7 @@ app.state.config.DATALAB_MARKER_STRIP_EXISTING_OCR = DATALAB_MARKER_STRIP_EXISTI
 app.state.config.DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION = (
     DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION
 )
+app.state.config.DATALAB_MARKER_FORMAT_LINES = DATALAB_MARKER_FORMAT_LINES
 app.state.config.DATALAB_MARKER_USE_LLM = DATALAB_MARKER_USE_LLM
 app.state.config.DATALAB_MARKER_OUTPUT_FORMAT = DATALAB_MARKER_OUTPUT_FORMAT
 app.state.config.EXTERNAL_DOCUMENT_LOADER_URL = EXTERNAL_DOCUMENT_LOADER_URL
@@ -1211,6 +1235,10 @@ app.include_router(
 )
 app.include_router(utils.router, prefix="/api/v1/utils", tags=["utils"])
 
+# SCIM 2.0 API for identity management
+if SCIM_ENABLED:
+    app.include_router(scim.router, prefix="/api/v1/scim/v2", tags=["scim"])
+
 
 try:
     audit_level = AuditLevel(AUDIT_LOG_LEVEL)
@@ -1296,7 +1324,7 @@ async def get_models(
         models = get_filtered_models(models, user)
 
     log.debug(
-        f"/api/models returned filtered models accessible to the user: {json.dumps([model['id'] for model in models])}"
+        f"/api/models returned filtered models accessible to the user: {json.dumps([model.get('id') for model in models])}"
     )
     return {"data": models}
@@ -1373,6 +1401,19 @@ async def chat_completion(
             request.state.direct = True
             request.state.model = model
 
+        model_info_params = (
+            model_info.params.model_dump() if model_info and model_info.params else {}
+        )
+
+        # Chat Params
+        stream_delta_chunk_size = form_data.get("params", {}).get(
+            "stream_delta_chunk_size"
+        )
+
+        # Model Params
+        if model_info_params.get("stream_delta_chunk_size"):
+            stream_delta_chunk_size = model_info_params.get("stream_delta_chunk_size")
+
         metadata = {
             "user_id": user.id,
             "chat_id": form_data.pop("chat_id", None),
@@ -1386,25 +1427,33 @@ async def chat_completion(
             "variables": form_data.get("variables", {}),
             "model": model,
             "direct": model_item.get("direct", False),
-            **(
-                {"function_calling": "native"}
-                if form_data.get("params", {}).get("function_calling") == "native"
-                or (
-                    model_info
-                    and model_info.params.model_dump().get("function_calling")
-                    == "native"
-                )
-                else {}
-            ),
+            "params": {
+                "stream_delta_chunk_size": stream_delta_chunk_size,
+                "function_calling": (
+                    "native"
+                    if (
+                        form_data.get("params", {}).get("function_calling") == "native"
+                        or model_info_params.get("function_calling") == "native"
+                    )
+                    else "default"
+                ),
+            },
         }
 
+        if metadata.get("chat_id") and (user and user.role != "admin"):
+            chat = Chats.get_chat_by_id_and_user_id(metadata["chat_id"], user.id)
+            if chat is None:
+                raise HTTPException(
+                    status_code=status.HTTP_404_NOT_FOUND,
+                    detail=ERROR_MESSAGES.DEFAULT(),
+                )
+
         request.state.metadata = metadata
         form_data["metadata"] = metadata
 
         form_data, metadata, events = await process_chat_payload(
             request, form_data, user, metadata, model
         )
 
     except Exception as e:
         log.debug(f"Error processing chat payload: {e}")
         if metadata.get("chat_id") and metadata.get("message_id"):
@@ -1424,6 +1473,14 @@ async def chat_completion(
 
     try:
         response = await chat_completion_handler(request, form_data, user)
+        if metadata.get("chat_id") and metadata.get("message_id"):
+            Chats.upsert_message_to_chat_by_id_and_message_id(
+                metadata["chat_id"],
+                metadata["message_id"],
+                {
+                    "model": model_id,
+                },
+            )
+
         return await process_chat_response(
             request, response, form_data, user, metadata, model, events, tasks
@@ -1563,6 +1620,7 @@ async def get_app_config(request: Request):
         "features": {
             "auth": WEBUI_AUTH,
             "auth_trusted_header": bool(app.state.AUTH_TRUSTED_EMAIL_HEADER),
+            "enable_signup_password_confirmation": ENABLE_SIGNUP_PASSWORD_CONFIRMATION,
             "enable_ldap": app.state.config.ENABLE_LDAP,
             "enable_api_key": app.state.config.ENABLE_API_KEY,
             "enable_signup": app.state.config.ENABLE_SIGNUP,
@@ -1641,14 +1699,17 @@ async def get_app_config(request: Request):
                 else {}
             ),
         }
-        if user is not None
+        if user is not None and (user.role in ["admin", "user"])
         else {
             **(
                 {
                     "metadata": {
                         "login_footer": app.state.LICENSE_METADATA.get(
                             "login_footer", ""
-                        )
+                        ),
+                        "auth_logo_position": app.state.LICENSE_METADATA.get(
+                            "auth_logo_position", ""
+                        ),
                     }
                 }
                 if app.state.LICENSE_METADATA
@@ -1765,11 +1826,10 @@ async def get_manifest_json():
     return {
         "name": app.state.WEBUI_NAME,
         "short_name": app.state.WEBUI_NAME,
-        "description": "Open WebUI is an open, extensible, user-friendly interface for AI that adapts to your workflow.",
+        "description": f"{app.state.WEBUI_NAME} is an open, extensible, user-friendly interface for AI that adapts to your workflow.",
         "start_url": "/",
         "display": "standalone",
         "background_color": "#343541",
-        "orientation": "any",
         "icons": [
             {
                 "src": "/static/logo.png",
@@ -2,8 +2,8 @@ from logging.config import fileConfig
 
 from alembic import context
 from open_webui.models.auths import Auth
-from open_webui.env import DATABASE_URL
-from sqlalchemy import engine_from_config, pool
+from open_webui.env import DATABASE_URL, DATABASE_PASSWORD
+from sqlalchemy import engine_from_config, pool, create_engine
 
 # this is the Alembic Config object, which provides
 # access to the values within the .ini file in use.
@@ -62,6 +62,33 @@ def run_migrations_online() -> None:
     and associate a connection with the context.
 
     """
+    # Handle SQLCipher URLs
+    if DB_URL and DB_URL.startswith("sqlite+sqlcipher://"):
+        if not DATABASE_PASSWORD or DATABASE_PASSWORD.strip() == "":
+            raise ValueError(
+                "DATABASE_PASSWORD is required when using sqlite+sqlcipher:// URLs"
+            )
+
+        # Extract database path from SQLCipher URL
+        db_path = DB_URL.replace("sqlite+sqlcipher://", "")
+        if db_path.startswith("/"):
+            db_path = db_path[1:]  # Remove leading slash for relative paths
+
+        # Create a custom creator function that uses sqlcipher3
+        def create_sqlcipher_connection():
+            import sqlcipher3
+
+            conn = sqlcipher3.connect(db_path, check_same_thread=False)
+            conn.execute(f"PRAGMA key = '{DATABASE_PASSWORD}'")
+            return conn
+
+        connectable = create_engine(
+            "sqlite://",  # Dummy URL since we're using creator
+            creator=create_sqlcipher_connection,
+            echo=False,
+        )
+    else:
+        # Standard database connection (existing logic)
         connectable = engine_from_config(
             config.get_section(config.config_ini_section, {}),
             prefix="sqlalchemy.",
@@ -6,6 +6,7 @@ from typing import Optional
 
 from open_webui.internal.db import Base, get_db
 from open_webui.models.tags import TagModel, Tag, Tags
+from open_webui.models.folders import Folders
 from open_webui.env import SRC_LOG_LEVELS
 
 from pydantic import BaseModel, ConfigDict
@@ -296,6 +297,9 @@ class ChatTable:
                 "user_id": f"shared-{chat_id}",
                 "title": chat.title,
                 "chat": chat.chat,
+                "meta": chat.meta,
+                "pinned": chat.pinned,
+                "folder_id": chat.folder_id,
                 "created_at": chat.created_at,
                 "updated_at": int(time.time()),
             }
@@ -327,7 +331,9 @@ class ChatTable:
 
             shared_chat.title = chat.title
             shared_chat.chat = chat.chat
+            shared_chat.meta = chat.meta
+            shared_chat.pinned = chat.pinned
+            shared_chat.folder_id = chat.folder_id
             shared_chat.updated_at = int(time.time())
             db.commit()
             db.refresh(shared_chat)
@@ -612,8 +618,45 @@ class ChatTable:
             if word.startswith("tag:")
         ]
 
+        # Extract folder names - handle spaces and case insensitivity
+        folders = Folders.search_folders_by_names(
+            user_id,
+            [
+                word.replace("folder:", "")
+                for word in search_text_words
+                if word.startswith("folder:")
+            ],
+        )
+        folder_ids = [folder.id for folder in folders]
+
+        is_pinned = None
+        if "pinned:true" in search_text_words:
+            is_pinned = True
+        elif "pinned:false" in search_text_words:
+            is_pinned = False
+
+        is_archived = None
+        if "archived:true" in search_text_words:
+            is_archived = True
+        elif "archived:false" in search_text_words:
+            is_archived = False
+
+        is_shared = None
+        if "shared:true" in search_text_words:
+            is_shared = True
+        elif "shared:false" in search_text_words:
+            is_shared = False
+
         search_text_words = [
-            word for word in search_text_words if not word.startswith("tag:")
+            word
+            for word in search_text_words
+            if (
+                not word.startswith("tag:")
+                and not word.startswith("folder:")
+                and not word.startswith("pinned:")
+                and not word.startswith("archived:")
+                and not word.startswith("shared:")
+            )
         ]
 
         search_text = " ".join(search_text_words)
@@ -621,9 +664,23 @@ class ChatTable:
         with get_db() as db:
             query = db.query(Chat).filter(Chat.user_id == user_id)
 
-            if not include_archived:
+            if is_archived is not None:
+                query = query.filter(Chat.archived == is_archived)
+            elif not include_archived:
                 query = query.filter(Chat.archived == False)
 
+            if is_pinned is not None:
+                query = query.filter(Chat.pinned == is_pinned)
+
+            if is_shared is not None:
+                if is_shared:
+                    query = query.filter(Chat.share_id.isnot(None))
+                else:
+                    query = query.filter(Chat.share_id.is_(None))
+
+            if folder_ids:
+                query = query.filter(Chat.folder_id.in_(folder_ids))
+
             query = query.order_by(Chat.updated_at.desc())
 
             # Check if the database dialect is either 'sqlite' or 'postgresql'
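Taken together, the two hunks above let the new tokens compose with free text and the existing tag: filter. Illustrative queries (the folder and tag names are hypothetical):

    queries = [
        "pinned:true roadmap",         # pinned chats mentioning "roadmap"
        "folder:Work archived:false",  # unarchived chats in a folder named "Work"
        "shared:true tag:research",    # shared chats carrying the "research" tag
    ]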
@@ -2,14 +2,14 @@ import logging
 import time
 import uuid
 from typing import Optional
+import re
+
+from pydantic import BaseModel, ConfigDict
+from sqlalchemy import BigInteger, Column, Text, JSON, Boolean, func
 
 from open_webui.internal.db import Base, get_db
-from open_webui.models.chats import Chats
 
 from open_webui.env import SRC_LOG_LEVELS
-from pydantic import BaseModel, ConfigDict
-from sqlalchemy import BigInteger, Column, Text, JSON, Boolean
-from open_webui.utils.access_control import get_permissions
 
 
 log = logging.getLogger(__name__)
@@ -106,7 +106,7 @@ class FolderTable:
 
     def get_children_folders_by_id_and_user_id(
         self, id: str, user_id: str
-    ) -> Optional[FolderModel]:
+    ) -> Optional[list[FolderModel]]:
         try:
             with get_db() as db:
                 folders = []
@@ -251,18 +251,15 @@ class FolderTable:
             log.error(f"update_folder: {e}")
             return
 
-    def delete_folder_by_id_and_user_id(
-        self, id: str, user_id: str, delete_chats=True
-    ) -> bool:
+    def delete_folder_by_id_and_user_id(self, id: str, user_id: str) -> list[str]:
         try:
+            folder_ids = []
             with get_db() as db:
                 folder = db.query(Folder).filter_by(id=id, user_id=user_id).first()
                 if not folder:
-                    return False
+                    return folder_ids
 
-                if delete_chats:
-                    # Delete all chats in the folder
-                    Chats.delete_chats_by_user_id_and_folder_id(user_id, folder.id)
+                folder_ids.append(folder.id)
 
                 # Delete all children folders
                 def delete_children(folder):
@@ -270,12 +267,9 @@ class FolderTable:
                         folder.id, user_id
                     )
                     for folder_child in folder_children:
-                        if delete_chats:
-                            Chats.delete_chats_by_user_id_and_folder_id(
-                                user_id, folder_child.id
-                            )
-
                         delete_children(folder_child)
+                        folder_ids.append(folder_child.id)
 
                         folder = db.query(Folder).filter_by(id=folder_child.id).first()
                         db.delete(folder)
@@ -284,10 +278,62 @@ class FolderTable:
                 delete_children(folder)
                 db.delete(folder)
                 db.commit()
-                return True
+                return folder_ids
         except Exception as e:
             log.error(f"delete_folder: {e}")
-            return False
+            return []
+
+    def normalize_folder_name(self, name: str) -> str:
+        # Replace _ and space with a single space, lower case, collapse multiple spaces
+        name = re.sub(r"[\s_]+", " ", name)
+        return name.strip().lower()
+
+    def search_folders_by_names(
+        self, user_id: str, queries: list[str]
+    ) -> list[FolderModel]:
+        """
+        Search for folders for a user where the name matches any of the queries, treating _ and space as equivalent, case-insensitive.
+        """
+        normalized_queries = [self.normalize_folder_name(q) for q in queries]
+        if not normalized_queries:
+            return []
+
+        results = {}
+        with get_db() as db:
+            folders = db.query(Folder).filter_by(user_id=user_id).all()
+            for folder in folders:
+                if self.normalize_folder_name(folder.name) in normalized_queries:
+                    results[folder.id] = FolderModel.model_validate(folder)
+
+                    # get children folders
+                    children = self.get_children_folders_by_id_and_user_id(
+                        folder.id, user_id
+                    )
+                    for child in children:
+                        results[child.id] = child
+
+        # Return the results as a list
+        if not results:
+            return []
+        else:
+            results = list(results.values())
+        return results
+
+    def search_folders_by_name_contains(
+        self, user_id: str, query: str
+    ) -> list[FolderModel]:
+        """
+        Partial match: normalized name contains (as substring) the normalized query.
+        """
+        normalized_query = self.normalize_folder_name(query)
+        results = []
+        with get_db() as db:
+            folders = db.query(Folder).filter_by(user_id=user_id).all()
+            for folder in folders:
+                norm_name = self.normalize_folder_name(folder.name)
+                if normalized_query in norm_name:
+                    results.append(FolderModel.model_validate(folder))
+        return results
 
 
 Folders = FolderTable()
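The normalization treats underscores and runs of whitespace as a single space and compares case-insensitively, so "My_Project_Notes" and "my project  notes" resolve to the same folder name. A self-contained check of the same rule:

    import re

    def normalize_folder_name(name: str) -> str:
        # Same rule as above: collapse _ and whitespace runs, lowercase, trim
        return re.sub(r"[\s_]+", " ", name).strip().lower()

    assert normalize_folder_name("My_Project  Notes") == "my project notes"
    assert normalize_folder_name("My_Project_Notes") == normalize_folder_name("my project notes")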
@@ -71,9 +71,13 @@ class MemoriesTable:
     ) -> Optional[MemoryModel]:
         with get_db() as db:
             try:
-                db.query(Memory).filter_by(id=id, user_id=user_id).update(
-                    {"content": content, "updated_at": int(time.time())}
-                )
+                memory = db.get(Memory, id)
+                if not memory or memory.user_id != user_id:
+                    return None
+
+                memory.content = content
+                memory.updated_at = int(time.time())
+
                 db.commit()
                 return self.get_memory_by_id(id)
             except Exception:
@@ -127,7 +131,12 @@ class MemoriesTable:
     def delete_memory_by_id_and_user_id(self, id: str, user_id: str) -> bool:
         with get_db() as db:
             try:
-                db.query(Memory).filter_by(id=id, user_id=user_id).delete()
+                memory = db.get(Memory, id)
+                if not memory or memory.user_id != user_id:
+                    return None
+
+                # Delete the memory
+                db.delete(memory)
                 db.commit()
 
                 return True
@@ -269,5 +269,49 @@ class ModelsTable:
         except Exception:
             return False
 
+    def sync_models(self, user_id: str, models: list[ModelModel]) -> list[ModelModel]:
+        try:
+            with get_db() as db:
+                # Get existing models
+                existing_models = db.query(Model).all()
+                existing_ids = {model.id for model in existing_models}
+
+                # Prepare a set of new model IDs
+                new_model_ids = {model.id for model in models}
+
+                # Update or insert models
+                for model in models:
+                    if model.id in existing_ids:
+                        db.query(Model).filter_by(id=model.id).update(
+                            {
+                                **model.model_dump(),
+                                "user_id": user_id,
+                                "updated_at": int(time.time()),
+                            }
+                        )
+                    else:
+                        new_model = Model(
+                            **{
+                                **model.model_dump(),
+                                "user_id": user_id,
+                                "updated_at": int(time.time()),
+                            }
+                        )
+                        db.add(new_model)
+
+                # Remove models that are no longer present
+                for model in existing_models:
+                    if model.id not in new_model_ids:
+                        db.delete(model)
+
+                db.commit()
+
+                return [
+                    ModelModel.model_validate(model) for model in db.query(Model).all()
+                ]
+        except Exception as e:
+            log.exception(f"Error syncing models for user {user_id}: {e}")
+            return []
+
 
 Models = ModelsTable()
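sync_models reconciles the table to exactly the list passed in: matching ids are updated, new ids inserted, and anything absent from the list deleted. A dependency-free sketch of that reconcile pattern:

    existing = {"a": {"id": "a"}, "b": {"id": "b"}}
    desired = [{"id": "b"}, {"id": "c"}]

    desired_ids = {m["id"] for m in desired}
    for m in desired:
        existing[m["id"]] = m          # upsert
    for mid in list(existing):
        if mid not in desired_ids:
            del existing[mid]          # drop stale rows

    assert set(existing) == {"b", "c"}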
@@ -258,6 +258,10 @@ class UsersTable:
         with get_db() as db:
             return db.query(User).count()
 
+    def has_users(self) -> bool:
+        with get_db() as db:
+            return db.query(db.query(User).exists()).scalar()
+
     def get_first_user(self) -> UserModel:
         try:
             with get_db() as db:
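The EXISTS form lets the database stop at the first matching row, whereas count() has to scan enough rows to produce an exact total, so has_users() stays cheap even on large user tables.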
@@ -15,24 +15,28 @@ class DatalabMarkerLoader:
         self,
         file_path: str,
         api_key: str,
-        langs: Optional[str] = None,
+        api_base_url: str,
+        additional_config: Optional[str] = None,
         use_llm: bool = False,
         skip_cache: bool = False,
         force_ocr: bool = False,
         paginate: bool = False,
         strip_existing_ocr: bool = False,
         disable_image_extraction: bool = False,
+        format_lines: bool = False,
         output_format: str = None,
     ):
         self.file_path = file_path
         self.api_key = api_key
-        self.langs = langs
+        self.api_base_url = api_base_url
+        self.additional_config = additional_config
         self.use_llm = use_llm
         self.skip_cache = skip_cache
         self.force_ocr = force_ocr
         self.paginate = paginate
         self.strip_existing_ocr = strip_existing_ocr
         self.disable_image_extraction = disable_image_extraction
+        self.format_lines = format_lines
         self.output_format = output_format
 
     def _get_mime_type(self, filename: str) -> str:
@@ -60,7 +64,7 @@ class DatalabMarkerLoader:
         return mime_map.get(ext, "application/octet-stream")
 
     def check_marker_request_status(self, request_id: str) -> dict:
-        url = f"https://www.datalab.to/api/v1/marker/{request_id}"
+        url = f"{self.api_base_url}/marker/{request_id}"
         headers = {"X-Api-Key": self.api_key}
         try:
             response = requests.get(url, headers=headers)
@@ -81,22 +85,24 @@ class DatalabMarkerLoader:
         )
 
     def load(self) -> List[Document]:
-        url = "https://www.datalab.to/api/v1/marker"
         filename = os.path.basename(self.file_path)
         mime_type = self._get_mime_type(filename)
         headers = {"X-Api-Key": self.api_key}
 
         form_data = {
-            "langs": self.langs,
             "use_llm": str(self.use_llm).lower(),
             "skip_cache": str(self.skip_cache).lower(),
             "force_ocr": str(self.force_ocr).lower(),
             "paginate": str(self.paginate).lower(),
             "strip_existing_ocr": str(self.strip_existing_ocr).lower(),
             "disable_image_extraction": str(self.disable_image_extraction).lower(),
+            "format_lines": str(self.format_lines).lower(),
             "output_format": self.output_format,
         }
 
+        if self.additional_config and self.additional_config.strip():
+            form_data["additional_config"] = self.additional_config
+
         log.info(
             f"Datalab Marker POST request parameters: {{'filename': '{filename}', 'mime_type': '{mime_type}', **{form_data}}}"
         )
@@ -105,7 +111,10 @@ class DatalabMarkerLoader:
             with open(self.file_path, "rb") as f:
                 files = {"file": (filename, f, mime_type)}
                 response = requests.post(
-                    url, data=form_data, files=files, headers=headers
+                    f"{self.api_base_url}/marker",
+                    data=form_data,
+                    files=files,
+                    headers=headers,
                 )
                 response.raise_for_status()
                 result = response.json()
@@ -133,11 +142,10 @@ class DatalabMarkerLoader:
 
         check_url = result.get("request_check_url")
         request_id = result.get("request_id")
-        if not check_url:
-            raise HTTPException(
-                status.HTTP_502_BAD_GATEWAY, detail="No request_check_url returned."
-            )
 
+        # Check if this is a direct response (self-hosted) or polling response (DataLab)
+        if check_url:
+            # DataLab polling pattern
             for _ in range(300):  # Up to 10 minutes
                 time.sleep(2)
                 try:
@@ -185,7 +193,8 @@ class DatalabMarkerLoader:
                     )
             else:
                 raise HTTPException(
-                    status.HTTP_504_GATEWAY_TIMEOUT, detail="Marker processing timed out"
+                    status.HTTP_504_GATEWAY_TIMEOUT,
+                    detail="Marker processing timed out",
                 )
 
             if not poll_result.get("success", False):
@@ -195,12 +204,30 @@ class DatalabMarkerLoader:
                     detail=f"Final processing failed: {error_msg}",
                 )
 
+            # DataLab format - content in format-specific fields
             content_key = self.output_format.lower()
             raw_content = poll_result.get(content_key)
+            final_result = poll_result
+        else:
+            # Self-hosted direct response - content in "output" field
+            if "output" in result:
+                log.info("Self-hosted Marker returned direct response without polling")
+                raw_content = result.get("output")
+                final_result = result
+            else:
+                available_fields = (
+                    list(result.keys())
+                    if isinstance(result, dict)
+                    else "non-dict response"
+                )
+                raise HTTPException(
+                    status.HTTP_502_BAD_GATEWAY,
+                    detail=f"Custom Marker endpoint returned success but no 'output' field found. Available fields: {available_fields}. Expected either 'request_check_url' for polling or 'output' field for direct response.",
+                )
 
-        if content_key == "json":
+        if self.output_format.lower() == "json":
             full_text = json.dumps(raw_content, indent=2)
-        elif content_key in {"markdown", "html"}:
+        elif self.output_format.lower() in {"markdown", "html"}:
             full_text = str(raw_content).strip()
         else:
             raise HTTPException(
@@ -211,14 +238,14 @@ class DatalabMarkerLoader:
         if not full_text:
             raise HTTPException(
                 status.HTTP_400_BAD_REQUEST,
-                detail="Datalab Marker returned empty content",
+                detail="Marker returned empty content",
             )
 
         marker_output_dir = os.path.join("/app/backend/data/uploads", "marker_output")
         os.makedirs(marker_output_dir, exist_ok=True)
 
         file_ext_map = {"markdown": "md", "json": "json", "html": "html"}
-        file_ext = file_ext_map.get(content_key, "txt")
+        file_ext = file_ext_map.get(self.output_format.lower(), "txt")
         output_filename = f"{os.path.splitext(filename)[0]}.{file_ext}"
         output_path = os.path.join(marker_output_dir, output_filename)
@@ -231,13 +258,13 @@ class DatalabMarkerLoader:
 
         metadata = {
             "source": filename,
-            "output_format": poll_result.get("output_format", self.output_format),
-            "page_count": poll_result.get("page_count", 0),
+            "output_format": final_result.get("output_format", self.output_format),
+            "page_count": final_result.get("page_count", 0),
             "processed_with_llm": self.use_llm,
             "request_id": request_id or "",
         }
 
-        images = poll_result.get("images", {})
+        images = final_result.get("images", {})
         if images:
             metadata["image_count"] = len(images)
             metadata["images"] = json.dumps(list(images.keys()))
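With the base URL now injectable, the loader can target a self-hosted Marker server, in which case load() accepts the direct "output" response and skips polling. A usage sketch (the module path, URL, and key are illustrative assumptions):

    from open_webui.retrieval.loaders.datalab_marker import DatalabMarkerLoader

    loader = DatalabMarkerLoader(
        file_path="docs/report.pdf",
        api_key="sk-placeholder",
        api_base_url="http://marker.internal:8001/api/v1",  # self-hosted endpoint
        output_format="markdown",
    )
    docs = loader.load()  # one Document with markdown text and page metadata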
@@ -181,7 +181,7 @@ class DoclingLoader:
                 if lang.strip()
             ]
 
-        endpoint = f"{self.url}/v1alpha/convert/file"
+        endpoint = f"{self.url}/v1/convert/file"
         r = requests.post(endpoint, files=files, data=params)
 
         if r.ok:
@@ -281,10 +281,15 @@ class Loader:
                 "tiff",
             ]
         ):
+            api_base_url = self.kwargs.get("DATALAB_MARKER_API_BASE_URL", "")
+            if not api_base_url or api_base_url.strip() == "":
+                api_base_url = "https://www.datalab.to/api/v1"
+
             loader = DatalabMarkerLoader(
                 file_path=file_path,
                 api_key=self.kwargs["DATALAB_MARKER_API_KEY"],
-                langs=self.kwargs.get("DATALAB_MARKER_LANGS"),
+                api_base_url=api_base_url,
+                additional_config=self.kwargs.get("DATALAB_MARKER_ADDITIONAL_CONFIG"),
                 use_llm=self.kwargs.get("DATALAB_MARKER_USE_LLM", False),
                 skip_cache=self.kwargs.get("DATALAB_MARKER_SKIP_CACHE", False),
                 force_ocr=self.kwargs.get("DATALAB_MARKER_FORCE_OCR", False),
@@ -295,6 +300,7 @@ class Loader:
                 disable_image_extraction=self.kwargs.get(
                     "DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION", False
                 ),
+                format_lines=self.kwargs.get("DATALAB_MARKER_FORMAT_LINES", False),
                 output_format=self.kwargs.get(
                     "DATALAB_MARKER_OUTPUT_FORMAT", "markdown"
                 ),
@@ -508,7 +508,11 @@ def get_sources_from_items(
                 # Note Attached
                 note = Notes.get_note_by_id(item.get("id"))
 
-                if user.role == "admin" or has_access(user.id, "read", note.access_control):
+                if note and (
+                    user.role == "admin"
+                    or note.user_id == user.id
+                    or has_access(user.id, "read", note.access_control)
+                ):
                     # User has access to the note
                     query_result = {
                         "documents": [[note.data.get("content", {}).get("md", "")]],
@@ -11,6 +11,8 @@ from open_webui.retrieval.vector.main import (
     SearchResult,
     GetResult,
 )
+from open_webui.retrieval.vector.utils import stringify_metadata
+
 from open_webui.config import (
     CHROMA_DATA_PATH,
     CHROMA_HTTP_HOST,

@@ -144,7 +146,7 @@ class ChromaClient(VectorDBBase):
         ids = [item["id"] for item in items]
         documents = [item["text"] for item in items]
         embeddings = [item["vector"] for item in items]
-        metadatas = [item["metadata"] for item in items]
+        metadatas = [stringify_metadata(item["metadata"]) for item in items]
 
         for batch in create_batches(
             api=self.client,

@@ -164,7 +166,7 @@ class ChromaClient(VectorDBBase):
         ids = [item["id"] for item in items]
         documents = [item["text"] for item in items]
         embeddings = [item["vector"] for item in items]
-        metadatas = [item["metadata"] for item in items]
+        metadatas = [stringify_metadata(item["metadata"]) for item in items]
 
         collection.upsert(
             ids=ids, documents=documents, embeddings=embeddings, metadatas=metadatas
@@ -2,6 +2,8 @@ from elasticsearch import Elasticsearch, BadRequestError
 from typing import Optional
 import ssl
 from elasticsearch.helpers import bulk, scan
+
+from open_webui.retrieval.vector.utils import stringify_metadata
 from open_webui.retrieval.vector.main import (
     VectorDBBase,
     VectorItem,

@@ -243,7 +245,7 @@ class ElasticsearchClient(VectorDBBase):
                     "collection": collection_name,
                     "vector": item["vector"],
                     "text": item["text"],
-                    "metadata": item["metadata"],
+                    "metadata": stringify_metadata(item["metadata"]),
                 },
             }
             for item in batch

@@ -264,7 +266,7 @@ class ElasticsearchClient(VectorDBBase):
                     "collection": collection_name,
                     "vector": item["vector"],
                     "text": item["text"],
-                    "metadata": item["metadata"],
+                    "metadata": stringify_metadata(item["metadata"]),
                 },
                 "doc_as_upsert": True,
             }
@ -3,6 +3,8 @@ from pymilvus import FieldSchema, DataType
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
from open_webui.retrieval.vector.utils import stringify_metadata
|
||||||
from open_webui.retrieval.vector.main import (
|
from open_webui.retrieval.vector.main import (
|
||||||
VectorDBBase,
|
VectorDBBase,
|
||||||
VectorItem,
|
VectorItem,
|
||||||
|
|
@ -311,7 +313,7 @@ class MilvusClient(VectorDBBase):
|
||||||
"id": item["id"],
|
"id": item["id"],
|
||||||
"vector": item["vector"],
|
"vector": item["vector"],
|
||||||
"data": {"text": item["text"]},
|
"data": {"text": item["text"]},
|
||||||
"metadata": item["metadata"],
|
"metadata": stringify_metadata(item["metadata"]),
|
||||||
}
|
}
|
||||||
for item in items
|
for item in items
|
||||||
],
|
],
|
||||||
|
|
@ -347,7 +349,7 @@ class MilvusClient(VectorDBBase):
|
||||||
"id": item["id"],
|
"id": item["id"],
|
||||||
"vector": item["vector"],
|
"vector": item["vector"],
|
||||||
"data": {"text": item["text"]},
|
"data": {"text": item["text"]},
|
||||||
"metadata": item["metadata"],
|
"metadata": stringify_metadata(item["metadata"]),
|
||||||
}
|
}
|
||||||
for item in items
|
for item in items
|
||||||
],
|
],
|
||||||
|
|
|
||||||
|
|
@@ -2,6 +2,7 @@ from opensearchpy import OpenSearch
 from opensearchpy.helpers import bulk
 from typing import Optional
 
+from open_webui.retrieval.vector.utils import stringify_metadata
 from open_webui.retrieval.vector.main import (
     VectorDBBase,
     VectorItem,

@@ -200,7 +201,7 @@ class OpenSearchClient(VectorDBBase):
                 "_source": {
                     "vector": item["vector"],
                     "text": item["text"],
-                    "metadata": item["metadata"],
+                    "metadata": stringify_metadata(item["metadata"]),
                 },
             }
             for item in batch

@@ -222,7 +223,7 @@ class OpenSearchClient(VectorDBBase):
                 "doc": {
                     "vector": item["vector"],
                     "text": item["text"],
-                    "metadata": item["metadata"],
+                    "metadata": stringify_metadata(item["metadata"]),
                 },
                 "doc_as_upsert": True,
             }
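The stringify_metadata helper imported by all four clients is not shown in this diff; a plausible minimal implementation would coerce non-primitive metadata values (dates, UUIDs, nested objects) to strings so every backend can index them uniformly:

    from datetime import datetime

    def stringify_metadata(metadata: dict) -> dict:
        # Hypothetical sketch: keep primitives, stringify everything else
        return {
            key: value if isinstance(value, (str, int, float, bool)) else str(value)
            for key, value in metadata.items()
        }

    print(stringify_metadata({"source": "a.pdf", "created": datetime(2025, 8, 9)}))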
943
backend/open_webui/retrieval/vector/dbs/oracle23ai.py
Normal file
943
backend/open_webui/retrieval/vector/dbs/oracle23ai.py
Normal file
|
|
@ -0,0 +1,943 @@
|
||||||
|
"""
|
||||||
|
Oracle 23ai Vector Database Client - Fixed Version
|
||||||
|
|
||||||
|
# .env
|
||||||
|
VECTOR_DB = "oracle23ai"
|
||||||
|
|
||||||
|
## DBCS or oracle 23ai free
|
||||||
|
ORACLE_DB_USE_WALLET = false
|
||||||
|
ORACLE_DB_USER = "DEMOUSER"
|
||||||
|
ORACLE_DB_PASSWORD = "Welcome123456"
|
||||||
|
ORACLE_DB_DSN = "localhost:1521/FREEPDB1"
|
||||||
|
|
||||||
|
## ADW or ATP
|
||||||
|
# ORACLE_DB_USE_WALLET = true
|
||||||
|
# ORACLE_DB_USER = "DEMOUSER"
|
||||||
|
# ORACLE_DB_PASSWORD = "Welcome123456"
|
||||||
|
# ORACLE_DB_DSN = "medium"
|
||||||
|
# ORACLE_DB_DSN = "(description= (retry_count=3)(retry_delay=3)(address=(protocol=tcps)(port=1522)(host=xx.oraclecloud.com))(connect_data=(service_name=yy.adb.oraclecloud.com))(security=(ssl_server_dn_match=no)))"
|
||||||
|
# ORACLE_WALLET_DIR = "/home/opc/adb_wallet"
|
||||||
|
# ORACLE_WALLET_PASSWORD = "Welcome1"
|
||||||
|
|
||||||
|
ORACLE_VECTOR_LENGTH = 768
|
||||||
|
|
||||||
|
ORACLE_DB_POOL_MIN = 2
|
||||||
|
ORACLE_DB_POOL_MAX = 10
|
||||||
|
ORACLE_DB_POOL_INCREMENT = 1
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Optional, List, Dict, Any, Union
|
||||||
|
from decimal import Decimal
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
import json
|
||||||
|
import array
|
||||||
|
import oracledb
|
||||||
|
|
||||||
|
from open_webui.retrieval.vector.main import (
|
||||||
|
VectorDBBase,
|
||||||
|
VectorItem,
|
||||||
|
SearchResult,
|
||||||
|
GetResult,
|
||||||
|
)
|
||||||
|
|
||||||
|
from open_webui.config import (
|
||||||
|
ORACLE_DB_USE_WALLET,
|
||||||
|
ORACLE_DB_USER,
|
||||||
|
ORACLE_DB_PASSWORD,
|
||||||
|
ORACLE_DB_DSN,
|
||||||
|
ORACLE_WALLET_DIR,
|
||||||
|
ORACLE_WALLET_PASSWORD,
|
||||||
|
ORACLE_VECTOR_LENGTH,
|
||||||
|
ORACLE_DB_POOL_MIN,
|
||||||
|
ORACLE_DB_POOL_MAX,
|
||||||
|
ORACLE_DB_POOL_INCREMENT,
|
||||||
|
)
|
||||||
|
from open_webui.env import SRC_LOG_LEVELS
|
||||||
|
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||||
|
|
||||||
|
|

class Oracle23aiClient(VectorDBBase):
    """
    Oracle Vector Database Client for vector similarity search using Oracle Database 23ai.

    This client provides an interface to store, retrieve, and search vector embeddings
    in an Oracle database. It uses connection pooling for efficient database access
    and supports vector similarity search operations.

    Attributes:
        pool: Connection pool for Oracle database connections
    """

    def __init__(self) -> None:
        """
        Initialize the Oracle23aiClient with a connection pool.

        Creates a connection pool with configurable min/max connections, initializes
        the database schema if needed, and sets up necessary tables and indexes.

        Raises:
            ValueError: If required configuration parameters are missing
            Exception: If database initialization fails
        """
        self.pool = None

        try:
            # Create the appropriate connection pool based on DB type
            if ORACLE_DB_USE_WALLET:
                self._create_adb_pool()
            else:  # DBCS
                self._create_dbcs_pool()

            dsn = ORACLE_DB_DSN
            log.info(f"Creating Connection Pool [{ORACLE_DB_USER}:**@{dsn}]")

            with self.get_connection() as connection:
                log.info(f"Connection version: {connection.version}")
                self._initialize_database(connection)

            log.info("Oracle Vector Search initialization complete.")
        except Exception as e:
            log.exception(f"Error during Oracle Vector Search initialization: {e}")
            raise

    def _create_adb_pool(self) -> None:
        """
        Create connection pool for Oracle Autonomous Database.

        Uses wallet-based authentication.
        """
        self.pool = oracledb.create_pool(
            user=ORACLE_DB_USER,
            password=ORACLE_DB_PASSWORD,
            dsn=ORACLE_DB_DSN,
            min=ORACLE_DB_POOL_MIN,
            max=ORACLE_DB_POOL_MAX,
            increment=ORACLE_DB_POOL_INCREMENT,
            config_dir=ORACLE_WALLET_DIR,
            wallet_location=ORACLE_WALLET_DIR,
            wallet_password=ORACLE_WALLET_PASSWORD,
        )
        log.info("Created ADB connection pool with wallet authentication.")

    def _create_dbcs_pool(self) -> None:
        """
        Create connection pool for Oracle Database Cloud Service.

        Uses basic authentication without wallet.
        """
        self.pool = oracledb.create_pool(
            user=ORACLE_DB_USER,
            password=ORACLE_DB_PASSWORD,
            dsn=ORACLE_DB_DSN,
            min=ORACLE_DB_POOL_MIN,
            max=ORACLE_DB_POOL_MAX,
            increment=ORACLE_DB_POOL_INCREMENT,
        )
        log.info("Created DB connection pool with basic authentication.")

    def get_connection(self):
        """
        Acquire a connection from the connection pool with retry logic.

        Returns:
            connection: A database connection with output type handler configured
        """
        max_retries = 3
        for attempt in range(max_retries):
            try:
                connection = self.pool.acquire()
                connection.outputtypehandler = self._output_type_handler
                return connection
            except oracledb.DatabaseError as e:
                (error_obj,) = e.args
                log.exception(
                    f"Connection attempt {attempt + 1} failed: {error_obj.message}"
                )

                if attempt < max_retries - 1:
                    wait_time = 2**attempt
                    log.info(f"Retrying in {wait_time} seconds...")
                    time.sleep(wait_time)
                else:
                    raise

    def start_health_monitor(self, interval_seconds: int = 60):
        """
        Start a background thread to periodically check the health of the connection pool.

        Args:
            interval_seconds (int): Number of seconds between health checks
        """

        def _monitor():
            while True:
                try:
                    log.info("[HealthCheck] Running periodic DB health check...")
                    self.ensure_connection()
                    log.info("[HealthCheck] Connection is healthy.")
                except Exception as e:
                    log.exception(f"[HealthCheck] Connection health check failed: {e}")
                time.sleep(interval_seconds)

        thread = threading.Thread(target=_monitor, daemon=True)
        thread.start()
        log.info(f"Started DB health monitor every {interval_seconds} seconds.")

    def _reconnect_pool(self):
        """
        Attempt to reinitialize the connection pool if it's been closed or broken.
        """
        try:
            log.info("Attempting to reinitialize the Oracle connection pool...")

            # Close existing pool if it exists
            if self.pool:
                try:
                    self.pool.close()
                except Exception as close_error:
                    log.warning(f"Error closing existing pool: {close_error}")

            # Re-create the appropriate connection pool based on DB type
            if ORACLE_DB_USE_WALLET:
                self._create_adb_pool()
            else:  # DBCS
                self._create_dbcs_pool()

            log.info("Connection pool reinitialized.")
        except Exception as e:
            log.exception(f"Failed to reinitialize the connection pool: {e}")
            raise

    def ensure_connection(self):
        """
        Ensure the database connection is alive, reconnecting pool if needed.
        """
        try:
            with self.get_connection() as connection:
                with connection.cursor() as cursor:
                    cursor.execute("SELECT 1 FROM dual")
        except Exception as e:
            log.exception(
                f"Connection check failed: {e}, attempting to reconnect pool..."
            )
            self._reconnect_pool()

    def _output_type_handler(self, cursor, metadata):
        """
        Handle Oracle vector type conversion.

        Args:
            cursor: Oracle database cursor
            metadata: Metadata for the column

        Returns:
            A variable with appropriate conversion for vector types
        """
        if metadata.type_code is oracledb.DB_TYPE_VECTOR:
            return cursor.var(
                metadata.type_code, arraysize=cursor.arraysize, outconverter=list
            )

    def _initialize_database(self, connection) -> None:
        """
        Initialize database schema, tables and indexes.

        Creates the document_chunk table and necessary indexes if they don't exist.

        Args:
            connection: Oracle database connection

        Raises:
            Exception: If schema initialization fails
        """
        with connection.cursor() as cursor:
            try:
                log.info("Creating Table document_chunk")
                cursor.execute(
                    """
                    BEGIN
                        EXECUTE IMMEDIATE '
                            CREATE TABLE IF NOT EXISTS document_chunk (
                                id VARCHAR2(255) PRIMARY KEY,
                                collection_name VARCHAR2(255) NOT NULL,
                                text CLOB,
                                vmetadata JSON,
                                vector vector(*, float32)
                            )
                        ';
                    EXCEPTION
                        WHEN OTHERS THEN
                            IF SQLCODE != -955 THEN
                                RAISE;
                            END IF;
                    END;
                    """
                )

                log.info("Creating Index document_chunk_collection_name_idx")
                cursor.execute(
                    """
                    BEGIN
                        EXECUTE IMMEDIATE '
                            CREATE INDEX IF NOT EXISTS document_chunk_collection_name_idx
                            ON document_chunk (collection_name)
                        ';
                    EXCEPTION
                        WHEN OTHERS THEN
                            IF SQLCODE != -955 THEN
                                RAISE;
                            END IF;
                    END;
                    """
                )

                log.info("Creating VECTOR INDEX document_chunk_vector_ivf_idx")
                cursor.execute(
                    """
                    BEGIN
                        EXECUTE IMMEDIATE '
                            CREATE VECTOR INDEX IF NOT EXISTS document_chunk_vector_ivf_idx
                            ON document_chunk(vector)
                            ORGANIZATION NEIGHBOR PARTITIONS
                            DISTANCE COSINE
                            WITH TARGET ACCURACY 95
                            PARAMETERS (TYPE IVF, NEIGHBOR PARTITIONS 100)
                        ';
                    EXCEPTION
                        WHEN OTHERS THEN
                            IF SQLCODE != -955 THEN
                                RAISE;
                            END IF;
                    END;
                    """
                )

                connection.commit()
                log.info("Database initialization completed successfully.")

            except Exception as e:
                connection.rollback()
                log.exception(f"Error during database initialization: {e}")
                raise

    def check_vector_length(self) -> None:
        """
        Check vector length compatibility (placeholder).

        This method would check if the configured vector length matches the database schema.
        Currently implemented as a placeholder.
        """
        pass

    def _vector_to_blob(self, vector: List[float]) -> array.array:
        """
        Convert a vector to the packed float32 array format that python-oracledb
        binds to an Oracle VECTOR column.

        Args:
            vector (List[float]): The vector to convert

        Returns:
            array.array: The vector as a float32 array suitable for binding
        """
        return array.array("f", vector)

    def adjust_vector_length(self, vector: List[float]) -> List[float]:
        """
        Adjust vector to the expected length if needed.

        Args:
            vector (List[float]): The vector to adjust

        Returns:
            List[float]: The adjusted vector
        """
        return vector

    def _decimal_handler(self, obj):
        """
        Handle Decimal objects for JSON serialization.

        Args:
            obj: Object to serialize

        Returns:
            float: Converted decimal value

        Raises:
            TypeError: If object is not JSON serializable
        """
        if isinstance(obj, Decimal):
            return float(obj)
        raise TypeError(f"{obj} is not JSON serializable")

    def _metadata_to_json(self, metadata: Dict) -> str:
        """
        Convert metadata dictionary to JSON string.

        Args:
            metadata (Dict): Metadata dictionary

        Returns:
            str: JSON representation of metadata
        """
        return json.dumps(metadata, default=self._decimal_handler) if metadata else "{}"

    def _json_to_metadata(self, json_str: str) -> Dict:
        """
        Convert JSON string to metadata dictionary.

        Args:
            json_str (str): JSON string

        Returns:
            Dict: Metadata dictionary
        """
        return json.loads(json_str) if json_str else {}

    def insert(self, collection_name: str, items: List[VectorItem]) -> None:
        """
        Insert vector items into the database.

        Args:
            collection_name (str): Name of the collection
            items (List[VectorItem]): List of vector items to insert

        Raises:
            Exception: If insertion fails

        Example:
            >>> client = Oracle23aiClient()
            >>> items = [
            ...     {"id": "1", "text": "Sample text", "vector": [0.1, 0.2, ...], "metadata": {"source": "doc1"}},
            ...     {"id": "2", "text": "Another text", "vector": [0.3, 0.4, ...], "metadata": {"source": "doc2"}}
            ... ]
            >>> client.insert("my_collection", items)
        """
        log.info(f"Inserting {len(items)} items into collection '{collection_name}'.")

        with self.get_connection() as connection:
            try:
                with connection.cursor() as cursor:
                    for item in items:
                        vector_blob = self._vector_to_blob(item["vector"])
                        metadata_json = self._metadata_to_json(item["metadata"])

                        cursor.execute(
                            """
                            INSERT INTO document_chunk
                            (id, collection_name, text, vmetadata, vector)
                            VALUES (:id, :collection_name, :text, :metadata, :vector)
                            """,
                            {
                                "id": item["id"],
                                "collection_name": collection_name,
                                "text": item["text"],
                                "metadata": metadata_json,
                                "vector": vector_blob,
                            },
                        )

                connection.commit()
                log.info(
                    f"Successfully inserted {len(items)} items into collection '{collection_name}'."
                )

            except Exception as e:
                connection.rollback()
                log.exception(f"Error during insert: {e}")
                raise

    def upsert(self, collection_name: str, items: List[VectorItem]) -> None:
        """
        Update or insert vector items into the database.

        If an item with the same ID exists, it will be updated;
        otherwise, it will be inserted.

        Args:
            collection_name (str): Name of the collection
            items (List[VectorItem]): List of vector items to upsert

        Raises:
            Exception: If upsert operation fails

        Example:
            >>> client = Oracle23aiClient()
            >>> items = [
            ...     {"id": "1", "text": "Updated text", "vector": [0.1, 0.2, ...], "metadata": {"source": "doc1"}},
            ...     {"id": "3", "text": "New item", "vector": [0.5, 0.6, ...], "metadata": {"source": "doc3"}}
            ... ]
            >>> client.upsert("my_collection", items)
        """
        log.info(f"Upserting {len(items)} items into collection '{collection_name}'.")

        with self.get_connection() as connection:
            try:
                with connection.cursor() as cursor:
                    for item in items:
                        vector_blob = self._vector_to_blob(item["vector"])
                        metadata_json = self._metadata_to_json(item["metadata"])

                        cursor.execute(
                            """
                            MERGE INTO document_chunk d
                            USING (SELECT :merge_id as id FROM dual) s
                            ON (d.id = s.id)
                            WHEN MATCHED THEN
                                UPDATE SET
                                    collection_name = :upd_collection_name,
                                    text = :upd_text,
                                    vmetadata = :upd_metadata,
                                    vector = :upd_vector
                            WHEN NOT MATCHED THEN
                                INSERT (id, collection_name, text, vmetadata, vector)
                                VALUES (:ins_id, :ins_collection_name, :ins_text, :ins_metadata, :ins_vector)
                            """,
                            {
                                "merge_id": item["id"],
                                "upd_collection_name": collection_name,
                                "upd_text": item["text"],
                                "upd_metadata": metadata_json,
                                "upd_vector": vector_blob,
                                "ins_id": item["id"],
                                "ins_collection_name": collection_name,
                                "ins_text": item["text"],
                                "ins_metadata": metadata_json,
                                "ins_vector": vector_blob,
                            },
                        )

                connection.commit()
                log.info(
                    f"Successfully upserted {len(items)} items into collection '{collection_name}'."
                )

            except Exception as e:
                connection.rollback()
                log.exception(f"Error during upsert: {e}")
                raise

    def search(
        self, collection_name: str, vectors: List[List[Union[float, int]]], limit: int
    ) -> Optional[SearchResult]:
        """
        Search for similar vectors in the database.

        Performs vector similarity search using cosine distance.

        Args:
            collection_name (str): Name of the collection to search
            vectors (List[List[Union[float, int]]]): Query vectors to find similar items for
            limit (int): Maximum number of results to return per query

        Returns:
            Optional[SearchResult]: Search results containing ids, distances, documents, and metadata

        Example:
            >>> client = Oracle23aiClient()
            >>> query_vector = [0.1, 0.2, 0.3, ...]  # Must match VECTOR_LENGTH
            >>> results = client.search("my_collection", [query_vector], limit=5)
            >>> if results:
            ...     log.info(f"Found {len(results.ids[0])} matches")
            ...     for i, (id, dist) in enumerate(zip(results.ids[0], results.distances[0])):
            ...         log.info(f"Match {i+1}: id={id}, distance={dist}")
        """
        log.info(
            f"Searching items from collection '{collection_name}' with limit {limit}."
        )

        try:
            if not vectors:
                log.warning("No vectors provided for search.")
                return None

            num_queries = len(vectors)

            ids = [[] for _ in range(num_queries)]
            distances = [[] for _ in range(num_queries)]
            documents = [[] for _ in range(num_queries)]
            metadatas = [[] for _ in range(num_queries)]

            with self.get_connection() as connection:
                with connection.cursor() as cursor:
                    for qid, vector in enumerate(vectors):
                        vector_blob = self._vector_to_blob(vector)

                        cursor.execute(
                            """
                            SELECT dc.id, dc.text,
                                   JSON_SERIALIZE(dc.vmetadata RETURNING VARCHAR2(4096)) as vmetadata,
                                   VECTOR_DISTANCE(dc.vector, :query_vector, COSINE) as distance
                            FROM document_chunk dc
                            WHERE dc.collection_name = :collection_name
                            ORDER BY VECTOR_DISTANCE(dc.vector, :query_vector, COSINE)
                            FETCH APPROX FIRST :limit ROWS ONLY
                            """,
                            {
                                "query_vector": vector_blob,
                                "collection_name": collection_name,
                                "limit": limit,
                            },
                        )

                        results = cursor.fetchall()

                        for row in results:
                            ids[qid].append(row[0])
                            documents[qid].append(
                                row[1].read()
                                if isinstance(row[1], oracledb.LOB)
                                else str(row[1])
                            )
                            # 🔧 FIXED: Parse JSON metadata properly
                            metadata_str = (
                                row[2].read()
                                if isinstance(row[2], oracledb.LOB)
                                else row[2]
                            )
                            metadatas[qid].append(self._json_to_metadata(metadata_str))
                            distances[qid].append(float(row[3]))

            log.info(
                f"Search completed. Found {sum(len(ids[i]) for i in range(num_queries))} total results."
            )

            return SearchResult(
                ids=ids, distances=distances, documents=documents, metadatas=metadatas
            )

        except Exception as e:
            log.exception(f"Error during search: {e}")
            return None

    def query(
        self, collection_name: str, filter: Dict, limit: Optional[int] = None
    ) -> Optional[GetResult]:
        """
        Query items based on metadata filters.

        Retrieves items that match specified metadata criteria.

        Args:
            collection_name (str): Name of the collection to query
            filter (Dict[str, Any]): Metadata filters to apply
            limit (Optional[int]): Maximum number of results to return

        Returns:
            Optional[GetResult]: Query results containing ids, documents, and metadata

        Example:
            >>> client = Oracle23aiClient()
            >>> filter = {"source": "doc1", "category": "finance"}
            >>> results = client.query("my_collection", filter, limit=20)
            >>> if results:
            ...     print(f"Found {len(results.ids[0])} matching documents")
        """
        log.info(f"Querying items from collection '{collection_name}' with filters.")

        try:
            limit = limit or 100

            query = """
                SELECT id, text, JSON_SERIALIZE(vmetadata RETURNING VARCHAR2(4096)) as vmetadata
                FROM document_chunk
                WHERE collection_name = :collection_name
            """

            params = {"collection_name": collection_name}

            for i, (key, value) in enumerate(filter.items()):
                param_name = f"value_{i}"
                query += f" AND JSON_VALUE(vmetadata, '$.{key}' RETURNING VARCHAR2(4096)) = :{param_name}"
                params[param_name] = str(value)

            query += " FETCH FIRST :limit ROWS ONLY"
            params["limit"] = limit

            with self.get_connection() as connection:
                with connection.cursor() as cursor:
                    cursor.execute(query, params)
                    results = cursor.fetchall()

                    if not results:
                        log.info("No results found for query.")
                        return None

                    ids = [[row[0] for row in results]]
                    documents = [
                        [
                            row[1].read() if isinstance(row[1], oracledb.LOB) else str(row[1])
                            for row in results
                        ]
                    ]
                    # 🔧 FIXED: Parse JSON metadata properly
                    metadatas = [
                        [
                            self._json_to_metadata(
                                row[2].read() if isinstance(row[2], oracledb.LOB) else row[2]
                            )
                            for row in results
                        ]
                    ]

                    log.info(f"Query completed. Found {len(results)} results.")

                    return GetResult(ids=ids, documents=documents, metadatas=metadatas)

        except Exception as e:
            log.exception(f"Error during query: {e}")
            return None

    def get(
        self, collection_name: str, limit: Optional[int] = None
    ) -> Optional[GetResult]:
        """
        Get all items in a collection.

        Retrieves items from a specified collection up to the limit.

        Args:
            collection_name (str): Name of the collection to retrieve
            limit (Optional[int]): Maximum number of items to retrieve

        Returns:
            Optional[GetResult]: Result containing ids, documents, and metadata

        Example:
            >>> client = Oracle23aiClient()
            >>> results = client.get("my_collection", limit=50)
            >>> if results:
            ...     print(f"Retrieved {len(results.ids[0])} documents from collection")
        """
        log.info(
            f"Getting items from collection '{collection_name}' with limit {limit}."
        )

        try:
            limit = limit or 1000

            with self.get_connection() as connection:
                with connection.cursor() as cursor:
                    cursor.execute(
                        """
                        SELECT /*+ MONITOR */ id, text, JSON_SERIALIZE(vmetadata RETURNING VARCHAR2(4096)) as vmetadata
                        FROM document_chunk
                        WHERE collection_name = :collection_name
                        FETCH FIRST :limit ROWS ONLY
                        """,
                        {"collection_name": collection_name, "limit": limit},
                    )

                    results = cursor.fetchall()

                    if not results:
                        log.info("No results found.")
                        return None

                    ids = [[row[0] for row in results]]
                    documents = [
                        [
                            row[1].read() if isinstance(row[1], oracledb.LOB) else str(row[1])
                            for row in results
                        ]
                    ]
                    # 🔧 FIXED: Parse JSON metadata properly
                    metadatas = [
                        [
                            self._json_to_metadata(
                                row[2].read() if isinstance(row[2], oracledb.LOB) else row[2]
                            )
                            for row in results
                        ]
                    ]

                    return GetResult(ids=ids, documents=documents, metadatas=metadatas)

        except Exception as e:
            log.exception(f"Error during get: {e}")
            return None

    def delete(
        self,
        collection_name: str,
        ids: Optional[List[str]] = None,
        filter: Optional[Dict[str, Any]] = None,
    ) -> None:
        """
        Delete items from the database.

        Deletes items from a collection based on IDs or metadata filters.

        Args:
            collection_name (str): Name of the collection to delete from
            ids (Optional[List[str]]): Specific item IDs to delete
            filter (Optional[Dict[str, Any]]): Metadata filters for deletion

        Raises:
            Exception: If deletion fails

        Example:
            >>> client = Oracle23aiClient()
            >>> # Delete specific items by ID
            >>> client.delete("my_collection", ids=["1", "3", "5"])
            >>> # Or delete by metadata filter
            >>> client.delete("my_collection", filter={"source": "deprecated_source"})
        """
        log.info(f"Deleting items from collection '{collection_name}'.")

        try:
            query = (
                "DELETE FROM document_chunk WHERE collection_name = :collection_name"
            )
            params = {"collection_name": collection_name}

            if ids:
                # 🔧 FIXED: Use proper parameterized query to prevent SQL injection
                placeholders = ",".join([f":id_{i}" for i in range(len(ids))])
                query += f" AND id IN ({placeholders})"
                for i, id_val in enumerate(ids):
                    params[f"id_{i}"] = id_val

            if filter:
                for i, (key, value) in enumerate(filter.items()):
                    param_name = f"value_{i}"
                    query += f" AND JSON_VALUE(vmetadata, '$.{key}' RETURNING VARCHAR2(4096)) = :{param_name}"
                    params[param_name] = str(value)

            with self.get_connection() as connection:
                with connection.cursor() as cursor:
                    cursor.execute(query, params)
                    deleted = cursor.rowcount
                    connection.commit()

            log.info(f"Deleted {deleted} items from collection '{collection_name}'.")

        except Exception as e:
            log.exception(f"Error during delete: {e}")
            raise

    def reset(self) -> None:
        """
        Reset the database by deleting all items.

        Deletes all items from the document_chunk table.

        Raises:
            Exception: If reset fails

        Example:
            >>> client = Oracle23aiClient()
            >>> client.reset()  # Warning: Removes all data!
        """
        log.info("Resetting database - deleting all items.")

        try:
            with self.get_connection() as connection:
                with connection.cursor() as cursor:
                    cursor.execute("DELETE FROM document_chunk")
                    deleted = cursor.rowcount
                    connection.commit()

            log.info(
                f"Reset complete. Deleted {deleted} items from 'document_chunk' table."
            )

        except Exception as e:
            log.exception(f"Error during reset: {e}")
            raise

    def close(self) -> None:
        """
        Close the database connection pool.

        Properly closes the connection pool and releases all resources.

        Example:
            >>> client = Oracle23aiClient()
            >>> # After finishing all operations
            >>> client.close()
        """
        try:
            if hasattr(self, "pool") and self.pool:
                self.pool.close()
                log.info("Oracle Vector Search connection pool closed.")
        except Exception as e:
            log.exception(f"Error closing connection pool: {e}")

    def has_collection(self, collection_name: str) -> bool:
        """
        Check if a collection exists.

        Args:
            collection_name (str): Name of the collection to check

        Returns:
            bool: True if the collection exists, False otherwise

        Example:
            >>> client = Oracle23aiClient()
            >>> if client.has_collection("my_collection"):
            ...     print("Collection exists!")
            ... else:
            ...     print("Collection does not exist.")
        """
        try:
            with self.get_connection() as connection:
                with connection.cursor() as cursor:
                    cursor.execute(
                        """
                        SELECT COUNT(*)
                        FROM document_chunk
                        WHERE collection_name = :collection_name
                        FETCH FIRST 1 ROWS ONLY
                        """,
                        {"collection_name": collection_name},
                    )

                    count = cursor.fetchone()[0]

            return count > 0

        except Exception as e:
            log.exception(f"Error checking collection existence: {e}")
            return False

    def delete_collection(self, collection_name: str) -> None:
        """
        Delete an entire collection.

        Removes all items belonging to the specified collection.

        Args:
            collection_name (str): Name of the collection to delete

        Example:
            >>> client = Oracle23aiClient()
            >>> client.delete_collection("obsolete_collection")
        """
        log.info(f"Deleting collection '{collection_name}'.")

        try:
            with self.get_connection() as connection:
                with connection.cursor() as cursor:
                    cursor.execute(
                        """
                        DELETE FROM document_chunk
                        WHERE collection_name = :collection_name
                        """,
                        {"collection_name": collection_name},
                    )

                    deleted = cursor.rowcount
                    connection.commit()

            log.info(
                f"Collection '{collection_name}' deleted. Removed {deleted} items."
            )

        except Exception as e:
            log.exception(f"Error deleting collection '{collection_name}': {e}")
            raise
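
For orientation, a minimal usage sketch of the new client, not part of the commit itself. It assumes the env configuration shown at the top of the file, that the module lands alongside the other vector backends, and a hypothetical `embed()` helper returning vectors of length ORACLE_VECTOR_LENGTH:

# Usage sketch (embed() is an assumed embedding function, not part of this commit)
from open_webui.retrieval.vector.dbs.oracle23ai import Oracle23aiClient

client = Oracle23aiClient()
client.upsert(
    "demo_collection",
    [
        {
            "id": "1",
            "text": "hello world",
            "vector": embed("hello world"),  # hypothetical helper
            "metadata": {"source": "demo"},
        }
    ],
)
results = client.search("demo_collection", [embed("greeting")], limit=3)
if results:
    print(results.ids[0], results.distances[0])
client.close()
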
backend/open_webui/retrieval/vector/dbs/pgvector.py

@@ -26,6 +26,8 @@ from pgvector.sqlalchemy import Vector
 from sqlalchemy.ext.mutable import MutableDict
 from sqlalchemy.exc import NoSuchTableError
 
+from open_webui.retrieval.vector.utils import stringify_metadata
+
 from open_webui.retrieval.vector.main import (
     VectorDBBase,
     VectorItem,
@@ -201,6 +203,8 @@ class PgvectorClient(VectorDBBase):
                 for item in items:
                     vector = self.adjust_vector_length(item["vector"])
                     # Use raw SQL for BYTEA/pgcrypto
+                    # Ensure metadata is converted to its JSON text representation
+                    json_metadata = json.dumps(item["metadata"])
                     self.session.execute(
                         text(
                             """
@@ -209,7 +213,7 @@ class PgvectorClient(VectorDBBase):
                             VALUES (
                                 :id, :vector, :collection_name,
                                 pgp_sym_encrypt(:text, :key),
-                                pgp_sym_encrypt(:metadata::text, :key)
+                                pgp_sym_encrypt(:metadata_text, :key)
                             )
                             ON CONFLICT (id) DO NOTHING
                             """
@@ -219,7 +223,7 @@ class PgvectorClient(VectorDBBase):
                             "vector": vector,
                             "collection_name": collection_name,
                             "text": item["text"],
-                            "metadata": json.dumps(item["metadata"]),
+                            "metadata_text": json_metadata,
                             "key": PGVECTOR_PGCRYPTO_KEY,
                         },
                     )
@@ -235,7 +239,7 @@ class PgvectorClient(VectorDBBase):
                         vector=vector,
                         collection_name=collection_name,
                         text=item["text"],
-                        vmetadata=item["metadata"],
+                        vmetadata=stringify_metadata(item["metadata"]),
                     )
                     new_items.append(new_chunk)
                 self.session.bulk_save_objects(new_items)
@@ -253,6 +257,7 @@ class PgvectorClient(VectorDBBase):
             if PGVECTOR_PGCRYPTO:
                 for item in items:
                     vector = self.adjust_vector_length(item["vector"])
+                    json_metadata = json.dumps(item["metadata"])
                     self.session.execute(
                         text(
                             """
@@ -261,7 +266,7 @@ class PgvectorClient(VectorDBBase):
                             VALUES (
                                 :id, :vector, :collection_name,
                                 pgp_sym_encrypt(:text, :key),
-                                pgp_sym_encrypt(:metadata::text, :key)
+                                pgp_sym_encrypt(:metadata_text, :key)
                             )
                             ON CONFLICT (id) DO UPDATE SET
                                 vector = EXCLUDED.vector,
@@ -275,7 +280,7 @@ class PgvectorClient(VectorDBBase):
                             "vector": vector,
                             "collection_name": collection_name,
                             "text": item["text"],
-                            "metadata": json.dumps(item["metadata"]),
+                            "metadata_text": json_metadata,
                             "key": PGVECTOR_PGCRYPTO_KEY,
                         },
                     )
@@ -292,7 +297,7 @@ class PgvectorClient(VectorDBBase):
                     if existing:
                         existing.vector = vector
                         existing.text = item["text"]
-                        existing.vmetadata = item["metadata"]
+                        existing.vmetadata = stringify_metadata(item["metadata"])
                         existing.collection_name = (
                             collection_name  # Update collection_name if necessary
                         )
@@ -302,7 +307,7 @@ class PgvectorClient(VectorDBBase):
                             vector=vector,
                             collection_name=collection_name,
                             text=item["text"],
-                            vmetadata=item["metadata"],
+                            vmetadata=stringify_metadata(item["metadata"]),
                         )
                         self.session.add(new_chunk)
                 self.session.commit()
@@ -416,10 +421,12 @@ class PgvectorClient(VectorDBBase):
                     documents[qid].append(row.text)
                     metadatas[qid].append(row.vmetadata)
 
+            self.session.rollback()  # read-only transaction
             return SearchResult(
                 ids=ids, distances=distances, documents=documents, metadatas=metadatas
             )
         except Exception as e:
+            self.session.rollback()
             log.exception(f"Error during search: {e}")
             return None
@@ -472,12 +479,14 @@ class PgvectorClient(VectorDBBase):
             documents = [[result.text for result in results]]
             metadatas = [[result.vmetadata for result in results]]
 
+            self.session.rollback()  # read-only transaction
             return GetResult(
                 ids=ids,
                 documents=documents,
                 metadatas=metadatas,
             )
         except Exception as e:
+            self.session.rollback()
             log.exception(f"Error during query: {e}")
             return None
@@ -518,8 +527,10 @@ class PgvectorClient(VectorDBBase):
             documents = [[result.text for result in results]]
             metadatas = [[result.vmetadata for result in results]]
 
+            self.session.rollback()  # read-only transaction
             return GetResult(ids=ids, documents=documents, metadatas=metadatas)
         except Exception as e:
+            self.session.rollback()
             log.exception(f"Error during get: {e}")
             return None
@@ -587,8 +598,10 @@ class PgvectorClient(VectorDBBase):
                 .first()
                 is not None
             )
+            self.session.rollback()  # read-only transaction
             return exists
         except Exception as e:
+            self.session.rollback()
             log.exception(f"Error checking collection existence: {e}")
             return False
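
The pgcrypto hunks above replace the in-SQL cast `:metadata::text` with a plain `:metadata_text` bind whose value is serialized client-side with `json.dumps` before execution. A minimal sketch of the resulting bind pattern, assuming a SQLAlchemy engine against a database with the pgcrypto extension and a hypothetical `demo` table and key:

# Sketch of the bind pattern (hypothetical table, DSN, and key; pgcrypto extension assumed)
import json
from sqlalchemy import create_engine, text

engine = create_engine("postgresql+psycopg2://user:pass@localhost/db")  # assumed DSN
metadata_text = json.dumps({"source": "doc1"})  # serialize before binding
with engine.begin() as conn:
    conn.execute(
        text("INSERT INTO demo (meta) VALUES (pgp_sym_encrypt(:metadata_text, :key))"),
        {"metadata_text": metadata_text, "key": "secret-key"},
    )
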
backend/open_webui/retrieval/vector/dbs/qdrant.py

@@ -19,6 +19,8 @@ from open_webui.config import (
     QDRANT_GRPC_PORT,
     QDRANT_PREFER_GRPC,
     QDRANT_COLLECTION_PREFIX,
+    QDRANT_TIMEOUT,
+    QDRANT_HNSW_M,
 )
 from open_webui.env import SRC_LOG_LEVELS
 
@@ -36,6 +38,8 @@ class QdrantClient(VectorDBBase):
         self.QDRANT_ON_DISK = QDRANT_ON_DISK
         self.PREFER_GRPC = QDRANT_PREFER_GRPC
         self.GRPC_PORT = QDRANT_GRPC_PORT
+        self.QDRANT_TIMEOUT = QDRANT_TIMEOUT
+        self.QDRANT_HNSW_M = QDRANT_HNSW_M
 
         if not self.QDRANT_URI:
             self.client = None
@@ -53,9 +57,14 @@ class QdrantClient(VectorDBBase):
                 grpc_port=self.GRPC_PORT,
                 prefer_grpc=self.PREFER_GRPC,
                 api_key=self.QDRANT_API_KEY,
+                timeout=self.QDRANT_TIMEOUT,
             )
         else:
-            self.client = Qclient(url=self.QDRANT_URI, api_key=self.QDRANT_API_KEY)
+            self.client = Qclient(
+                url=self.QDRANT_URI,
+                api_key=self.QDRANT_API_KEY,
+                timeout=QDRANT_TIMEOUT,
+            )
 
     def _result_to_get_result(self, points) -> GetResult:
         ids = []
@@ -85,6 +94,9 @@ class QdrantClient(VectorDBBase):
                 distance=models.Distance.COSINE,
                 on_disk=self.QDRANT_ON_DISK,
             ),
+            hnsw_config=models.HnswConfigDiff(
+                m=self.QDRANT_HNSW_M,
+            ),
         )
 
         # Create payload indexes for efficient filtering
@@ -171,23 +183,23 @@ class QdrantClient(VectorDBBase):
                     )
                 )
 
-            points = self.client.query_points(
+            points = self.client.scroll(
                 collection_name=f"{self.collection_prefix}_{collection_name}",
-                query_filter=models.Filter(should=field_conditions),
+                scroll_filter=models.Filter(should=field_conditions),
                 limit=limit,
             )
-            return self._result_to_get_result(points.points)
+            return self._result_to_get_result(points[0])
         except Exception as e:
             log.exception(f"Error querying a collection '{collection_name}': {e}")
             return None
 
     def get(self, collection_name: str) -> Optional[GetResult]:
         # Get all the items in the collection.
-        points = self.client.query_points(
+        points = self.client.scroll(
             collection_name=f"{self.collection_prefix}_{collection_name}",
             limit=NO_LIMIT,  # otherwise qdrant would set limit to 10!
         )
-        return self._result_to_get_result(points.points)
+        return self._result_to_get_result(points[0])
 
     def insert(self, collection_name: str, items: list[VectorItem]):
         # Insert the items into the collection, if the collection does not exist, it will be created.
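
One detail worth noting on the `query_points` to `scroll` switch: in qdrant-client, `scroll()` returns a `(records, next_page_offset)` tuple rather than a response object with a `.points` attribute, which is why the call sites now index `points[0]`. A small sketch of the shape, assuming a locally running Qdrant and a hypothetical collection name:

# Sketch: scroll() returns a (records, next_page_offset) tuple
from qdrant_client import QdrantClient, models

client = QdrantClient(url="http://localhost:6333")  # assumed local instance
records, next_offset = client.scroll(
    collection_name="open_webui_demo",  # hypothetical collection
    scroll_filter=models.Filter(must=[]),
    limit=100,
)
for record in records:
    print(record.id, record.payload)
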
backend/open_webui/retrieval/vector/dbs/qdrant_multitenancy.py

@@ -10,6 +10,8 @@ from open_webui.config import (
     QDRANT_PREFER_GRPC,
     QDRANT_URI,
     QDRANT_COLLECTION_PREFIX,
+    QDRANT_TIMEOUT,
+    QDRANT_HNSW_M,
 )
 from open_webui.env import SRC_LOG_LEVELS
 from open_webui.retrieval.vector.main import (
@@ -51,6 +53,8 @@ class QdrantClient(VectorDBBase):
         self.QDRANT_ON_DISK = QDRANT_ON_DISK
         self.PREFER_GRPC = QDRANT_PREFER_GRPC
         self.GRPC_PORT = QDRANT_GRPC_PORT
+        self.QDRANT_TIMEOUT = QDRANT_TIMEOUT
+        self.QDRANT_HNSW_M = QDRANT_HNSW_M
 
         if not self.QDRANT_URI:
             raise ValueError(
@@ -69,9 +73,14 @@ class QdrantClient(VectorDBBase):
                 grpc_port=self.GRPC_PORT,
                 prefer_grpc=self.PREFER_GRPC,
                 api_key=self.QDRANT_API_KEY,
+                timeout=self.QDRANT_TIMEOUT,
             )
             if self.PREFER_GRPC
-            else Qclient(url=self.QDRANT_URI, api_key=self.QDRANT_API_KEY)
+            else Qclient(
+                url=self.QDRANT_URI,
+                api_key=self.QDRANT_API_KEY,
+                timeout=self.QDRANT_TIMEOUT,
+            )
         )
 
         # Main collection types for multi-tenancy
@@ -133,6 +142,12 @@ class QdrantClient(VectorDBBase):
                 distance=models.Distance.COSINE,
                 on_disk=self.QDRANT_ON_DISK,
             ),
+            # Disable global index building due to multitenancy
+            # For more details https://qdrant.tech/documentation/guides/multiple-partitions/#calibrate-performance
+            hnsw_config=models.HnswConfigDiff(
+                payload_m=self.QDRANT_HNSW_M,
+                m=0,
+            ),
         )
         log.info(
             f"Multi-tenant collection {mt_collection_name} created with dimension {dimension}!"
@@ -278,12 +293,12 @@ class QdrantClient(VectorDBBase):
         tenant_filter = _tenant_filter(tenant_id)
         field_conditions = [_metadata_filter(k, v) for k, v in filter.items()]
         combined_filter = models.Filter(must=[tenant_filter, *field_conditions])
-        points = self.client.query_points(
+        points = self.client.scroll(
             collection_name=mt_collection,
-            query_filter=combined_filter,
+            scroll_filter=combined_filter,
             limit=limit,
         )
-        return self._result_to_get_result(points.points)
+        return self._result_to_get_result(points[0])
 
     def get(self, collection_name: str) -> Optional[GetResult]:
         """
@@ -296,12 +311,12 @@ class QdrantClient(VectorDBBase):
             log.debug(f"Collection {mt_collection} doesn't exist, get returns None")
             return None
         tenant_filter = _tenant_filter(tenant_id)
-        points = self.client.query_points(
+        points = self.client.scroll(
             collection_name=mt_collection,
-            query_filter=models.Filter(must=[tenant_filter]),
+            scroll_filter=models.Filter(must=[tenant_filter]),
             limit=NO_LIMIT,
         )
-        return self._result_to_get_result(points.points)
+        return self._result_to_get_result(points[0])
 
     def upsert(self, collection_name: str, items: List[VectorItem]):
         """
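
The `m=0` / `payload_m` pair in the multi-tenant hunk follows Qdrant's documented pattern for partitioned collections: the global HNSW graph is disabled and per-tenant graphs are built from the payload index instead. A minimal sketch of creating such a collection, assuming a local Qdrant, a 768-dimensional embedding, and hypothetical names:

# Sketch: per-tenant HNSW graphs for a multi-tenant collection (assumes local Qdrant)
from qdrant_client import QdrantClient, models

client = QdrantClient(url="http://localhost:6333")
client.create_collection(
    collection_name="mt_demo",  # hypothetical name
    vectors_config=models.VectorParams(size=768, distance=models.Distance.COSINE),
    hnsw_config=models.HnswConfigDiff(m=0, payload_m=16),  # m=0 disables the global graph
)
# A keyword payload index on the tenant field lets Qdrant filter per tenant efficiently
client.create_payload_index(
    collection_name="mt_demo",
    field_name="tenant_id",  # hypothetical tenant field
    field_schema=models.PayloadSchemaType.KEYWORD,
)
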
752
backend/open_webui/retrieval/vector/dbs/s3vector.py
Normal file
752
backend/open_webui/retrieval/vector/dbs/s3vector.py
Normal file
|
|
@ -0,0 +1,752 @@
|
||||||
|
from backend.open_webui.retrieval.vector.utils import stringify_metadata
|
||||||
|
from open_webui.retrieval.vector.main import (
|
||||||
|
VectorDBBase,
|
||||||
|
VectorItem,
|
||||||
|
GetResult,
|
||||||
|
SearchResult,
|
||||||
|
)
|
||||||
|
from open_webui.config import S3_VECTOR_BUCKET_NAME, S3_VECTOR_REGION
|
||||||
|
from open_webui.env import SRC_LOG_LEVELS
|
||||||
|
from typing import List, Optional, Dict, Any, Union
|
||||||
|
import logging
|
||||||
|
import boto3
|
||||||
|
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||||
|
|
||||||
|
|
||||||
|
class S3VectorClient(VectorDBBase):
|
||||||
|
"""
|
||||||
|
AWS S3 Vector integration for Open WebUI Knowledge.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.bucket_name = S3_VECTOR_BUCKET_NAME
|
||||||
|
self.region = S3_VECTOR_REGION
|
||||||
|
|
||||||
|
# Simple validation - log warnings instead of raising exceptions
|
||||||
|
if not self.bucket_name:
|
||||||
|
log.warning("S3_VECTOR_BUCKET_NAME not set - S3Vector will not work")
|
||||||
|
if not self.region:
|
||||||
|
log.warning("S3_VECTOR_REGION not set - S3Vector will not work")
|
||||||
|
|
||||||
|
if self.bucket_name and self.region:
|
||||||
|
try:
|
||||||
|
self.client = boto3.client("s3vectors", region_name=self.region)
|
||||||
|
log.info(
|
||||||
|
f"S3Vector client initialized for bucket '{self.bucket_name}' in region '{self.region}'"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
log.error(f"Failed to initialize S3Vector client: {e}")
|
||||||
|
self.client = None
|
||||||
|
else:
|
||||||
|
self.client = None
|
||||||
|
|
||||||
|
def _create_index(
|
||||||
|
self,
|
||||||
|
index_name: str,
|
||||||
|
dimension: int,
|
||||||
|
data_type: str = "float32",
|
||||||
|
distance_metric: str = "cosine",
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Create a new index in the S3 vector bucket for the given collection if it does not exist.
|
||||||
|
"""
|
||||||
|
if self.has_collection(index_name):
|
||||||
|
log.debug(f"Index '{index_name}' already exists, skipping creation")
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
self.client.create_index(
|
||||||
|
vectorBucketName=self.bucket_name,
|
||||||
|
indexName=index_name,
|
||||||
|
dataType=data_type,
|
||||||
|
dimension=dimension,
|
||||||
|
distanceMetric=distance_metric,
|
||||||
|
)
|
||||||
|
log.info(
|
||||||
|
f"Created S3 index: {index_name} (dim={dimension}, type={data_type}, metric={distance_metric})"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
log.error(f"Error creating S3 index '{index_name}': {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def _filter_metadata(
|
||||||
|
self, metadata: Dict[str, Any], item_id: str
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Filter vector metadata keys to comply with S3 Vector API limit of 10 keys maximum.
|
||||||
|
"""
|
||||||
|
if not isinstance(metadata, dict) or len(metadata) <= 10:
|
||||||
|
return metadata
|
||||||
|
|
||||||
|
# Keep only the first 10 keys, prioritizing important ones based on actual Open WebUI metadata
|
||||||
|
important_keys = [
|
||||||
|
"text", # The actual document content
|
||||||
|
"file_id", # File ID
|
||||||
|
"source", # Document source file
|
||||||
|
"title", # Document title
|
||||||
|
"page", # Page number
|
||||||
|
"total_pages", # Total pages in document
|
||||||
|
"embedding_config", # Embedding configuration
|
||||||
|
"created_by", # User who created it
|
||||||
|
"name", # Document name
|
||||||
|
"hash", # Content hash
|
||||||
|
]
|
||||||
|
filtered_metadata = {}
|
||||||
|
|
||||||
|
# First, add important keys if they exist
|
||||||
|
for key in important_keys:
|
||||||
|
if key in metadata:
|
||||||
|
filtered_metadata[key] = metadata[key]
|
||||||
|
if len(filtered_metadata) >= 10:
|
||||||
|
break
|
||||||
|
|
||||||
|
# If we still have room, add other keys
|
||||||
|
if len(filtered_metadata) < 10:
|
||||||
|
for key, value in metadata.items():
|
||||||
|
if key not in filtered_metadata:
|
||||||
|
filtered_metadata[key] = value
|
||||||
|
if len(filtered_metadata) >= 10:
|
||||||
|
break
|
||||||
|
|
||||||
|
log.warning(
|
||||||
|
f"Metadata for key '{item_id}' had {len(metadata)} keys, limited to 10 keys"
|
||||||
|
)
|
||||||
|
return filtered_metadata
|
||||||
|
|
||||||
|
def has_collection(self, collection_name: str) -> bool:
|
||||||
|
"""
|
||||||
|
Check if a vector index (collection) exists in the S3 vector bucket.
|
||||||
|
"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = self.client.list_indexes(vectorBucketName=self.bucket_name)
|
||||||
|
indexes = response.get("indexes", [])
|
||||||
|
return any(idx.get("indexName") == collection_name for idx in indexes)
|
||||||
|
except Exception as e:
|
||||||
|
log.error(f"Error listing indexes: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def delete_collection(self, collection_name: str) -> None:
|
||||||
|
"""
|
||||||
|
Delete an entire S3 Vector index/collection.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if not self.has_collection(collection_name):
|
||||||
|
log.warning(
|
||||||
|
f"Collection '{collection_name}' does not exist, nothing to delete"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
log.info(f"Deleting collection '{collection_name}'")
|
||||||
|
self.client.delete_index(
|
||||||
|
vectorBucketName=self.bucket_name, indexName=collection_name
|
||||||
|
)
|
||||||
|
log.info(f"Successfully deleted collection '{collection_name}'")
|
||||||
|
except Exception as e:
|
||||||
|
log.error(f"Error deleting collection '{collection_name}': {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def insert(self, collection_name: str, items: List[VectorItem]) -> None:
|
||||||
|
"""
|
||||||
|
Insert vector items into the S3 Vector index. Create index if it does not exist.
|
||||||
|
"""
|
||||||
|
if not items:
|
||||||
|
log.warning("No items to insert")
|
||||||
|
return
|
||||||
|
|
||||||
|
dimension = len(items[0]["vector"])
|
||||||
|
|
||||||
|
try:
|
||||||
|
if not self.has_collection(collection_name):
|
||||||
|
log.info(f"Index '{collection_name}' does not exist. Creating index.")
|
||||||
|
self._create_index(
|
||||||
|
index_name=collection_name,
|
||||||
|
dimension=dimension,
|
||||||
|
data_type="float32",
|
||||||
|
distance_metric="cosine",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Prepare vectors for insertion
|
||||||
|
vectors = []
|
||||||
|
for item in items:
|
||||||
|
# Ensure vector data is in the correct format for S3 Vector API
|
||||||
|
vector_data = item["vector"]
|
||||||
|
if isinstance(vector_data, list):
|
||||||
|
# Convert list to float32 values as required by S3 Vector API
|
||||||
|
vector_data = [float(x) for x in vector_data]
|
||||||
|
|
||||||
|
# Prepare metadata, ensuring the text field is preserved
|
||||||
|
metadata = item.get("metadata", {}).copy()
|
||||||
|
|
||||||
|
# Add the text field to metadata so it's available for retrieval
|
||||||
|
metadata["text"] = item["text"]
|
||||||
|
|
||||||
|
# Convert metadata to string format for consistency
|
||||||
|
metadata = stringify_metadata(metadata)
|
||||||
|
|
||||||
|
# Filter metadata to comply with S3 Vector API limit of 10 keys
|
||||||
|
metadata = self._filter_metadata(metadata, item["id"])
|
||||||
|
|
||||||
|
vectors.append(
|
||||||
|
{
|
||||||
|
"key": item["id"],
|
||||||
|
"data": {"float32": vector_data},
|
||||||
|
"metadata": metadata,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
# Insert vectors
|
||||||
|
self.client.put_vectors(
|
||||||
|
vectorBucketName=self.bucket_name,
|
||||||
|
indexName=collection_name,
|
||||||
|
vectors=vectors,
|
||||||
|
)
|
||||||
|
log.info(f"Inserted {len(vectors)} vectors into index '{collection_name}'.")
|
||||||
|
except Exception as e:
|
||||||
|
log.error(f"Error inserting vectors: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def upsert(self, collection_name: str, items: List[VectorItem]) -> None:
|
||||||
|
"""
|
||||||
|
Insert or update vector items in the S3 Vector index. Create index if it does not exist.
|
||||||
|
"""
|
||||||
|
if not items:
|
||||||
|
log.warning("No items to upsert")
|
||||||
|
return
|
||||||
|
|
||||||
|
dimension = len(items[0]["vector"])
|
||||||
|
log.info(f"Upsert dimension: {dimension}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
if not self.has_collection(collection_name):
|
||||||
|
log.info(
|
||||||
|
f"Index '{collection_name}' does not exist. Creating index for upsert."
|
||||||
|
)
|
||||||
|
self._create_index(
|
||||||
|
index_name=collection_name,
|
||||||
|
dimension=dimension,
|
||||||
|
data_type="float32",
|
||||||
|
distance_metric="cosine",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Prepare vectors for upsert
|
||||||
|
vectors = []
|
||||||
|
for item in items:
|
||||||
|
# Ensure vector data is in the correct format for S3 Vector API
|
||||||
|
vector_data = item["vector"]
|
||||||
|
if isinstance(vector_data, list):
|
||||||
|
# Convert list to float32 values as required by S3 Vector API
|
||||||
|
vector_data = [float(x) for x in vector_data]
|
||||||
|
|
||||||
|
# Prepare metadata, ensuring the text field is preserved
|
||||||
|
metadata = item.get("metadata", {}).copy()
|
||||||
|
# Add the text field to metadata so it's available for retrieval
|
||||||
|
metadata["text"] = item["text"]
|
||||||
|
|
||||||
|
# Convert metadata to string format for consistency
|
||||||
|
metadata = stringify_metadata(metadata)
|
||||||
|
|
||||||
|
# Filter metadata to comply with S3 Vector API limit of 10 keys
|
||||||
|
metadata = self._filter_metadata(metadata, item["id"])
|
||||||
|
|
||||||
|
vectors.append(
|
||||||
|
{
|
||||||
|
"key": item["id"],
|
||||||
|
"data": {"float32": vector_data},
|
||||||
|
"metadata": metadata,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
# Upsert vectors (using put_vectors for upsert semantics)
|
||||||
|
log.info(
|
||||||
|
f"Upserting {len(vectors)} vectors. First vector sample: key={vectors[0]['key']}, data_type={type(vectors[0]['data']['float32'])}, data_len={len(vectors[0]['data']['float32'])}"
|
||||||
|
)
|
||||||
|
self.client.put_vectors(
|
||||||
|
vectorBucketName=self.bucket_name,
|
||||||
|
indexName=collection_name,
|
||||||
|
vectors=vectors,
|
||||||
|
)
|
||||||
|
log.info(f"Upserted {len(vectors)} vectors into index '{collection_name}'.")
|
||||||
|
except Exception as e:
|
||||||
|
log.error(f"Error upserting vectors: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
    def search(
        self, collection_name: str, vectors: List[List[Union[float, int]]], limit: int
    ) -> Optional[SearchResult]:
        """
        Search for similar vectors in a collection using multiple query vectors.
        """
        if not self.has_collection(collection_name):
            log.warning(f"Collection '{collection_name}' does not exist")
            return None

        if not vectors:
            log.warning("No query vectors provided")
            return None

        try:
            log.info(
                f"Searching collection '{collection_name}' with {len(vectors)} query vectors, limit={limit}"
            )

            # Initialize result lists
            all_ids = []
            all_documents = []
            all_metadatas = []
            all_distances = []

            # Process each query vector
            for i, query_vector in enumerate(vectors):
                log.debug(f"Processing query vector {i+1}/{len(vectors)}")

                # Prepare the query vector in S3 Vector format
                query_vector_dict = {"float32": [float(x) for x in query_vector]}

                # Call the S3 Vector query API
                response = self.client.query_vectors(
                    vectorBucketName=self.bucket_name,
                    indexName=collection_name,
                    topK=limit,
                    queryVector=query_vector_dict,
                    returnMetadata=True,
                    returnDistance=True,
                )

                # Process results for this query
                query_ids = []
                query_documents = []
                query_metadatas = []
                query_distances = []

                result_vectors = response.get("vectors", [])

                for vector in result_vectors:
                    vector_id = vector.get("key")
                    vector_metadata = vector.get("metadata", {})
                    vector_distance = vector.get("distance", 0.0)

                    # Extract document text from metadata
                    document_text = ""
                    if isinstance(vector_metadata, dict):
                        # Get the text field first (highest priority)
                        document_text = vector_metadata.get("text")
                        if not document_text:
                            # Fall back to other possible text fields
                            document_text = (
                                vector_metadata.get("content")
                                or vector_metadata.get("document")
                                or vector_id
                            )
                    else:
                        document_text = vector_id

                    query_ids.append(vector_id)
                    query_documents.append(document_text)
                    query_metadatas.append(vector_metadata)
                    query_distances.append(vector_distance)

                # Add this query's results to the overall results
                all_ids.append(query_ids)
                all_documents.append(query_documents)
                all_metadatas.append(query_metadatas)
                all_distances.append(query_distances)

            log.info(f"Search completed. Found results for {len(all_ids)} queries")

            # Return SearchResult format
            return SearchResult(
                ids=all_ids if all_ids else None,
                documents=all_documents if all_documents else None,
                metadatas=all_metadatas if all_metadatas else None,
                distances=all_distances if all_distances else None,
            )

        except Exception as e:
            log.error(f"Error searching collection '{collection_name}': {str(e)}")
            # Handle specific AWS exceptions
            if hasattr(e, "response") and "Error" in e.response:
                error_code = e.response["Error"]["Code"]
                if error_code == "NotFoundException":
                    log.warning(f"Collection '{collection_name}' not found")
                    return None
                elif error_code == "ValidationException":
                    log.error("Invalid query vector dimensions or parameters")
                    return None
                elif error_code == "AccessDeniedException":
                    log.error(
                        f"Access denied for collection '{collection_name}'. Check permissions."
                    )
                    return None
            raise
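A sketch of consuming the `SearchResult` returned above. Results are grouped per query vector (lists of lists), and with the cosine metric configured at index creation a smaller distance means a closer match:

```python
# Pair up ids, documents, and distances for one query vector of a SearchResult.
def top_hits(result, query_index=0):
    if result is None or not result.ids:
        return []
    return list(
        zip(
            result.ids[query_index],
            result.documents[query_index],
            result.distances[query_index],
        )
    )

# result = client.search("file-doc-1", vectors=[[0.1, 0.0, 0.9, 0.4]], limit=3)
# for key, text, distance in top_hits(result):
#     print(f"{key}: distance={distance:.4f} text={text[:60]}")
```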
    def query(
        self, collection_name: str, filter: Dict, limit: Optional[int] = None
    ) -> Optional[GetResult]:
        """
        Query vectors from a collection using a metadata filter.
        """
        if not self.has_collection(collection_name):
            log.warning(f"Collection '{collection_name}' does not exist")
            return GetResult(ids=[[]], documents=[[]], metadatas=[[]])

        if not filter:
            log.warning("No filter provided, returning all vectors")
            return self.get(collection_name)

        try:
            log.info(f"Querying collection '{collection_name}' with filter: {filter}")

            # S3 Vector may not support complex server-side filtering, so we
            # retrieve all vectors via list_vectors and filter client-side.

            # Get all vectors first
            all_vectors_result = self.get(collection_name)

            if not all_vectors_result or not all_vectors_result.ids:
                log.warning("No vectors found in collection")
                return GetResult(ids=[[]], documents=[[]], metadatas=[[]])

            # Extract the lists from the result
            all_ids = all_vectors_result.ids[0] if all_vectors_result.ids else []
            all_documents = (
                all_vectors_result.documents[0] if all_vectors_result.documents else []
            )
            all_metadatas = (
                all_vectors_result.metadatas[0] if all_vectors_result.metadatas else []
            )

            # Apply client-side filtering
            filtered_ids = []
            filtered_documents = []
            filtered_metadatas = []

            for i, metadata in enumerate(all_metadatas):
                if self._matches_filter(metadata, filter):
                    if i < len(all_ids):
                        filtered_ids.append(all_ids[i])
                    if i < len(all_documents):
                        filtered_documents.append(all_documents[i])
                    filtered_metadatas.append(metadata)

                    # Apply limit if specified
                    if limit and len(filtered_ids) >= limit:
                        break

            log.info(
                f"Filter applied: {len(filtered_ids)} vectors match out of {len(all_ids)} total"
            )

            # Return GetResult format
            if filtered_ids:
                return GetResult(
                    ids=[filtered_ids],
                    documents=[filtered_documents],
                    metadatas=[filtered_metadatas],
                )
            else:
                return GetResult(ids=[[]], documents=[[]], metadatas=[[]])

        except Exception as e:
            log.error(f"Error querying collection '{collection_name}': {str(e)}")
            # Handle specific AWS exceptions
            if hasattr(e, "response") and "Error" in e.response:
                error_code = e.response["Error"]["Code"]
                if error_code == "NotFoundException":
                    log.warning(f"Collection '{collection_name}' not found")
                    return GetResult(ids=[[]], documents=[[]], metadatas=[[]])
                elif error_code == "AccessDeniedException":
                    log.error(
                        f"Access denied for collection '{collection_name}'. Check permissions."
                    )
                    return GetResult(ids=[[]], documents=[[]], metadatas=[[]])
            raise
    def get(self, collection_name: str) -> Optional[GetResult]:
        """
        Retrieve all vectors from a collection.
        """
        if not self.has_collection(collection_name):
            log.warning(f"Collection '{collection_name}' does not exist")
            return GetResult(ids=[[]], documents=[[]], metadatas=[[]])

        try:
            log.info(f"Retrieving all vectors from collection '{collection_name}'")

            # Initialize result lists
            all_ids = []
            all_documents = []
            all_metadatas = []

            # Handle pagination
            next_token = None

            while True:
                # Prepare request parameters
                request_params = {
                    "vectorBucketName": self.bucket_name,
                    "indexName": collection_name,
                    "returnData": False,  # Don't include vector data (not needed for get)
                    "returnMetadata": True,  # Include metadata
                    "maxResults": 500,  # Use a reasonable page size
                }

                if next_token:
                    request_params["nextToken"] = next_token

                # Call the S3 Vector API
                response = self.client.list_vectors(**request_params)

                # Process vectors in this page
                vectors = response.get("vectors", [])

                for vector in vectors:
                    vector_id = vector.get("key")
                    vector_data = vector.get("data", {})
                    vector_metadata = vector.get("metadata", {})

                    # Extract the actual vector array
                    vector_array = vector_data.get("float32", [])

                    # For documents, try to extract text from metadata or use the vector ID
                    document_text = ""
                    if isinstance(vector_metadata, dict):
                        # Get the text field first (highest priority)
                        document_text = vector_metadata.get("text")
                        if not document_text:
                            # Fall back to other possible text fields
                            document_text = (
                                vector_metadata.get("content")
                                or vector_metadata.get("document")
                                or vector_id
                            )

                        # Log the actual content for debugging
                        log.debug(
                            f"Document text preview (first 200 chars): {str(document_text)[:200]}"
                        )
                    else:
                        document_text = vector_id

                    all_ids.append(vector_id)
                    all_documents.append(document_text)
                    all_metadatas.append(vector_metadata)

                # Check if there are more pages
                next_token = response.get("nextToken")
                if not next_token:
                    break

            log.info(
                f"Retrieved {len(all_ids)} vectors from collection '{collection_name}'"
            )

            # Return in GetResult format; Open WebUI's GetResult expects lists
            # of lists, so wrap each list.
            if all_ids:
                return GetResult(
                    ids=[all_ids], documents=[all_documents], metadatas=[all_metadatas]
                )
            else:
                return GetResult(ids=[[]], documents=[[]], metadatas=[[]])

        except Exception as e:
            log.error(
                f"Error retrieving vectors from collection '{collection_name}': {str(e)}"
            )
            # Handle specific AWS exceptions
            if hasattr(e, "response") and "Error" in e.response:
                error_code = e.response["Error"]["Code"]
                if error_code == "NotFoundException":
                    log.warning(f"Collection '{collection_name}' not found")
                    return GetResult(ids=[[]], documents=[[]], metadatas=[[]])
                elif error_code == "AccessDeniedException":
                    log.error(
                        f"Access denied for collection '{collection_name}'. Check permissions."
                    )
                    return GetResult(ids=[[]], documents=[[]], metadatas=[[]])
            raise
    def delete(
        self,
        collection_name: str,
        ids: Optional[List[str]] = None,
        filter: Optional[Dict] = None,
    ) -> None:
        """
        Delete vectors by ID or filter from a collection.
        """
        if not self.has_collection(collection_name):
            log.warning(
                f"Collection '{collection_name}' does not exist, nothing to delete"
            )
            return

        # Check if this is a knowledge collection (not file-specific)
        is_knowledge_collection = not collection_name.startswith("file-")

        try:
            if ids:
                # Delete by specific vector IDs/keys
                log.info(
                    f"Deleting {len(ids)} vectors by IDs from collection '{collection_name}'"
                )
                self.client.delete_vectors(
                    vectorBucketName=self.bucket_name,
                    indexName=collection_name,
                    keys=ids,
                )
                log.info(f"Deleted {len(ids)} vectors from index '{collection_name}'")

            elif filter:
                # Handle filter-based deletion
                log.info(
                    f"Deleting vectors by filter from collection '{collection_name}': {filter}"
                )

                # If this is a knowledge collection and we have a file_id filter,
                # also clean up the corresponding file-specific collection
                if is_knowledge_collection and "file_id" in filter:
                    file_id = filter["file_id"]
                    file_collection_name = f"file-{file_id}"
                    if self.has_collection(file_collection_name):
                        log.info(
                            f"Found related file-specific collection '{file_collection_name}', deleting it to prevent duplicates"
                        )
                        self.delete_collection(file_collection_name)

                # For the main collection, implement query-then-delete:
                # first, query to get the IDs matching the filter
                query_result = self.query(collection_name, filter)
                if query_result and query_result.ids and query_result.ids[0]:
                    matching_ids = query_result.ids[0]
                    log.info(
                        f"Found {len(matching_ids)} vectors matching filter, deleting them"
                    )

                    # Delete the matching vectors by ID
                    self.client.delete_vectors(
                        vectorBucketName=self.bucket_name,
                        indexName=collection_name,
                        keys=matching_ids,
                    )
                    log.info(
                        f"Deleted {len(matching_ids)} vectors from index '{collection_name}' using filter"
                    )
                else:
                    log.warning("No vectors found matching the filter criteria")
            else:
                log.warning("No IDs or filter provided for deletion")
        except Exception as e:
            log.error(
                f"Error deleting vectors from collection '{collection_name}': {e}"
            )
            raise
    def reset(self) -> None:
        """
        Reset/clear all vector data. For S3 Vector, this deletes all indexes.
        """
        try:
            log.warning(
                "Reset called - this will delete all vector indexes in the S3 bucket"
            )

            # List all indexes
            response = self.client.list_indexes(vectorBucketName=self.bucket_name)
            indexes = response.get("indexes", [])

            if not indexes:
                log.warning("No indexes found to delete")
                return

            # Delete all indexes
            deleted_count = 0
            for index in indexes:
                index_name = index.get("indexName")
                if index_name:
                    try:
                        self.client.delete_index(
                            vectorBucketName=self.bucket_name, indexName=index_name
                        )
                        deleted_count += 1
                        log.info(f"Deleted index: {index_name}")
                    except Exception as e:
                        log.error(f"Error deleting index '{index_name}': {e}")

            log.info(f"Reset completed: deleted {deleted_count} indexes")

        except Exception as e:
            log.error(f"Error during reset: {e}")
            raise
    def _matches_filter(self, metadata: Dict[str, Any], filter: Dict[str, Any]) -> bool:
        """
        Check if metadata matches the given filter conditions.
        """
        if not isinstance(metadata, dict) or not isinstance(filter, dict):
            return False

        # Check each filter condition
        for key, expected_value in filter.items():
            # Handle special operators
            if key.startswith("$"):
                if key == "$and":
                    # All conditions must match
                    if not isinstance(expected_value, list):
                        continue
                    for condition in expected_value:
                        if not self._matches_filter(metadata, condition):
                            return False
                elif key == "$or":
                    # At least one condition must match
                    if not isinstance(expected_value, list):
                        continue
                    any_match = False
                    for condition in expected_value:
                        if self._matches_filter(metadata, condition):
                            any_match = True
                            break
                    if not any_match:
                        return False
                continue

            # Get the actual value from metadata
            actual_value = metadata.get(key)

            # Handle different types of expected values
            if isinstance(expected_value, dict):
                # Handle comparison operators
                for op, op_value in expected_value.items():
                    if op == "$eq":
                        if actual_value != op_value:
                            return False
                    elif op == "$ne":
                        if actual_value == op_value:
                            return False
                    elif op == "$in":
                        if (
                            not isinstance(op_value, list)
                            or actual_value not in op_value
                        ):
                            return False
                    elif op == "$nin":
                        if isinstance(op_value, list) and actual_value in op_value:
                            return False
                    elif op == "$exists":
                        if bool(op_value) != (key in metadata):
                            return False
                    # Add more operators as needed
            else:
                # Simple equality check
                if actual_value != expected_value:
                    return False

        return True
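Some examples of filter documents the client-side matcher above accepts, with the outcomes implied by its logic; this is a sketch, with `client._matches_filter` standing in for the private method:

```python
metadata = {"file_id": "doc-1", "page": 2, "source": "upload"}

filters = [
    ({"file_id": "doc-1"}, True),                       # simple equality
    ({"page": {"$in": [1, 2, 3]}}, True),               # membership
    ({"page": {"$ne": 2}}, False),                      # negated equality
    ({"$or": [{"page": 9}, {"source": "upload"}]}, True),
    ({"$and": [{"file_id": "doc-1"}, {"page": {"$exists": True}}]}, True),
]

# for f, expected in filters:
#     assert client._matches_filter(metadata, f) == expected
```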
@@ -30,6 +30,10 @@ class Vector:
                 from open_webui.retrieval.vector.dbs.pinecone import PineconeClient

                 return PineconeClient()
+            case VectorType.S3VECTOR:
+                from open_webui.retrieval.vector.dbs.s3vector import S3VectorClient
+
+                return S3VectorClient()
             case VectorType.OPENSEARCH:
                 from open_webui.retrieval.vector.dbs.opensearch import OpenSearchClient

@@ -48,6 +52,10 @@ class Vector:
                 from open_webui.retrieval.vector.dbs.chroma import ChromaClient

                 return ChromaClient()
+            case VectorType.ORACLE23AI:
+                from open_webui.retrieval.vector.dbs.oracle23ai import Oracle23aiClient
+
+                return Oracle23aiClient()
             case _:
                 raise ValueError(f"Unsupported vector type: {vector_type}")

@@ -9,3 +9,5 @@ class VectorType(StrEnum):
     ELASTICSEARCH = "elasticsearch"
     OPENSEARCH = "opensearch"
     PGVECTOR = "pgvector"
+    ORACLE23AI = "oracle23ai"
+    S3VECTOR = "s3vector"
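A sketch of how the factory additions are exercised, assuming the `VECTOR_DB` environment variable drives backend selection as elsewhere in Open WebUI; the accessor call is a hypothetical stand-in, not the factory's actual public API:

```python
import os

# With VECTOR_DB="s3vector" (or "oracle23ai"), the Vector factory's match
# statement now resolves to the new clients.
os.environ.setdefault("VECTOR_DB", "s3vector")

# from open_webui.retrieval.vector.type import VectorType
# vector_type = VectorType(os.environ["VECTOR_DB"])
# client = Vector(...)  # hypothetical accessor; see the factory above
```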
backend/open_webui/retrieval/vector/utils.py — new file (14 lines)
@@ -0,0 +1,14 @@
+from datetime import datetime
+
+
+def stringify_metadata(
+    metadata: dict[str, any],
+) -> dict[str, any]:
+    for key, value in metadata.items():
+        if (
+            isinstance(value, datetime)
+            or isinstance(value, list)
+            or isinstance(value, dict)
+        ):
+            metadata[key] = str(value)
+    return metadata
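A quick check of the new helper's behavior: datetime, list, and dict values are flattened to strings in place, while scalars pass through untouched. The function below is a condensed restatement of the helper for illustration:

```python
from datetime import datetime


def stringify_metadata(metadata):
    # Condensed form of the helper above: non-scalar values become strings.
    for key, value in metadata.items():
        if isinstance(value, (datetime, list, dict)):
            metadata[key] = str(value)
    return metadata


meta = {"created_at": datetime(2025, 8, 9, 12, 0), "tags": ["rag", "s3"], "page": 3}
print(stringify_metadata(meta))
# {'created_at': '2025-08-09 12:00:00', 'tags': "['rag', 's3']", 'page': 3}
```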
@@ -561,7 +561,11 @@ def transcription_handler(request, file_path, metadata):
         file_path,
         beam_size=5,
         vad_filter=request.app.state.config.WHISPER_VAD_FILTER,
-        language=metadata.get("language") or WHISPER_LANGUAGE,
+        language=(
+            metadata.get("language", None)
+            if WHISPER_LANGUAGE == ""
+            else WHISPER_LANGUAGE
+        ),
     )
     log.info(
         "Detected language '%s' with probability %f"
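The new expression gives the server-wide `WHISPER_LANGUAGE` setting precedence over per-request metadata; only an empty setting defers to the request, and `None` lets Whisper auto-detect. A condensed restatement of the precedence rule:

```python
def pick_language(whisper_language: str, metadata: dict):
    # Config override wins; an empty config defers to the request metadata.
    return metadata.get("language") if whisper_language == "" else whisper_language

assert pick_language("", {"language": "de"}) == "de"    # request wins when config unset
assert pick_language("en", {"language": "de"}) == "en"  # config override wins
assert pick_language("", {}) is None                    # None -> Whisper auto-detects
```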
@@ -351,11 +351,9 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
         user = Users.get_user_by_email(email)
         if not user:
             try:
-                user_count = Users.get_num_users()
-
                 role = (
                     "admin"
-                    if user_count == 0
+                    if not Users.has_users()
                     else request.app.state.config.DEFAULT_USER_ROLE
                 )

@@ -489,7 +487,7 @@ async def signin(request: Request, response: Response, form_data: SigninForm):
         if Users.get_user_by_email(admin_email.lower()):
             user = Auths.authenticate_user(admin_email.lower(), admin_password)
         else:
-            if Users.get_num_users() != 0:
+            if Users.has_users():
                 raise HTTPException(400, detail=ERROR_MESSAGES.EXISTING_USERS)

             await signup(
@@ -556,6 +554,7 @@ async def signin(request: Request, response: Response, form_data: SigninForm):

 @router.post("/signup", response_model=SessionUserResponse)
 async def signup(request: Request, response: Response, form_data: SignupForm):
+    has_users = Users.has_users()

     if WEBUI_AUTH:
         if (
@@ -566,12 +565,11 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
                 status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED
             )
     else:
-        if Users.get_num_users() != 0:
+        if has_users:
             raise HTTPException(
                 status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED
             )

-    user_count = Users.get_num_users()
     if not validate_email_format(form_data.email.lower()):
         raise HTTPException(
             status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT
@@ -581,9 +579,7 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
         raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)

     try:
-        role = (
-            "admin" if user_count == 0 else request.app.state.config.DEFAULT_USER_ROLE
-        )
+        role = "admin" if not has_users else request.app.state.config.DEFAULT_USER_ROLE

         # The password passed to bcrypt must be 72 bytes or fewer. If it is longer, it will be truncated before hashing.
         if len(form_data.password.encode("utf-8")) > 72:
@@ -644,7 +640,7 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
             user.id, request.app.state.config.USER_PERMISSIONS
         )

-        if user_count == 0:
+        if not has_users:
             # Disable signup after the first user is created
             request.app.state.config.ENABLE_SIGNUP = False

@@ -673,7 +669,7 @@ async def signout(request: Request, response: Response):

     if ENABLE_OAUTH_SIGNUP.value:
         oauth_id_token = request.cookies.get("oauth_id_token")
-        if oauth_id_token:
+        if oauth_id_token and OPENID_PROVIDER_URL.value:
             try:
                 async with ClientSession(trust_env=True) as session:
                     async with session.get(OPENID_PROVIDER_URL.value) as resp:
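The refactor replaces repeated `Users.get_num_users()` calls with a single `Users.has_users()` check captured once per request, which also avoids a full count query. A condensed sketch of the bootstrap rule it preserves:

```python
def resolve_role(has_users: bool, default_role: str) -> str:
    # The very first account becomes admin; everyone after gets the default role.
    return "admin" if not has_users else default_role

assert resolve_role(False, "pending") == "admin"
assert resolve_role(True, "pending") == "pending"
```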
@@ -434,13 +434,6 @@ async def update_message_by_id(
             status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
         )

-    if user.role != "admin" and not has_access(
-        user.id, type="read", access_control=channel.access_control
-    ):
-        raise HTTPException(
-            status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
-        )
-
     message = Messages.get_message_by_id(message_id)
     if not message:
         raise HTTPException(
@@ -452,6 +445,15 @@ async def update_message_by_id(
             status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
         )

+    if (
+        user.role != "admin"
+        and message.user_id != user.id
+        and not has_access(user.id, type="read", access_control=channel.access_control)
+    ):
+        raise HTTPException(
+            status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
+        )
+
     try:
         message = Messages.update_message_by_id(message_id, form_data)
         message = Messages.get_message_by_id(message_id)
@@ -641,13 +643,6 @@ async def delete_message_by_id(
             status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
         )

-    if user.role != "admin" and not has_access(
-        user.id, type="read", access_control=channel.access_control
-    ):
-        raise HTTPException(
-            status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
-        )
-
     message = Messages.get_message_by_id(message_id)
     if not message:
         raise HTTPException(
@@ -659,6 +654,15 @@ async def delete_message_by_id(
             status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
         )

+    if (
+        user.role != "admin"
+        and message.user_id != user.id
+        and not has_access(user.id, type="read", access_control=channel.access_control)
+    ):
+        raise HTTPException(
+            status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
+        )
+
     try:
         Messages.delete_message_by_id(message_id)
         await sio.emit(
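The moved check means the message is now fetched before authorization, and a message's author keeps edit and delete rights even without channel read access. A condensed form of the new predicate, with plain values standing in for the ORM objects:

```python
def may_modify(role: str, user_id: str, message_user_id: str, has_read_access: bool) -> bool:
    # Admins always pass, authors always pass, everyone else needs channel read access.
    return role == "admin" or message_user_id == user_id or has_read_access

assert may_modify("user", "u1", "u1", False)    # author always allowed
assert may_modify("admin", "u2", "u1", False)   # admin always allowed
assert not may_modify("user", "u2", "u1", False)
```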
@@ -617,7 +617,18 @@ async def clone_chat_by_id(
             "title": form_data.title if form_data.title else f"Clone of {chat.title}",
         }

-        chat = Chats.insert_new_chat(user.id, ChatForm(**{"chat": updated_chat}))
+        chat = Chats.import_chat(
+            user.id,
+            ChatImportForm(
+                **{
+                    "chat": updated_chat,
+                    "meta": chat.meta,
+                    "pinned": chat.pinned,
+                    "folder_id": chat.folder_id,
+                }
+            ),
+        )
+
         return ChatResponse(**chat.model_dump())
     else:
         raise HTTPException(
@@ -646,7 +657,17 @@ async def clone_shared_chat_by_id(id: str, user=Depends(get_verified_user)):
             "title": f"Clone of {chat.title}",
         }

-        chat = Chats.insert_new_chat(user.id, ChatForm(**{"chat": updated_chat}))
+        chat = Chats.import_chat(
+            user.id,
+            ChatImportForm(
+                **{
+                    "chat": updated_chat,
+                    "meta": chat.meta,
+                    "pinned": chat.pinned,
+                    "folder_id": chat.folder_id,
+                }
+            ),
+        )
         return ChatResponse(**chat.model_dump())
     else:
         raise HTTPException(
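Cloning now goes through `Chats.import_chat` with a `ChatImportForm`, so the clone keeps the source chat's meta, pinned state, and folder instead of losing them. A sketch of the payload shape, with field names taken from the diff and the calls left commented out:

```python
clone_payload = {
    "chat": {"title": "Clone of Project notes", "messages": []},
    "meta": {"tags": ["project"]},
    "pinned": True,
    "folder_id": "folder-123",
}

# form = ChatImportForm(**clone_payload)
# chat = Chats.import_chat(user.id, form)
```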
@@ -7,7 +7,11 @@ from open_webui.utils.auth import get_admin_user, get_verified_user
 from open_webui.config import get_config, save_config
 from open_webui.config import BannerModel

-from open_webui.utils.tools import get_tool_server_data, get_tool_servers_data
+from open_webui.utils.tools import (
+    get_tool_server_data,
+    get_tool_servers_data,
+    get_tool_server_url,
+)


 router = APIRouter()

@@ -135,7 +139,7 @@ async def verify_tool_servers_config(
         elif form_data.auth_type == "session":
             token = request.state.token.credentials

-        url = f"{form_data.url}/{form_data.path}"
+        url = get_tool_server_url(form_data.url, form_data.path)
         return await get_tool_server_data(token, url)
     except Exception as e:
         raise HTTPException(
@@ -129,6 +129,9 @@ async def create_feedback(

 @router.get("/feedback/{id}", response_model=FeedbackModel)
 async def get_feedback_by_id(id: str, user=Depends(get_verified_user)):
-    feedback = Feedbacks.get_feedback_by_id_and_user_id(id=id, user_id=user.id)
+    if user.role == "admin":
+        feedback = Feedbacks.get_feedback_by_id(id=id)
+    else:
+        feedback = Feedbacks.get_feedback_by_id_and_user_id(id=id, user_id=user.id)

     if not feedback:
@@ -143,6 +146,9 @@ async def get_feedback_by_id(id: str, user=Depends(get_verified_user)):
 async def update_feedback_by_id(
     id: str, form_data: FeedbackForm, user=Depends(get_verified_user)
 ):
-    feedback = Feedbacks.update_feedback_by_id_and_user_id(
-        id=id, user_id=user.id, form_data=form_data
-    )
+    if user.role == "admin":
+        feedback = Feedbacks.update_feedback_by_id(id=id, form_data=form_data)
+    else:
+        feedback = Feedbacks.update_feedback_by_id_and_user_id(
+            id=id, user_id=user.id, form_data=form_data
+        )
@@ -244,11 +244,11 @@ async def delete_folder_by_id(
     folder = Folders.get_folder_by_id_and_user_id(id, user.id)
     if folder:
         try:
-            result = Folders.delete_folder_by_id_and_user_id(id, user.id)
-            if result:
-                return result
-            else:
-                raise Exception("Error deleting folder")
+            folder_ids = Folders.delete_folder_by_id_and_user_id(id, user.id)
+            for folder_id in folder_ids:
+                Chats.delete_chats_by_user_id_and_folder_id(user.id, folder_id)
+
+            return True
         except Exception as e:
             log.exception(e)
             log.error(f"Error deleting folder: {id}")
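`delete_folder_by_id_and_user_id` now appears to return the ids of every folder removed (the folder plus its descendants), so the router can cascade-delete the chats inside each one. A sketch of that control flow with hypothetical stand-ins for the ORM calls:

```python
def delete_folder(folder_ids: list[str], delete_chats) -> bool:
    # Cascade: remove the chats belonging to each deleted folder id.
    for folder_id in folder_ids:
        delete_chats(folder_id)
    return True

deleted: list[str] = []
assert delete_folder(["f1", "f2"], deleted.append) is True
assert deleted == ["f1", "f2"]
```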
@@ -131,15 +131,29 @@ async def load_function_from_url(
 ############################


-class SyncFunctionsForm(FunctionForm):
+class SyncFunctionsForm(BaseModel):
     functions: list[FunctionModel] = []


-@router.post("/sync", response_model=Optional[FunctionModel])
+@router.post("/sync", response_model=list[FunctionModel])
 async def sync_functions(
     request: Request, form_data: SyncFunctionsForm, user=Depends(get_admin_user)
 ):
-    return Functions.sync_functions(user.id, form_data.functions)
+    try:
+        for function in form_data.functions:
+            function.content = replace_imports(function.content)
+            function_module, function_type, frontmatter = load_function_module_by_id(
+                function.id,
+                content=function.content,
+            )
+
+        return Functions.sync_functions(user.id, form_data.functions)
+    except Exception as e:
+        log.exception(f"Failed to load a function: {e}")
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail=ERROR_MESSAGES.DEFAULT(e),
+        )


 ############################
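`/functions/sync` now loads every submitted function module before persisting anything, so a broken module rejects the whole batch with a 400 instead of landing in the database. The snippet below illustrates the same fail-fast idea with plain `compile()` rather than the repo's `load_function_module_by_id`, which does more (import rewriting, frontmatter parsing):

```python
def validate_sources(sources: dict[str, str]) -> None:
    # Raises SyntaxError on the first bad module, before any write happens.
    for name, code in sources.items():
        compile(code, name, "exec")

validate_sources({"ok.py": "x = 1"})
try:
    validate_sources({"bad.py": "def ("})
except SyntaxError:
    print("rejected batch before any write")
```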
@@ -25,6 +25,7 @@ from open_webui.utils.access_control import has_access, has_permission


 from open_webui.env import SRC_LOG_LEVELS
+from open_webui.config import ENABLE_ADMIN_WORKSPACE_CONTENT_ACCESS
 from open_webui.models.models import Models, ModelForm

@@ -42,7 +43,7 @@ router = APIRouter()
 async def get_knowledge(user=Depends(get_verified_user)):
     knowledge_bases = []

-    if user.role == "admin":
+    if user.role == "admin" and ENABLE_ADMIN_WORKSPACE_CONTENT_ACCESS:
         knowledge_bases = Knowledges.get_knowledge_bases()
     else:
         knowledge_bases = Knowledges.get_knowledge_bases_by_user_id(user.id, "read")
@@ -90,7 +91,7 @@ async def get_knowledge(user=Depends(get_verified_user)):
 async def get_knowledge_list(user=Depends(get_verified_user)):
     knowledge_bases = []

-    if user.role == "admin":
+    if user.role == "admin" and ENABLE_ADMIN_WORKSPACE_CONTENT_ACCESS:
         knowledge_bases = Knowledges.get_knowledge_bases()
     else:
         knowledge_bases = Knowledges.get_knowledge_bases_by_user_id(user.id, "write")
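The same gate recurs across the knowledge, models, and prompts routers in this commit: admins only see all workspace content when the `ENABLE_ADMIN_WORKSPACE_CONTENT_ACCESS` flag is on; otherwise they fall back to the per-user access-control lookup like everyone else. Condensed:

```python
def admin_sees_everything(role: str, flag_enabled: bool) -> bool:
    # Admin visibility of all workspace content is now opt-in via config.
    return role == "admin" and flag_enabled

assert admin_sees_everything("admin", True)
assert not admin_sees_everything("admin", False)  # falls back to ACL lookup
assert not admin_sees_everything("user", True)
```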
@@ -82,6 +82,10 @@ class QueryMemoryForm(BaseModel):
 async def query_memory(
     request: Request, form_data: QueryMemoryForm, user=Depends(get_verified_user)
 ):
+    memories = Memories.get_memories_by_user_id(user.id)
+    if not memories:
+        raise HTTPException(status_code=404, detail="No memories found for user")
+
     results = VECTOR_DB_CLIENT.search(
         collection_name=f"user-memory-{user.id}",
         vectors=[request.app.state.EMBEDDING_FUNCTION(form_data.content, user=user)],
@@ -7,13 +7,15 @@ from open_webui.models.models import (
     ModelUserResponse,
     Models,
 )
+from pydantic import BaseModel
 from open_webui.constants import ERROR_MESSAGES
 from fastapi import APIRouter, Depends, HTTPException, Request, status


 from open_webui.utils.auth import get_admin_user, get_verified_user
 from open_webui.utils.access_control import has_access, has_permission
+from open_webui.config import ENABLE_ADMIN_WORKSPACE_CONTENT_ACCESS

 router = APIRouter()

@@ -25,7 +27,7 @@ router = APIRouter()

 @router.get("/", response_model=list[ModelUserResponse])
 async def get_models(id: Optional[str] = None, user=Depends(get_verified_user)):
-    if user.role == "admin":
+    if user.role == "admin" and ENABLE_ADMIN_WORKSPACE_CONTENT_ACCESS:
         return Models.get_models()
     else:
         return Models.get_models_by_user_id(user.id)
@@ -78,6 +80,32 @@ async def create_new_model(
     )


+############################
+# ExportModels
+############################
+
+
+@router.get("/export", response_model=list[ModelModel])
+async def export_models(user=Depends(get_admin_user)):
+    return Models.get_models()
+
+
+############################
+# SyncModels
+############################
+
+
+class SyncModelsForm(BaseModel):
+    models: list[ModelModel] = []
+
+
+@router.post("/sync", response_model=list[ModelModel])
+async def sync_models(
+    request: Request, form_data: SyncModelsForm, user=Depends(get_admin_user)
+):
+    return Models.sync_models(user.id, form_data.models)
+
+
 ###########################
 # GetModelById
 ###########################

@@ -102,7 +130,7 @@ async def get_model_by_id(id: str, user=Depends(get_verified_user)):


 ############################
-# ToggelModelById
+# ToggleModelById
 ############################
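The new admin-only endpoints allow a simple export/sync round trip between two instances. A sketch using `requests`; the hostnames and token handling are assumptions for illustration, not part of the diff:

```python
import requests

SRC = "https://source.example.com/api/v1/models"   # hypothetical source instance
DST = "https://target.example.com/api/v1/models"   # hypothetical target instance
headers = {"Authorization": "Bearer <admin-token>"}

# models = requests.get(f"{SRC}/export", headers=headers).json()
# requests.post(f"{DST}/sync", headers=headers, json={"models": models})
```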
@@ -2,17 +2,20 @@ import asyncio
 import hashlib
 import json
 import logging
-from pathlib import Path
-from typing import Literal, Optional, overload
+from typing import Optional

 import aiohttp
 from aiocache import cached
 import requests
 from urllib.parse import quote

-from fastapi import Depends, FastAPI, HTTPException, Request, APIRouter
-from fastapi.middleware.cors import CORSMiddleware
-from fastapi.responses import FileResponse, StreamingResponse
+from fastapi import Depends, HTTPException, Request, APIRouter
+from fastapi.responses import (
+    FileResponse,
+    StreamingResponse,
+    JSONResponse,
+    PlainTextResponse,
+)
 from pydantic import BaseModel
 from starlette.background import BackgroundTask

@@ -31,7 +34,7 @@ from open_webui.env import (
 from open_webui.models.users import UserModel

 from open_webui.constants import ERROR_MESSAGES
-from open_webui.env import ENV, SRC_LOG_LEVELS
+from open_webui.env import SRC_LOG_LEVELS


 from open_webui.utils.payload import (

@@ -95,12 +98,12 @@ async def cleanup_response(
     await session.close()


-def openai_o_series_handler(payload):
+def openai_reasoning_model_handler(payload):
     """
-    Handle "o" series specific parameters
+    Handle reasoning model specific parameters
     """
     if "max_tokens" in payload:
-        # Convert "max_tokens" to "max_completion_tokens" for all o-series models
+        # Convert "max_tokens" to "max_completion_tokens" for all reasoning models
         payload["max_completion_tokens"] = payload["max_tokens"]
         del payload["max_tokens"]
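What the renamed handler does to a request payload: reasoning models (the o1/o3/o4 series and now gpt-5) reject `max_tokens`, so it is renamed to `max_completion_tokens` before dispatch. A self-contained restatement:

```python
def openai_reasoning_model_handler(payload):
    # Reasoning models accept max_completion_tokens instead of max_tokens.
    if "max_tokens" in payload:
        payload["max_completion_tokens"] = payload["max_tokens"]
        del payload["max_tokens"]
    return payload

payload = {"model": "gpt-5", "max_tokens": 1024}
if payload["model"].lower().startswith(("o1", "o3", "o4", "gpt-5")):
    payload = openai_reasoning_model_handler(payload)
assert payload == {"model": "gpt-5", "max_completion_tokens": 1024}
```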
@@ -362,7 +365,9 @@ async def get_all_models_responses(request: Request, user: UserModel) -> list:
                 response if isinstance(response, list) else response.get("data", [])
             ):
                 if prefix_id:
-                    model["id"] = f"{prefix_id}.{model['id']}"
+                    model["id"] = (
+                        f"{prefix_id}.{model.get('id', model.get('name', ''))}"
+                    )

                 if tags:
                     model["tags"] = tags

@@ -593,15 +598,21 @@ async def verify_connection(
                 headers=headers,
                 ssl=AIOHTTP_CLIENT_SESSION_SSL,
             ) as r:
-                if r.status != 200:
-                    # Extract response error details if available
-                    error_detail = f"HTTP Error: {r.status}"
-                    res = await r.json()
-                    if "error" in res:
-                        error_detail = f"External Error: {res['error']}"
-                    raise Exception(error_detail)
-
-                response_data = await r.json()
+                try:
+                    response_data = await r.json()
+                except Exception:
+                    response_data = await r.text()
+
+                if r.status != 200:
+                    if isinstance(response_data, (dict, list)):
+                        return JSONResponse(
+                            status_code=r.status, content=response_data
+                        )
+                    else:
+                        return PlainTextResponse(
+                            status_code=r.status, content=response_data
+                        )
+
                 return response_data
         else:
             headers["Authorization"] = f"Bearer {key}"

@@ -611,15 +622,21 @@ async def verify_connection(
                 headers=headers,
                 ssl=AIOHTTP_CLIENT_SESSION_SSL,
             ) as r:
-                if r.status != 200:
-                    # Extract response error details if available
-                    error_detail = f"HTTP Error: {r.status}"
-                    res = await r.json()
-                    if "error" in res:
-                        error_detail = f"External Error: {res['error']}"
-                    raise Exception(error_detail)
-
-                response_data = await r.json()
+                try:
+                    response_data = await r.json()
+                except Exception:
+                    response_data = await r.text()
+
+                if r.status != 200:
+                    if isinstance(response_data, (dict, list)):
+                        return JSONResponse(
+                            status_code=r.status, content=response_data
+                        )
+                    else:
+                        return PlainTextResponse(
+                            status_code=r.status, content=response_data
+                        )
+
                 return response_data

     except aiohttp.ClientError as e:
@@ -630,8 +647,9 @@ async def verify_connection(
         )
     except Exception as e:
         log.exception(f"Unexpected error: {e}")
-        error_detail = f"Unexpected error: {str(e)}"
-        raise HTTPException(status_code=500, detail=error_detail)
+        raise HTTPException(
+            status_code=500, detail="Open WebUI: Server Connection Error"
+        )


 def get_azure_allowed_params(api_version: str) -> set[str]:

@@ -787,10 +805,12 @@ async def generate_chat_completion(
     url = request.app.state.config.OPENAI_API_BASE_URLS[idx]
     key = request.app.state.config.OPENAI_API_KEYS[idx]

-    # Check if model is from "o" series
-    is_o_series = payload["model"].lower().startswith(("o1", "o3", "o4"))
-    if is_o_series:
-        payload = openai_o_series_handler(payload)
+    # Check if model is a reasoning model that needs special handling
+    is_reasoning_model = (
+        payload["model"].lower().startswith(("o1", "o3", "o4", "gpt-5"))
+    )
+    if is_reasoning_model:
+        payload = openai_reasoning_model_handler(payload)
     elif "api.openai.com" not in url:
         # Remove "max_completion_tokens" from the payload for backward compatibility
         if "max_completion_tokens" in payload:

@@ -881,21 +901,19 @@ async def generate_chat_completion(
                 log.error(e)
                 response = await r.text()

-            r.raise_for_status()
+            if r.status >= 400:
+                if isinstance(response, (dict, list)):
+                    return JSONResponse(status_code=r.status, content=response)
+                else:
+                    return PlainTextResponse(status_code=r.status, content=response)
+
             return response
     except Exception as e:
         log.exception(e)

-        detail = None
-        if isinstance(response, dict):
-            if "error" in response:
-                detail = f"{response['error']['message'] if 'message' in response['error'] else response['error']}"
-        elif isinstance(response, str):
-            detail = response
-
         raise HTTPException(
             status_code=r.status if r else 500,
-            detail=detail if detail else "Open WebUI: Server Connection Error",
+            detail="Open WebUI: Server Connection Error",
         )
     finally:
         if not streaming:

@@ -949,7 +967,7 @@ async def embeddings(request: Request, form_data: dict, user):
             ),
         },
     )
-    r.raise_for_status()
     if "text/event-stream" in r.headers.get("Content-Type", ""):
         streaming = True
         return StreamingResponse(

@@ -961,21 +979,25 @@ async def embeddings(request: Request, form_data: dict, user):
             ),
         )
     else:
-        response_data = await r.json()
+        try:
+            response_data = await r.json()
+        except Exception:
+            response_data = await r.text()
+
+        if r.status >= 400:
+            if isinstance(response_data, (dict, list)):
+                return JSONResponse(status_code=r.status, content=response_data)
+            else:
+                return PlainTextResponse(
+                    status_code=r.status, content=response_data
+                )
+
         return response_data
 except Exception as e:
     log.exception(e)
-    detail = None
-    if r is not None:
-        try:
-            res = await r.json()
-            if "error" in res:
-                detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}"
-        except Exception:
-            detail = f"External: {e}"
-
     raise HTTPException(
         status_code=r.status if r else 500,
-        detail=detail if detail else "Open WebUI: Server Connection Error",
+        detail="Open WebUI: Server Connection Error",
     )
 finally:
     if not streaming:

@@ -1041,7 +1063,6 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
             headers=headers,
             ssl=AIOHTTP_CLIENT_SESSION_SSL,
         )
-        r.raise_for_status()

         # Check if response is SSE
         if "text/event-stream" in r.headers.get("Content-Type", ""):

@@ -1055,24 +1076,26 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
             ),
         )
     else:
-        response_data = await r.json()
+        try:
+            response_data = await r.json()
+        except Exception:
+            response_data = await r.text()
+
+        if r.status >= 400:
+            if isinstance(response_data, (dict, list)):
+                return JSONResponse(status_code=r.status, content=response_data)
+            else:
+                return PlainTextResponse(
+                    status_code=r.status, content=response_data
+                )
+
         return response_data

 except Exception as e:
     log.exception(e)

-    detail = None
-    if r is not None:
-        try:
-            res = await r.json()
-            log.error(res)
-            if "error" in res:
-                detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}"
-        except Exception:
-            detail = f"External: {e}"
-
     raise HTTPException(
         status_code=r.status if r else 500,
-        detail=detail if detail else "Open WebUI: Server Connection Error",
+        detail="Open WebUI: Server Connection Error",
     )
 finally:
     if not streaming:
@@ -1,4 +1,5 @@
 from typing import Optional
+from fastapi import APIRouter, Depends, HTTPException, status, Request

 from open_webui.models.prompts import (
     PromptForm,
@@ -7,9 +8,9 @@ from open_webui.models.prompts import (
     Prompts,
 )
 from open_webui.constants import ERROR_MESSAGES
-from fastapi import APIRouter, Depends, HTTPException, status, Request
 from open_webui.utils.auth import get_admin_user, get_verified_user
 from open_webui.utils.access_control import has_access, has_permission
+from open_webui.config import ENABLE_ADMIN_WORKSPACE_CONTENT_ACCESS

 router = APIRouter()

@@ -20,7 +21,7 @@ router = APIRouter()

 @router.get("/", response_model=list[PromptModel])
 async def get_prompts(user=Depends(get_verified_user)):
-    if user.role == "admin":
+    if user.role == "admin" and ENABLE_ADMIN_WORKSPACE_CONTENT_ACCESS:
         prompts = Prompts.get_prompts()
     else:
         prompts = Prompts.get_prompts_by_user_id(user.id, "read")
@@ -30,7 +31,7 @@ async def get_prompts(user=Depends(get_verified_user)):

 @router.get("/list", response_model=list[PromptUserResponse])
 async def get_prompt_list(user=Depends(get_verified_user)):
-    if user.role == "admin":
+    if user.role == "admin" and ENABLE_ADMIN_WORKSPACE_CONTENT_ACCESS:
         prompts = Prompts.get_prompts()
     else:
         prompts = Prompts.get_prompts_by_user_id(user.id, "write")
@@ -401,12 +401,14 @@ async def get_rag_config(request: Request, user=Depends(get_admin_user)):
         "CONTENT_EXTRACTION_ENGINE": request.app.state.config.CONTENT_EXTRACTION_ENGINE,
         "PDF_EXTRACT_IMAGES": request.app.state.config.PDF_EXTRACT_IMAGES,
         "DATALAB_MARKER_API_KEY": request.app.state.config.DATALAB_MARKER_API_KEY,
-        "DATALAB_MARKER_LANGS": request.app.state.config.DATALAB_MARKER_LANGS,
+        "DATALAB_MARKER_API_BASE_URL": request.app.state.config.DATALAB_MARKER_API_BASE_URL,
+        "DATALAB_MARKER_ADDITIONAL_CONFIG": request.app.state.config.DATALAB_MARKER_ADDITIONAL_CONFIG,
         "DATALAB_MARKER_SKIP_CACHE": request.app.state.config.DATALAB_MARKER_SKIP_CACHE,
         "DATALAB_MARKER_FORCE_OCR": request.app.state.config.DATALAB_MARKER_FORCE_OCR,
         "DATALAB_MARKER_PAGINATE": request.app.state.config.DATALAB_MARKER_PAGINATE,
         "DATALAB_MARKER_STRIP_EXISTING_OCR": request.app.state.config.DATALAB_MARKER_STRIP_EXISTING_OCR,
         "DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION": request.app.state.config.DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION,
+        "DATALAB_MARKER_FORMAT_LINES": request.app.state.config.DATALAB_MARKER_FORMAT_LINES,
         "DATALAB_MARKER_USE_LLM": request.app.state.config.DATALAB_MARKER_USE_LLM,
         "DATALAB_MARKER_OUTPUT_FORMAT": request.app.state.config.DATALAB_MARKER_OUTPUT_FORMAT,
         "EXTERNAL_DOCUMENT_LOADER_URL": request.app.state.config.EXTERNAL_DOCUMENT_LOADER_URL,
@@ -566,12 +568,14 @@ class ConfigForm(BaseModel):
     CONTENT_EXTRACTION_ENGINE: Optional[str] = None
     PDF_EXTRACT_IMAGES: Optional[bool] = None
     DATALAB_MARKER_API_KEY: Optional[str] = None
-    DATALAB_MARKER_LANGS: Optional[str] = None
+    DATALAB_MARKER_API_BASE_URL: Optional[str] = None
+    DATALAB_MARKER_ADDITIONAL_CONFIG: Optional[str] = None
     DATALAB_MARKER_SKIP_CACHE: Optional[bool] = None
     DATALAB_MARKER_FORCE_OCR: Optional[bool] = None
     DATALAB_MARKER_PAGINATE: Optional[bool] = None
     DATALAB_MARKER_STRIP_EXISTING_OCR: Optional[bool] = None
     DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION: Optional[bool] = None
+    DATALAB_MARKER_FORMAT_LINES: Optional[bool] = None
     DATALAB_MARKER_USE_LLM: Optional[bool] = None
     DATALAB_MARKER_OUTPUT_FORMAT: Optional[str] = None
     EXTERNAL_DOCUMENT_LOADER_URL: Optional[str] = None
@@ -683,10 +687,15 @@ async def update_rag_config(
         if form_data.DATALAB_MARKER_API_KEY is not None
         else request.app.state.config.DATALAB_MARKER_API_KEY
     )
-    request.app.state.config.DATALAB_MARKER_LANGS = (
-        form_data.DATALAB_MARKER_LANGS
-        if form_data.DATALAB_MARKER_LANGS is not None
-        else request.app.state.config.DATALAB_MARKER_LANGS
+    request.app.state.config.DATALAB_MARKER_API_BASE_URL = (
+        form_data.DATALAB_MARKER_API_BASE_URL
+        if form_data.DATALAB_MARKER_API_BASE_URL is not None
+        else request.app.state.config.DATALAB_MARKER_API_BASE_URL
+    )
+    request.app.state.config.DATALAB_MARKER_ADDITIONAL_CONFIG = (
+        form_data.DATALAB_MARKER_ADDITIONAL_CONFIG
+        if form_data.DATALAB_MARKER_ADDITIONAL_CONFIG is not None
+        else request.app.state.config.DATALAB_MARKER_ADDITIONAL_CONFIG
     )
     request.app.state.config.DATALAB_MARKER_SKIP_CACHE = (
         form_data.DATALAB_MARKER_SKIP_CACHE
@@ -713,6 +722,11 @@ async def update_rag_config(
         if form_data.DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION is not None
         else request.app.state.config.DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION
     )
+    request.app.state.config.DATALAB_MARKER_FORMAT_LINES = (
+        form_data.DATALAB_MARKER_FORMAT_LINES
+        if form_data.DATALAB_MARKER_FORMAT_LINES is not None
+        else request.app.state.config.DATALAB_MARKER_FORMAT_LINES
+    )
     request.app.state.config.DATALAB_MARKER_OUTPUT_FORMAT = (
         form_data.DATALAB_MARKER_OUTPUT_FORMAT
         if form_data.DATALAB_MARKER_OUTPUT_FORMAT is not None
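Each `DATALAB_MARKER_*` setting above is merged with the same `value if value is not None else current` pattern. A hedged sketch of how that repetition could be expressed once (hypothetical helper, not part of this diff):

```python
def merge_setting(new_value, current_value):
    # Keep the existing value when the form field was omitted (None);
    # otherwise accept the submitted value, including falsy ones like False or "".
    return new_value if new_value is not None else current_value

# Usage mirroring the handler above:
# config.DATALAB_MARKER_FORMAT_LINES = merge_setting(
#     form_data.DATALAB_MARKER_FORMAT_LINES, config.DATALAB_MARKER_FORMAT_LINES
# )
```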
@@ -1006,7 +1020,8 @@ async def update_rag_config(
         "CONTENT_EXTRACTION_ENGINE": request.app.state.config.CONTENT_EXTRACTION_ENGINE,
         "PDF_EXTRACT_IMAGES": request.app.state.config.PDF_EXTRACT_IMAGES,
         "DATALAB_MARKER_API_KEY": request.app.state.config.DATALAB_MARKER_API_KEY,
-        "DATALAB_MARKER_LANGS": request.app.state.config.DATALAB_MARKER_LANGS,
+        "DATALAB_MARKER_API_BASE_URL": request.app.state.config.DATALAB_MARKER_API_BASE_URL,
+        "DATALAB_MARKER_ADDITIONAL_CONFIG": request.app.state.config.DATALAB_MARKER_ADDITIONAL_CONFIG,
         "DATALAB_MARKER_SKIP_CACHE": request.app.state.config.DATALAB_MARKER_SKIP_CACHE,
         "DATALAB_MARKER_FORCE_OCR": request.app.state.config.DATALAB_MARKER_FORCE_OCR,
         "DATALAB_MARKER_PAGINATE": request.app.state.config.DATALAB_MARKER_PAGINATE,
@@ -1229,27 +1244,14 @@ def save_docs_to_vector_db(
         {
             **doc.metadata,
             **(metadata if metadata else {}),
-            "embedding_config": json.dumps(
-                {
-                    "engine": request.app.state.config.RAG_EMBEDDING_ENGINE,
-                    "model": request.app.state.config.RAG_EMBEDDING_MODEL,
-                }
-            ),
+            "embedding_config": {
+                "engine": request.app.state.config.RAG_EMBEDDING_ENGINE,
+                "model": request.app.state.config.RAG_EMBEDDING_MODEL,
+            },
         }
         for doc in docs
     ]

-    # ChromaDB does not like datetime formats
-    # for meta-data so convert them to string.
-    for metadata in metadatas:
-        for key, value in metadata.items():
-            if (
-                isinstance(value, datetime)
-                or isinstance(value, list)
-                or isinstance(value, dict)
-            ):
-                metadata[key] = str(value)
-
     try:
         if VECTOR_DB_CLIENT.has_collection(collection_name=collection_name):
             log.info(f"collection {collection_name} already exists")
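The hunk above stops serializing `embedding_config` with `json.dumps` and stores it as a plain dict, and drops the loop that stringified datetime/list/dict metadata for ChromaDB. The practical difference, sketched with example engine/model values:

```python
import json

embedding = {"engine": "openai", "model": "text-embedding-3-small"}  # example values

old_metadata = {"embedding_config": json.dumps(embedding)}  # '{"engine": ...}' string
new_metadata = {"embedding_config": embedding}              # nested dict, usable as-is

# Reading the old form required a second parse step:
engine = json.loads(old_metadata["embedding_config"])["engine"]
# The new form does not:
engine = new_metadata["embedding_config"]["engine"]
```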
@@ -1406,12 +1408,14 @@ def process_file(
         loader = Loader(
             engine=request.app.state.config.CONTENT_EXTRACTION_ENGINE,
             DATALAB_MARKER_API_KEY=request.app.state.config.DATALAB_MARKER_API_KEY,
-            DATALAB_MARKER_LANGS=request.app.state.config.DATALAB_MARKER_LANGS,
+            DATALAB_MARKER_API_BASE_URL=request.app.state.config.DATALAB_MARKER_API_BASE_URL,
+            DATALAB_MARKER_ADDITIONAL_CONFIG=request.app.state.config.DATALAB_MARKER_ADDITIONAL_CONFIG,
             DATALAB_MARKER_SKIP_CACHE=request.app.state.config.DATALAB_MARKER_SKIP_CACHE,
             DATALAB_MARKER_FORCE_OCR=request.app.state.config.DATALAB_MARKER_FORCE_OCR,
             DATALAB_MARKER_PAGINATE=request.app.state.config.DATALAB_MARKER_PAGINATE,
             DATALAB_MARKER_STRIP_EXISTING_OCR=request.app.state.config.DATALAB_MARKER_STRIP_EXISTING_OCR,
             DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION=request.app.state.config.DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION,
+            DATALAB_MARKER_FORMAT_LINES=request.app.state.config.DATALAB_MARKER_FORMAT_LINES,
             DATALAB_MARKER_USE_LLM=request.app.state.config.DATALAB_MARKER_USE_LLM,
             DATALAB_MARKER_OUTPUT_FORMAT=request.app.state.config.DATALAB_MARKER_OUTPUT_FORMAT,
             EXTERNAL_DOCUMENT_LOADER_URL=request.app.state.config.EXTERNAL_DOCUMENT_LOADER_URL,
@@ -1785,7 +1789,7 @@ def search_web(request: Request, engine: str, query: str) -> list[SearchResult]:
                 request.app.state.config.SERPLY_API_KEY,
                 query,
                 request.app.state.config.WEB_SEARCH_RESULT_COUNT,
-                request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST,
+                filter_list=request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST,
             )
         else:
             raise Exception("No SERPLY_API_KEY found in environment variables")
@@ -1961,7 +1965,7 @@ async def process_web_search(
                 },
             )
             for result in search_results
-            if hasattr(result, "snippet")
+            if hasattr(result, "snippet") and result.snippet is not None
         ]
     else:
         loader = get_web_loader(
||||||
926
backend/open_webui/routers/scim.py
Normal file
926
backend/open_webui/routers/scim.py
Normal file
|
|
@ -0,0 +1,926 @@
|
||||||
|
"""
|
||||||
|
Experimental SCIM 2.0 Implementation for Open WebUI
|
||||||
|
Provides System for Cross-domain Identity Management endpoints for users and groups
|
||||||
|
|
||||||
|
NOTE: This is an experimental implementation and may not fully comply with SCIM 2.0 standards, and is subject to change.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import uuid
|
||||||
|
import time
|
||||||
|
from typing import Optional, List, Dict, Any
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, Request, Query, Header, status
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
|
from pydantic import BaseModel, Field, ConfigDict
|
||||||
|
|
||||||
|
from open_webui.models.users import Users, UserModel
|
||||||
|
from open_webui.models.groups import Groups, GroupModel
|
||||||
|
from open_webui.utils.auth import (
|
||||||
|
get_admin_user,
|
||||||
|
get_current_user,
|
||||||
|
decode_token,
|
||||||
|
get_verified_user,
|
||||||
|
)
|
||||||
|
from open_webui.constants import ERROR_MESSAGES
|
||||||
|
from open_webui.env import SRC_LOG_LEVELS
|
||||||
|
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
log.setLevel(SRC_LOG_LEVELS["MAIN"])
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
# SCIM 2.0 Schema URIs
|
||||||
|
SCIM_USER_SCHEMA = "urn:ietf:params:scim:schemas:core:2.0:User"
|
||||||
|
SCIM_GROUP_SCHEMA = "urn:ietf:params:scim:schemas:core:2.0:Group"
|
||||||
|
SCIM_LIST_RESPONSE_SCHEMA = "urn:ietf:params:scim:api:messages:2.0:ListResponse"
|
||||||
|
SCIM_ERROR_SCHEMA = "urn:ietf:params:scim:api:messages:2.0:Error"
|
||||||
|
|
||||||
|
# SCIM Resource Types
|
||||||
|
SCIM_RESOURCE_TYPE_USER = "User"
|
||||||
|
SCIM_RESOURCE_TYPE_GROUP = "Group"
|
||||||
|
|
||||||
|
|
||||||
|
def scim_error(status_code: int, detail: str, scim_type: Optional[str] = None):
|
||||||
|
"""Create a SCIM-compliant error response"""
|
||||||
|
error_body = {
|
||||||
|
"schemas": [SCIM_ERROR_SCHEMA],
|
||||||
|
"status": str(status_code),
|
||||||
|
"detail": detail,
|
||||||
|
}
|
||||||
|
|
||||||
|
if scim_type:
|
||||||
|
error_body["scimType"] = scim_type
|
||||||
|
elif status_code == 404:
|
||||||
|
error_body["scimType"] = "invalidValue"
|
||||||
|
elif status_code == 409:
|
||||||
|
error_body["scimType"] = "uniqueness"
|
||||||
|
elif status_code == 400:
|
||||||
|
error_body["scimType"] = "invalidSyntax"
|
||||||
|
|
||||||
|
return JSONResponse(status_code=status_code, content=error_body)
|
||||||
|
|
||||||
|
|
||||||
|
class SCIMError(BaseModel):
|
||||||
|
"""SCIM Error Response"""
|
||||||
|
|
||||||
|
schemas: List[str] = [SCIM_ERROR_SCHEMA]
|
||||||
|
status: str
|
||||||
|
scimType: Optional[str] = None
|
||||||
|
detail: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
class SCIMMeta(BaseModel):
|
||||||
|
"""SCIM Resource Metadata"""
|
||||||
|
|
||||||
|
resourceType: str
|
||||||
|
created: str
|
||||||
|
lastModified: str
|
||||||
|
location: Optional[str] = None
|
||||||
|
version: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
class SCIMName(BaseModel):
|
||||||
|
"""SCIM User Name"""
|
||||||
|
|
||||||
|
formatted: Optional[str] = None
|
||||||
|
familyName: Optional[str] = None
|
||||||
|
givenName: Optional[str] = None
|
||||||
|
middleName: Optional[str] = None
|
||||||
|
honorificPrefix: Optional[str] = None
|
||||||
|
honorificSuffix: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
class SCIMEmail(BaseModel):
|
||||||
|
"""SCIM Email"""
|
||||||
|
|
||||||
|
value: str
|
||||||
|
type: Optional[str] = "work"
|
||||||
|
primary: bool = True
|
||||||
|
display: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
class SCIMPhoto(BaseModel):
|
||||||
|
"""SCIM Photo"""
|
||||||
|
|
||||||
|
value: str
|
||||||
|
type: Optional[str] = "photo"
|
||||||
|
primary: bool = True
|
||||||
|
display: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
class SCIMGroupMember(BaseModel):
|
||||||
|
"""SCIM Group Member"""
|
||||||
|
|
||||||
|
value: str # User ID
|
||||||
|
ref: Optional[str] = Field(None, alias="$ref")
|
||||||
|
type: Optional[str] = "User"
|
||||||
|
display: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
class SCIMUser(BaseModel):
|
||||||
|
"""SCIM User Resource"""
|
||||||
|
|
||||||
|
model_config = ConfigDict(populate_by_name=True)
|
||||||
|
|
||||||
|
schemas: List[str] = [SCIM_USER_SCHEMA]
|
||||||
|
id: str
|
||||||
|
externalId: Optional[str] = None
|
||||||
|
userName: str
|
||||||
|
name: Optional[SCIMName] = None
|
||||||
|
displayName: str
|
||||||
|
emails: List[SCIMEmail]
|
||||||
|
active: bool = True
|
||||||
|
photos: Optional[List[SCIMPhoto]] = None
|
||||||
|
groups: Optional[List[Dict[str, str]]] = None
|
||||||
|
meta: SCIMMeta
|
||||||
|
|
||||||
|
|
||||||
|
class SCIMUserCreateRequest(BaseModel):
|
||||||
|
"""SCIM User Create Request"""
|
||||||
|
|
||||||
|
model_config = ConfigDict(populate_by_name=True)
|
||||||
|
|
||||||
|
schemas: List[str] = [SCIM_USER_SCHEMA]
|
||||||
|
externalId: Optional[str] = None
|
||||||
|
userName: str
|
||||||
|
name: Optional[SCIMName] = None
|
||||||
|
displayName: str
|
||||||
|
emails: List[SCIMEmail]
|
||||||
|
active: bool = True
|
||||||
|
password: Optional[str] = None
|
||||||
|
photos: Optional[List[SCIMPhoto]] = None
|
||||||
|
|
||||||
|
|
||||||
|
class SCIMUserUpdateRequest(BaseModel):
|
||||||
|
"""SCIM User Update Request"""
|
||||||
|
|
||||||
|
model_config = ConfigDict(populate_by_name=True)
|
||||||
|
|
||||||
|
schemas: List[str] = [SCIM_USER_SCHEMA]
|
||||||
|
id: Optional[str] = None
|
||||||
|
externalId: Optional[str] = None
|
||||||
|
userName: Optional[str] = None
|
||||||
|
name: Optional[SCIMName] = None
|
||||||
|
displayName: Optional[str] = None
|
||||||
|
emails: Optional[List[SCIMEmail]] = None
|
||||||
|
active: Optional[bool] = None
|
||||||
|
photos: Optional[List[SCIMPhoto]] = None
|
||||||
|
|
||||||
|
|
||||||
|
class SCIMGroup(BaseModel):
|
||||||
|
"""SCIM Group Resource"""
|
||||||
|
|
||||||
|
model_config = ConfigDict(populate_by_name=True)
|
||||||
|
|
||||||
|
schemas: List[str] = [SCIM_GROUP_SCHEMA]
|
||||||
|
id: str
|
||||||
|
displayName: str
|
||||||
|
members: Optional[List[SCIMGroupMember]] = []
|
||||||
|
meta: SCIMMeta
|
||||||
|
|
||||||
|
|
||||||
|
class SCIMGroupCreateRequest(BaseModel):
|
||||||
|
"""SCIM Group Create Request"""
|
||||||
|
|
||||||
|
model_config = ConfigDict(populate_by_name=True)
|
||||||
|
|
||||||
|
schemas: List[str] = [SCIM_GROUP_SCHEMA]
|
||||||
|
displayName: str
|
||||||
|
members: Optional[List[SCIMGroupMember]] = []
|
||||||
|
|
||||||
|
|
||||||
|
class SCIMGroupUpdateRequest(BaseModel):
|
||||||
|
"""SCIM Group Update Request"""
|
||||||
|
|
||||||
|
model_config = ConfigDict(populate_by_name=True)
|
||||||
|
|
||||||
|
schemas: List[str] = [SCIM_GROUP_SCHEMA]
|
||||||
|
displayName: Optional[str] = None
|
||||||
|
members: Optional[List[SCIMGroupMember]] = None
|
||||||
|
|
||||||
|
|
||||||
|
class SCIMListResponse(BaseModel):
|
||||||
|
"""SCIM List Response"""
|
||||||
|
|
||||||
|
schemas: List[str] = [SCIM_LIST_RESPONSE_SCHEMA]
|
||||||
|
totalResults: int
|
||||||
|
itemsPerPage: int
|
||||||
|
startIndex: int
|
||||||
|
Resources: List[Any]
|
||||||
|
|
||||||
|
|
||||||
|
class SCIMPatchOperation(BaseModel):
|
||||||
|
"""SCIM Patch Operation"""
|
||||||
|
|
||||||
|
op: str # "add", "replace", "remove"
|
||||||
|
path: Optional[str] = None
|
||||||
|
value: Optional[Any] = None
|
||||||
|
|
||||||
|
|
||||||
|
class SCIMPatchRequest(BaseModel):
|
||||||
|
"""SCIM Patch Request"""
|
||||||
|
|
||||||
|
schemas: List[str] = ["urn:ietf:params:scim:api:messages:2.0:PatchOp"]
|
||||||
|
Operations: List[SCIMPatchOperation]
|
||||||
|
|
||||||
|
|
||||||
|
def get_scim_auth(
|
||||||
|
request: Request, authorization: Optional[str] = Header(None)
|
||||||
|
) -> bool:
|
||||||
|
"""
|
||||||
|
Verify SCIM authentication
|
||||||
|
Checks for SCIM-specific bearer token configured in the system
|
||||||
|
"""
|
||||||
|
if not authorization:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Authorization header required",
|
||||||
|
headers={"WWW-Authenticate": "Bearer"},
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
parts = authorization.split()
|
||||||
|
if len(parts) != 2:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Invalid authorization format. Expected: Bearer <token>",
|
||||||
|
)
|
||||||
|
|
||||||
|
scheme, token = parts
|
||||||
|
if scheme.lower() != "bearer":
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Invalid authentication scheme",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check if SCIM is enabled
|
||||||
|
scim_enabled = getattr(request.app.state, "SCIM_ENABLED", False)
|
||||||
|
log.info(
|
||||||
|
f"SCIM auth check - raw SCIM_ENABLED: {scim_enabled}, type: {type(scim_enabled)}"
|
||||||
|
)
|
||||||
|
# Handle both PersistentConfig and direct value
|
||||||
|
if hasattr(scim_enabled, "value"):
|
||||||
|
scim_enabled = scim_enabled.value
|
||||||
|
log.info(f"SCIM enabled status after conversion: {scim_enabled}")
|
||||||
|
if not scim_enabled:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
detail="SCIM is not enabled",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify the SCIM token
|
||||||
|
scim_token = getattr(request.app.state, "SCIM_TOKEN", None)
|
||||||
|
# Handle both PersistentConfig and direct value
|
||||||
|
if hasattr(scim_token, "value"):
|
||||||
|
scim_token = scim_token.value
|
||||||
|
log.debug(f"SCIM token configured: {bool(scim_token)}")
|
||||||
|
if not scim_token or token != scim_token:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Invalid SCIM token",
|
||||||
|
)
|
||||||
|
|
||||||
|
return True
|
||||||
|
except HTTPException:
|
||||||
|
# Re-raise HTTP exceptions as-is
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
log.error(f"SCIM authentication error: {e}")
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
log.error(f"Traceback: {traceback.format_exc()}")
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Authentication failed",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def user_to_scim(user: UserModel, request: Request) -> SCIMUser:
|
||||||
|
"""Convert internal User model to SCIM User"""
|
||||||
|
# Parse display name into name components
|
||||||
|
name_parts = user.name.split(" ", 1) if user.name else ["", ""]
|
||||||
|
given_name = name_parts[0] if name_parts else ""
|
||||||
|
family_name = name_parts[1] if len(name_parts) > 1 else ""
|
||||||
|
|
||||||
|
# Get user's groups
|
||||||
|
user_groups = Groups.get_groups_by_member_id(user.id)
|
||||||
|
groups = [
|
||||||
|
{
|
||||||
|
"value": group.id,
|
||||||
|
"display": group.name,
|
||||||
|
"$ref": f"{request.base_url}api/v1/scim/v2/Groups/{group.id}",
|
||||||
|
"type": "direct",
|
||||||
|
}
|
||||||
|
for group in user_groups
|
||||||
|
]
|
||||||
|
|
||||||
|
return SCIMUser(
|
||||||
|
id=user.id,
|
||||||
|
userName=user.email,
|
||||||
|
name=SCIMName(
|
||||||
|
formatted=user.name,
|
||||||
|
givenName=given_name,
|
||||||
|
familyName=family_name,
|
||||||
|
),
|
||||||
|
displayName=user.name,
|
||||||
|
emails=[SCIMEmail(value=user.email)],
|
||||||
|
active=user.role != "pending",
|
||||||
|
photos=(
|
||||||
|
[SCIMPhoto(value=user.profile_image_url)]
|
||||||
|
if user.profile_image_url
|
||||||
|
else None
|
||||||
|
),
|
||||||
|
groups=groups if groups else None,
|
||||||
|
meta=SCIMMeta(
|
||||||
|
resourceType=SCIM_RESOURCE_TYPE_USER,
|
||||||
|
created=datetime.fromtimestamp(
|
||||||
|
user.created_at, tz=timezone.utc
|
||||||
|
).isoformat(),
|
||||||
|
lastModified=datetime.fromtimestamp(
|
||||||
|
user.updated_at, tz=timezone.utc
|
||||||
|
).isoformat(),
|
||||||
|
location=f"{request.base_url}api/v1/scim/v2/Users/{user.id}",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def group_to_scim(group: GroupModel, request: Request) -> SCIMGroup:
|
||||||
|
"""Convert internal Group model to SCIM Group"""
|
||||||
|
members = []
|
||||||
|
for user_id in group.user_ids:
|
||||||
|
user = Users.get_user_by_id(user_id)
|
||||||
|
if user:
|
||||||
|
members.append(
|
||||||
|
SCIMGroupMember(
|
||||||
|
value=user.id,
|
||||||
|
ref=f"{request.base_url}api/v1/scim/v2/Users/{user.id}",
|
||||||
|
display=user.name,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return SCIMGroup(
|
||||||
|
id=group.id,
|
||||||
|
displayName=group.name,
|
||||||
|
members=members,
|
||||||
|
meta=SCIMMeta(
|
||||||
|
resourceType=SCIM_RESOURCE_TYPE_GROUP,
|
||||||
|
created=datetime.fromtimestamp(
|
||||||
|
group.created_at, tz=timezone.utc
|
||||||
|
).isoformat(),
|
||||||
|
lastModified=datetime.fromtimestamp(
|
||||||
|
group.updated_at, tz=timezone.utc
|
||||||
|
).isoformat(),
|
||||||
|
location=f"{request.base_url}api/v1/scim/v2/Groups/{group.id}",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# SCIM Service Provider Config
|
||||||
|
@router.get("/ServiceProviderConfig")
|
||||||
|
async def get_service_provider_config():
|
||||||
|
"""Get SCIM Service Provider Configuration"""
|
||||||
|
return {
|
||||||
|
"schemas": ["urn:ietf:params:scim:schemas:core:2.0:ServiceProviderConfig"],
|
||||||
|
"patch": {"supported": True},
|
||||||
|
"bulk": {"supported": False, "maxOperations": 1000, "maxPayloadSize": 1048576},
|
||||||
|
"filter": {"supported": True, "maxResults": 200},
|
||||||
|
"changePassword": {"supported": False},
|
||||||
|
"sort": {"supported": False},
|
||||||
|
"etag": {"supported": False},
|
||||||
|
"authenticationSchemes": [
|
||||||
|
{
|
||||||
|
"type": "oauthbearertoken",
|
||||||
|
"name": "OAuth Bearer Token",
|
||||||
|
"description": "Authentication using OAuth 2.0 Bearer Token",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# SCIM Resource Types
|
||||||
|
@router.get("/ResourceTypes")
|
||||||
|
async def get_resource_types(request: Request):
|
||||||
|
"""Get SCIM Resource Types"""
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"schemas": ["urn:ietf:params:scim:schemas:core:2.0:ResourceType"],
|
||||||
|
"id": "User",
|
||||||
|
"name": "User",
|
||||||
|
"endpoint": "/Users",
|
||||||
|
"schema": SCIM_USER_SCHEMA,
|
||||||
|
"meta": {
|
||||||
|
"location": f"{request.base_url}api/v1/scim/v2/ResourceTypes/User",
|
||||||
|
"resourceType": "ResourceType",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"schemas": ["urn:ietf:params:scim:schemas:core:2.0:ResourceType"],
|
||||||
|
"id": "Group",
|
||||||
|
"name": "Group",
|
||||||
|
"endpoint": "/Groups",
|
||||||
|
"schema": SCIM_GROUP_SCHEMA,
|
||||||
|
"meta": {
|
||||||
|
"location": f"{request.base_url}api/v1/scim/v2/ResourceTypes/Group",
|
||||||
|
"resourceType": "ResourceType",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
# SCIM Schemas
|
||||||
|
@router.get("/Schemas")
|
||||||
|
async def get_schemas():
|
||||||
|
"""Get SCIM Schemas"""
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"schemas": ["urn:ietf:params:scim:schemas:core:2.0:Schema"],
|
||||||
|
"id": SCIM_USER_SCHEMA,
|
||||||
|
"name": "User",
|
||||||
|
"description": "User Account",
|
||||||
|
"attributes": [
|
||||||
|
{
|
||||||
|
"name": "userName",
|
||||||
|
"type": "string",
|
||||||
|
"required": True,
|
||||||
|
"uniqueness": "server",
|
||||||
|
},
|
||||||
|
{"name": "displayName", "type": "string", "required": True},
|
||||||
|
{
|
||||||
|
"name": "emails",
|
||||||
|
"type": "complex",
|
||||||
|
"multiValued": True,
|
||||||
|
"required": True,
|
||||||
|
},
|
||||||
|
{"name": "active", "type": "boolean", "required": False},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"schemas": ["urn:ietf:params:scim:schemas:core:2.0:Schema"],
|
||||||
|
"id": SCIM_GROUP_SCHEMA,
|
||||||
|
"name": "Group",
|
||||||
|
"description": "Group",
|
||||||
|
"attributes": [
|
||||||
|
{"name": "displayName", "type": "string", "required": True},
|
||||||
|
{
|
||||||
|
"name": "members",
|
||||||
|
"type": "complex",
|
||||||
|
"multiValued": True,
|
||||||
|
"required": False,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
# Users endpoints
|
||||||
|
@router.get("/Users", response_model=SCIMListResponse)
|
||||||
|
async def get_users(
|
||||||
|
request: Request,
|
||||||
|
startIndex: int = Query(1, ge=1),
|
||||||
|
count: int = Query(20, ge=1, le=100),
|
||||||
|
filter: Optional[str] = None,
|
||||||
|
_: bool = Depends(get_scim_auth),
|
||||||
|
):
|
||||||
|
"""List SCIM Users"""
|
||||||
|
skip = startIndex - 1
|
||||||
|
limit = count
|
||||||
|
|
||||||
|
# Get users from database
|
||||||
|
if filter:
|
||||||
|
# Simple filter parsing - supports userName eq "email"
|
||||||
|
# In production, you'd want a more robust filter parser
|
||||||
|
if "userName eq" in filter:
|
||||||
|
email = filter.split('"')[1]
|
||||||
|
user = Users.get_user_by_email(email)
|
||||||
|
users_list = [user] if user else []
|
||||||
|
total = 1 if user else 0
|
||||||
|
else:
|
||||||
|
response = Users.get_users(skip=skip, limit=limit)
|
||||||
|
users_list = response["users"]
|
||||||
|
total = response["total"]
|
||||||
|
else:
|
||||||
|
response = Users.get_users(skip=skip, limit=limit)
|
||||||
|
users_list = response["users"]
|
||||||
|
total = response["total"]
|
||||||
|
|
||||||
|
# Convert to SCIM format
|
||||||
|
scim_users = [user_to_scim(user, request) for user in users_list]
|
||||||
|
|
||||||
|
return SCIMListResponse(
|
||||||
|
totalResults=total,
|
||||||
|
itemsPerPage=len(scim_users),
|
||||||
|
startIndex=startIndex,
|
||||||
|
Resources=scim_users,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/Users/{user_id}", response_model=SCIMUser)
|
||||||
|
async def get_user(
|
||||||
|
user_id: str,
|
||||||
|
request: Request,
|
||||||
|
_: bool = Depends(get_scim_auth),
|
||||||
|
):
|
||||||
|
"""Get SCIM User by ID"""
|
||||||
|
user = Users.get_user_by_id(user_id)
|
||||||
|
if not user:
|
||||||
|
return scim_error(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND, detail=f"User {user_id} not found"
|
||||||
|
)
|
||||||
|
|
||||||
|
return user_to_scim(user, request)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/Users", response_model=SCIMUser, status_code=status.HTTP_201_CREATED)
|
||||||
|
async def create_user(
|
||||||
|
request: Request,
|
||||||
|
user_data: SCIMUserCreateRequest,
|
||||||
|
_: bool = Depends(get_scim_auth),
|
||||||
|
):
|
||||||
|
"""Create SCIM User"""
|
||||||
|
# Check if user already exists
|
||||||
|
existing_user = Users.get_user_by_email(user_data.userName)
|
||||||
|
if existing_user:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_409_CONFLICT,
|
||||||
|
detail=f"User with email {user_data.userName} already exists",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create user
|
||||||
|
user_id = str(uuid.uuid4())
|
||||||
|
email = user_data.emails[0].value if user_data.emails else user_data.userName
|
||||||
|
|
||||||
|
# Parse name if provided
|
||||||
|
name = user_data.displayName
|
||||||
|
if user_data.name:
|
||||||
|
if user_data.name.formatted:
|
||||||
|
name = user_data.name.formatted
|
||||||
|
elif user_data.name.givenName or user_data.name.familyName:
|
||||||
|
name = f"{user_data.name.givenName or ''} {user_data.name.familyName or ''}".strip()
|
||||||
|
|
||||||
|
# Get profile image if provided
|
||||||
|
profile_image = "/user.png"
|
||||||
|
if user_data.photos and len(user_data.photos) > 0:
|
||||||
|
profile_image = user_data.photos[0].value
|
||||||
|
|
||||||
|
# Create user
|
||||||
|
new_user = Users.insert_new_user(
|
||||||
|
id=user_id,
|
||||||
|
name=name,
|
||||||
|
email=email,
|
||||||
|
profile_image_url=profile_image,
|
||||||
|
role="user" if user_data.active else "pending",
|
||||||
|
)
|
||||||
|
|
||||||
|
if not new_user:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail="Failed to create user",
|
||||||
|
)
|
||||||
|
|
||||||
|
return user_to_scim(new_user, request)
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("/Users/{user_id}", response_model=SCIMUser)
|
||||||
|
async def update_user(
|
||||||
|
user_id: str,
|
||||||
|
request: Request,
|
||||||
|
user_data: SCIMUserUpdateRequest,
|
||||||
|
_: bool = Depends(get_scim_auth),
|
||||||
|
):
|
||||||
|
"""Update SCIM User (full update)"""
|
||||||
|
user = Users.get_user_by_id(user_id)
|
||||||
|
if not user:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail=f"User {user_id} not found",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Build update dict
|
||||||
|
update_data = {}
|
||||||
|
|
||||||
|
if user_data.userName:
|
||||||
|
update_data["email"] = user_data.userName
|
||||||
|
|
||||||
|
if user_data.displayName:
|
||||||
|
update_data["name"] = user_data.displayName
|
||||||
|
elif user_data.name:
|
||||||
|
if user_data.name.formatted:
|
||||||
|
update_data["name"] = user_data.name.formatted
|
||||||
|
elif user_data.name.givenName or user_data.name.familyName:
|
||||||
|
update_data["name"] = (
|
||||||
|
f"{user_data.name.givenName or ''} {user_data.name.familyName or ''}".strip()
|
||||||
|
)
|
||||||
|
|
||||||
|
if user_data.emails and len(user_data.emails) > 0:
|
||||||
|
update_data["email"] = user_data.emails[0].value
|
||||||
|
|
||||||
|
if user_data.active is not None:
|
||||||
|
update_data["role"] = "user" if user_data.active else "pending"
|
||||||
|
|
||||||
|
if user_data.photos and len(user_data.photos) > 0:
|
||||||
|
update_data["profile_image_url"] = user_data.photos[0].value
|
||||||
|
|
||||||
|
# Update user
|
||||||
|
updated_user = Users.update_user_by_id(user_id, update_data)
|
||||||
|
if not updated_user:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail="Failed to update user",
|
||||||
|
)
|
||||||
|
|
||||||
|
return user_to_scim(updated_user, request)
|
||||||
|
|
||||||
|
|
||||||
|
@router.patch("/Users/{user_id}", response_model=SCIMUser)
|
||||||
|
async def patch_user(
|
||||||
|
user_id: str,
|
||||||
|
request: Request,
|
||||||
|
patch_data: SCIMPatchRequest,
|
||||||
|
_: bool = Depends(get_scim_auth),
|
||||||
|
):
|
||||||
|
"""Update SCIM User (partial update)"""
|
||||||
|
user = Users.get_user_by_id(user_id)
|
||||||
|
if not user:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail=f"User {user_id} not found",
|
||||||
|
)
|
||||||
|
|
||||||
|
update_data = {}
|
||||||
|
|
||||||
|
for operation in patch_data.Operations:
|
||||||
|
op = operation.op.lower()
|
||||||
|
path = operation.path
|
||||||
|
value = operation.value
|
||||||
|
|
||||||
|
if op == "replace":
|
||||||
|
if path == "active":
|
||||||
|
update_data["role"] = "user" if value else "pending"
|
||||||
|
elif path == "userName":
|
||||||
|
update_data["email"] = value
|
||||||
|
elif path == "displayName":
|
||||||
|
update_data["name"] = value
|
||||||
|
elif path == "emails[primary eq true].value":
|
||||||
|
update_data["email"] = value
|
||||||
|
elif path == "name.formatted":
|
||||||
|
update_data["name"] = value
|
||||||
|
|
||||||
|
# Update user
|
||||||
|
if update_data:
|
||||||
|
updated_user = Users.update_user_by_id(user_id, update_data)
|
||||||
|
if not updated_user:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail="Failed to update user",
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
updated_user = user
|
||||||
|
|
||||||
|
return user_to_scim(updated_user, request)
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/Users/{user_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||||
|
async def delete_user(
|
||||||
|
user_id: str,
|
||||||
|
request: Request,
|
||||||
|
_: bool = Depends(get_scim_auth),
|
||||||
|
):
|
||||||
|
"""Delete SCIM User"""
|
||||||
|
user = Users.get_user_by_id(user_id)
|
||||||
|
if not user:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail=f"User {user_id} not found",
|
||||||
|
)
|
||||||
|
|
||||||
|
success = Users.delete_user_by_id(user_id)
|
||||||
|
if not success:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail="Failed to delete user",
|
||||||
|
)
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
# Groups endpoints
|
||||||
|
@router.get("/Groups", response_model=SCIMListResponse)
|
||||||
|
async def get_groups(
|
||||||
|
request: Request,
|
||||||
|
startIndex: int = Query(1, ge=1),
|
||||||
|
count: int = Query(20, ge=1, le=100),
|
||||||
|
filter: Optional[str] = None,
|
||||||
|
_: bool = Depends(get_scim_auth),
|
||||||
|
):
|
||||||
|
"""List SCIM Groups"""
|
||||||
|
# Get all groups
|
||||||
|
groups_list = Groups.get_groups()
|
||||||
|
|
||||||
|
# Apply pagination
|
||||||
|
total = len(groups_list)
|
||||||
|
start = startIndex - 1
|
||||||
|
end = start + count
|
||||||
|
paginated_groups = groups_list[start:end]
|
||||||
|
|
||||||
|
# Convert to SCIM format
|
||||||
|
scim_groups = [group_to_scim(group, request) for group in paginated_groups]
|
||||||
|
|
||||||
|
return SCIMListResponse(
|
||||||
|
totalResults=total,
|
||||||
|
itemsPerPage=len(scim_groups),
|
||||||
|
startIndex=startIndex,
|
||||||
|
Resources=scim_groups,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/Groups/{group_id}", response_model=SCIMGroup)
|
||||||
|
async def get_group(
|
||||||
|
group_id: str,
|
||||||
|
request: Request,
|
||||||
|
_: bool = Depends(get_scim_auth),
|
||||||
|
):
|
||||||
|
"""Get SCIM Group by ID"""
|
||||||
|
group = Groups.get_group_by_id(group_id)
|
||||||
|
if not group:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail=f"Group {group_id} not found",
|
||||||
|
)
|
||||||
|
|
||||||
|
return group_to_scim(group, request)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/Groups", response_model=SCIMGroup, status_code=status.HTTP_201_CREATED)
|
||||||
|
async def create_group(
|
||||||
|
request: Request,
|
||||||
|
group_data: SCIMGroupCreateRequest,
|
||||||
|
_: bool = Depends(get_scim_auth),
|
||||||
|
):
|
||||||
|
"""Create SCIM Group"""
|
||||||
|
# Extract member IDs
|
||||||
|
member_ids = []
|
||||||
|
if group_data.members:
|
||||||
|
for member in group_data.members:
|
||||||
|
member_ids.append(member.value)
|
||||||
|
|
||||||
|
# Create group
|
||||||
|
from open_webui.models.groups import GroupForm
|
||||||
|
|
||||||
|
form = GroupForm(
|
||||||
|
name=group_data.displayName,
|
||||||
|
description="",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Need to get the creating user's ID - we'll use the first admin
|
||||||
|
admin_user = Users.get_super_admin_user()
|
||||||
|
if not admin_user:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail="No admin user found",
|
||||||
|
)
|
||||||
|
|
||||||
|
new_group = Groups.insert_new_group(admin_user.id, form)
|
||||||
|
if not new_group:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail="Failed to create group",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add members if provided
|
||||||
|
if member_ids:
|
||||||
|
from open_webui.models.groups import GroupUpdateForm
|
||||||
|
|
||||||
|
update_form = GroupUpdateForm(
|
||||||
|
name=new_group.name,
|
||||||
|
description=new_group.description,
|
||||||
|
user_ids=member_ids,
|
||||||
|
)
|
||||||
|
Groups.update_group_by_id(new_group.id, update_form)
|
||||||
|
new_group = Groups.get_group_by_id(new_group.id)
|
||||||
|
|
||||||
|
return group_to_scim(new_group, request)
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("/Groups/{group_id}", response_model=SCIMGroup)
|
||||||
|
async def update_group(
|
||||||
|
group_id: str,
|
||||||
|
request: Request,
|
||||||
|
group_data: SCIMGroupUpdateRequest,
|
||||||
|
_: bool = Depends(get_scim_auth),
|
||||||
|
):
|
||||||
|
"""Update SCIM Group (full update)"""
|
||||||
|
group = Groups.get_group_by_id(group_id)
|
||||||
|
if not group:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail=f"Group {group_id} not found",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Build update form
|
||||||
|
from open_webui.models.groups import GroupUpdateForm
|
||||||
|
|
||||||
|
update_form = GroupUpdateForm(
|
||||||
|
name=group_data.displayName if group_data.displayName else group.name,
|
||||||
|
description=group.description,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Handle members if provided
|
||||||
|
if group_data.members is not None:
|
||||||
|
member_ids = [member.value for member in group_data.members]
|
||||||
|
update_form.user_ids = member_ids
|
||||||
|
|
||||||
|
# Update group
|
||||||
|
updated_group = Groups.update_group_by_id(group_id, update_form)
|
||||||
|
if not updated_group:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail="Failed to update group",
|
||||||
|
)
|
||||||
|
|
||||||
|
return group_to_scim(updated_group, request)
|
||||||
|
|
||||||
|
|
||||||
|
@router.patch("/Groups/{group_id}", response_model=SCIMGroup)
|
||||||
|
async def patch_group(
|
||||||
|
group_id: str,
|
||||||
|
request: Request,
|
||||||
|
patch_data: SCIMPatchRequest,
|
||||||
|
_: bool = Depends(get_scim_auth),
|
||||||
|
):
|
||||||
|
"""Update SCIM Group (partial update)"""
|
||||||
|
group = Groups.get_group_by_id(group_id)
|
||||||
|
if not group:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail=f"Group {group_id} not found",
|
||||||
|
)
|
||||||
|
|
||||||
|
from open_webui.models.groups import GroupUpdateForm
|
||||||
|
|
||||||
|
update_form = GroupUpdateForm(
|
||||||
|
name=group.name,
|
||||||
|
description=group.description,
|
||||||
|
user_ids=group.user_ids.copy() if group.user_ids else [],
|
||||||
|
)
|
||||||
|
|
||||||
|
for operation in patch_data.Operations:
|
||||||
|
op = operation.op.lower()
|
||||||
|
path = operation.path
|
||||||
|
value = operation.value
|
||||||
|
|
||||||
|
if op == "replace":
|
||||||
|
if path == "displayName":
|
||||||
|
update_form.name = value
|
||||||
|
elif path == "members":
|
||||||
|
# Replace all members
|
||||||
|
update_form.user_ids = [member["value"] for member in value]
|
||||||
|
elif op == "add":
|
||||||
|
if path == "members":
|
||||||
|
# Add members
|
||||||
|
if isinstance(value, list):
|
||||||
|
for member in value:
|
||||||
|
if isinstance(member, dict) and "value" in member:
|
||||||
|
if member["value"] not in update_form.user_ids:
|
||||||
|
update_form.user_ids.append(member["value"])
|
||||||
|
elif op == "remove":
|
||||||
|
if path and path.startswith("members[value eq"):
|
||||||
|
# Remove specific member
|
||||||
|
member_id = path.split('"')[1]
|
||||||
|
if member_id in update_form.user_ids:
|
||||||
|
update_form.user_ids.remove(member_id)
|
||||||
|
|
||||||
|
# Update group
|
||||||
|
updated_group = Groups.update_group_by_id(group_id, update_form)
|
||||||
|
if not updated_group:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail="Failed to update group",
|
||||||
|
)
|
||||||
|
|
||||||
|
return group_to_scim(updated_group, request)
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/Groups/{group_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||||
|
async def delete_group(
|
||||||
|
group_id: str,
|
||||||
|
request: Request,
|
||||||
|
_: bool = Depends(get_scim_auth),
|
||||||
|
):
|
||||||
|
"""Delete SCIM Group"""
|
||||||
|
group = Groups.get_group_by_id(group_id)
|
||||||
|
if not group:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail=f"Group {group_id} not found",
|
||||||
|
)
|
||||||
|
|
||||||
|
success = Groups.delete_group_by_id(group_id)
|
||||||
|
if not success:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail="Failed to delete group",
|
||||||
|
)
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
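For a sense of how an identity provider would drive the new router, here is a hedged client sketch against the endpoints defined above. The base URL, mount path (`/api/v1/scim/v2`, inferred from the `location` URLs the router emits), and token value are assumptions about a concrete deployment:

```python
import requests  # third-party; pip install requests

BASE = "https://openwebui.example.com/api/v1/scim/v2"  # assumed deployment URL
HEADERS = {"Authorization": "Bearer <SCIM_TOKEN>"}     # token configured server-side

# List the first page of users (startIndex is 1-based, per SCIM).
page = requests.get(
    f"{BASE}/Users", headers=HEADERS, params={"startIndex": 1, "count": 20}
).json()
print(page["totalResults"], [u["userName"] for u in page["Resources"]])

# Provision a user; userName doubles as the e-mail address in this implementation.
created = requests.post(f"{BASE}/Users", headers=HEADERS, json={
    "schemas": ["urn:ietf:params:scim:schemas:core:2.0:User"],
    "userName": "jane@example.com",
    "displayName": "Jane Doe",
    "emails": [{"value": "jane@example.com"}],
    "active": True,
})
print(created.status_code)  # 201 on success, 409 if the e-mail already exists

# Deactivate via PATCH, which the router maps to the "pending" role.
requests.patch(f"{BASE}/Users/{created.json()['id']}", headers=HEADERS, json={
    "schemas": ["urn:ietf:params:scim:api:messages:2.0:PatchOp"],
    "Operations": [{"op": "replace", "path": "active", "value": False}],
})
```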
@@ -5,6 +5,8 @@ import time
 import re
 import aiohttp
 from pydantic import BaseModel, HttpUrl
+from fastapi import APIRouter, Depends, HTTPException, Request, status
+

 from open_webui.models.tools import (
     ToolForm,
@@ -14,16 +16,15 @@ from open_webui.models.tools import (
     Tools,
 )
 from open_webui.utils.plugin import load_tool_module_by_id, replace_imports
-from open_webui.config import CACHE_DIR
-from open_webui.constants import ERROR_MESSAGES
-from fastapi import APIRouter, Depends, HTTPException, Request, status
 from open_webui.utils.tools import get_tool_specs
 from open_webui.utils.auth import get_admin_user, get_verified_user
 from open_webui.utils.access_control import has_access, has_permission
-from open_webui.env import SRC_LOG_LEVELS

 from open_webui.utils.tools import get_tool_servers_data

+from open_webui.env import SRC_LOG_LEVELS
+from open_webui.config import CACHE_DIR, ENABLE_ADMIN_WORKSPACE_CONTENT_ACCESS
+from open_webui.constants import ERROR_MESSAGES

 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["MAIN"])

@@ -74,14 +75,16 @@ async def get_tools(request: Request, user=Depends(get_verified_user)):
             )
         )

-    if user.role != "admin":
+    if user.role == "admin" and ENABLE_ADMIN_WORKSPACE_CONTENT_ACCESS:
+        # Admin can see all tools
+        return tools
+    else:
         tools = [
             tool
             for tool in tools
            if tool.user_id == user.id
             or has_access(user.id, "read", tool.access_control)
         ]

     return tools

@@ -92,7 +95,7 @@ async def get_tools(request: Request, user=Depends(get_verified_user)):

 @router.get("/list", response_model=list[ToolUserResponse])
 async def get_tool_list(user=Depends(get_verified_user)):
-    if user.role == "admin":
+    if user.role == "admin" and ENABLE_ADMIN_WORKSPACE_CONTENT_ACCESS:
         tools = Tools.get_tools()
     else:
         tools = Tools.get_tools_by_user_id(user.id, "write")
@@ -1,5 +1,13 @@
 import logging
 from typing import Optional
+import base64
+import io
+
+
+from fastapi import APIRouter, Depends, HTTPException, Request, status
+from fastapi.responses import Response, StreamingResponse, FileResponse
+from pydantic import BaseModel
+

 from open_webui.models.auths import Auths
 from open_webui.models.groups import Groups
@@ -21,9 +29,8 @@ from open_webui.socket.main import (
     get_user_active_status,
 )
 from open_webui.constants import ERROR_MESSAGES
-from open_webui.env import SRC_LOG_LEVELS
-from fastapi import APIRouter, Depends, HTTPException, Request, status
-from pydantic import BaseModel
+from open_webui.env import SRC_LOG_LEVELS, STATIC_DIR

 from open_webui.utils.auth import get_admin_user, get_password_hash, get_verified_user
 from open_webui.utils.access_control import get_permissions, has_permission
@@ -134,7 +141,9 @@ class SharingPermissions(BaseModel):

 class ChatPermissions(BaseModel):
     controls: bool = True
+    valves: bool = True
     system_prompt: bool = True
+    params: bool = True
     file_upload: bool = True
     delete: bool = True
     edit: bool = True
@@ -327,6 +336,43 @@ async def get_user_by_id(user_id: str, user=Depends(get_verified_user)):
         )


+############################
+# GetUserProfileImageById
+############################
+
+
+@router.get("/{user_id}/profile/image")
+async def get_user_profile_image_by_id(user_id: str, user=Depends(get_verified_user)):
+    user = Users.get_user_by_id(user_id)
+    if user:
+        if user.profile_image_url:
+            # check if it's url or base64
+            if user.profile_image_url.startswith("http"):
+                return Response(
+                    status_code=status.HTTP_302_FOUND,
+                    headers={"Location": user.profile_image_url},
+                )
+            elif user.profile_image_url.startswith("data:image"):
+                try:
+                    header, base64_data = user.profile_image_url.split(",", 1)
+                    image_data = base64.b64decode(base64_data)
+                    image_buffer = io.BytesIO(image_data)
+
+                    return StreamingResponse(
+                        image_buffer,
+                        media_type="image/png",
+                        headers={"Content-Disposition": "inline; filename=image.png"},
+                    )
+                except Exception as e:
+                    pass
+        return FileResponse(f"{STATIC_DIR}/user.png")
+    else:
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail=ERROR_MESSAGES.USER_NOT_FOUND,
+        )
+
+
 ############################
 # GetUserActiveStatusById
 ############################
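The new `/{user_id}/profile/image` endpoint above serves three cases: an `http(s)` URL becomes a 302 redirect, a `data:image` URL is decoded and streamed, and anything else falls back to the bundled `static/user.png`. A hedged sketch of producing a data URL the streaming branch can decode (the local file name is illustrative):

```python
import base64

with open("avatar.png", "rb") as f:  # illustrative local file
    b64 = base64.b64encode(f.read()).decode("ascii")

profile_image_url = f"data:image/png;base64,{b64}"

# The endpoint splits on the first comma, exactly like the handler above:
header, base64_data = profile_image_url.split(",", 1)
assert base64.b64decode(base64_data) == open("avatar.png", "rb").read()
```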
@@ -22,9 +22,11 @@ from open_webui.env import (
     ENABLE_WEBSOCKET_SUPPORT,
     WEBSOCKET_MANAGER,
     WEBSOCKET_REDIS_URL,
+    WEBSOCKET_REDIS_CLUSTER,
     WEBSOCKET_REDIS_LOCK_TIMEOUT,
     WEBSOCKET_SENTINEL_PORT,
     WEBSOCKET_SENTINEL_HOSTS,
+    REDIS_KEY_PREFIX,
 )
 from open_webui.utils.auth import decode_token
 from open_webui.socket.utils import RedisDict, RedisLock, YdocManager
@@ -85,6 +87,7 @@ if WEBSOCKET_MANAGER == "redis":
         redis_sentinels=get_sentinels_from_env(
             WEBSOCKET_SENTINEL_HOSTS, WEBSOCKET_SENTINEL_PORT
         ),
+        redis_cluster=WEBSOCKET_REDIS_CLUSTER,
         async_mode=True,
     )

@@ -92,19 +95,22 @@ if WEBSOCKET_MANAGER == "redis":
         WEBSOCKET_SENTINEL_HOSTS, WEBSOCKET_SENTINEL_PORT
     )
     SESSION_POOL = RedisDict(
-        "open-webui:session_pool",
+        f"{REDIS_KEY_PREFIX}:session_pool",
         redis_url=WEBSOCKET_REDIS_URL,
         redis_sentinels=redis_sentinels,
+        redis_cluster=WEBSOCKET_REDIS_CLUSTER,
     )
     USER_POOL = RedisDict(
-        "open-webui:user_pool",
+        f"{REDIS_KEY_PREFIX}:user_pool",
         redis_url=WEBSOCKET_REDIS_URL,
         redis_sentinels=redis_sentinels,
+        redis_cluster=WEBSOCKET_REDIS_CLUSTER,
     )
     USAGE_POOL = RedisDict(
-        "open-webui:usage_pool",
+        f"{REDIS_KEY_PREFIX}:usage_pool",
         redis_url=WEBSOCKET_REDIS_URL,
         redis_sentinels=redis_sentinels,
+        redis_cluster=WEBSOCKET_REDIS_CLUSTER,
     )

     clean_up_lock = RedisLock(
@@ -112,6 +118,7 @@ if WEBSOCKET_MANAGER == "redis":
         lock_name="usage_cleanup_lock",
         timeout_secs=WEBSOCKET_REDIS_LOCK_TIMEOUT,
         redis_sentinels=redis_sentinels,
+        redis_cluster=WEBSOCKET_REDIS_CLUSTER,
     )
     aquire_func = clean_up_lock.aquire_lock
     renew_func = clean_up_lock.renew_lock
@@ -126,7 +133,7 @@ else:

 YDOC_MANAGER = YdocManager(
     redis=REDIS,
-    redis_key_prefix="open-webui:ydoc:documents",
+    redis_key_prefix=f"{REDIS_KEY_PREFIX}:ydoc:documents",
 )

@@ -581,7 +588,7 @@ async def yjs_document_leave(sid, data):
     )

     if (
-        YDOC_MANAGER.document_exists(document_id)
+        await YDOC_MANAGER.document_exists(document_id)
         and len(await YDOC_MANAGER.get_users(document_id)) == 0
     ):
         log.info(f"Cleaning up document {document_id} as no users are left")
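The socket-pool keys above move from hard-coded `open-webui:*` names to `f"{REDIS_KEY_PREFIX}:*"`, so two deployments can share one Redis without their keys colliding. Sketch, assuming the prefix is taken from the environment with the old name as the default:

```python
import os

REDIS_KEY_PREFIX = os.environ.get("REDIS_KEY_PREFIX", "open-webui")

# Instance A (default) and instance B (REDIS_KEY_PREFIX=open-webui-staging)
# now read and write disjoint keys in the same Redis:
print(f"{REDIS_KEY_PREFIX}:session_pool")  # open-webui:session_pool
print(f"{REDIS_KEY_PREFIX}:usage_pool")    # open-webui:usage_pool
```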
@@ -1,18 +1,30 @@
 import json
 import uuid
 from open_webui.utils.redis import get_redis_connection
+from open_webui.env import REDIS_KEY_PREFIX
 from typing import Optional, List, Tuple
 import pycrdt as Y


 class RedisLock:
-    def __init__(self, redis_url, lock_name, timeout_secs, redis_sentinels=[]):
+    def __init__(
+        self,
+        redis_url,
+        lock_name,
+        timeout_secs,
+        redis_sentinels=[],
+        redis_cluster=False,
+    ):
         self.lock_name = lock_name
         self.lock_id = str(uuid.uuid4())
         self.timeout_secs = timeout_secs
         self.lock_obtained = False
         self.redis = get_redis_connection(
-            redis_url, redis_sentinels, decode_responses=True
+            redis_url,
+            redis_sentinels,
+            redis_cluster=redis_cluster,
+            decode_responses=True,
         )

     def aquire_lock(self):
@@ -35,10 +47,13 @@ class RedisLock:


 class RedisDict:
-    def __init__(self, name, redis_url, redis_sentinels=[]):
+    def __init__(self, name, redis_url, redis_sentinels=[], redis_cluster=False):
         self.name = name
         self.redis = get_redis_connection(
-            redis_url, redis_sentinels, decode_responses=True
+            redis_url,
+            redis_sentinels,
+            redis_cluster=redis_cluster,
+            decode_responses=True,
         )

     def __setitem__(self, key, value):
@@ -97,7 +112,7 @@ class YdocManager:
     def __init__(
         self,
         redis=None,
-        redis_key_prefix: str = "open-webui:ydoc:documents",
+        redis_key_prefix: str = f"{REDIS_KEY_PREFIX}:ydoc:documents",
     ):
         self._updates = {}
         self._users = {}
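With `redis_cluster` threaded through, the lock and dict helpers above can target a Redis Cluster as well as a single node or a Sentinel setup. A usage sketch under the constructor signature shown in the diff (the URL is an assumption, and `do_cleanup` is a hypothetical stand-in for the protected work; note the class spells the method `aquire_lock`):

```python
lock = RedisLock(
    redis_url="redis://localhost:6379/0",  # assumed local Redis
    lock_name="usage_cleanup_lock",
    timeout_secs=60,
    redis_cluster=False,
)

if lock.aquire_lock():  # method name spelled this way in the class above
    do_cleanup()        # hypothetical protected work
    lock.renew_lock()   # extend the hold if the work outlives timeout_secs
```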
BIN backend/open_webui/static/user.png Normal file
Binary file not shown. (Size after: 7.7 KiB)
@@ -8,7 +8,7 @@ from redis.asyncio import Redis
 from fastapi import Request
 from typing import Dict, List, Optional

-from open_webui.env import SRC_LOG_LEVELS
+from open_webui.env import SRC_LOG_LEVELS, REDIS_KEY_PREFIX


 log = logging.getLogger(__name__)
@@ -19,9 +19,9 @@ tasks: Dict[str, asyncio.Task] = {}
 item_tasks = {}


-REDIS_TASKS_KEY = "open-webui:tasks"
-REDIS_ITEM_TASKS_KEY = "open-webui:tasks:item"
-REDIS_PUBSUB_CHANNEL = "open-webui:tasks:commands"
+REDIS_TASKS_KEY = f"{REDIS_KEY_PREFIX}:tasks"
+REDIS_ITEM_TASKS_KEY = f"{REDIS_KEY_PREFIX}:tasks:item"
+REDIS_PUBSUB_CHANNEL = f"{REDIS_KEY_PREFIX}:tasks:commands"


 async def redis_task_command_listener(app):
@@ -221,7 +221,7 @@ def get_current_user(
         token = request.cookies.get("token")

     if token is None:
-        raise HTTPException(status_code=403, detail="Not authenticated")
+        raise HTTPException(status_code=401, detail="Not authenticated")

     # auth by api key
     if token.startswith("sk-"):
@@ -5,8 +5,6 @@ from typing import TYPE_CHECKING
 from loguru import logger
 from opentelemetry import trace
-
-
 from open_webui.env import (
     AUDIT_UVICORN_LOGGER_NAMES,
     AUDIT_LOG_FILE_ROTATION_SIZE,
@@ -14,6 +12,7 @@ from open_webui.env import (
     AUDIT_LOGS_FILE_PATH,
     GLOBAL_LOG_LEVEL,
     ENABLE_OTEL,
+    ENABLE_OTEL_LOGS,
 )
@@ -30,13 +29,16 @@ def stdout_format(record: "Record") -> str:
     Returns:
         str: A formatted log string intended for stdout.
     """
-    record["extra"]["extra_json"] = json.dumps(record["extra"])
+    if record["extra"]:
+        record["extra"]["extra_json"] = json.dumps(record["extra"])
+        extra_format = " - {extra[extra_json]}"
+    else:
+        extra_format = ""
     return (
         "<green>{time:YYYY-MM-DD HH:mm:ss.SSS}</green> | "
         "<level>{level: <8}</level> | "
         "<cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> - "
-        "<level>{message}</level> - {extra[extra_json]}"
-        "\n{exception}"
+        "<level>{message}</level>" + extra_format + "\n{exception}"
     )
@@ -65,6 +67,10 @@ class InterceptHandler(logging.Handler):
         logger.opt(depth=depth, exception=record.exc_info).bind(
             **self._get_extras()
         ).log(level, record.getMessage())
+        if ENABLE_OTEL and ENABLE_OTEL_LOGS:
+            from open_webui.utils.telemetry.logs import otel_handler
+
+            otel_handler.emit(record)

     def _get_extras(self):
         if not ENABLE_OTEL:
@@ -126,7 +132,6 @@ def start_logger():
         format=stdout_format,
         filter=lambda record: "auditable" not in record["extra"],
     )
-
     if AUDIT_LOG_LEVEL != "NONE":
         try:
             logger.add(
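Note: the reworked `stdout_format` only appends the extras JSON when the record actually carries extras, so plain log lines no longer end in a dangling separator. A behavioral sketch (not the real loguru record type):

```python
# Behavioral sketch of the new branch; the real function formats a loguru Record.
import json

def extra_suffix(extra: dict) -> str:
    if extra:
        return " - " + json.dumps(extra)
    return ""

print("INFO | module:fn:42 - message" + extra_suffix({"request_id": "abc"}))
print("INFO | module:fn:42 - message" + extra_suffix({}))  # no trailing " - "
```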
@@ -83,6 +83,7 @@ from open_webui.utils.filter import (
     process_filter_functions,
 )
 from open_webui.utils.code_interpreter import execute_code_jupyter
+from open_webui.utils.payload import apply_model_system_prompt_to_body

 from open_webui.tasks import create_task

@@ -94,6 +95,7 @@ from open_webui.config import (
 from open_webui.env import (
     SRC_LOG_LEVELS,
     GLOBAL_LOG_LEVEL,
+    CHAT_RESPONSE_STREAM_DELTA_CHUNK_SIZE,
     BYPASS_MODEL_ACCESS_CONTROL,
     ENABLE_REALTIME_CHAT_SAVE,
 )
@@ -683,6 +685,7 @@ def apply_params_to_form_data(form_data, model):

     open_webui_params = {
         "stream_response": bool,
+        "stream_delta_chunk_size": int,
         "function_calling": str,
         "system": str,
     }
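Note: registering `stream_delta_chunk_size` as an `int` here means it can be supplied per request alongside other model params; presumably something like:

```python
# Hypothetical request payload showing where the new parameter would sit.
form_data = {
    "model": "llama3",
    "messages": [{"role": "user", "content": "hi"}],
    "params": {
        "stream_delta_chunk_size": 8,  # batch 8 streamed deltas per emitted event
    },
}
```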
@@ -774,8 +777,8 @@ async def process_chat_payload(request, form_data, user, metadata, model):

     if folder and folder.data:
         if "system_prompt" in folder.data:
-            form_data["messages"] = add_or_update_system_message(
-                folder.data["system_prompt"], form_data["messages"]
+            form_data = apply_model_system_prompt_to_body(
+                folder.data["system_prompt"], form_data, metadata, user
             )
         if "files" in folder.data:
             form_data["files"] = [
@@ -929,7 +932,7 @@ async def process_chat_payload(request, form_data, user, metadata, model):
     }

     if tools_dict:
-        if metadata.get("function_calling") == "native":
+        if metadata.get("params", {}).get("function_calling") == "native":
             # If the function calling is native, then call the tools function calling handler
             metadata["tools"] = tools_dict
             form_data["tools"] = [
@@ -1381,14 +1384,6 @@ async def process_chat_response(
     task_id = str(uuid4())  # Create a unique task ID.
     model_id = form_data.get("model", "")

-    Chats.upsert_message_to_chat_by_id_and_message_id(
-        metadata["chat_id"],
-        metadata["message_id"],
-        {
-            "model": model_id,
-        },
-    )
-
     def split_content_and_whitespace(content):
         content_stripped = content.rstrip()
         original_whitespace = (
@@ -1410,13 +1405,18 @@ async def process_chat_response(

         for block in content_blocks:
             if block["type"] == "text":
-                content = f"{content}{block['content'].strip()}\n"
+                block_content = block["content"].strip()
+                if block_content:
+                    content = f"{content}{block_content}\n"
             elif block["type"] == "tool_calls":
                 attributes = block.get("attributes", {})

                 tool_calls = block.get("content", [])
                 results = block.get("results", [])

+                if content and not content.endswith("\n"):
+                    content += "\n"
+
                 if results:

                     tool_calls_display_content = ""
@@ -1439,12 +1439,12 @@ async def process_chat_response(
                             break

                     if tool_result:
-                        tool_calls_display_content = f'{tool_calls_display_content}\n<details type="tool_calls" done="true" id="{tool_call_id}" name="{tool_name}" arguments="{html.escape(json.dumps(tool_arguments))}" result="{html.escape(json.dumps(tool_result, ensure_ascii=False))}" files="{html.escape(json.dumps(tool_result_files)) if tool_result_files else ""}">\n<summary>Tool Executed</summary>\n</details>\n'
+                        tool_calls_display_content = f'{tool_calls_display_content}<details type="tool_calls" done="true" id="{tool_call_id}" name="{tool_name}" arguments="{html.escape(json.dumps(tool_arguments))}" result="{html.escape(json.dumps(tool_result, ensure_ascii=False))}" files="{html.escape(json.dumps(tool_result_files)) if tool_result_files else ""}">\n<summary>Tool Executed</summary>\n</details>\n'
                     else:
-                        tool_calls_display_content = f'{tool_calls_display_content}\n<details type="tool_calls" done="false" id="{tool_call_id}" name="{tool_name}" arguments="{html.escape(json.dumps(tool_arguments))}">\n<summary>Executing...</summary>\n</details>'
+                        tool_calls_display_content = f'{tool_calls_display_content}<details type="tool_calls" done="false" id="{tool_call_id}" name="{tool_name}" arguments="{html.escape(json.dumps(tool_arguments))}">\n<summary>Executing...</summary>\n</details>\n'

                 if not raw:
-                    content = f"{content}\n{tool_calls_display_content}\n\n"
+                    content = f"{content}{tool_calls_display_content}"
                 else:
                     tool_calls_display_content = ""

@@ -1457,10 +1457,10 @@ async def process_chat_response(
                             "arguments", ""
                         )

-                        tool_calls_display_content = f'{tool_calls_display_content}\n<details type="tool_calls" done="false" id="{tool_call_id}" name="{tool_name}" arguments="{html.escape(json.dumps(tool_arguments))}">\n<summary>Executing...</summary>\n</details>'
+                        tool_calls_display_content = f'{tool_calls_display_content}\n<details type="tool_calls" done="false" id="{tool_call_id}" name="{tool_name}" arguments="{html.escape(json.dumps(tool_arguments))}">\n<summary>Executing...</summary>\n</details>\n'

                 if not raw:
-                    content = f"{content}\n{tool_calls_display_content}\n\n"
+                    content = f"{content}{tool_calls_display_content}"

             elif block["type"] == "reasoning":
                 reasoning_display_content = "\n".join(
@@ -1470,16 +1470,26 @@ async def process_chat_response(

                 reasoning_duration = block.get("duration", None)

+                start_tag = block.get("start_tag", "")
+                end_tag = block.get("end_tag", "")
+
+                if content and not content.endswith("\n"):
+                    content += "\n"
+
                 if reasoning_duration is not None:
                     if raw:
-                        content = f'{content}\n{block["start_tag"]}{block["content"]}{block["end_tag"]}\n'
+                        content = (
+                            f'{content}{start_tag}{block["content"]}{end_tag}\n'
+                        )
                     else:
-                        content = f'{content}\n<details type="reasoning" done="true" duration="{reasoning_duration}">\n<summary>Thought for {reasoning_duration} seconds</summary>\n{reasoning_display_content}\n</details>\n'
+                        content = f'{content}<details type="reasoning" done="true" duration="{reasoning_duration}">\n<summary>Thought for {reasoning_duration} seconds</summary>\n{reasoning_display_content}\n</details>\n'
                 else:
                     if raw:
-                        content = f'{content}\n{block["start_tag"]}{block["content"]}{block["end_tag"]}\n'
+                        content = (
+                            f'{content}{start_tag}{block["content"]}{end_tag}\n'
+                        )
                     else:
-                        content = f'{content}\n<details type="reasoning" done="false">\n<summary>Thinking…</summary>\n{reasoning_display_content}\n</details>\n'
+                        content = f'{content}<details type="reasoning" done="false">\n<summary>Thinking…</summary>\n{reasoning_display_content}\n</details>\n'

             elif block["type"] == "code_interpreter":
                 attributes = block.get("attributes", {})
@@ -1499,26 +1509,30 @@ async def process_chat_response(
                     # Keep content as is - either closing backticks or no backticks
                     content = content_stripped + original_whitespace

+                if content and not content.endswith("\n"):
+                    content += "\n"
+
                 if output:
                     output = html.escape(json.dumps(output))

                     if raw:
-                        content = f'{content}\n<code_interpreter type="code" lang="{lang}">\n{block["content"]}\n</code_interpreter>\n```output\n{output}\n```\n'
+                        content = f'{content}<code_interpreter type="code" lang="{lang}">\n{block["content"]}\n</code_interpreter>\n```output\n{output}\n```\n'
                     else:
-                        content = f'{content}\n<details type="code_interpreter" done="true" output="{output}">\n<summary>Analyzed</summary>\n```{lang}\n{block["content"]}\n```\n</details>\n'
+                        content = f'{content}<details type="code_interpreter" done="true" output="{output}">\n<summary>Analyzed</summary>\n```{lang}\n{block["content"]}\n```\n</details>\n'
                 else:
                     if raw:
-                        content = f'{content}\n<code_interpreter type="code" lang="{lang}">\n{block["content"]}\n</code_interpreter>\n'
+                        content = f'{content}<code_interpreter type="code" lang="{lang}">\n{block["content"]}\n</code_interpreter>\n'
                     else:
-                        content = f'{content}\n<details type="code_interpreter" done="false">\n<summary>Analyzing...</summary>\n```{lang}\n{block["content"]}\n```\n</details>\n'
+                        content = f'{content}<details type="code_interpreter" done="false">\n<summary>Analyzing...</summary>\n```{lang}\n{block["content"]}\n```\n</details>\n'

             else:
                 block_content = str(block["content"]).strip()
-                content = f"{content}{block['type']}: {block_content}\n"
+                if block_content:
+                    content = f"{content}{block['type']}: {block_content}\n"

     return content.strip()


-def convert_content_blocks_to_messages(content_blocks):
+def convert_content_blocks_to_messages(content_blocks, raw=False):
     messages = []

     temp_blocks = []
|
||||||
messages.append(
|
messages.append(
|
||||||
{
|
{
|
||||||
"role": "assistant",
|
"role": "assistant",
|
||||||
"content": serialize_content_blocks(temp_blocks),
|
"content": serialize_content_blocks(temp_blocks, raw),
|
||||||
"tool_calls": block.get("content"),
|
"tool_calls": block.get("content"),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
@ -1547,7 +1561,7 @@ async def process_chat_response(
|
||||||
temp_blocks.append(block)
|
temp_blocks.append(block)
|
||||||
|
|
||||||
if temp_blocks:
|
if temp_blocks:
|
||||||
content = serialize_content_blocks(temp_blocks)
|
content = serialize_content_blocks(temp_blocks, raw)
|
||||||
if content:
|
if content:
|
||||||
messages.append(
|
messages.append(
|
||||||
{
|
{
|
||||||
|
|
@@ -1804,6 +1818,15 @@ async def process_chat_response(

                 response_tool_calls = []

+                delta_count = 0
+                delta_chunk_size = max(
+                    CHAT_RESPONSE_STREAM_DELTA_CHUNK_SIZE,
+                    int(
+                        metadata.get("params", {}).get("stream_delta_chunk_size")
+                        or 1
+                    ),
+                )
+
                 async for line in response.body_iterator:
                     line = line.decode("utf-8") if isinstance(line, bytes) else line
                     data = line
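Note: the effective chunk size is the larger of the server-wide env default and the per-request param, falling back to 1 — the old emit-every-delta behavior. A runnable sketch of that resolution:

```python
# Minimal sketch of the resolution logic added above (names taken from the diff).
CHAT_RESPONSE_STREAM_DELTA_CHUNK_SIZE = 1  # env-provided default

def effective_chunk_size(metadata: dict) -> int:
    return max(
        CHAT_RESPONSE_STREAM_DELTA_CHUNK_SIZE,
        int(metadata.get("params", {}).get("stream_delta_chunk_size") or 1),
    )

assert effective_chunk_size({"params": {"stream_delta_chunk_size": 8}}) == 8
assert effective_chunk_size({}) == 1  # unchanged default behavior
```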
@@ -1943,8 +1966,8 @@ async def process_chat_response(
                                 ):
                                     reasoning_block = {
                                         "type": "reasoning",
-                                        "start_tag": "think",
-                                        "end_tag": "/think",
+                                        "start_tag": "<think>",
+                                        "end_tag": "</think>",
                                         "attributes": {
                                             "type": "reasoning_content"
                                         },
@@ -2051,6 +2074,17 @@ async def process_chat_response(
                                 ),
                             }

+                            if delta:
+                                delta_count += 1
+                                if delta_count >= delta_chunk_size:
+                                    await event_emitter(
+                                        {
+                                            "type": "chat:completion",
+                                            "data": data,
+                                        }
+                                    )
+                                    delta_count = 0
+                            else:
                                 await event_emitter(
                                     {
                                         "type": "chat:completion",
@@ -2083,6 +2117,15 @@ async def process_chat_response(
                     }
                 )

+                if content_blocks[-1]["type"] == "reasoning":
+                    reasoning_block = content_blocks[-1]
+                    if reasoning_block.get("ended_at") is None:
+                        reasoning_block["ended_at"] = time.time()
+                        reasoning_block["duration"] = int(
+                            reasoning_block["ended_at"]
+                            - reasoning_block["started_at"]
+                        )
+
                 if response_tool_calls:
                     tool_calls.append(response_tool_calls)

@@ -2095,6 +2138,7 @@ async def process_chat_response(
                 tool_call_retries = 0

                 while len(tool_calls) > 0 and tool_call_retries < MAX_TOOL_CALL_RETRIES:
+
                     tool_call_retries += 1

                     response_tool_calls = tool_calls.pop(0)

@@ -2246,7 +2290,9 @@ async def process_chat_response(
                         "tools": form_data["tools"],
                         "messages": [
                             *form_data["messages"],
-                            *convert_content_blocks_to_messages(content_blocks),
+                            *convert_content_blocks_to_messages(
+                                content_blocks, True
+                            ),
                         ],
                     }
@@ -227,7 +227,7 @@ def openai_chat_chunk_message_template(
     if tool_calls:
         template["choices"][0]["delta"]["tool_calls"] = tool_calls

-    if not content and not tool_calls:
+    if not content and not reasoning_content and not tool_calls:
         template["choices"][0]["finish_reason"] = "stop"

     if usage:
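Note: the extra `reasoning_content` check matters because a chunk carrying only reasoning deltas must not be treated as end-of-stream. A self-contained sketch of the fixed guard (template shape assumed from the surrounding code):

```python
# Sketch: reasoning-only deltas no longer set finish_reason = "stop".
def chunk_template(content=None, reasoning_content=None, tool_calls=None):
    template = {"choices": [{"delta": {}, "finish_reason": None}]}
    if content:
        template["choices"][0]["delta"]["content"] = content
    if reasoning_content:
        template["choices"][0]["delta"]["reasoning_content"] = reasoning_content
    if tool_calls:
        template["choices"][0]["delta"]["tool_calls"] = tool_calls
    if not content and not reasoning_content and not tool_calls:
        template["choices"][0]["finish_reason"] = "stop"
    return template

assert chunk_template(reasoning_content="...")["choices"][0]["finish_reason"] is None
```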
@@ -27,6 +27,7 @@ from open_webui.config import (
     ENABLE_OAUTH_GROUP_CREATION,
     OAUTH_BLOCKED_GROUPS,
     OAUTH_ROLES_CLAIM,
+    OAUTH_SUB_CLAIM,
     OAUTH_GROUPS_CLAIM,
     OAUTH_EMAIL_CLAIM,
     OAUTH_PICTURE_CLAIM,
@@ -65,6 +66,7 @@ auth_manager_config.ENABLE_OAUTH_GROUP_MANAGEMENT = ENABLE_OAUTH_GROUP_MANAGEMENT
 auth_manager_config.ENABLE_OAUTH_GROUP_CREATION = ENABLE_OAUTH_GROUP_CREATION
 auth_manager_config.OAUTH_BLOCKED_GROUPS = OAUTH_BLOCKED_GROUPS
 auth_manager_config.OAUTH_ROLES_CLAIM = OAUTH_ROLES_CLAIM
+auth_manager_config.OAUTH_SUB_CLAIM = OAUTH_SUB_CLAIM
 auth_manager_config.OAUTH_GROUPS_CLAIM = OAUTH_GROUPS_CLAIM
 auth_manager_config.OAUTH_EMAIL_CLAIM = OAUTH_EMAIL_CLAIM
 auth_manager_config.OAUTH_PICTURE_CLAIM = OAUTH_PICTURE_CLAIM
@@ -88,11 +90,12 @@ class OAuthManager:
         return self.oauth.create_client(provider_name)

     def get_user_role(self, user, user_data):
-        if user and Users.get_num_users() == 1:
+        user_count = Users.get_num_users()
+        if user and user_count == 1:
             # If the user is the only user, assign the role "admin" - actually repairs role for single user on login
             log.debug("Assigning the only user the admin role")
             return "admin"
-        if not user and Users.get_num_users() == 0:
+        if not user and user_count == 0:
             # If there are no users, assign the role "admin", as the first user will be an admin
             log.debug("Assigning the first user the admin role")
             return "admin"
@@ -358,11 +361,18 @@ class OAuthManager:
             log.warning(f"OAuth callback failed, user data is missing: {token}")
             raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)

-        sub = user_data.get(OAUTH_PROVIDERS[provider].get("sub_claim", "sub"))
+        if auth_manager_config.OAUTH_SUB_CLAIM:
+            sub = user_data.get(auth_manager_config.OAUTH_SUB_CLAIM)
+        else:
+            # Fallback to the default sub claim if not configured
+            sub = user_data.get(OAUTH_PROVIDERS[provider].get("sub_claim", "sub"))
+
         if not sub:
             log.warning(f"OAuth callback failed, sub is missing: {user_data}")
             raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)

         provider_sub = f"{provider}@{sub}"

         email_claim = auth_manager_config.OAUTH_EMAIL_CLAIM
         email = user_data.get(email_claim, "")
         # We currently mandate that email addresses are provided
@@ -449,8 +459,6 @@ class OAuthManager:
             log.debug(f"Updated profile picture for user {user.email}")

         if not user:
-            user_count = Users.get_num_users()
-
             # If the user does not exist, check if signups are enabled
             if auth_manager_config.ENABLE_OAUTH_SIGNUP:
                 # Check if an existing user with the same email already exists
@@ -521,7 +529,7 @@ class OAuthManager:
         response.set_cookie(
             key="token",
             value=jwt_token,
-            httponly=True,  # Ensures the cookie is not accessible via JavaScript
+            httponly=False,  # Required for frontend access
             samesite=WEBUI_AUTH_COOKIE_SAME_SITE,
             secure=WEBUI_AUTH_COOKIE_SECURE,
         )
@@ -540,6 +548,6 @@ class OAuthManager:
         redirect_base_url = str(request.app.state.config.WEBUI_URL or request.base_url)
         if redirect_base_url.endswith("/"):
             redirect_base_url = redirect_base_url[:-1]
-        redirect_url = f"{redirect_base_url}/auth#token={jwt_token}"
+        redirect_url = f"{redirect_base_url}/auth"

         return RedirectResponse(url=redirect_url, headers=response.headers)
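Note: `OAUTH_SUB_CLAIM` lets a deployment choose which claim identifies the user instead of the provider default. A hedged sketch of the resolution order — the claim name below is an example, not from this diff:

```python
# Example only: "oid" stands in for a provider-specific identifier claim.
OAUTH_SUB_CLAIM = "oid"  # would normally come from open_webui.config

user_data = {"oid": "11111111-2222", "sub": "provider-default"}
if OAUTH_SUB_CLAIM:
    sub = user_data.get(OAUTH_SUB_CLAIM)
else:
    sub = user_data.get("sub")
assert sub == "11111111-2222"
```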
@@ -69,6 +69,7 @@ def remove_open_webui_params(params: dict) -> dict:
     """
     open_webui_params = {
         "stream_response": bool,
+        "stream_delta_chunk_size": int,
         "function_calling": str,
         "system": str,
     }
@@ -10,6 +10,9 @@ from open_webui.env import REDIS_SENTINEL_MAX_RETRY_COUNT
 log = logging.getLogger(__name__)


+_CONNECTION_CACHE = {}
+
+
 class SentinelRedisProxy:
     def __init__(self, sentinel, service, *, async_mode: bool = True, **kw):
         self._sentinel = sentinel
@@ -93,8 +96,8 @@ class SentinelRedisProxy:

 def parse_redis_service_url(redis_url):
     parsed_url = urlparse(redis_url)
-    if parsed_url.scheme != "redis":
-        raise ValueError("Invalid Redis URL scheme. Must be 'redis'.")
+    if parsed_url.scheme != "redis" and parsed_url.scheme != "rediss":
+        raise ValueError("Invalid Redis URL scheme. Must be 'redis' or 'rediss'.")

     return {
         "username": parsed_url.username or None,
@@ -106,8 +109,25 @@ def parse_redis_service_url(redis_url):


 def get_redis_connection(
-    redis_url, redis_sentinels, async_mode=False, decode_responses=True
+    redis_url,
+    redis_sentinels,
+    redis_cluster=False,
+    async_mode=False,
+    decode_responses=True,
 ):
+    cache_key = (
+        redis_url,
+        tuple(redis_sentinels) if redis_sentinels else (),
+        async_mode,
+        decode_responses,
+    )
+
+    if cache_key in _CONNECTION_CACHE:
+        return _CONNECTION_CACHE[cache_key]
+
+    connection = None
+
     if async_mode:
         import redis.asyncio as redis
@@ -122,15 +142,19 @@ def get_redis_connection(
                 password=redis_config["password"],
                 decode_responses=decode_responses,
             )
-            return SentinelRedisProxy(
+            connection = SentinelRedisProxy(
                 sentinel,
                 redis_config["service"],
                 async_mode=async_mode,
             )
+        elif redis_cluster:
+            if not redis_url:
+                raise ValueError("Redis URL must be provided for cluster mode.")
+            return redis.cluster.RedisCluster.from_url(
+                redis_url, decode_responses=decode_responses
+            )
         elif redis_url:
-            return redis.from_url(redis_url, decode_responses=decode_responses)
-        else:
-            return None
+            connection = redis.from_url(redis_url, decode_responses=decode_responses)
     else:
         import redis
@@ -144,15 +168,24 @@ def get_redis_connection(
                 password=redis_config["password"],
                 decode_responses=decode_responses,
             )
-            return SentinelRedisProxy(
+            connection = SentinelRedisProxy(
                 sentinel,
                 redis_config["service"],
                 async_mode=async_mode,
             )
+        elif redis_cluster:
+            if not redis_url:
+                raise ValueError("Redis URL must be provided for cluster mode.")
+            return redis.cluster.RedisCluster.from_url(
+                redis_url, decode_responses=decode_responses
+            )
         elif redis_url:
-            return redis.Redis.from_url(redis_url, decode_responses=decode_responses)
-        else:
-            return None
+            connection = redis.Redis.from_url(
+                redis_url, decode_responses=decode_responses
+            )
+
+    _CONNECTION_CACHE[cache_key] = connection
+    return connection


 def get_sentinels_from_env(sentinel_hosts_env, sentinel_port_env):
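Note: connections are now memoized on `(url, sentinels, async_mode, decode_responses)`; the cluster branch returns early and is therefore not cached. A usage sketch with placeholder URLs:

```python
# Usage sketch (URLs are placeholders).
from open_webui.utils.redis import get_redis_connection

r1 = get_redis_connection("redis://localhost:6379/0", [])
r2 = get_redis_connection("redis://localhost:6379/0", [])
assert r1 is r2  # second call is served from _CONNECTION_CACHE

cluster = get_redis_connection(
    "redis://cluster-endpoint:6379", [], redis_cluster=True
)  # RedisCluster.from_url(); note this branch returns before caching
```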
@@ -1,31 +0,0 @@
-import threading
-
-from opentelemetry.sdk.trace import ReadableSpan
-from opentelemetry.sdk.trace.export import BatchSpanProcessor
-
-
-class LazyBatchSpanProcessor(BatchSpanProcessor):
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.done = True
-        with self.condition:
-            self.condition.notify_all()
-        self.worker_thread.join()
-        self.done = False
-        self.worker_thread = None
-
-    def on_end(self, span: ReadableSpan) -> None:
-        if self.worker_thread is None:
-            self.worker_thread = threading.Thread(
-                name=self.__class__.__name__, target=self.worker, daemon=True
-            )
-            self.worker_thread.start()
-        super().on_end(span)
-
-    def shutdown(self) -> None:
-        self.done = True
-        with self.condition:
-            self.condition.notify_all()
-        if self.worker_thread:
-            self.worker_thread.join()
-        self.span_exporter.shutdown()
 53  backend/open_webui/utils/telemetry/logs.py (new file)
@@ -0,0 +1,53 @@
+import logging
+from base64 import b64encode
+from opentelemetry.sdk._logs import (
+    LoggingHandler,
+    LoggerProvider,
+)
+from opentelemetry.exporter.otlp.proto.grpc._log_exporter import OTLPLogExporter
+from opentelemetry.exporter.otlp.proto.http._log_exporter import (
+    OTLPLogExporter as HttpOTLPLogExporter,
+)
+from opentelemetry.sdk._logs.export import BatchLogRecordProcessor
+from opentelemetry._logs import set_logger_provider
+from opentelemetry.sdk.resources import SERVICE_NAME, Resource
+from open_webui.env import (
+    OTEL_SERVICE_NAME,
+    OTEL_LOGS_EXPORTER_OTLP_ENDPOINT,
+    OTEL_LOGS_EXPORTER_OTLP_INSECURE,
+    OTEL_LOGS_BASIC_AUTH_USERNAME,
+    OTEL_LOGS_BASIC_AUTH_PASSWORD,
+    OTEL_LOGS_OTLP_SPAN_EXPORTER,
+)
+
+
+def setup_logging():
+    headers = []
+    if OTEL_LOGS_BASIC_AUTH_USERNAME and OTEL_LOGS_BASIC_AUTH_PASSWORD:
+        auth_string = f"{OTEL_LOGS_BASIC_AUTH_USERNAME}:{OTEL_LOGS_BASIC_AUTH_PASSWORD}"
+        auth_header = b64encode(auth_string.encode()).decode()
+        headers = [("authorization", f"Basic {auth_header}")]
+    resource = Resource.create(attributes={SERVICE_NAME: OTEL_SERVICE_NAME})
+
+    if OTEL_LOGS_OTLP_SPAN_EXPORTER == "http":
+        exporter = HttpOTLPLogExporter(
+            endpoint=OTEL_LOGS_EXPORTER_OTLP_ENDPOINT,
+            headers=headers,
+        )
+    else:
+        exporter = OTLPLogExporter(
+            endpoint=OTEL_LOGS_EXPORTER_OTLP_ENDPOINT,
+            insecure=OTEL_LOGS_EXPORTER_OTLP_INSECURE,
+            headers=headers,
+        )
+    logger_provider = LoggerProvider(resource=resource)
+    set_logger_provider(logger_provider)
+
+    logger_provider.add_log_record_processor(BatchLogRecordProcessor(exporter))
+
+    otel_handler = LoggingHandler(logger_provider=logger_provider)
+
+    return otel_handler
+
+
+otel_handler = setup_logging()
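Note: the handler above is only wired in when both `ENABLE_OTEL` and `ENABLE_OTEL_LOGS` are set (see the `InterceptHandler` change earlier). A plausible minimal environment for it — variable names from this module, values purely illustrative, and the exact env parsing lives in open_webui.env:

```python
# Plausible minimal configuration (assumed env names; values are examples).
import os

os.environ.update(
    {
        "ENABLE_OTEL": "true",
        "ENABLE_OTEL_LOGS": "true",
        "OTEL_SERVICE_NAME": "open-webui",
        "OTEL_LOGS_EXPORTER_OTLP_ENDPOINT": "http://otel-collector:4317",
        "OTEL_LOGS_OTLP_SPAN_EXPORTER": "grpc",  # or "http"
    }
)
```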
@@ -19,34 +19,66 @@ from __future__ import annotations

 import time
 from typing import Dict, List, Sequence, Any
+from base64 import b64encode

 from fastapi import FastAPI, Request
 from opentelemetry import metrics
 from opentelemetry.exporter.otlp.proto.grpc.metric_exporter import (
     OTLPMetricExporter,
 )
+from opentelemetry.exporter.otlp.proto.http.metric_exporter import (
+    OTLPMetricExporter as OTLPHttpMetricExporter,
+)
 from opentelemetry.sdk.metrics import MeterProvider
 from opentelemetry.sdk.metrics.view import View
 from opentelemetry.sdk.metrics.export import (
     PeriodicExportingMetricReader,
 )
-from opentelemetry.sdk.resources import SERVICE_NAME, Resource
-
-from open_webui.env import OTEL_SERVICE_NAME, OTEL_EXPORTER_OTLP_ENDPOINT
+from opentelemetry.sdk.resources import Resource
+
+from open_webui.env import (
+    OTEL_SERVICE_NAME,
+    OTEL_METRICS_EXPORTER_OTLP_ENDPOINT,
+    OTEL_METRICS_BASIC_AUTH_USERNAME,
+    OTEL_METRICS_BASIC_AUTH_PASSWORD,
+    OTEL_METRICS_OTLP_SPAN_EXPORTER,
+    OTEL_METRICS_EXPORTER_OTLP_INSECURE,
+)
 from open_webui.socket.main import get_active_user_ids
 from open_webui.models.users import Users

 _EXPORT_INTERVAL_MILLIS = 10_000  # 10 seconds


-def _build_meter_provider() -> MeterProvider:
+def _build_meter_provider(resource: Resource) -> MeterProvider:
     """Return a configured MeterProvider."""
+    headers = []
+    if OTEL_METRICS_BASIC_AUTH_USERNAME and OTEL_METRICS_BASIC_AUTH_PASSWORD:
+        auth_string = (
+            f"{OTEL_METRICS_BASIC_AUTH_USERNAME}:{OTEL_METRICS_BASIC_AUTH_PASSWORD}"
+        )
+        auth_header = b64encode(auth_string.encode()).decode()
+        headers = [("authorization", f"Basic {auth_header}")]

     # Periodic reader pushes metrics over OTLP/gRPC to collector
-    readers: List[PeriodicExportingMetricReader] = [
-        PeriodicExportingMetricReader(
-            OTLPMetricExporter(endpoint=OTEL_EXPORTER_OTLP_ENDPOINT),
-            export_interval_millis=_EXPORT_INTERVAL_MILLIS,
-        )
-    ]
+    if OTEL_METRICS_OTLP_SPAN_EXPORTER == "http":
+        readers: List[PeriodicExportingMetricReader] = [
+            PeriodicExportingMetricReader(
+                OTLPHttpMetricExporter(
+                    endpoint=OTEL_METRICS_EXPORTER_OTLP_ENDPOINT, headers=headers
+                ),
+                export_interval_millis=_EXPORT_INTERVAL_MILLIS,
+            )
+        ]
+    else:
+        readers: List[PeriodicExportingMetricReader] = [
+            PeriodicExportingMetricReader(
+                OTLPMetricExporter(
+                    endpoint=OTEL_METRICS_EXPORTER_OTLP_ENDPOINT,
+                    insecure=OTEL_METRICS_EXPORTER_OTLP_INSECURE,
+                    headers=headers,
+                ),
+                export_interval_millis=_EXPORT_INTERVAL_MILLIS,
+            )
+        ]
@@ -70,17 +102,17 @@ def _build_meter_provider() -> MeterProvider:
     ]

     provider = MeterProvider(
-        resource=Resource.create({SERVICE_NAME: OTEL_SERVICE_NAME}),
+        resource=resource,
         metric_readers=list(readers),
         views=views,
     )
     return provider


-def setup_metrics(app: FastAPI) -> None:
+def setup_metrics(app: FastAPI, resource: Resource) -> None:
     """Attach OTel metrics middleware to *app* and initialise provider."""

-    metrics.set_meter_provider(_build_meter_provider())
+    metrics.set_meter_provider(_build_meter_provider(resource))
     meter = metrics.get_meter(__name__)

     # Instruments
@@ -1,15 +1,16 @@
 from fastapi import FastAPI
 from opentelemetry import trace

 from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
 from opentelemetry.exporter.otlp.proto.http.trace_exporter import (
     OTLPSpanExporter as HttpOTLPSpanExporter,
 )
 from opentelemetry.sdk.resources import SERVICE_NAME, Resource
 from opentelemetry.sdk.trace import TracerProvider
+from opentelemetry.sdk.trace.export import BatchSpanProcessor
 from sqlalchemy import Engine
 from base64 import b64encode

-from open_webui.utils.telemetry.exporters import LazyBatchSpanProcessor
 from open_webui.utils.telemetry.instrumentors import Instrumentor
 from open_webui.utils.telemetry.metrics import setup_metrics
 from open_webui.env import (
@@ -25,11 +26,8 @@ from open_webui.env import (

 def setup(app: FastAPI, db_engine: Engine):
     # set up trace
-    trace.set_tracer_provider(
-        TracerProvider(
-            resource=Resource.create(attributes={SERVICE_NAME: OTEL_SERVICE_NAME})
-        )
-    )
+    resource = Resource.create(attributes={SERVICE_NAME: OTEL_SERVICE_NAME})
+    trace.set_tracer_provider(TracerProvider(resource=resource))

     # Add basic auth header only if both username and password are not empty
     headers = []
@@ -42,7 +40,6 @@ def setup(app: FastAPI, db_engine: Engine):
     if OTEL_OTLP_SPAN_EXPORTER == "http":
         exporter = HttpOTLPSpanExporter(
             endpoint=OTEL_EXPORTER_OTLP_ENDPOINT,
-            insecure=OTEL_EXPORTER_OTLP_INSECURE,
             headers=headers,
         )
     else:
@@ -51,9 +48,9 @@ def setup(app: FastAPI, db_engine: Engine):
             insecure=OTEL_EXPORTER_OTLP_INSECURE,
             headers=headers,
         )
-    trace.get_tracer_provider().add_span_processor(LazyBatchSpanProcessor(exporter))
+    trace.get_tracer_provider().add_span_processor(BatchSpanProcessor(exporter))
     Instrumentor(app=app, db_engine=db_engine).instrument()

     # set up metrics only if enabled
     if ENABLE_OTEL_METRICS:
-        setup_metrics(app)
+        setup_metrics(app, resource)
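Note: with the custom `LazyBatchSpanProcessor` removed, tracing rides on the stock OTel SDK. Stripped of the auth and HTTP branches, the new setup path reduces to roughly:

```python
# Reduced equivalent of the new setup path (endpoint is a placeholder).
from opentelemetry import trace
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
from opentelemetry.sdk.resources import SERVICE_NAME, Resource
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor

resource = Resource.create(attributes={SERVICE_NAME: "open-webui"})
trace.set_tracer_provider(TracerProvider(resource=resource))
exporter = OTLPSpanExporter(endpoint="http://otel-collector:4317", insecure=True)
trace.get_tracer_provider().add_span_processor(BatchSpanProcessor(exporter))
```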
@@ -377,7 +377,6 @@ def convert_openapi_to_tool_payload(openapi_spec):
     for method, operation in methods.items():
         if operation.get("operationId"):
             tool = {
-                "type": "function",
                 "name": operation.get("operationId"),
                 "description": operation.get(
                     "description",
@@ -399,10 +398,16 @@ def convert_openapi_to_tool_payload(openapi_spec):
                     description += (
                         f". Possible values: {', '.join(param_schema.get('enum'))}"
                     )
-                tool["parameters"]["properties"][param_name] = {
+                param_property = {
                     "type": param_schema.get("type"),
                     "description": description,
                 }
+
+                # Include items property for array types (required by OpenAI)
+                if param_schema.get("type") == "array" and "items" in param_schema:
+                    param_property["items"] = param_schema["items"]
+
+                tool["parameters"]["properties"][param_name] = param_property
                 if param.get("required"):
                     tool["parameters"]["required"].append(param_name)

@@ -489,15 +494,7 @@ async def get_tool_servers_data(
         if server.get("config", {}).get("enable"):
             # Path (to OpenAPI spec URL) can be either a full URL or a path to append to the base URL
             openapi_path = server.get("path", "openapi.json")
-            if "://" in openapi_path:
-                # If it contains "://", it's a full URL
-                full_url = openapi_path
-            else:
-                if not openapi_path.startswith("/"):
-                    # Ensure the path starts with a slash
-                    openapi_path = f"/{openapi_path}"
-
-                full_url = f"{server.get('url')}{openapi_path}"
+            full_url = get_tool_server_url(server.get("url"), openapi_path)

             info = server.get("info", {})

@@ -528,6 +525,8 @@ async def get_tool_servers_data(
         openapi_data = response.get("openapi", {})

         if info and isinstance(openapi_data, dict):
+            openapi_data["info"] = openapi_data.get("info", {})
+
             if "name" in info:
                 openapi_data["info"]["title"] = info.get("name", "Tool Server")

@@ -643,3 +642,16 @@ async def execute_tool_server(
         error = str(err)
         log.exception(f"API Request Error: {error}")
         return {"error": error}
+
+
+def get_tool_server_url(url: Optional[str], path: str) -> str:
+    """
+    Build the full URL for a tool server, given a base url and a path.
+    """
+    if "://" in path:
+        # If it contains "://", it's a full URL
+        return path
+    if not path.startswith("/"):
+        # Ensure the path starts with a slash
+        path = f"/{path}"
+    return f"{url}{path}"
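Note: extracting the URL-joining rule into `get_tool_server_url` makes it testable on its own; a doctest-style restatement of its contract:

```python
# Doctest-style restatement of get_tool_server_url's contract.
def get_tool_server_url(url, path):
    if "://" in path:
        return path  # path is already a full URL
    if not path.startswith("/"):
        path = f"/{path}"
    return f"{url}{path}"

assert get_tool_server_url("http://tools:8000", "openapi.json") == "http://tools:8000/openapi.json"
assert get_tool_server_url("http://tools:8000", "https://example.com/spec.json") == "https://example.com/spec.json"
```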
@@ -9,7 +9,7 @@ passlib[bcrypt]==1.7.4
 cryptography

 requests==2.32.4
-aiohttp==3.11.11
+aiohttp==3.12.15
 async-timeout
 aiocache
 aiofiles
@@ -27,7 +27,7 @@ bcrypt==4.3.0

 pymongo
 redis
-boto3==1.35.53
+boto3==1.40.5

 argon2-cffi==23.1.0
 APScheduler==3.10.4
@@ -42,14 +42,14 @@ asgiref==3.8.1
 # AI libraries
 openai
 anthropic
-google-genai==1.15.0
+google-genai==1.28.0
 google-generativeai==0.8.5
 tiktoken

 langchain==0.3.26
 langchain-community==0.3.26

-fake-useragent==2.1.0
+fake-useragent==2.2.0
 chromadb==0.6.3
 posthog==5.4.0
 pymilvus==2.5.0
@@ -58,11 +58,14 @@ opensearch-py==2.8.0
 playwright==1.49.1 # Caution: version must match docker-compose.playwright.yaml
 elasticsearch==9.0.1
 pinecone==6.0.2
+oracledb==3.2.0

+av==14.0.1 # Caution: Set due to FATAL FIPS SELFTEST FAILURE, see discussion https://github.com/open-webui/open-webui/discussions/15720
 transformers
 sentence-transformers==4.1.0
 accelerate
 colbert-ai==0.2.21
+pyarrow==20.0.0
 einops==0.8.1
@@ -74,7 +77,7 @@ docx2txt==0.8
 python-pptx==1.0.2
 unstructured==0.16.17
 nltk==3.9.1
-Markdown==3.7
+Markdown==3.8.2
 pypandoc==1.15
 pandas==2.2.3
 openpyxl==3.1.5
@@ -86,7 +89,7 @@ sentencepiece
 soundfile==0.13.1
 azure-ai-documentintelligence==1.0.2

-pillow==11.2.1
+pillow==11.3.0
 opencv-python-headless==4.11.0.86
 rapidocr-onnxruntime==1.4.4
 rank-bm25==0.2.2
@@ -96,7 +99,7 @@ onnxruntime==1.20.1
 faster-whisper==1.1.1

 PyJWT[crypto]==2.10.1
-authlib==1.4.1
+authlib==1.6.1

 black==25.1.0
 langfuse==2.44.0
@@ -133,14 +136,14 @@ firecrawl-py==1.12.0
 tencentcloud-sdk-python==3.0.1336

 ## Trace
-opentelemetry-api==1.32.1
-opentelemetry-sdk==1.32.1
-opentelemetry-exporter-otlp==1.32.1
-opentelemetry-instrumentation==0.53b1
-opentelemetry-instrumentation-fastapi==0.53b1
-opentelemetry-instrumentation-sqlalchemy==0.53b1
-opentelemetry-instrumentation-redis==0.53b1
-opentelemetry-instrumentation-requests==0.53b1
-opentelemetry-instrumentation-logging==0.53b1
-opentelemetry-instrumentation-httpx==0.53b1
-opentelemetry-instrumentation-aiohttp-client==0.53b1
+opentelemetry-api==1.36.0
+opentelemetry-sdk==1.36.0
+opentelemetry-exporter-otlp==1.36.0
+opentelemetry-instrumentation==0.57b0
+opentelemetry-instrumentation-fastapi==0.57b0
+opentelemetry-instrumentation-sqlalchemy==0.57b0
+opentelemetry-instrumentation-redis==0.57b0
+opentelemetry-instrumentation-requests==0.57b0
+opentelemetry-instrumentation-logging==0.57b0
+opentelemetry-instrumentation-httpx==0.57b0
+opentelemetry-instrumentation-aiohttp-client==0.57b0
 758  package-lock.json (generated)
File diff suppressed because it is too large
@@ -1,6 +1,6 @@
 {
     "name": "open-webui",
-    "version": "0.6.18",
+    "version": "0.6.19",
     "private": true,
     "scripts": {
         "dev": "npm run pyodide:fetch && vite dev --host",
@@ -59,6 +59,7 @@
         "@codemirror/theme-one-dark": "^6.1.2",
         "@floating-ui/dom": "^1.7.2",
         "@huggingface/transformers": "^3.0.0",
+        "@joplin/turndown-plugin-gfm": "^1.0.62",
         "@mediapipe/tasks-vision": "^0.10.17",
         "@pyscript/core": "^0.4.32",
         "@sveltejs/adapter-node": "^2.0.0",
@@ -73,6 +74,7 @@
         "@tiptap/extension-image": "^3.0.7",
         "@tiptap/extension-link": "^3.0.7",
         "@tiptap/extension-list": "^3.0.7",
+        "@tiptap/extension-mention": "^3.0.9",
         "@tiptap/extension-table": "^3.0.7",
         "@tiptap/extension-typography": "^3.0.7",
         "@tiptap/extension-youtube": "^3.0.7",
@@ -17,7 +17,7 @@ dependencies = [
     "cryptography",

     "requests==2.32.4",
-    "aiohttp==3.11.11",
+    "aiohttp==3.12.15",
     "async-timeout",
     "aiocache",
     "aiofiles",
@@ -35,7 +35,7 @@ dependencies = [

     "pymongo",
     "redis",
-    "boto3==1.35.53",
+    "boto3==1.40.5",

     "argon2-cffi==23.1.0",
     "APScheduler==3.10.4",
@@ -50,14 +50,14 @@ dependencies = [

     "openai",
     "anthropic",
-    "google-genai==1.15.0",
+    "google-genai==1.28.0",
     "google-generativeai==0.8.5",
     "tiktoken",

     "langchain==0.3.26",
     "langchain-community==0.3.26",

-    "fake-useragent==2.1.0",
+    "fake-useragent==2.2.0",
     "chromadb==0.6.3",
     "pymilvus==2.5.0",
     "qdrant-client==1.14.3",
@@ -65,11 +65,13 @@ dependencies = [
     "playwright==1.49.1",
     "elasticsearch==9.0.1",
     "pinecone==6.0.2",
+    "oracledb==3.2.0",

     "transformers",
     "sentence-transformers==4.1.0",
     "accelerate",
     "colbert-ai==0.2.21",
+    "pyarrow==20.0.0",
     "einops==0.8.1",

     "ftfy==6.2.3",
@@ -80,7 +82,7 @@ dependencies = [
     "python-pptx==1.0.2",
     "unstructured==0.16.17",
     "nltk==3.9.1",
-    "Markdown==3.7",
+    "Markdown==3.8.2",
     "pypandoc==1.15",
     "pandas==2.2.3",
     "openpyxl==3.1.5",
@@ -92,7 +94,7 @@ dependencies = [
     "soundfile==0.13.1",
     "azure-ai-documentintelligence==1.0.2",

-    "pillow==11.2.1",
+    "pillow==11.3.0",
     "opencv-python-headless==4.11.0.86",
     "rapidocr-onnxruntime==1.4.4",
     "rank-bm25==0.2.2",
@@ -102,7 +104,7 @@ dependencies = [
     "faster-whisper==1.1.1",

     "PyJWT[crypto]==2.10.1",
-    "authlib==1.4.1",
+    "authlib==1.6.1",

     "black==25.1.0",
     "langfuse==2.44.0",
@@ -135,7 +137,7 @@ dependencies = [
     "gcp-storage-emulator>=2024.8.3",

     "moto[s3]>=5.0.26",
+    "oracledb>=3.2.0",
     "posthog==5.4.0",

 ]
 21  src/app.css
@@ -401,6 +401,17 @@ input[type='number'] {
     }
 }

+.tiptap .mention {
+    border-radius: 0.4rem;
+    box-decoration-break: clone;
+    padding: 0.1rem 0.3rem;
+    @apply text-blue-900 dark:text-blue-100 bg-blue-300/20 dark:bg-blue-500/20;
+}
+
+.tiptap .mention::after {
+    content: '\200B';
+}
+
 .input-prose .tiptap ul[data-type='taskList'] {
     list-style: none;
     margin-left: 0;
@@ -616,3 +627,13 @@ input[type='number'] {
     padding-right: 2px;
     white-space: nowrap;
 }
+
+body {
+    background: #fff;
+    color: #000;
+}
+
+.dark body {
+    background: #171717;
+    color: #eee;
+}
@@ -56,7 +56,6 @@
             document.documentElement.classList.add('light');
             metaThemeColorTag.setAttribute('content', '#ffffff');
         } else if (localStorage.theme === 'her') {
-            document.documentElement.classList.add('dark');
             document.documentElement.classList.add('her');
             metaThemeColorTag.setAttribute('content', '#983724');
         } else {
@ -465,7 +465,7 @@ export const executeToolServer = async (
|
||||||
...(token && { authorization: `Bearer ${token}` })
|
...(token && { authorization: `Bearer ${token}` })
|
||||||
};
|
};
|
||||||
|
|
||||||
let requestOptions: RequestInit = {
|
const requestOptions: RequestInit = {
|
||||||
method: httpMethod.toUpperCase(),
|
method: httpMethod.toUpperCase(),
|
||||||
headers
|
headers
|
||||||
};
|
};
|
||||||
|
|
@ -818,7 +818,7 @@ export const generateQueries = async (
|
||||||
model: string,
|
model: string,
|
||||||
messages: object[],
|
messages: object[],
|
||||||
prompt: string,
|
prompt: string,
|
||||||
type?: string = 'web_search'
|
type: string = 'web_search'
|
||||||
) => {
|
) => {
|
||||||
let error = null;
|
let error = null;
|
||||||
|
|
||||||
|
|
@@ -1014,7 +1014,7 @@ export const getPipelinesList = async (token: string = '') => {
 		throw error;
 	}

-	let pipelines = res?.data ?? [];
+	const pipelines = res?.data ?? [];
 	return pipelines;
 };
@@ -1157,7 +1157,7 @@ export const getPipelines = async (token: string, urlIdx?: string) => {
 		throw error;
 	}

-	let pipelines = res?.data ?? [];
+	const pipelines = res?.data ?? [];
 	return pipelines;
 };
@@ -331,7 +331,7 @@ export const generateTextCompletion = async (token: string = '', model: string,
 };

 export const generateChatCompletion = async (token: string = '', body: object) => {
-	let controller = new AbortController();
+	const controller = new AbortController();
 	let error = null;

 	const res = await fetch(`${OLLAMA_API_BASE_URL}/api/chat`, {
@@ -126,7 +126,7 @@ export const getUsers = async (
 	let error = null;
 	let res = null;

-	let searchParams = new URLSearchParams();
+	const searchParams = new URLSearchParams();

 	searchParams.set('page', `${page}`);
@@ -35,9 +35,7 @@
 	let connectionType = 'external';
 	let azure = false;
 	$: azure =
-		(url.includes('azure.com') || url.includes('cognitive.microsoft.com')) && !direct
-			? true
-			: false;
+		(url.includes('azure.') || url.includes('cognitive.microsoft.com')) && !direct ? true : false;

 	let prefixId = '';
 	let enable = true;
@@ -62,7 +62,7 @@

 	<div class="w-full h-full absolute top-0 left-0 backdrop-blur-xs bg-black/50"></div>

-	<div class="relative bg-transparent w-full min-h-screen flex z-10">
+	<div class="relative bg-transparent w-full h-screen max-h-[100dvh] flex z-10">
 		<div class="flex flex-col justify-end w-full items-center pb-10 text-center">
 			<div class="text-5xl lg:text-7xl font-secondary">
 				<Marquee
@@ -13,7 +13,7 @@
 	import GarbageBin from '$lib/components/icons/GarbageBin.svelte';
 	import Pencil from '$lib/components/icons/Pencil.svelte';
 	import Tooltip from '$lib/components/common/Tooltip.svelte';
-	import Download from '$lib/components/icons/Download.svelte';
+	import Download from '$lib/components/icons/ArrowDownTray.svelte';

 	let show = false;
 </script>
@@ -54,6 +54,20 @@
 <div class="flex flex-col md:flex-row w-full px-5 pb-4 md:space-x-4 dark:text-gray-200">
 	{#if loaded}
 		<div class="flex flex-col w-full">
+			<div class="flex flex-col w-full mb-2">
+				<div class=" mb-1 text-xs text-gray-500">{$i18n.t('Chat ID')}</div>
+
+				<div class="flex-1 text-xs">
+					<a
+						href={`/s/${selectedFeedback?.meta?.chat_id}`}
+						class=" hover:underline"
+						target="_blank"
+					>
+						<span>{selectedFeedback?.meta?.chat_id ?? '-'}</span>
+					</a>
+				</div>
+			</div>
+
 			{#if feedbackData}
 				{@const messageId = feedbackData?.meta?.message_id}
 				{@const messages = feedbackData?.snapshot?.chat?.chat?.history.messages}
@@ -24,6 +24,7 @@
 	import ChevronUp from '$lib/components/icons/ChevronUp.svelte';
 	import ChevronDown from '$lib/components/icons/ChevronDown.svelte';
 	import { WEBUI_BASE_URL } from '$lib/constants';
+	import { config } from '$lib/stores';

 	export let feedbacks = [];
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</td>
|
</td>
|
||||||
|
|
||||||
|
{#if feedback?.data?.rating}
|
||||||
<td class="px-3 py-1 text-right font-medium text-gray-900 dark:text-white w-max">
|
<td class="px-3 py-1 text-right font-medium text-gray-900 dark:text-white w-max">
|
||||||
<div class=" flex justify-end">
|
<div class=" flex justify-end">
|
||||||
{#if feedback.data.rating.toString() === '1'}
|
{#if feedback?.data?.rating.toString() === '1'}
|
||||||
<Badge type="info" content={$i18n.t('Won')} />
|
<Badge type="info" content={$i18n.t('Won')} />
|
||||||
{:else if feedback.data.rating.toString() === '0'}
|
{:else if feedback?.data?.rating.toString() === '0'}
|
||||||
<Badge type="muted" content={$i18n.t('Draw')} />
|
<Badge type="muted" content={$i18n.t('Draw')} />
|
||||||
{:else if feedback.data.rating.toString() === '-1'}
|
{:else if feedback?.data?.rating.toString() === '-1'}
|
||||||
<Badge type="error" content={$i18n.t('Lost')} />
|
<Badge type="error" content={$i18n.t('Lost')} />
|
||||||
{/if}
|
{/if}
|
||||||
</div>
|
</div>
|
||||||
</td>
|
</td>
|
||||||
|
{/if}
|
||||||
|
|
||||||
<td class=" px-3 py-1 text-right font-medium">
|
<td class=" px-3 py-1 text-right font-medium">
|
||||||
{dayjs(feedback.updated_at * 1000).fromNow()}
|
{dayjs(feedback.updated_at * 1000).fromNow()}
|
||||||
|
|
@@ -390,7 +394,7 @@
 	{/if}
 </div>

-{#if feedbacks.length > 0}
+{#if feedbacks.length > 0 && $config?.features?.enable_community_sharing}
 	<div class=" flex flex-col justify-end w-full text-right gap-1">
 		<div class="line-clamp-1 text-gray-500 text-xs">
 			{$i18n.t('Help us create the best community leaderboard by sharing your feedback history!')}
@@ -151,6 +151,8 @@
 	}

 	feedbacks.forEach((feedback) => {
+		if (!feedback?.data?.model_id || !feedback?.data?.rating) return;
+
 		const modelA = feedback.data.model_id;
 		const statsA = getOrDefaultStats(modelA);
 		let outcome: number;
@@ -334,7 +336,9 @@
 	onClose={closeLeaderboardModal}
 />

-<div class="mt-0.5 mb-2 gap-1 flex flex-col md:flex-row justify-between">
+<div
+	class="pt-0.5 pb-2 gap-1 flex flex-col md:flex-row justify-between sticky top-0 z-10 bg-white dark:bg-gray-900"
+>
 	<div class="flex md:self-center text-lg font-medium px-0.5 shrink-0 items-center">
 		<div class=" gap-1">
 			{$i18n.t('Leaderboard')}
@@ -569,7 +569,7 @@

 	<a
 		class=" flex cursor-pointer items-center justify-between hover:bg-gray-50 dark:hover:bg-gray-850 w-full mb-2 px-3.5 py-1.5 rounded-xl transition"
-		href="https://openwebui.com/#open-webui-community"
+		href="https://openwebui.com/functions"
 		target="_blank"
 	>
 		<div class=" self-center">
@@ -196,7 +196,6 @@
 	const submitHandler = async () => {
 		updateOpenAIHandler();
 		updateOllamaHandler();
-		updateDirectConnectionsHandler();

 		dispatch('save');
 	};
@@ -98,6 +98,7 @@
 	<SensitiveInput
 		inputClassName=" outline-hidden bg-transparent w-full"
 		placeholder={$i18n.t('API Key')}
+		required={false}
 		bind:value={key}
 	/>
 </div>
@@ -7,6 +7,7 @@
 	import { config, user } from '$lib/stores';
 	import { toast } from 'svelte-sonner';
 	import { getAllUserChats } from '$lib/apis/chats';
+	import { getAllUsers } from '$lib/apis/users';
 	import { exportConfig, importConfig } from '$lib/apis/configs';

 	const i18n = getContext('i18n');
@@ -20,6 +21,29 @@
 		saveAs(blob, `all-chats-export-${Date.now()}.json`);
 	};

+	const exportUsers = async () => {
+		const users = await getAllUsers(localStorage.token);
+
+		const headers = ['id', 'name', 'email', 'role'];
+
+		const csv = [
+			headers.join(','),
+			...users.users.map((user) => {
+				return headers
+					.map((header) => {
+						if (user[header] === null || user[header] === undefined) {
+							return '';
+						}
+						return `"${String(user[header]).replace(/"/g, '""')}"`;
+					})
+					.join(',');
+			})
+		].join('\n');
+
+		const blob = new Blob([csv], { type: 'text/csv;charset=utf-8;' });
+		saveAs(blob, 'users.csv');
+	};
+
 	onMount(async () => {
 		// permissions = await getUserPermissions(localStorage.token);
 	});
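The `exportUsers` helper above builds the CSV by quoting every field and doubling any embedded quotes. A minimal standalone sketch of that quoting rule, with a hypothetical sample row (the `escapeCsvField` name and the sample data are ours, for illustration only):

```ts
// Sketch of the CSV quoting used by exportUsers above: null/undefined become
// empty fields; everything else is stringified, embedded quotes are doubled,
// and the whole field is wrapped in quotes.
const escapeCsvField = (value: unknown): string =>
	value === null || value === undefined ? '' : `"${String(value).replace(/"/g, '""')}"`;

const headers = ['id', 'name', 'email', 'role'];
// Hypothetical sample row for illustration only.
const user: Record<string, unknown> = {
	id: 'u1',
	name: 'Ada "Admin" L.',
	email: 'ada@example.com',
	role: 'admin'
};

console.log(headers.map((h) => escapeCsvField(user[h])).join(','));
// "u1","Ada ""Admin"" L.","ada@example.com","admin"
```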
@@ -180,6 +204,32 @@
 						{$i18n.t('Export All Chats (All Users)')}
 					</div>
 				</button>

+				<button
+					class=" flex rounded-md py-2 px-3 w-full hover:bg-gray-200 dark:hover:bg-gray-800 transition"
+					on:click={() => {
+						exportUsers();
+					}}
+				>
+					<div class=" self-center mr-3">
+						<svg
+							xmlns="http://www.w3.org/2000/svg"
+							viewBox="0 0 16 16"
+							fill="currentColor"
+							class="w-4 h-4"
+						>
+							<path d="M2 3a1 1 0 0 1 1-1h10a1 1 0 0 1 1 1v1a1 1 0 0 1-1 1H3a1 1 0 0 1-1-1V3Z" />
+							<path
+								fill-rule="evenodd"
+								d="M13 6H3v6a2 2 0 0 0 2 2h6a2 2 0 0 0 2-2V6ZM8.75 7.75a.75.75 0 0 0-1.5 0v2.69L6.03 9.22a.75.75 0 0 0-1.06 1.06l2.5 2.5a.75.75 0 0 0 1.06 0l2.5-2.5a.75.75 0 1 0-1.06-1.06l-1.22 1.22V7.75Z"
+								clip-rule="evenodd"
+							/>
+						</svg>
+					</div>
+					<div class=" self-center text-sm font-medium">
+						{$i18n.t('Export Users')}
+					</div>
+				</button>
 			{/if}
 		</div>
 	</div>
@@ -170,6 +170,19 @@
 			return;
 		}

+		if (
+			RAGConfig.CONTENT_EXTRACTION_ENGINE === 'datalab_marker' &&
+			RAGConfig.DATALAB_MARKER_ADDITIONAL_CONFIG &&
+			RAGConfig.DATALAB_MARKER_ADDITIONAL_CONFIG.trim() !== ''
+		) {
+			try {
+				JSON.parse(RAGConfig.DATALAB_MARKER_ADDITIONAL_CONFIG);
+			} catch (e) {
+				toast.error($i18n.t('Invalid JSON format in Additional Config'));
+				return;
+			}
+		}
+
 		if (
 			RAGConfig.CONTENT_EXTRACTION_ENGINE === 'document_intelligence' &&
 			(RAGConfig.DOCUMENT_INTELLIGENCE_ENDPOINT === '' ||
@@ -195,10 +208,6 @@
 			ALLOWED_FILE_EXTENSIONS: RAGConfig.ALLOWED_FILE_EXTENSIONS.split(',')
 				.map((ext) => ext.trim())
 				.filter((ext) => ext !== ''),
-			DATALAB_MARKER_LANGS: RAGConfig.DATALAB_MARKER_LANGS.split(',')
-				.map((code) => code.trim())
-				.filter((code) => code !== '')
-				.join(', '),
 			DOCLING_PICTURE_DESCRIPTION_LOCAL: JSON.parse(
 				RAGConfig.DOCLING_PICTURE_DESCRIPTION_LOCAL || '{}'
 			),
@@ -336,6 +345,21 @@
 					</div>
 				</div>
 			{:else if RAGConfig.CONTENT_EXTRACTION_ENGINE === 'datalab_marker'}
+				<div class="my-0.5 flex gap-2 pr-2">
+					<Tooltip
+						content={$i18n.t(
+							'API Base URL for Datalab Marker service. Defaults to: https://www.datalab.to/api/v1/marker'
+						)}
+						placement="top-start"
+						className="w-full"
+					>
+						<input
+							class="flex-1 w-full text-sm bg-transparent outline-hidden"
+							placeholder={$i18n.t('Enter Datalab Marker API Base URL')}
+							bind:value={RAGConfig.DATALAB_MARKER_API_BASE_URL}
+						/>
+					</Tooltip>
+				</div>
 				<div class="my-0.5 flex gap-2 pr-2">
 					<SensitiveInput
 						placeholder={$i18n.t('Enter Datalab Marker API Key')}
@@ -344,24 +368,33 @@
 					/>
 				</div>

-				<div class="flex justify-between w-full mt-2">
-					<div class="text-xs font-medium">
-						{$i18n.t('Languages')}
-					</div>
-					<input
-						class="text-sm bg-transparent outline-hidden"
-						type="text"
-						bind:value={RAGConfig.DATALAB_MARKER_LANGS}
-						placeholder={$i18n.t('e.g.) en,fr,de')}
-					/>
-				</div>
+				<div class="flex flex-col gap-2 mt-2">
+					<div class=" flex flex-col w-full justify-between">
+						<div class=" mb-1 text-xs font-medium">
+							{$i18n.t('Additional Config')}
+						</div>
+						<div class="flex w-full items-center relative">
+							<Tooltip
+								content={$i18n.t(
+									'Additional configuration options for marker. This should be a JSON string with key-value pairs. For example, \'{"key": "value"}\'. Supported keys include: disable_links, keep_pageheader_in_output, keep_pagefooter_in_output, filter_blank_pages, drop_repeated_text, layout_coverage_threshold, merge_threshold, height_tolerance, gap_threshold, image_threshold, min_line_length, level_count, default_level'
+								)}
+								placement="top-start"
+								className="w-full"
+							>
+								<Textarea
+									bind:value={RAGConfig.DATALAB_MARKER_ADDITIONAL_CONFIG}
+									placeholder={$i18n.t('Enter JSON config (e.g., {"disable_links": true})')}
+								/>
+							</Tooltip>
+						</div>
+					</div>
+				</div>

 				<div class="flex justify-between w-full mt-2">
 					<div class="self-center text-xs font-medium">
 						<Tooltip
 							content={$i18n.t(
-								'Significantly improves accuracy by using an LLM to enhance tables, forms, inline math, and layout detection. Will increase latency. Defaults to True.'
+								'Significantly improves accuracy by using an LLM to enhance tables, forms, inline math, and layout detection. Will increase latency. Defaults to False.'
 							)}
 							placement="top-start"
 						>
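Per the tooltip above, Additional Config is stored as a raw JSON string and only parsed for validation before saving. A sketch of composing such a value, assuming keys from the tooltip's list (their exact semantics are defined by the Datalab Marker service, not shown here):

```ts
// Illustrative config using keys named in the tooltip above; values are examples.
const additionalConfig = {
	disable_links: true,
	filter_blank_pages: true,
	drop_repeated_text: true
};

// The settings field holds the serialized string…
const raw = JSON.stringify(additionalConfig);

// …and the submit handler's check amounts to this parse-or-reject step.
try {
	JSON.parse(raw);
} catch {
	throw new Error('Invalid JSON format in Additional Config');
}
```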
@@ -445,6 +478,21 @@
 						<Switch bind:state={RAGConfig.DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION} />
 					</div>
 				</div>
+				<div class="flex justify-between w-full mt-2">
+					<div class="self-center text-xs font-medium">
+						<Tooltip
+							content={$i18n.t(
+								'Format the lines in the output. Defaults to False. If set to True, the lines will be formatted to detect inline math and styles.'
+							)}
+							placement="top-start"
+						>
+							{$i18n.t('Format Lines')}
+						</Tooltip>
+					</div>
+					<div class="flex items-center">
+						<Switch bind:state={RAGConfig.DATALAB_MARKER_FORMAT_LINES} />
+					</div>
+				</div>
 				<div class="flex justify-between w-full mt-2">
 					<div class="self-center text-xs font-medium">
 						<Tooltip
@@ -1011,24 +1059,73 @@
 	{/if}

 	{#if RAGConfig.ENABLE_RAG_HYBRID_SEARCH === true}
-		<div class="mb-2.5 flex w-full justify-between">
+		<div class=" mb-2.5 py-0.5 w-full justify-between">
+			<Tooltip
+				content={$i18n.t(
+					'The Weight of BM25 Hybrid Search. 0 more lexical, 1 more semantic. Default 0.5'
+				)}
+				placement="top-start"
+				className="inline-tooltip"
+			>
+				<div class="flex w-full justify-between">
 					<div class=" self-center text-xs font-medium">
-						{$i18n.t('Weight of BM25 Retrieval')}
+						{$i18n.t('BM25 Weight')}
 					</div>
-			<div class="flex items-center relative">
+					<button
+						class="p-1 px-3 text-xs flex rounded-sm transition shrink-0 outline-hidden"
+						type="button"
+						on:click={() => {
+							RAGConfig.HYBRID_BM25_WEIGHT =
+								(RAGConfig?.HYBRID_BM25_WEIGHT ?? null) === null ? 0.5 : null;
+						}}
+					>
+						{#if (RAGConfig?.HYBRID_BM25_WEIGHT ?? null) === null}
+							<span class="ml-2 self-center"> {$i18n.t('Default')} </span>
+						{:else}
+							<span class="ml-2 self-center"> {$i18n.t('Custom')} </span>
+						{/if}
+					</button>
+				</div>
+			</Tooltip>

+			{#if (RAGConfig?.HYBRID_BM25_WEIGHT ?? null) !== null}
+				<div class="flex mt-0.5 space-x-2">
+					<div class=" flex-1">
 						<input
-					class="flex-1 w-full text-sm bg-transparent outline-hidden"
-					type="number"
-					step="0.01"
-					placeholder={$i18n.t('Enter BM25 Weight')}
+							id="steps-range"
+							type="range"
+							min="0"
+							max="1"
+							step="0.05"
 							bind:value={RAGConfig.HYBRID_BM25_WEIGHT}
-					autocomplete="off"
-					min="0.0"
-					max="1.0"
+							class="w-full h-2 rounded-lg appearance-none cursor-pointer dark:bg-gray-700"
+						/>
+						<div class="py-0.5">
+							<div class="flex w-full justify-between">
+								<div class=" text-left text-xs font-small">
+									{$i18n.t('lexical')}
+								</div>
+								<div class=" text-right text-xs font-small">
+									{$i18n.t('semantic')}
+								</div>
+							</div>
+						</div>
+					</div>
+					<div>
+						<input
+							bind:value={RAGConfig.HYBRID_BM25_WEIGHT}
+							type="number"
+							class=" bg-transparent text-center w-14"
+							min="0"
+							max="1"
+							step="any"
 						/>
 					</div>
 				</div>
 			{/if}
+		</div>
+	{/if}
 {/if}

 <div class=" mb-2.5 flex flex-col w-full justify-between">
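The slider above labels its endpoints 'lexical' (0) and 'semantic' (1), which suggests the weight acts as a convex combination of the two retrieval scores. A sketch under that assumption — the actual fusion runs server-side and may normalize scores differently, and note the tooltip calls the value 'BM25 Weight' while describing 1 as 'more semantic':

```ts
// Hypothetical fusion matching the slider labels: 0 keeps only the lexical
// (BM25) score, 1 keeps only the semantic (vector) score.
const hybridScore = (bm25: number, semantic: number, weight = 0.5): number =>
	(1 - weight) * bm25 + weight * semantic;

console.log(hybridScore(0.8, 0.4, 0)); // 0.8 — purely lexical
console.log(hybridScore(0.8, 0.4, 1)); // 0.4 — purely semantic
console.log(hybridScore(0.8, 0.4, 0.5)); // 0.6 — the 0.5 default blend
```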
@@ -57,14 +57,6 @@
 		await config.set(await getBackendConfig());
 	};

-	onMount(async () => {
-		await init();
-		taskConfig = await getTaskConfig(localStorage.token);
-
-		promptSuggestions = $config?.default_prompt_suggestions ?? [];
-		banners = await getBanners(localStorage.token);
-	});
-
 	const updateBanners = async () => {
 		_banners.set(await setBanners(localStorage.token, banners));
 	};

@@ -75,6 +67,10 @@
 	let models = null;

 	const init = async () => {
+		taskConfig = await getTaskConfig(localStorage.token);
+		promptSuggestions = $config?.default_prompt_suggestions ?? [];
+		banners = await getBanners(localStorage.token);
+
 		workspaceModels = await getBaseModels(localStorage.token);
 		baseModels = await getModels(localStorage.token, null, false);

@@ -99,6 +95,10 @@

 		console.debug('models', models);
 	};

+	onMount(async () => {
+		await init();
+	});
 </script>

 {#if models !== null && taskConfig}
@@ -460,25 +460,27 @@
 	<div class="grid lg:grid-cols-2 flex-col gap-1.5">
 		{#each promptSuggestions as prompt, promptIdx}
 			<div
-				class=" flex border border-gray-100 dark:border-none dark:bg-gray-850 rounded-xl py-1.5"
+				class=" flex border rounded-xl border-gray-50 dark:border-none dark:bg-gray-850 py-1.5"
 			>
 				<div class="flex flex-col flex-1 pl-1">
-					<div class="flex border-b border-gray-100 dark:border-gray-850 w-full">
+					<div class="py-1 gap-1">
 						<input
-							class="px-3 py-1.5 text-xs w-full bg-transparent outline-hidden border-r border-gray-100 dark:border-gray-850"
+							class="px-3 text-sm font-medium w-full bg-transparent outline-hidden"
 							placeholder={$i18n.t('Title (e.g. Tell me a fun fact)')}
 							bind:value={prompt.title[0]}
 						/>

 						<input
-							class="px-3 py-1.5 text-xs w-full bg-transparent outline-hidden border-r border-gray-100 dark:border-gray-850"
+							class="px-3 text-xs w-full bg-transparent outline-hidden text-gray-600 dark:text-gray-400"
 							placeholder={$i18n.t('Subtitle (e.g. about the Roman Empire)')}
 							bind:value={prompt.title[1]}
 						/>
 					</div>

+					<hr class="border-gray-50 dark:border-gray-850 my-1" />
+
 					<textarea
-						class="px-3 py-1.5 text-xs w-full bg-transparent outline-hidden border-r border-gray-100 dark:border-gray-850 resize-none"
+						class="px-3 py-1.5 text-xs w-full bg-transparent outline-hidden resize-none"
 						placeholder={$i18n.t(
 							'Prompt (e.g. Tell me a fun fact about the Roman Empire)'
 						)}

@@ -487,8 +489,9 @@
 					/>
 				</div>

+				<div class="">
 					<button
-						class="px-3"
+						class="p-3"
 						type="button"
 						on:click={() => {
 							promptSuggestions.splice(promptIdx, 1);

@@ -507,6 +510,7 @@
 						</svg>
 					</button>
 				</div>
+			</div>
 		{/each}
 	</div>
@@ -66,7 +66,9 @@
 		},
 		chat: {
 			controls: true,
+			valves: true,
 			system_prompt: true,
+			params: true,
 			file_upload: true,
 			delete: true,
 			edit: true,

@@ -48,10 +48,20 @@
 		},
 		chat: {
 			controls: true,
+			valves: true,
+			system_prompt: true,
+			params: true,
 			file_upload: true,
 			delete: true,
 			edit: true,
-			temporary: true
+			share: true,
+			export: true,
+			stt: true,
+			tts: true,
+			call: true,
+			multiple_models: true,
+			temporary: true,
+			temporary_enforced: false
 		},
 		features: {
 			direct_tool_servers: false,

@@ -21,6 +21,9 @@
 		},
 		chat: {
 			controls: true,
+			valves: true,
+			system_prompt: true,
+			params: true,
 			file_upload: true,
 			delete: true,
 			edit: true,
@@ -263,6 +266,15 @@
 		<Switch bind:state={permissions.chat.controls} />
 	</div>

+	{#if permissions.chat.controls}
+		<div class=" flex w-full justify-between my-2 pr-2">
+			<div class=" self-center text-xs font-medium">
+				{$i18n.t('Allow Chat Valves')}
+			</div>
+
+			<Switch bind:state={permissions.chat.valves} />
+		</div>
+
 	<div class=" flex w-full justify-between my-2 pr-2">
 		<div class=" self-center text-xs font-medium">
 			{$i18n.t('Allow Chat System Prompt')}

@@ -271,6 +283,15 @@
 		<Switch bind:state={permissions.chat.system_prompt} />
 	</div>

+	<div class=" flex w-full justify-between my-2 pr-2">
+		<div class=" self-center text-xs font-medium">
+			{$i18n.t('Allow Chat Params')}
+		</div>
+
+		<Switch bind:state={permissions.chat.params} />
+	</div>
+	{/if}
+
 	<div class=" flex w-full justify-between my-2 pr-2">
 		<div class=" self-center text-xs font-medium">
 			{$i18n.t('Allow Chat Delete')}
@@ -142,8 +142,7 @@
 						type: 'error',
 						title: 'License Error',
 						content:
-							'Exceeded the number of seats in your license. Please contact support to increase the number of seats.',
-						dismissable: true
+							'Exceeded the number of seats in your license. Please contact support to increase the number of seats.'
 					}}
 				/>
 			</div>

@@ -154,7 +153,9 @@
 			<Spinner className="size-5" />
 		</div>
 	{:else}
-		<div class="mt-0.5 mb-2 gap-1 flex flex-col md:flex-row justify-between">
+		<div
+			class="pt-0.5 pb-2 gap-1 flex flex-col md:flex-row justify-between sticky top-0 z-10 bg-white dark:bg-gray-900"
+		>
 			<div class="flex md:self-center text-lg font-medium px-0.5">
 				<div class="flex-shrink-0">
 					{$i18n.t('Users')}

@@ -494,8 +495,10 @@
 			ⓘ {$i18n.t("Click on the user role button to change a user's role.")}
 		</div>

+		{#if total > 30}
 			<Pagination bind:page count={total} perPage={30} />
 		{/if}
+		{/if}

 	{#if !$config?.license_metadata}
 		{#if total > 50}
@@ -10,6 +10,7 @@
 	import Modal from '$lib/components/common/Modal.svelte';
 	import { generateInitialsImage } from '$lib/utils';
 	import XMark from '$lib/components/icons/XMark.svelte';
+	import SensitiveInput from '$lib/components/common/SensitiveInput.svelte';

 	const i18n = getContext('i18n');
 	const dispatch = createEventDispatcher();

@@ -224,12 +225,13 @@
 						<div class=" mb-1 text-xs text-gray-500">{$i18n.t('Password')}</div>

 						<div class="flex-1">
-							<input
+							<SensitiveInput
 								class="w-full text-sm bg-transparent disabled:text-gray-500 dark:disabled:text-gray-500 outline-hidden"
 								type="password"
 								bind:value={_user.password}
 								placeholder={$i18n.t('Enter Your Password')}
 								autocomplete="off"
+								required
 							/>
 						</div>
 					</div>
@@ -9,6 +9,7 @@
 	import Modal from '$lib/components/common/Modal.svelte';
 	import localizedFormat from 'dayjs/plugin/localizedFormat';
 	import XMark from '$lib/components/icons/XMark.svelte';
+	import SensitiveInput from '$lib/components/common/SensitiveInput.svelte';

 	const i18n = getContext('i18n');
 	const dispatch = createEventDispatcher();

@@ -139,12 +140,13 @@
 					<div class=" mb-1 text-xs text-gray-500">{$i18n.t('New Password')}</div>

 					<div class="flex-1">
-						<input
+						<SensitiveInput
 							class="w-full text-sm bg-transparent outline-hidden"
 							type="password"
 							placeholder={$i18n.t('Enter New Password')}
 							bind:value={_user.password}
 							autocomplete="new-password"
+							required={false}
 						/>
 					</div>
 				</div>
@@ -33,7 +33,6 @@
 	import InputVariablesModal from '../chat/MessageInput/InputVariablesModal.svelte';

 	export let placeholder = $i18n.t('Send a Message');
-	export let transparentBackground = false;

 	export let id = null;

@@ -60,7 +59,7 @@
 	export let scrollToBottom: Function = () => {};

 	export let acceptFiles = true;
-	export let showFormattingButtons = true;
+	export let showFormattingToolbar = true;

 	let showInputVariablesModal = false;
 	let inputVariables: Record<string, any> = {};

@@ -327,7 +326,9 @@
 					let imageUrl = event.target.result;

 					// Compress the image if settings or config require it
+					if ($settings?.imageCompression && $settings?.imageCompressionInChannels) {
 						imageUrl = await compressImageHandler(imageUrl, $settings, $config);
+					}

 					files = [
 						...files,

@@ -700,7 +701,7 @@
 				bind:this={chatInputElement}
 				json={true}
 				messageInput={true}
-				{showFormattingButtons}
+				{showFormattingToolbar}
 				shiftEnter={!($settings?.ctrlEnterToSend ?? false) &&
 					(!$mobile ||
 						!(
@@ -28,7 +28,7 @@
 	import Image from '$lib/components/common/Image.svelte';
 	import FileItem from '$lib/components/common/FileItem.svelte';
 	import ProfilePreview from './Message/ProfilePreview.svelte';
-	import ChatBubbleOvalEllipsis from '$lib/components/icons/ChatBubbleOvalEllipsis.svelte';
+	import ChatBubbleOvalEllipsis from '$lib/components/icons/ChatBubble.svelte';
 	import FaceSmile from '$lib/components/icons/FaceSmile.svelte';
 	import ReactionPicker from './Message/ReactionPicker.svelte';
 	import ChevronRight from '$lib/components/icons/ChevronRight.svelte';
@@ -2,14 +2,15 @@
 	import { getContext } from 'svelte';
 	import { toast } from 'svelte-sonner';

-	import { showArchivedChats, showSidebar, user } from '$lib/stores';
+	import { mobile, showArchivedChats, showSidebar, user } from '$lib/stores';

 	import { slide } from 'svelte/transition';
 	import { page } from '$app/stores';

 	import UserMenu from '$lib/components/layout/Sidebar/UserMenu.svelte';
-	import MenuLines from '../icons/MenuLines.svelte';
 	import PencilSquare from '../icons/PencilSquare.svelte';
+	import Tooltip from '../common/Tooltip.svelte';
+	import Sidebar from '../icons/Sidebar.svelte';

 	const i18n = getContext('i18n');

@@ -23,24 +24,30 @@

 <div class=" flex max-w-full w-full mx-auto px-1 pt-0.5 bg-transparent">
 	<div class="flex items-center w-full max-w-full">
+		{#if $mobile}
 			<div
 				class="{$showSidebar
 					? 'md:hidden'
-					: ''} mr-1 self-start flex flex-none items-center text-gray-600 dark:text-gray-400"
+					: ''} mr-1.5 mt-0.5 self-start flex flex-none items-center text-gray-600 dark:text-gray-400"
+			>
+				<Tooltip
+					content={$showSidebar ? $i18n.t('Close Sidebar') : $i18n.t('Open Sidebar')}
+					interactive={true}
 				>
 					<button
 						id="sidebar-toggle-button"
-					class="cursor-pointer px-2 py-2 flex rounded-xl hover:bg-gray-50 dark:hover:bg-gray-850 transition"
+						class=" cursor-pointer flex rounded-lg hover:bg-gray-100 dark:hover:bg-gray-850 transition cursor-"
 						on:click={() => {
 							showSidebar.set(!$showSidebar);
 						}}
-					aria-label="Toggle Sidebar"
 					>
-					<div class=" m-auto self-center">
-						<MenuLines />
+						<div class=" self-center p-1.5">
+							<Sidebar />
 						</div>
 					</button>
+				</Tooltip>
 			</div>
+		{/if}

 		<div
 			class="flex-1 overflow-hidden max-w-full py-0.5
@@ -88,6 +88,8 @@
 	import NotificationToast from '../NotificationToast.svelte';
 	import Spinner from '../common/Spinner.svelte';
 	import { fade } from 'svelte/transition';
+	import Tooltip from '../common/Tooltip.svelte';
+	import Sidebar from '../icons/Sidebar.svelte';

 	export let chatIdProp = '';

@@ -128,6 +130,9 @@
 	let showCommands = false;

+	let generating = false;
+	let generationController = null;
+
 	let chat = null;
 	let tags = [];

@@ -1479,14 +1484,23 @@

 		saveSessionSelectedModels();

-		await sendPrompt(history, userPrompt, userMessageId, { newChat: true });
+		await sendMessage(history, userMessageId, { newChat: true });
 	};

-	const sendPrompt = async (
+	const sendMessage = async (
 		_history,
-		prompt: string,
 		parentId: string,
-		{ modelId = null, modelIdx = null, newChat = false } = {}
+		{
+			messages = null,
+			modelId = null,
+			modelIdx = null,
+			newChat = false
+		}: {
+			messages?: any[] | null;
+			modelId?: string | null;
+			modelIdx?: number | null;
+			newChat?: boolean;
+		} = {}
 	) => {
 		if (autoScroll) {
 			scrollToBottom();
@@ -1556,9 +1570,8 @@
 		const model = $models.filter((m) => m.id === modelId).at(0);

 		if (model) {
-			const messages = createMessagesList(_history, parentId);
 			// If there are image files, check if model is vision capable
-			const hasImages = messages.some((message) =>
+			const hasImages = createMessagesList(_history, parentId).some((message) =>
 				message.files?.some((file) => file.type === 'image')
 			);

@@ -1575,7 +1588,15 @@
 			const chatEventEmitter = await getChatEventEmitter(model.id, _chatId);

 			scrollToBottom();
-			await sendPromptSocket(_history, model, responseMessageId, _chatId);
+			await sendMessageSocket(
+				model,
+				messages && messages.length > 0
+					? messages
+					: createMessagesList(_history, responseMessageId),
+				_history,
+				responseMessageId,
+				_chatId
+			);

 			if (chatEventEmitter) clearInterval(chatEventEmitter);
 		} else {
@@ -1588,12 +1609,11 @@
 		chats.set(await getChatList(localStorage.token, $currentChatPage));
 	};

-	const sendPromptSocket = async (_history, model, responseMessageId, _chatId) => {
-		const chatMessages = createMessagesList(history, history.currentId);
+	const sendMessageSocket = async (model, _messages, _history, responseMessageId, _chatId) => {
 		const responseMessage = _history.messages[responseMessageId];
 		const userMessage = _history.messages[responseMessage.parentId];

-		const chatMessageFiles = chatMessages
+		const chatMessageFiles = _messages
 			.filter((message) => message.files)
 			.flatMap((message) => message.files);
@@ -1647,7 +1667,7 @@
 						)}`
 					}
 				: undefined,
-			...createMessagesList(_history, responseMessageId).map((message) => ({
+			..._messages.map((message) => ({
 				...message,
 				content: processDetails(message.content)
 			}))
@@ -1857,6 +1877,12 @@
 				scrollToBottom();
 			}
 		}
+
+		if (generating) {
+			generating = false;
+			generationController?.abort();
+			generationController = null;
+		}
 	};

 	const submitMessage = async (parentId, prompt) => {
@@ -1889,31 +1915,39 @@
 			scrollToBottom();
 		}

-		await sendPrompt(history, userPrompt, userMessageId);
+		await sendMessage(history, userMessageId);
 	};

-	const regenerateResponse = async (message) => {
+	const regenerateResponse = async (message, suggestionPrompt = null) => {
 		console.log('regenerateResponse');

 		if (history.currentId) {
 			let userMessage = history.messages[message.parentId];
-			let userPrompt = userMessage.content;

 			if (autoScroll) {
 				scrollToBottom();
 			}

-			if ((userMessage?.models ?? [...selectedModels]).length == 1) {
-				// If user message has only one model selected, sendPrompt automatically selects it for regeneration
-				await sendPrompt(history, userPrompt, userMessage.id);
-			} else {
-				// If there are multiple models selected, use the model of the response message for regeneration
-				// e.g. many model chat
-				await sendPrompt(history, userPrompt, userMessage.id, {
+			await sendMessage(history, userMessage.id, {
+				...(suggestionPrompt
+					? {
+							messages: [
+								...createMessagesList(history, message.id),
+								{
+									role: 'user',
+									content: suggestionPrompt
+								}
+							]
+						}
+					: {}),
+				...((userMessage?.models ?? [...selectedModels]).length > 1
+					? {
+							// If multiple models are selected, use the model from the message
 							modelId: message.model,
 							modelIdx: message.modelIdx
-				});
-			}
+						}
+					: {})
+			});
 		}
 	};
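With a suggestion prompt, the rewritten `regenerateResponse` above replays the conversation up to the regenerated message and appends the guidance as one extra user turn. A sketch of the resulting message list (placeholder contents; `createMessagesList` builds the real history):

```ts
// Shape of the messages array sent when suggestionPrompt is set.
const messagesWithSuggestion = [
	{ role: 'user', content: 'Summarize this article.' }, // original user turn
	{ role: 'assistant', content: 'The article argues that…' }, // answer being regenerated
	{ role: 'user', content: 'More concise' } // suggestionPrompt appended last
];
```

The model therefore sees its previous answer together with the requested change, rather than regenerating blind.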
@@ -1931,7 +1965,13 @@
 			.at(0);

 		if (model) {
-			await sendPromptSocket(history, model, responseMessage.id, _chatId);
+			await sendMessageSocket(
+				model,
+				createMessagesList(history, responseMessage.id),
+				history,
+				responseMessage.id,
+				_chatId
+			);
 		}
 	}
 };
@@ -1947,6 +1987,7 @@
 		history.messages[messageId] = message;

 		try {
+			generating = true;
 			const [res, controller] = await generateMoACompletion(
 				localStorage.token,
 				message.model,

@@ -1954,11 +1995,14 @@
 				responses
 			);

-			if (res && res.ok && res.body) {
+			if (res && res.ok && res.body && generating) {
+				generationController = controller;
 				const textStream = await createOpenAITextStream(res.body, $settings.splitLargeChunks);
 				for await (const update of textStream) {
 					const { value, done, sources, error, usage } = update;
 					if (error || done) {
+						generating = false;
+						generationController = null;
 						break;
 					}
@@ -2038,6 +2082,33 @@
 			}
 		}
 	};
+
+	const MAX_DRAFT_LENGTH = 5000;
+	let saveDraftTimeout = null;
+
+	const saveDraft = async (draft, chatId = null) => {
+		if (saveDraftTimeout) {
+			clearTimeout(saveDraftTimeout);
+		}
+
+		if (draft.prompt !== null && draft.prompt.length < MAX_DRAFT_LENGTH) {
+			saveDraftTimeout = setTimeout(async () => {
+				await sessionStorage.setItem(
+					`chat-input${chatId ? `-${chatId}` : ''}`,
+					JSON.stringify(draft)
+				);
+			}, 500);
+		} else {
+			sessionStorage.removeItem(`chat-input${chatId ? `-${chatId}` : ''}`);
+		}
+	};
+
+	const clearDraft = async (chatId = null) => {
+		if (saveDraftTimeout) {
+			clearTimeout(saveDraftTimeout);
+		}
+		await sessionStorage.removeItem(`chat-input${chatId ? `-${chatId}` : ''}`);
+	};
 </script>

 <svelte:head>
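The `saveDraft`/`clearDraft` pair above debounces writes by 500 ms and keys drafts per chat in `sessionStorage`. A short usage sketch of the resulting keys (the draft shape is assumed from the `draft.prompt` access above):

```ts
// Keys written by saveDraft above: 'chat-input' before a chat exists,
// 'chat-input-<chatId>' afterwards.
saveDraft({ prompt: 'Half-typed message…' }); // → sessionStorage['chat-input']
saveDraft({ prompt: 'Half-typed message…' }, 'abc123'); // → sessionStorage['chat-input-abc123']

// Drafts of MAX_DRAFT_LENGTH (5000) characters or more are removed instead of
// saved, and clearDraft('abc123') cancels any pending write and deletes the key.
```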
@@ -2137,7 +2208,7 @@
 					}}
 					{selectedModels}
 					{atSelectedModel}
-					{sendPrompt}
+					{sendMessage}
 					{showMessage}
 					{submitMessage}
 					{continueResponse}

@@ -2145,6 +2216,7 @@
 					{mergeResponses}
 					{chatActionHandler}
 					{addMessages}
+					topPadding={true}
 					bottomPadding={files.length > 0}
 					{onSelect}
 				/>
@@ -2168,21 +2240,12 @@
 					bind:atSelectedModel
 					bind:showCommands
 					toolServers={$toolServers}
-					transparentBackground={$settings?.backgroundImageUrl ??
-						$config?.license_metadata?.background_image_url ??
-						false}
+					{generating}
 					{stopResponse}
 					{createMessagePair}
-					onChange={(input) => {
+					onChange={(data) => {
 						if (!$temporaryChatEnabled) {
-							if (input.prompt !== null) {
-								sessionStorage.setItem(
-									`chat-input${$chatId ? `-${$chatId}` : ''}`,
-									JSON.stringify(input)
-								);
-							} else {
-								sessionStorage.removeItem(`chat-input${$chatId ? `-${$chatId}` : ''}`);
-							}
+							saveDraft(data, $chatId);
 						}
 					}}
 					on:upload={async (e) => {

@@ -2197,6 +2260,7 @@
 					}
 				}}
 				on:submit={async (e) => {
+					clearDraft();
 					if (e.detail || files.length > 0) {
 						await tick();
 						submitPrompt(
@@ -2230,13 +2294,15 @@
 				bind:webSearchEnabled
 				bind:atSelectedModel
 				bind:showCommands
-				transparentBackground={$settings?.backgroundImageUrl ??
-					$config?.license_metadata?.background_image_url ??
-					false}
 				toolServers={$toolServers}
 				{stopResponse}
 				{createMessagePair}
 				{onSelect}
+				onChange={(data) => {
+					if (!$temporaryChatEnabled) {
+						saveDraft(data);
+					}
+				}}
 				on:upload={async (e) => {
 					const { type, data } = e.detail;

@@ -2247,6 +2313,7 @@
 					}
 				}}
 				on:submit={async (e) => {
+					clearDraft();
 					if (e.detail || files.length > 0) {
 						await tick();
 						submitPrompt(
@@ -4,7 +4,7 @@
 	import DOMPurify from 'dompurify';
 	import { marked } from 'marked';

-	import { getContext, tick } from 'svelte';
+	import { getContext, tick, onDestroy } from 'svelte';
 	const i18n = getContext('i18n');

 	import { chatCompletion } from '$lib/apis/openai';
|
|
@ -17,135 +17,124 @@
|
||||||
export let id = '';
|
export let id = '';
|
||||||
export let model = null;
|
export let model = null;
|
||||||
export let messages = [];
|
export let messages = [];
|
||||||
export let onAdd = () => {};
|
export let actions = [];
|
||||||
|
export let onAdd = (e) => {};
|
||||||
|
|
||||||
let floatingInput = false;
|
let floatingInput = false;
|
||||||
|
let selectedAction = null;
|
||||||
|
|
||||||
let selectedText = '';
|
let selectedText = '';
|
||||||
let floatingInputValue = '';
|
let floatingInputValue = '';
|
||||||
|
|
||||||
let prompt = '';
|
let content = '';
|
||||||
let responseContent = null;
|
let responseContent = null;
|
||||||
let responseDone = false;
|
let responseDone = false;
|
||||||
|
let controller = null;
|
||||||
|
|
+	$: if (actions.length === 0) {
+		actions = DEFAULT_ACTIONS;
+	}
+
+	const DEFAULT_ACTIONS = [
+		{
+			id: 'ask',
+			label: $i18n.t('Ask'),
+			icon: ChatBubble,
+			input: true,
+			prompt: `{{SELECTED_CONTENT}}\n\n\n{{INPUT_CONTENT}}`
+		},
+		{
+			id: 'explain',
+			label: $i18n.t('Explain'),
+			icon: LightBulb,
+			prompt: `{{SELECTED_CONTENT}}\n\n\n${$i18n.t('Explain')}`
+		}
+	];
+
 	const autoScroll = async () => {
-		// Scroll to bottom only if the scroll is at the bottom give 50px buffer
 		const responseContainer = document.getElementById('response-container');
+		if (responseContainer) {
+			// Scroll to bottom only if the scroll is at the bottom give 50px buffer
 			if (
 				responseContainer.scrollHeight - responseContainer.clientHeight <=
 				responseContainer.scrollTop + 50
 			) {
 				responseContainer.scrollTop = responseContainer.scrollHeight;
 			}
+		}
 	};

-	const askHandler = async () => {
+	const actionHandler = async (actionId) => {
 		if (!model) {
 			toast.error('Model not selected');
 			return;
 		}
-		prompt = [
-			// Blockquote each line of the selected text
-			...selectedText.split('\n').map((line) => `> ${line}`),
-			'',
-			// Then your question
-			floatingInputValue
-		].join('\n');
-		floatingInputValue = '';
-
-		responseContent = '';
-		const [res, controller] = await chatCompletion(localStorage.token, {
-			model: model,
-			messages: [
-				...messages,
-				{
-					role: 'user',
-					content: prompt
-				}
-			].map((message) => ({
-				role: message.role,
-				content: message.content
-			})),
-			stream: true // Enable streaming
-		});
-
-		if (res && res.ok) {
-			const reader = res.body.getReader();
-			const decoder = new TextDecoder();
-
-			const processStream = async () => {
-				while (true) {
-					// Read data chunks from the response stream
-					const { done, value } = await reader.read();
-					if (done) {
-						break;
-					}
-
-					// Decode the received chunk
-					const chunk = decoder.decode(value, { stream: true });
-
-					// Process lines within the chunk
-					const lines = chunk.split('\n').filter((line) => line.trim() !== '');
-
-					for (const line of lines) {
-						if (line.startsWith('data: ')) {
-							if (line.startsWith('data: [DONE]')) {
-								responseDone = true;
-
-								await tick();
-								autoScroll();
-								continue;
-							} else {
-								// Parse the JSON chunk
-								try {
-									const data = JSON.parse(line.slice(6));
-
-									// Append the `content` field from the "choices" object
-									if (data.choices && data.choices[0]?.delta?.content) {
-										responseContent += data.choices[0].delta.content;
-
-										autoScroll();
-									}
-								} catch (e) {
-									console.error(e);
-								}
-							}
-						}
-					}
-				}
-			};
-
-			// Process the stream in the background
-			await processStream();
-		} else {
-			toast.error('An error occurred while fetching the explanation');
-		}
-	};
-
-	const explainHandler = async () => {
-		if (!model) {
-			toast.error('Model not selected');
-			return;
-		}
-		const quotedText = selectedText
+		let selectedContent = selectedText
 			.split('\n')
 			.map((line) => `> ${line}`)
 			.join('\n');
-		prompt = `${quotedText}\n\nExplain`;
+
+		let selectedAction = actions.find((action) => action.id === actionId);
+		if (!selectedAction) {
+			toast.error('Action not found');
+			return;
+		}
+
+		let prompt = selectedAction?.prompt ?? '';
+		let toolIds = [];
+
+		// Handle the {{variableId|tool:id="toolId"}} pattern:
+		// this regex captures variableId and toolId from {{variableId|tool:id="toolId"}}
+		const varToolPattern = /\{\{(.*?)\|tool:id="([^"]+)"\}\}/g;
+		prompt = prompt.replace(varToolPattern, (match, variableId, toolId) => {
+			toolIds.push(toolId);
+			return variableId; // Replace with just variableId
+		});
+
+		// Legacy {{TOOL:toolId}} pattern (for backward compatibility)
+		let toolIdPattern = /\{\{TOOL:([^\}]+)\}\}/g;
+		let match;
+		while ((match = toolIdPattern.exec(prompt)) !== null) {
+			toolIds.push(match[1]);
+		}
+
+		// Remove all TOOL placeholders from the prompt
+		prompt = prompt.replace(toolIdPattern, '');
+
+		if (prompt.includes('{{INPUT_CONTENT}}') && !floatingInput) {
+			prompt = prompt.replace('{{INPUT_CONTENT}}', floatingInputValue);
+			floatingInputValue = '';
+		}
+
+		prompt = prompt.replace('{{CONTENT}}', selectedText);
+		prompt = prompt.replace('{{SELECTED_CONTENT}}', selectedContent);
+
+		content = prompt;
 		responseContent = '';
-		const [res, controller] = await chatCompletion(localStorage.token, {
+
+		let res;
+		[res, controller] = await chatCompletion(localStorage.token, {
 			model: model,
 			messages: [
 				...messages,
 				{
 					role: 'user',
-					content: prompt
+					content: content
 				}
 			].map((message) => ({
 				role: message.role,
 				content: message.content
 			})),
+			...(toolIds.length > 0
+				? {
+						tool_ids: toolIds
+						// params: {
+						// 	function_calling: 'native'
+						// }
+					}
+				: {}),
 			stream: true // Enable streaming
 		});
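Note: the two regex passes added above can be read in isolation. Below is a minimal standalone sketch of the same placeholder handling; the parseToolPlaceholders name is ours, not part of the component.

// Sketch of actionHandler's placeholder parsing: collect tool ids, strip placeholders.
function parseToolPlaceholders(prompt) {
	const toolIds = [];

	// {{variableId|tool:id="toolId"}} -> keep variableId, collect toolId
	prompt = prompt.replace(/\{\{(.*?)\|tool:id="([^"]+)"\}\}/g, (match, variableId, toolId) => {
		toolIds.push(toolId);
		return variableId;
	});

	// Legacy {{TOOL:toolId}} -> collect toolId, then remove the placeholder
	const toolIdPattern = /\{\{TOOL:([^}]+)\}\}/g;
	let match;
	while ((match = toolIdPattern.exec(prompt)) !== null) {
		toolIds.push(match[1]);
	}
	prompt = prompt.replace(toolIdPattern, '');

	return { prompt, toolIds };
}

// parseToolPlaceholders('{{SELECTED_CONTENT|tool:id="web_search"}} {{TOOL:calculator}}')
// -> { prompt: 'SELECTED_CONTENT ', toolIds: ['web_search', 'calculator'] }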
@@ -196,7 +185,13 @@
 			};

 			// Process the stream in the background
+			try {
 				await processStream();
+			} catch (e) {
+				if (e.name !== 'AbortError') {
+					console.error(e);
+				}
+			}
 		} else {
 			toast.error('An error occurred while fetching the explanation');
 		}
@@ -206,7 +201,7 @@
 		const messages = [
 			{
 				role: 'user',
-				content: prompt
+				content: content
 			},
 			{
 				role: 'assistant',
@@ -222,11 +217,23 @@
 	};

 	export const closeHandler = () => {
+		if (controller) {
+			controller.abort();
+		}
+
+		selectedAction = null;
+		selectedText = '';
 		responseContent = null;
 		responseDone = false;
 		floatingInput = false;
 		floatingInputValue = '';
 	};
+
+	onDestroy(() => {
+		if (controller) {
+			controller.abort();
+		}
+	});
 </script>
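Note: the try/catch added around processStream pairs with the new abort handling — closeHandler and onDestroy now call controller.abort(), so an AbortError is expected rather than logged. A self-contained sketch of that abortable read loop, assuming an OpenAI-compatible `data:` event stream (the function name, URL, and onDelta callback are illustrative, not part of the component):

// Stream an OpenAI-style chat completion; aborting via the controller is silent.
async function streamCompletion(url, token, body, onDelta, controller) {
	const res = await fetch(url, {
		method: 'POST',
		headers: {
			Authorization: `Bearer ${token}`,
			'Content-Type': 'application/json'
		},
		body: JSON.stringify({ ...body, stream: true }),
		signal: controller.signal
	});
	if (!res.ok || !res.body) {
		throw new Error('request failed');
	}

	const reader = res.body.getReader();
	const decoder = new TextDecoder();

	try {
		while (true) {
			const { done, value } = await reader.read();
			if (done) break;

			// Each chunk may carry several "data: {...}" lines
			for (const line of decoder.decode(value, { stream: true }).split('\n')) {
				if (!line.startsWith('data: ') || line.startsWith('data: [DONE]')) continue;
				try {
					const delta = JSON.parse(line.slice(6))?.choices?.[0]?.delta?.content;
					if (delta) onDelta(delta);
				} catch (e) {
					console.error(e); // tolerate partial or malformed JSON lines
				}
			}
		}
	} catch (e) {
		if (e.name !== 'AbortError') throw e; // abort on close is expected
	}
}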
@@ -239,11 +246,16 @@
 	<div
 		class="flex flex-row gap-0.5 shrink-0 p-1 bg-white dark:bg-gray-850 dark:text-gray-100 text-medium rounded-lg shadow-xl"
 	>
+		{#each actions as action}
 			<button
 				class="px-1 hover:bg-gray-50 dark:hover:bg-gray-800 rounded-sm flex items-center gap-1 min-w-fit"
 				on:click={async () => {
 					selectedText = window.getSelection().toString();
+					selectedAction = action;
+
+					if (action.prompt.includes('{{INPUT_CONTENT}}')) {
 						floatingInput = true;
+						floatingInputValue = '';
+
 						await tick();
 						setTimeout(() => {
@@ -252,23 +264,17 @@
 								input.focus();
 							}
 						}, 0);
+					} else {
+						actionHandler(action.id);
+					}
 				}}
 			>
-				<ChatBubble className="size-3 shrink-0" />
-				<div class="shrink-0">{$i18n.t('Ask')}</div>
-			</button>
-			<button
-				class="px-1 hover:bg-gray-50 dark:hover:bg-gray-800 rounded-sm flex items-center gap-1 min-w-fit"
-				on:click={() => {
-					selectedText = window.getSelection().toString();
-					explainHandler();
-				}}
-			>
-				<LightBulb className="size-3 shrink-0" />
-
-				<div class="shrink-0">{$i18n.t('Explain')}</div>
+				{#if action.icon}
+					<svelte:component this={action.icon} className="size-3 shrink-0" />
+				{/if}
+				<div class="shrink-0">{action.label}</div>
 			</button>
+		{/each}
 	</div>
 {:else}
 	<div
@@ -282,7 +288,7 @@
 			bind:value={floatingInputValue}
 			on:keydown={(e) => {
 				if (e.key === 'Enter') {
-					askHandler();
+					actionHandler(selectedAction?.id);
 				}
 			}}
 		/>
@@ -293,7 +299,7 @@
 				? 'bg-black text-white hover:bg-gray-900 dark:bg-white dark:text-black dark:hover:bg-gray-100 '
 				: 'text-white bg-gray-200 dark:text-gray-900 dark:bg-gray-700 disabled'} transition rounded-full p-1.5 m-0.5 self-center"
 			on:click={() => {
-				askHandler();
+				actionHandler(selectedAction?.id);
 			}}
 		>
 			<svg
@@ -318,7 +324,7 @@
 		class="bg-gray-50/50 dark:bg-gray-800 dark:text-gray-100 text-medium rounded-xl px-3.5 py-3 w-full"
 	>
 		<div class="font-medium">
-			<Markdown id={`${id}-float-prompt`} content={prompt} />
+			<Markdown id={`${id}-float-prompt`} {content} />
 		</div>
 	</div>
@@ -326,7 +332,7 @@
 		class="bg-white dark:bg-gray-850 dark:text-gray-100 text-medium rounded-xl px-3.5 py-3 w-full"
 	>
 		<div class=" max-h-80 overflow-y-auto w-full markdown-prose-xs" id="response-container">
-			{#if responseContent.trim() === ''}
+			{#if !responseContent || responseContent?.trim() === ''}
 				<Skeleton size="sm" />
 			{:else}
 				<Markdown id={`${id}-float-response`} content={responseContent} />
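Note: taken together, the handlers above fill an action's prompt template in a fixed order — tool placeholders are stripped first, then {{INPUT_CONTENT}} (the floating input), then the selection. A simplified standalone sketch of the substitution step (it omits the tool-id handling and the floating-input gating shown earlier; the function name and sample values are ours):

// Simplified sketch of actionHandler's template substitution.
function fillPromptTemplate(template, selectedText, inputValue) {
	// The selection is blockquoted line by line, as in the component
	const selectedContent = selectedText
		.split('\n')
		.map((line) => `> ${line}`)
		.join('\n');

	return template
		.replace('{{INPUT_CONTENT}}', inputValue ?? '')
		.replace('{{CONTENT}}', selectedText)
		.replace('{{SELECTED_CONTENT}}', selectedContent);
}

// fillPromptTemplate('{{SELECTED_CONTENT}}\n\n\n{{INPUT_CONTENT}}', 'a\nb', 'Why?')
// -> '> a\n> b\n\n\nWhy?'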
@@ -30,6 +30,7 @@
 	</button>
 </div>

+{#if $user?.role === 'admin' || ($user?.permissions.chat?.controls ?? true)}
 	<div class=" dark:text-gray-200 text-sm font-primary py-0.5 px-0.5">
 		{#if chatFiles.length > 0}
 			<Collapsible title={$i18n.t('Files')} open={true} buttonClassName="w-full">
@@ -61,15 +62,17 @@
 			<hr class="my-2 border-gray-50 dark:border-gray-700/10" />
 		{/if}

+		{#if $user?.role === 'admin' || ($user?.permissions.chat?.valves ?? true)}
 			<Collapsible bind:open={showValves} title={$i18n.t('Valves')} buttonClassName="w-full">
 				<div class="text-sm" slot="content">
 					<Valves show={showValves} />
 				</div>
 			</Collapsible>

-		{#if $user?.role === 'admin' || ($user?.permissions.chat?.system_prompt ?? true)}
 			<hr class="my-2 border-gray-50 dark:border-gray-700/10" />
+		{/if}
+
+		{#if $user?.role === 'admin' || ($user?.permissions.chat?.system_prompt ?? true)}
 			<Collapsible title={$i18n.t('System Prompt')} open={true} buttonClassName="w-full">
 				<div class="" slot="content">
 					<textarea
@@ -82,11 +85,11 @@
 					/>
 				</div>
 			</Collapsible>

+			<hr class="my-2 border-gray-50 dark:border-gray-700/10" />
 		{/if}

-		{#if $user?.role === 'admin' || ($user?.permissions.chat?.controls ?? true)}
-			<hr class="my-2 border-gray-50 dark:border-gray-700/10" />
+		{#if $user?.role === 'admin' || ($user?.permissions.chat?.params ?? true)}

 			<Collapsible title={$i18n.t('Advanced Params')} open={true} buttonClassName="w-full">
 				<div class="text-sm mt-1.5" slot="content">
 					<div>
@@ -96,4 +99,5 @@
 			</Collapsible>
 		{/if}
 	</div>
+{/if}
 </div>
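Note: every section of this panel is now gated by the same default-allow check — admins always pass, and a missing permission entry counts as allowed. A small sketch of that check (the user object shape is inferred from the guards above, not a documented API):

// Default-allow permission check mirroring the template guards.
function canUseChatFeature(user, feature) {
	if (user?.role === 'admin') return true;
	// Missing entries default to allowed
	return user?.permissions?.chat?.[feature] ?? true;
}

// canUseChatFeature({ role: 'user', permissions: { chat: { valves: false } } }, 'valves') // false
// canUseChatFeature({ role: 'user', permissions: { chat: {} } }, 'system_prompt')         // true
// canUseChatFeature({ role: 'admin' }, 'params')                                          // true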
@@ -72,15 +72,15 @@

 	import { KokoroWorker } from '$lib/workers/KokoroWorker';
 	import InputVariablesModal from './MessageInput/InputVariablesModal.svelte';
+	import Voice from '../icons/Voice.svelte';
 	const i18n = getContext('i18n');

-	export let transparentBackground = false;
-
 	export let onChange: Function = () => {};
 	export let createMessagePair: Function;
 	export let stopResponse: Function;

 	export let autoScroll = false;
+	export let generating = false;

 	export let atSelectedModel: Model | undefined = undefined;
 	export let selectedModels: [''];
@@ -927,7 +927,7 @@
 	</div>
 </div>

-<div class="{transparentBackground ? 'bg-transparent' : 'bg-white dark:bg-gray-900'} ">
+<div class="bg-transparent">
 	<div
 		class="{($settings?.widescreenMode ?? null)
 			? 'max-w-full'
@@ -1084,6 +1084,7 @@
 		class="scrollbar-hidden rtl:text-right ltr:text-left bg-transparent dark:text-gray-100 outline-hidden w-full pt-2.5 pb-[5px] px-1 resize-none h-fit max-h-80 overflow-auto"
 		id="chat-input-container"
 	>
+		{#key $settings?.showFormattingToolbar ?? false}
 			<RichTextInput
 				bind:this={chatInputElement}
 				id="chat-input"
@@ -1093,7 +1094,8 @@
 				}}
 				json={true}
 				messageInput={true}
-				showFormattingButtons={false}
+				showFormattingToolbar={$settings?.showFormattingToolbar ?? false}
+				floatingMenuPlacement={'top-start'}
 				insertPromptAsRichText={$settings?.insertPromptAsRichText ?? false}
 				shiftEnter={!($settings?.ctrlEnterToSend ?? false) &&
 					(!$mobile ||
@@ -1181,7 +1183,9 @@
 						commandsElement.selectUp();

 						const commandOptionButton = [
-							...document.getElementsByClassName('selected-command-option-button')
+							...document.getElementsByClassName(
+								'selected-command-option-button'
+							)
 						]?.at(-1);
 						commandOptionButton.scrollIntoView({ block: 'center' });
 					}
@@ -1191,7 +1195,9 @@
 						commandsElement.selectDown();

 						const commandOptionButton = [
-							...document.getElementsByClassName('selected-command-option-button')
+							...document.getElementsByClassName(
+								'selected-command-option-button'
+							)
 						]?.at(-1);
 						commandOptionButton.scrollIntoView({ block: 'center' });
 					}
@@ -1200,7 +1206,9 @@
 						e.preventDefault();

 						const commandOptionButton = [
-							...document.getElementsByClassName('selected-command-option-button')
+							...document.getElementsByClassName(
+								'selected-command-option-button'
+							)
 						]?.at(-1);

 						commandOptionButton?.click();
@@ -1210,7 +1218,9 @@
 						e.preventDefault();

 						const commandOptionButton = [
-							...document.getElementsByClassName('selected-command-option-button')
+							...document.getElementsByClassName(
+								'selected-command-option-button'
+							)
 						]?.at(-1);

 						if (commandOptionButton) {
@@ -1298,9 +1308,13 @@
 					if (text.length > PASTED_TEXT_CHARACTER_LIMIT) {
 						e.preventDefault();
 						const blob = new Blob([text], { type: 'text/plain' });
-						const file = new File([blob], `Pasted_Text_${Date.now()}.txt`, {
+						const file = new File(
+							[blob],
+							`Pasted_Text_${Date.now()}.txt`,
+							{
 							type: 'text/plain'
-						});
+							}
+						);

 						await uploadFileHandler(file, true);
 					}
@@ -1310,13 +1324,14 @@
 					}
 				}}
 			/>
+		{/key}
 	</div>
 {:else}
 	<textarea
 		id="chat-input"
 		dir={$settings?.chatDirection ?? 'auto'}
 		bind:this={chatInputElement}
-		class="scrollbar-hidden bg-transparent dark:text-gray-200 outline-hidden w-full pt-3 px-1 resize-none"
+		class="scrollbar-hidden bg-transparent dark:text-gray-200 outline-hidden w-full pt-4 pb-1 px-1 resize-none"
 		placeholder={placeholder ? placeholder : $i18n.t('Send a Message')}
 		bind:value={prompt}
 		on:input={() => {
@@ -1819,7 +1834,7 @@
 					</Tooltip>
 				{/if}

-				{#if (taskIds && taskIds.length > 0) || (history.currentId && history.messages[history.currentId]?.done != true)}
+				{#if (taskIds && taskIds.length > 0) || (history.currentId && history.messages[history.currentId]?.done != true) || generating}
 					<div class=" flex items-center">
 						<Tooltip content={$i18n.t('Stop')}>
 							<button
@@ -1902,7 +1917,7 @@
 								}}
 								aria-label={$i18n.t('Voice mode')}
 							>
-								<Headphone className="size-5" />
+								<Voice className="size-5" strokeWidth="2.5" />
 							</button>
 						</Tooltip>
 					</div>
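Note: the new File(...) call above was only reflowed across lines, but the surrounding logic is worth seeing whole: pastes longer than PASTED_TEXT_CHARACTER_LIMIT are intercepted and attached as a .txt file instead of being inserted inline. A condensed sketch (the limit value and the uploadFile callback are illustrative; the real constant lives elsewhere in the component):

// Sketch: turn an oversized paste into a .txt attachment.
const PASTED_TEXT_CHARACTER_LIMIT = 2000; // illustrative value

async function handlePaste(e, uploadFile) {
	const text = e.clipboardData?.getData('text/plain') ?? '';
	if (text.length <= PASTED_TEXT_CHARACTER_LIMIT) {
		return; // small pastes are inserted inline by the editor
	}

	e.preventDefault();
	const blob = new Blob([text], { type: 'text/plain' });
	const file = new File([blob], `Pasted_Text_${Date.now()}.txt`, {
		type: 'text/plain'
	});
	await uploadFile(file, true);
}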
@@ -100,7 +100,7 @@

 <div slot="content">
 	<DropdownMenu.Content
-		class="w-full max-w-[200px] rounded-xl px-1 py-1 border border-gray-300/30 dark:border-gray-700/50 z-50 bg-white dark:bg-gray-850 dark:text-white shadow-sm"
+		class="w-full max-w-[240px] rounded-xl px-1 py-1 border border-gray-300/30 dark:border-gray-700/50 z-50 bg-white dark:bg-gray-850 dark:text-white shadow-sm"
 		sideOffset={10}
 		alignOffset={-8}
 		side="top"
@@ -38,7 +38,7 @@

 	export let setInputText: Function = () => {};

-	export let sendPrompt: Function;
+	export let sendMessage: Function;
 	export let continueResponse: Function;
 	export let regenerateResponse: Function;
 	export let mergeResponses: Function;
@@ -50,6 +50,7 @@

 	export let readOnly = false;

+	export let topPadding = false;
 	export let bottomPadding = false;
 	export let autoScroll;
@@ -294,7 +295,7 @@
 			history.currentId = userMessageId;

 			await tick();
-			await sendPrompt(history, userPrompt, userMessageId);
+			await sendMessage(history, userMessageId);
 		} else {
 			// Edit user message
 			history.messages[messageId].content = content;
@@ -445,6 +446,7 @@
 				{addMessages}
 				{triggerScroll}
 				{readOnly}
+				{topPadding}
 			/>
 		{/each}
 	</div>
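Note: sendPrompt(history, userPrompt, userMessageId) becoming sendMessage(history, userMessageId) drops a redundant argument — the prompt text is already stored on the message record inside history. A sketch of the branched-history shape this relies on (field names follow the surrounding code; the helper itself is ours):

// Branched chat history: a flat message map plus a currentId pointer.
// Editing a user message forks a sibling branch under the same parent.
function appendUserMessage(history, parentId, content) {
	const userMessageId = crypto.randomUUID();
	history.messages[userMessageId] = {
		id: userMessageId,
		parentId,
		childrenIds: [],
		role: 'user',
		content
	};
	if (parentId !== null) {
		history.messages[parentId].childrenIds.push(userMessageId);
	}
	history.currentId = userMessageId;
	return userMessageId;
}

// The caller then only needs the id:
// await sendMessage(history, appendUserMessage(history, history.currentId, 'Hello'));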
@@ -5,6 +5,7 @@
 	import { WEBUI_API_BASE_URL } from '$lib/constants';

 	import XMark from '$lib/components/icons/XMark.svelte';
+	import Textarea from '$lib/components/common/Textarea.svelte';

 	const i18n = getContext('i18n');
@@ -111,15 +112,12 @@
 				</div>
 			</Tooltip>
 			{#if document.metadata?.parameters}
-				<div class="text-sm font-medium dark:text-gray-300 mt-2">
+				<div class="text-sm font-medium dark:text-gray-300 mt-2 mb-0.5">
 					{$i18n.t('Parameters')}
 				</div>
-				<pre
-					class="text-sm dark:text-gray-400 bg-gray-50 dark:bg-gray-800 p-2 rounded-md overflow-auto max-h-40">{JSON.stringify(
-					document.metadata.parameters,
-					null,
-					2
-				)}</pre>
+				<Textarea readonly value={JSON.stringify(document.metadata.parameters, null, 2)}
+				></Textarea>
 			{/if}
 			{#if showRelevance}
 				<div class="text-sm font-medium dark:text-gray-300 mt-2">
Some files were not shown because too many files have changed in this diff.