mirror of
https://github.com/open-webui/open-webui.git
synced 2025-12-24 10:15:22 +00:00
Compare commits
No commits in common. "main" and "v0.6.36" have entirely different histories.
434 changed files with 14066 additions and 30663 deletions
3
.github/workflows/deploy-to-hf-spaces.yml
vendored
3
.github/workflows/deploy-to-hf-spaces.yml
vendored
|
|
@ -57,8 +57,7 @@ jobs:
|
||||||
git lfs install
|
git lfs install
|
||||||
git lfs track "*.ttf"
|
git lfs track "*.ttf"
|
||||||
git lfs track "*.jpg"
|
git lfs track "*.jpg"
|
||||||
rm demo.png
|
rm demo.gif
|
||||||
rm banner.png
|
|
||||||
git add .
|
git add .
|
||||||
git commit -m "GitHub deploy: ${{ github.sha }}"
|
git commit -m "GitHub deploy: ${{ github.sha }}"
|
||||||
git push --force https://open-webui:${HF_TOKEN}@huggingface.co/spaces/open-webui/open-webui main
|
git push --force https://open-webui:${HF_TOKEN}@huggingface.co/spaces/open-webui/open-webui main
|
||||||
|
|
|
||||||
6
.github/workflows/docker-build.yaml
vendored
6
.github/workflows/docker-build.yaml
vendored
|
|
@ -141,9 +141,6 @@ jobs:
|
||||||
platform=${{ matrix.platform }}
|
platform=${{ matrix.platform }}
|
||||||
echo "PLATFORM_PAIR=${platform//\//-}" >> $GITHUB_ENV
|
echo "PLATFORM_PAIR=${platform//\//-}" >> $GITHUB_ENV
|
||||||
|
|
||||||
- name: Delete huge unnecessary tools folder
|
|
||||||
run: rm -rf /opt/hostedtoolcache
|
|
||||||
|
|
||||||
- name: Checkout repository
|
- name: Checkout repository
|
||||||
uses: actions/checkout@v5
|
uses: actions/checkout@v5
|
||||||
|
|
||||||
|
|
@ -246,9 +243,6 @@ jobs:
|
||||||
platform=${{ matrix.platform }}
|
platform=${{ matrix.platform }}
|
||||||
echo "PLATFORM_PAIR=${platform//\//-}" >> $GITHUB_ENV
|
echo "PLATFORM_PAIR=${platform//\//-}" >> $GITHUB_ENV
|
||||||
|
|
||||||
- name: Delete huge unnecessary tools folder
|
|
||||||
run: rm -rf /opt/hostedtoolcache
|
|
||||||
|
|
||||||
- name: Checkout repository
|
- name: Checkout repository
|
||||||
uses: actions/checkout@v5
|
uses: actions/checkout@v5
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -3,6 +3,8 @@ pnpm-lock.yaml
|
||||||
package-lock.json
|
package-lock.json
|
||||||
yarn.lock
|
yarn.lock
|
||||||
|
|
||||||
|
kubernetes/
|
||||||
|
|
||||||
# Copy of .gitignore
|
# Copy of .gitignore
|
||||||
.DS_Store
|
.DS_Store
|
||||||
node_modules
|
node_modules
|
||||||
|
|
|
||||||
314
CHANGELOG.md
314
CHANGELOG.md
|
|
@ -5,320 +5,6 @@ All notable changes to this project will be documented in this file.
|
||||||
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
|
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
|
||||||
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
|
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
|
||||||
|
|
||||||
## [0.6.43] - 2025-12-22
|
|
||||||
|
|
||||||
### Fixed
|
|
||||||
|
|
||||||
- 🐍 **Python dependency installation issues** were resolved by correcting pip dependency handling, preventing installation failures in certain environments and improving setup reliability. [Commit](https://github.com/open-webui/open-webui/commit/5c5f87a)
|
|
||||||
- 🎙️ **Speech-to-Text default content type handling** was fixed and refactored to ensure correct MIME type usage, improving compatibility across STT providers and preventing transcription errors caused by incorrect defaults. [Commit](https://github.com/open-webui/open-webui/commit/4ab917c)
|
|
||||||
- 🖼️ **Temporary chat image handling** was fixed and refactored, ensuring images generated or edited in temporary chats are correctly processed, stored, and displayed without inconsistencies or missing references. [Commit](https://github.com/open-webui/open-webui/commit/423983f)
|
|
||||||
- 🎨 **Image action button fixed**, restoring the ability to trigger image generation, editing, and related image actions from the chat UI. [Commit](https://github.com/open-webui/open-webui/commit/def8a00)
|
|
||||||
|
|
||||||
## [0.6.42] - 2025-12-21
|
|
||||||
|
|
||||||
### Added
|
|
||||||
|
|
||||||
- 📚 Knowledge base file management was overhauled with server-side pagination loading 30 files at a time instead of loading entire collections at once, dramatically improving performance and responsiveness for large knowledge bases with hundreds or thousands of files, reducing initial load times and memory usage while adding server-side search and filtering, view options for files added by the user versus shared files, customizable sorting by name or date, and file authorship tracking with upload timestamps. [Commit](https://github.com/open-webui/open-webui/commit/94a8439105f30203ea9d729787c9c5978f5c22a2)
|
|
||||||
- ✨ Knowledge base file management was enhanced with automatic list refresh after file operations ensuring immediate UI updates, improved permission validation at the model layer, and automatic channel-file association for files uploaded with channel metadata. [Commit](https://github.com/open-webui/open-webui/commit/c15201620d03a9b60b800a34d8dc3426722c5b8b)
|
|
||||||
- 🔎 Knowledge command in chat input now uses server-side search for massive performance increases when selecting knowledge bases and files. [Commit](https://github.com/open-webui/open-webui/commit/0addc1ea461d7b4eee8fe0ca2fedd615b3988b0e)
|
|
||||||
- 🗂️ Knowledge workspace listing now uses server-side pagination loading 30 collections at a time with new search endpoints supporting query filtering and view options for created versus shared collections. [Commit](https://github.com/open-webui/open-webui/commit/ceae3d48e603f53313d5483abe94099e20e914e8)
|
|
||||||
- 📖 Knowledge workspace now displays all collections with read access including shared read-only collections, enabling users to discover and explore knowledge bases they don't own while maintaining proper access controls through visual "Read Only" badges and automatically disabled editing controls for name, description, file uploads, content editing, and deletion operations. [Commit](https://github.com/open-webui/open-webui/commit/693636d971d0e8398fa0c9ec3897686750007af5)
|
|
||||||
- 📁 Bulk website and YouTube video attachment now supports adding multiple URLs at once (newline-separated) with automatic YouTube detection and transcript retrieval, processed sequentially to prevent resource strain, and both websites and videos can now be added directly to knowledge bases through the workspace UI. [Commit](https://github.com/open-webui/open-webui/commit/7746e9f4b831f09953ad2b659b96e0fd52911031), [#6202](https://github.com/open-webui/open-webui/issues/6202), [#19587](https://github.com/open-webui/open-webui/pull/19587), [#8231](https://github.com/open-webui/open-webui/pull/8231)
|
|
||||||
- 🪟 Sidebar width is now resizable on desktop devices with persistent storage in localStorage, enforcing minimum and maximum width constraints (220px to 480px) while all layout components now reference the dynamic sidebar width via CSS variables for consistent responsive behavior. [Commit](https://github.com/open-webui/open-webui/commit/b364cf43d3e8fd3557f65f17bc285bfaca5ed368)
|
|
||||||
- 📝 Notes feature now supports server-side search and filtering with view options for notes created by the user versus notes shared with them, customizable sorting by name or date in both list and grid view modes within a redesigned interface featuring consolidated note management controls in a unified header, group-based permission sharing with read, write, and read-only access control displaying note authorship and sharing status for better collaboration, and paginated infinite scroll for improved performance with large note collections. [Commit](https://github.com/open-webui/open-webui/commit/9b24cddef6c4862bd899eb8d6332cafff54e871d)
|
|
||||||
- 👁️ Notes now support read-only access permissions, allowing users to share notes for viewing without granting edit rights, with the editor automatically becoming non-editable and appropriate UI indicators when read-only access is detected. [Commit](https://github.com/open-webui/open-webui/commit/4363df175d50e0f9729381ac2ba9b37a3c3a966d)
|
|
||||||
- 📄 Notes can now be created directly from the chat input field, allowing users to save drafted messages or content as notes without navigation or retyping. [Commit](https://github.com/open-webui/open-webui/commit/00c2b6ca405d617e3d7520953a00a36c19c790ec)
|
|
||||||
- 🪟 Sidebar folders, channels, and pinned models sections now automatically expand when creating new items or pinning models, providing immediate visual feedback for user actions. [Commit](https://github.com/open-webui/open-webui/commit/f826d3ed75213a0a1b31b50d030bfb1d5e91d199), [#19929](https://github.com/open-webui/open-webui/pull/19929)
|
|
||||||
- 📋 Chat file associations are now properly tracked in the database through a new "chat_file" table, enabling accurate file management across chats and ensuring proper cleanup of files when chats are deleted, while improving database consistency in multi-node deployments. [Commit](https://github.com/open-webui/open-webui/commit/f1bf4f20c53e6493f0eb6fa2f12cb84c2d22da52)
|
|
||||||
- 🖼️ User-uploaded images are now automatically converted from base64 to actual file storage on the server, eliminating large inline base64 strings from being stored in chat history and reducing message payload sizes while enabling better image management and sharing across multiple chats. [Commit](https://github.com/open-webui/open-webui/commit/f1bf4f20c53e6493f0eb6fa2f12cb84c2d22da52)
|
|
||||||
- 📸 Shared chats with generated or edited images now correctly display images when accessed by other users by properly linking generated images to their chat and message through the chat_file table, ensuring images remain accessible in shared chat links. [Commit](https://github.com/open-webui/open-webui/commit/446cc0ac6063402a743e949f50612376ed5a8437), [#19393](https://github.com/open-webui/open-webui/issues/19393)
|
|
||||||
- 📊 File viewer modal was significantly enhanced with native-like viewers for Excel/CSV spreadsheets rendering as interactive scrollable tables with multi-sheet navigation support, Markdown documents displaying with full typography including headers, lists, links, and tables, and source code files showing syntax highlighting, all accessible through a tabbed interface defaulting to raw text view. [#20035](https://github.com/open-webui/open-webui/pull/20035), [#2867](https://github.com/open-webui/open-webui/issues/2867)
|
|
||||||
- 📏 Chat input now displays an expand button in the top-right corner when messages exceed two lines, providing optional access to a full-screen editor for composing longer messages with enhanced workspace and visibility while temporarily disabling the main input to prevent editing conflicts. [Commit](https://github.com/open-webui/open-webui/commit/205c7111200c22da42e9b5fe1e676aec9cca6daa)
|
|
||||||
- 💬 Channel message data lazy loading was implemented, deferring attachment and file metadata retrieval until needed to improve initial message list load performance. [Commit](https://github.com/open-webui/open-webui/commit/54b7ec56d6bcd2d79addc1694b757dab18cf18c5)
|
|
||||||
- 🖼️ Channel image upload handling was optimized to process and store compressed images directly as files rather than inline data, improving memory efficiency and message load times. [Commit](https://github.com/open-webui/open-webui/commit/22f1b764a7ea1add0a896906a9ef00b4b6743adc)
|
|
||||||
- 🎥 Video file playback support was added to channel messages, enabling inline video viewing with native player controls. [Commit](https://github.com/open-webui/open-webui/commit/7b126b23d50a0bd36a350fe09dc1dbe3df105318)
|
|
||||||
- 🔐 LDAP authentication now supports user entries with multiple username attributes, correctly handling cases where the username field contains a list of values. [Commit](https://github.com/open-webui/open-webui/commit/379f888c9dc6dce21c3ef7a1fc455258aff993dc), [#19878](https://github.com/open-webui/open-webui/issues/19878)
|
|
||||||
- 👨👩👧👦 The "ENABLE_PUBLIC_ACTIVE_USERS_COUNT" environment variable now allows restricting active user count visibility to administrators, reducing backend load and addressing privacy concerns in large deployments. [#20027](https://github.com/open-webui/open-webui/pull/20027), [#13026](https://github.com/open-webui/open-webui/issues/13026)
|
|
||||||
- 🚀 Models page search input performance was optimized with a 300ms debounce to reduce server load and improve responsiveness. [#19832](https://github.com/open-webui/open-webui/pull/19832)
|
|
||||||
- 💨 Frontend performance was optimized by preventing unnecessary API calls for API Keys and Channels features when they are disabled in admin settings, reducing backend noise and improving overall system efficiency. [#20043](https://github.com/open-webui/open-webui/pull/20043), [#19967](https://github.com/open-webui/open-webui/issues/19967)
|
|
||||||
- 📎 Channel file association tracking was implemented, automatically linking uploaded files to their respective channels with a dedicated association table enabling better organization and future file management features within channels. [Commit](https://github.com/open-webui/open-webui/commit/2bccf8350d0915f69b8020934bb179c52e81b7b5)
|
|
||||||
- 👥 User profile previews now display group membership information for easier identification of user roles and permissions. [Commit](https://github.com/open-webui/open-webui/commit/2b1a29d44bde9fbc20ff9f0a5ded1ce8ded9d90d)
|
|
||||||
- 🌍 The "SEARXNG_LANGUAGE" environment variable now allows configuring search language for SearXNG queries, replacing the hardcoded "en-US" default with a configurable setting that defaults to "all". [#19909](https://github.com/open-webui/open-webui/pull/19909)
|
|
||||||
- ⏳ The "MINERU_API_TIMEOUT" environment variable now allows configuring request timeouts for MinerU document processing operations. [#20016](https://github.com/open-webui/open-webui/pull/20016), [#18495](https://github.com/open-webui/open-webui/issues/18495)
|
|
||||||
- 🔧 The "RAG_EXTERNAL_RERANKER_TIMEOUT" environment variable now allows configuring request timeouts for external reranker operations. [#20049](https://github.com/open-webui/open-webui/pull/20049), [#19900](https://github.com/open-webui/open-webui/issues/19900)
|
|
||||||
- 🎨 OpenAI GPT-IMAGE 1.5 model support was added for image generation and editing with automatic image size capabilities. [Commit](https://github.com/open-webui/open-webui/commit/4c2e5c93e9287479f56f780708656136849ccaee)
|
|
||||||
- 🔑 The "OAUTH_AUDIENCE" environment variable now allows OAuth providers to specify audience parameters for JWT access token generation. [#19768](https://github.com/open-webui/open-webui/pull/19768)
|
|
||||||
- ⏰ The "REDIS_SOCKET_CONNECT_TIMEOUT" environment variable now allows configuring socket connection timeouts for Redis and Sentinel connections, addressing potential failover and responsiveness issues in distributed deployments. [#19799](https://github.com/open-webui/open-webui/pull/19799), [Docs:#882](https://github.com/open-webui/docs/pull/882)
|
|
||||||
- ⏱️ The "WEB_LOADER_TIMEOUT" environment variable now allows configuring request timeouts for SafeWebBaseLoader operations. [#19804](https://github.com/open-webui/open-webui/pull/19804), [#19734](https://github.com/open-webui/open-webui/issues/19734)
|
|
||||||
- 🚀 Models API endpoint performance was optimized through batched model loading, eliminating N+1 queries and significantly reducing response times when filtering models by user permissions. [Commit](https://github.com/open-webui/open-webui/commit/0dd2cfe1f273fbacdbe90300a97c021f2e678656)
|
|
||||||
- 🔀 Custom model fallback handling was added, allowing workspace-created custom models to automatically fall back to the default chat model when their configured base model is not found; set "ENABLE_CUSTOM_MODEL_FALLBACK" to true to enable, preventing workflow disruption when base models are removed or renamed, while ensuring other requests remain unaffected. [Commit](https://github.com/open-webui/open-webui/commit/b35aeb8f46e0e278c6f4538382c2b6838e24cc5a), [#19985](https://github.com/open-webui/open-webui/pull/19985)
|
|
||||||
- 📡 A new /feedbacks/all/ids API endpoint was added to return only feedback IDs without metadata, significantly improving performance for external integrations working with large feedback collections. [Commit](https://github.com/open-webui/open-webui/commit/53c1ca64b7205d85f6de06bd69e3e265d15546b8)
|
|
||||||
- 📈 An experimental chat usage statistics endpoint (GET /api/v1/chats/stats/usage) was added with pagination support (50 chats per page) and comprehensive per-chat analytics including model usage counts, user and assistant message breakdowns, average response times calculated from message timestamps, average content lengths, and last activity timestamps; this endpoint remains experimental and not suitable for production use as it performs intensive calculations by processing entire message histories for each chat without caching. [Commit](https://github.com/open-webui/open-webui/commit/a7993f6f4e4591cd2aaa4718ece9e5623557d019)
|
|
||||||
- 🔄 Various improvements were implemented across the frontend and backend to enhance performance, stability, and security.
|
|
||||||
- 🌐 Translations for German, Danish, Finnish, Korean, Portuguese (Brazil), Simplified Chinese, Traditional Chinese, Catalan, and Spanish were enhanced and expanded.
|
|
||||||
|
|
||||||
### Fixed
|
|
||||||
|
|
||||||
- ⚡ External reranker operations were optimized to prevent event loop blocking by offloading synchronous HTTP requests to a thread pool using asyncio.to_thread(), eliminating application freezes during RAG reranking queries. [#20049](https://github.com/open-webui/open-webui/pull/20049), [#19900](https://github.com/open-webui/open-webui/issues/19900)
|
|
||||||
- 💭 Text loss in the explanation feature when using the "CHAT_STREAM_RESPONSE_CHUNK_MAX_BUFFER_SIZE" environment variable was resolved by correcting newline handling in streaming responses. [#19829](https://github.com/open-webui/open-webui/pull/19829)
|
|
||||||
- 📚 Knowledge base batch file addition failures caused by Pydantic validation errors are now prevented by making the meta field optional in file metadata responses, allowing files without metadata to be processed correctly. [#20022](https://github.com/open-webui/open-webui/pull/20022), [#14220](https://github.com/open-webui/open-webui/issues/14220)
|
|
||||||
- 🗄️ PostgreSQL null byte insertion failures when attaching web pages or processing embedded content are now prevented by consolidating text sanitization logic across chat messages, web search results, and knowledge base documents, removing null bytes and invalid UTF-8 surrogates before database insertion. [#20072](https://github.com/open-webui/open-webui/pull/20072), [#19867](https://github.com/open-webui/open-webui/issues/19867), [#18201](https://github.com/open-webui/open-webui/issues/18201), [#15616](https://github.com/open-webui/open-webui/issues/15616)
|
|
||||||
- 🎫 MCP OAuth 2.1 token exchange failures are now fixed by removing duplicate credential passing that caused "ID1,ID1" concatenation and 401 errors from the token endpoint. [#20076](https://github.com/open-webui/open-webui/pull/20076), [#19823](https://github.com/open-webui/open-webui/issues/19823)
|
|
||||||
- 📝 Notes "Improve" action now works correctly after the streaming API change in v0.6.41 by ensuring uploaded files are fully retrieved with complete metadata before processing, restoring note improvement and summarization functionality. [Commit](https://github.com/open-webui/open-webui/commit/a3458f492c53a3b00405f59fbe1ea953fe364f18), [#20078](https://github.com/open-webui/open-webui/discussions/20078)
|
|
||||||
- 🔑 MCP OAuth 2.1 tool servers now work correctly in multi-node deployments through lazy-loading of OAuth clients from Redis-synced configuration, eliminating 404 errors when load balancers route requests to nodes that didn't process the original config update. [#20076](https://github.com/open-webui/open-webui/pull/20076), [#19902](https://github.com/open-webui/open-webui/pull/19902), [#19901](https://github.com/open-webui/open-webui/issues/19901)
|
|
||||||
- 🧩 Chat loading failures when channels permissions were disabled are now prevented through graceful error handling. [Commit](https://github.com/open-webui/open-webui/commit/5c2df97f04cce5cb7087d288f816f91a739688c1)
|
|
||||||
- 🔍 Search bar freezing and crashing issues in Models, Chat, and Archived Chat pages caused by excessively long queries exceeding server URL limits were resolved by truncating queries to 500 characters, and knowledge base layout shifting with long names was fixed by adjusting flex container properties. [#19832](https://github.com/open-webui/open-webui/pull/19832)
|
|
||||||
- 🎛️ Rate limiting errors (HTTP 429) with Brave Search free tier when generating multiple queries are now prevented through asyncio.Semaphore-based concurrency control applied globally to all search engines. [#20070](https://github.com/open-webui/open-webui/pull/20070), [#20003](https://github.com/open-webui/open-webui/issues/20003), [#14107](https://github.com/open-webui/open-webui/issues/14107), [#15134](https://github.com/open-webui/open-webui/issues/15134)
|
|
||||||
- 💥 UI crashes and white screen errors caused by null chat lists during loading or network failures were prevented by adding null safety checks to chat iteration in folder placeholders and archived chat modals. [#19898](https://github.com/open-webui/open-webui/pull/19898)
|
|
||||||
- 🧩 Chat overview tab crashes caused by undefined model references were resolved by adding proper null checks when accessing deleted or ejected models. [#19935](https://github.com/open-webui/open-webui/pull/19935)
|
|
||||||
- 🔄 MultiResponseMessages component crashes when navigating chat history after removing or changing selected models are now prevented through proper component re-initialization. [Commit](https://github.com/open-webui/open-webui/commit/870e29e3738da968c396b70532f365a3c2f71995), [#18599](https://github.com/open-webui/open-webui/issues/18599)
|
|
||||||
- 🚫 Channel API endpoint access is now correctly blocked when channels are globally disabled, preventing users with channel permissions from accessing channel data via API requests when the feature is turned off in admin settings. [#19957](https://github.com/open-webui/open-webui/pull/19957), [#19914](https://github.com/open-webui/open-webui/issues/19914)
|
|
||||||
- 👤 User list popup display in the admin panel was fixed to correctly track user identity when sorting or filtering changes the list order, preventing popups from showing incorrect user information. [Commit](https://github.com/open-webui/open-webui/commit/ae47101dc6aef2c7d8ae0d843985341fff820057), [#20046](https://github.com/open-webui/open-webui/issues/20046)
|
|
||||||
- 👥 User selection in the "Edit User Group" modal now preserves pagination position, allowing administrators to select multiple users across pages without resetting to page 1. [#19959](https://github.com/open-webui/open-webui/pull/19959)
|
|
||||||
- 📸 Model avatar images now update immediately in the admin models list through proper Cache-Control headers, eliminating the need for manual cache clearing. [#19959](https://github.com/open-webui/open-webui/pull/19959)
|
|
||||||
- 🔒 Temporary chat permission enforcement now correctly prevents users from enabling the feature through personal settings when disabled in default or group permissions. [#19785](https://github.com/open-webui/open-webui/issues/19785)
|
|
||||||
- 🎨 Image editing with reference images now correctly uses both previously generated images and newly uploaded reference images. [Commit](https://github.com/open-webui/open-webui/commit/bcd50ed8f1b7387fd700538ae0d74fc72f3c53d0)
|
|
||||||
- 🧠 Image generation and editing operations are now explicitly injected into system context, improving LLM comprehension even for weaker models so they reliably acknowledge operations instead of incorrectly claiming they cannot generate images. [Commit](https://github.com/open-webui/open-webui/commit/28b2fcab0cd036dbe646a66fe81890f288c77121)
|
|
||||||
- 📑 Source citation rendering errors when citation syntax appeared in user messages or contexts without source data were resolved. [Commit](https://github.com/open-webui/open-webui/commit/3c8f1cf8e58d52e86375634b0381374298b1b4f3)
|
|
||||||
- 📄 DOCX file parsing now works correctly in temporary chats through client-side text extraction, preventing raw data from being displayed. [Commit](https://github.com/open-webui/open-webui/commit/6993b0b40b10af8cdbe6626702cc94080fff9e22)
|
|
||||||
- 🔧 Pipeline settings save failures when valve properties contain null values are now handled correctly. [#19791](https://github.com/open-webui/open-webui/pull/19791)
|
|
||||||
- ⚙️ Model usage settings are now correctly preserved when switching between models instead of being unexpectedly cleared or reset. [#19868](https://github.com/open-webui/open-webui/pull/19868), [#19549](https://github.com/open-webui/open-webui/issues/19549)
|
|
||||||
- 🛡️ Invalid PASSWORD_VALIDATION_REGEX_PATTERN configurations no longer cause startup warnings, with automatic fallback to the default pattern when regex compilation fails. [#20058](https://github.com/open-webui/open-webui/pull/20058)
|
|
||||||
- 🎯 The DefaultFiltersSelector component in model settings now correctly displays when only global toggleable filters are present, enabling per-model default configuration. [#20066](https://github.com/open-webui/open-webui/pull/20066)
|
|
||||||
- 🎤 Audio file upload failures caused by MIME type matching issues with spacing variations and codec parameters were resolved by implementing proper MIME type parsing. [#17771](https://github.com/open-webui/open-webui/pull/17771), [#17761](https://github.com/open-webui/open-webui/issues/17761)
|
|
||||||
- ⌨️ Regenerate response keyboard shortcut now only activates when chat input is selected, preventing unintended regeneration when modals are open or other UI elements are focused. [#19875](https://github.com/open-webui/open-webui/pull/19875)
|
|
||||||
- 📋 Log truncation issues in Docker deployments during application crashes were resolved by disabling Python stdio buffering, ensuring complete diagnostic output is captured. [#19844](https://github.com/open-webui/open-webui/issues/19844)
|
|
||||||
- 🔴 Redis cluster compatibility issues with disabled KEYS command were resolved by replacing blocking KEYS operations with production-safe SCAN iterations. [#19871](https://github.com/open-webui/open-webui/pull/19871), [#15834](https://github.com/open-webui/open-webui/issues/15834)
|
|
||||||
- 🔤 File attachment container layout issues when using RTL languages were resolved by applying chat direction settings to file containers across all message types. [#19891](https://github.com/open-webui/open-webui/pull/19891), [#19742](https://github.com/open-webui/open-webui/issues/19742)
|
|
||||||
- 🔃 Ollama model list now automatically refreshes after model deletion, preventing deleted models from persisting in the UI and being inadvertently re-downloaded during subsequent pull operations. [#19912](https://github.com/open-webui/open-webui/pull/19912)
|
|
||||||
- 🌐 Ollama Cloud web search now correctly applies domain filtering to search results. [Commit](https://github.com/open-webui/open-webui/commit/d4bd938a77c22409a1643c058b937a06e07baca9)
|
|
||||||
- 📜 Tool specification serialization now preserves non-ASCII characters including Chinese text, improving LLM comprehension and tool selection accuracy by avoiding Unicode escape sequences. [#19942](https://github.com/open-webui/open-webui/pull/19942)
|
|
||||||
- 🛟 Model editor stability was improved with null safety checks for tools, functions, and file input operations, preventing crashes when stores are undefined or file objects are invalid. [#19939](https://github.com/open-webui/open-webui/pull/19939)
|
|
||||||
- 🗣️ MoA completion handling stability was improved with null safety checks for response objects, boolean casting for settings, and proper timeout type definitions. [#19921](https://github.com/open-webui/open-webui/pull/19921)
|
|
||||||
- 🎛️ Chat functionality failures caused by empty logit_bias parameter values are now prevented by properly handling empty strings in the parameter parsing middleware. [#19982](https://github.com/open-webui/open-webui/issues/19982)
|
|
||||||
- 🔏 Administrators can now delete read-only knowledge bases from deleted users, resolving permission issues that previously prevented cleanup of orphaned read-only content. [Commit](https://github.com/open-webui/open-webui/commit/59d6eb2badf46f9c2b1e879484ac33432915b575)
|
|
||||||
- 💾 Cloned prompts and tools now correctly preserve their access control settings instead of being reset to null, preventing unintended visibility changes when duplicating private or restricted items. [#19960](https://github.com/open-webui/open-webui/pull/19960), [#19360](https://github.com/open-webui/open-webui/issues/19360)
|
|
||||||
- 🎚️ Text scale adjustment buttons in Interface Settings were fixed to correctly increment and decrement the scale value. [#19699](https://github.com/open-webui/open-webui/pull/19699)
|
|
||||||
- 🎭 Group channel invite button text visibility in light theme was corrected to display properly against dark backgrounds. [#19828](https://github.com/open-webui/open-webui/issues/19828)
|
|
||||||
- 📁 The move button is now hidden when no folders exist, preventing display of non-functional controls. [#19705](https://github.com/open-webui/open-webui/pull/19705)
|
|
||||||
- 📦 Qdrant client dependency was updated to resolve startup version incompatibility warnings. [#19757](https://github.com/open-webui/open-webui/pull/19757)
|
|
||||||
- 🧮 The "ENABLE_ASYNC_EMBEDDING" environment variable is now correctly applied to embedding operations when configured exclusively via environment variables. [#19748](https://github.com/open-webui/open-webui/pull/19748)
|
|
||||||
- 🌄 The "COMFYUI_WORKFLOW_NODES" and "IMAGES_EDIT_COMFYUI_WORKFLOW_NODES" environment variables are now correctly loaded and parsed as JSON lists, and the configuration key name was corrected from "COMFYUI_WORKFLOW" to "COMFYUI_WORKFLOW_NODES". [#19918](https://github.com/open-webui/open-webui/pull/19918), [#19886](https://github.com/open-webui/open-webui/issues/19886)
|
|
||||||
- 💫 Channel name length is now limited to 128 characters with validation to prevent display issues caused by excessively long names. [Commit](https://github.com/open-webui/open-webui/commit/f509f5542dde384d34402f6df763f49a06bea109)
|
|
||||||
- 🔐 Invalid PASSWORD_VALIDATION_REGEX_PATTERN configurations no longer cause startup warnings, with automatic fallback to the default pattern when regex compilation fails. [#20058](https://github.com/open-webui/open-webui/pull/20058)
|
|
||||||
- 🔎 Bocha search with filter list functionality now works correctly by returning results as a list instead of a dictionary wrapper, ensuring compatibility with result filtering operations. [Commit](https://github.com/open-webui/open-webui/commit/b5bd8704fe1672da839bb3be6210d7cb494797ce), [#19733](https://github.com/open-webui/open-webui/issues/19733)
|
|
||||||
|
|
||||||
### Changed
|
|
||||||
|
|
||||||
- ⚠️ This release includes database schema changes; multi-worker, multi-server, or load-balanced deployments must update all instances simultaneously rather than performing rolling updates, as running mixed versions will cause application failures due to schema incompatibility between old and new instances.
|
|
||||||
- 📡 WEB_SEARCH_CONCURRENT_REQUESTS default changed from 10 to 0 (unlimited) — This setting now applies to all search engines instead of only DuckDuckGo; previously users were implicitly limited to 10 concurrent queries, but now have unlimited parallel requests by default; set to 1 for sequential execution if using rate-limited APIs like Brave free tier. [#20070](https://github.com/open-webui/open-webui/pull/20070)
|
|
||||||
- 💾 SQLCipher absolute path handling was fixed to properly support absolute database paths (e.g., "/app/data.db") instead of incorrectly stripping leading slashes and converting them to relative paths; this restores functionality for Docker volume mounts and explicit absolute path configurations while maintaining backward compatibility with relative paths. [#20074](https://github.com/open-webui/open-webui/pull/20074)
|
|
||||||
- 🔌 Knowledge base file listing API was redesigned with paginated responses and new filtering parameters; the GET /knowledge/{id}/files endpoint now returns paginated results with user attribution instead of embedding all files in the knowledge object, which may require updates to custom integrations or scripts accessing knowledge base data programmatically. [Commit](https://github.com/open-webui/open-webui/commit/94a8439105f30203ea9d729787c9c5978f5c22a2)
|
|
||||||
- 🗑️ Legacy knowledge base support for deprecated document collections and tag-based collections was removed; users with pre-knowledge base documents must migrate to the current knowledge base system as legacy items will no longer appear in selectors or command menus. [Commit](https://github.com/open-webui/open-webui/commit/a934dc997ed67a036dd7975e380f8036c447d3ed)
|
|
||||||
- 🔨 Source-level log environment variables (AUDIO_LOG_LEVEL, CONFIG_LOG_LEVEL, MODELS_LOG_LEVEL, etc.) were removed as they provided limited configuration options and added significant complexity across 100+ files; the GLOBAL_LOG_LEVEL environment variable, which already took precedence over source-level settings, now serves as the exclusive logging configuration method. [#20045](https://github.com/open-webui/open-webui/pull/20045)
|
|
||||||
- 🐍 LangChain was upgraded to version 1.2.0, representing a major dependency update and significant progress toward Python 3.13 compatibility while improving RAG pipeline functionality for document loading and retrieval operations. [#19991](https://github.com/open-webui/open-webui/pull/19991)
|
|
||||||
|
|
||||||
## [0.6.41] - 2025-12-02
|
|
||||||
|
|
||||||
### Added
|
|
||||||
|
|
||||||
- 🚦 Sign-in rate limiting was implemented to protect against brute force attacks, limiting login attempts to 15 per 3-minute window per email address using Redis with automatic fallback to in-memory storage when Redis is unavailable. [Commit](https://github.com/open-webui/open-webui/commit/7b166370432414ce8f186747fb098e0c70fb2d6b)
|
|
||||||
- 📂 Administrators can now globally disable the folders feature and control user-level folder permissions through the admin panel, enabling minimalist interface configurations for deployments that don't require workspace organization features. [#19529](https://github.com/open-webui/open-webui/pull/19529), [#19210](https://github.com/open-webui/open-webui/discussions/19210), [#18459](https://github.com/open-webui/open-webui/discussions/18459), [#18299](https://github.com/open-webui/open-webui/discussions/18299)
|
|
||||||
- 👥 Group channels were introduced as a new channel type enabling membership-based collaboration spaces where users explicitly join as members rather than accessing through permissions, with support for public or private visibility, automatic member inclusion from specified user groups, member role tracking with invitation metadata, and post-creation member management allowing channel managers to add or remove members through the channel info modal. [Commit](https://github.com/open-webui/open-webui/commit/f589b7c1895a6a77166c047891acfa21bc0936c4), [Commit](https://github.com/open-webui/open-webui/commit/3f1d9ccbf8443a2fa5278f36202bad930a216680)
|
|
||||||
- 💬 Direct Message channels were introduced with a dedicated channel type selector and multi-user member selection interface, enabling private conversations between specific users without requiring full channel visibility. [Commit](https://github.com/open-webui/open-webui/commit/64b4d5d9c280b926746584aaf92b447d09deb386)
|
|
||||||
- 📨 Direct Message channels now support a complete user-to-user messaging system with member-based access control, automatic deduplication for one-on-one conversations, optional channel naming, and distinct visual presentation using participant avatars instead of channel icons. [Commit](https://github.com/open-webui/open-webui/commit/acccb9afdd557274d6296c70258bb897bbb6652f)
|
|
||||||
- 🙈 Users can now hide Direct Message channels from their sidebar while preserving message history, with automatic reactivation when new messages arrive from other participants, providing a cleaner interface for managing active conversations. [Commit](https://github.com/open-webui/open-webui/commit/acccb9afdd557274d6296c70258bb897bbb6652f)
|
|
||||||
- ☑️ A comprehensive user selection component was added to the channel creation modal, featuring search functionality, sortable user lists, pagination support, and multi-select checkboxes for building Direct Message participant lists. [Commit](https://github.com/open-webui/open-webui/commit/acccb9afdd557274d6296c70258bb897bbb6652f)
|
|
||||||
- 🔴 Channel unread message count tracking was implemented with visual badge indicators in the sidebar, automatically updating counts in real-time and marking messages as read when users view channels, with join/leave functionality to manage membership status. [Commit](https://github.com/open-webui/open-webui/commit/64b4d5d9c280b926746584aaf92b447d09deb386)
|
|
||||||
- 📌 Message pinning functionality was added to channels, allowing users to pin important messages for easy reference with visual highlighting, a dedicated pinned messages modal accessible from the navbar, and complete backend support for tracking pinned status, pin timestamp, and the user who pinned each message. [Commit](https://github.com/open-webui/open-webui/commit/64b4d5d9c280b926746584aaf92b447d09deb386), [Commit](https://github.com/open-webui/open-webui/commit/aae2fce17355419d9c29f8100409108037895201)
|
|
||||||
- 🟢 Direct Message channels now display an active status indicator for one-on-one conversations, showing a green dot when the other participant is currently online or a gray dot when offline. [Commit](https://github.com/open-webui/open-webui/commit/4b6773885cd7527c5a56b963781dac5e95105eec), [Commit](https://github.com/open-webui/open-webui/commit/39645102d14f34e71b34e5ddce0625790be33f6f)
|
|
||||||
- 🆔 Users can now start Direct Message conversations directly from user profile previews by clicking the "Message" button, enabling quick access to private messaging without navigating away from the current channel. [Commit](https://github.com/open-webui/open-webui/commit/a0826ec9fedb56320532616d568fa59dda831d4e)
|
|
||||||
- ⚡ Channel messages now appear instantly when sent using optimistic UI rendering, displaying with a pending state while the server confirms delivery, providing a more responsive messaging experience. [Commit](https://github.com/open-webui/open-webui/commit/25994dd3da90600401f53596d4e4fb067c1b8eaa)
|
|
||||||
- 👍 Channel message reactions now display the names of users who reacted when hovering over the emoji, showing up to three names with a count for additional reactors. [Commit](https://github.com/open-webui/open-webui/commit/05e79bdd0c7af70b631e958924e3656db1013b80)
|
|
||||||
- 🛠️ Channel creators can now edit and delete their own group and DM channels without requiring administrator privileges, enabling users to manage the channels they create independently. [Commit](https://github.com/open-webui/open-webui/commit/f589b7c1895a6a77166c047891acfa21bc0936c4)
|
|
||||||
- 🔌 A new API endpoint was added to directly get or create a Direct Message channel with a specific user by their ID, streamlining programmatic DM channel creation for integrations and frontend workflows. [Commit](https://github.com/open-webui/open-webui/commit/f589b7c1895a6a77166c047891acfa21bc0936c4)
|
|
||||||
- 💭 Users can now set a custom status with an emoji and message that displays in profile previews, the sidebar user menu, and Direct Message channel items in the sidebar, with the ability to clear status at any time, providing visibility into availability or current focus similar to team communication platforms. [Commit](https://github.com/open-webui/open-webui/commit/51621ba91a982e52da168ce823abffd11ad3e4fa), [Commit](https://github.com/open-webui/open-webui/commit/f5e8d4d5a004115489c35725408b057e24dfe318)
|
|
||||||
- 📤 A group export API endpoint was added, enabling administrators to export complete group data including member lists for backup and migration purposes. [Commit](https://github.com/open-webui/open-webui/commit/09b6ea38c579659f8ca43ae5ea3746df3ac561ad)
|
|
||||||
- 📡 A new API endpoint was added to retrieve all users belonging to a specific group, enabling programmatic access to group membership information for administrative workflows. [Commit](https://github.com/open-webui/open-webui/commit/01868e856a10f474f74fbd1b4425dafdf949222f)
|
|
||||||
- 👁️ The admin user list now displays an active status indicator next to each user, showing a visual green dot for users who have been active within the last three minutes. [Commit](https://github.com/open-webui/open-webui/commit/1b095d12ff2465b83afa94af89ded9593f8a8655)
|
|
||||||
- 🔑 The admin user edit modal now displays OAuth identity information with a per-provider breakdown, showing each linked identity provider and its associated subject identifier separately. [#19573](https://github.com/open-webui/open-webui/pull/19573)
|
|
||||||
- 🧩 OAuth role claim parsing now respects the "OAUTH_ROLES_SEPARATOR" configuration, enabling proper parsing of roles returned as comma-separated strings and providing consistent behavior with group claim handling. [#19514](https://github.com/open-webui/open-webui/pull/19514)
|
|
||||||
- 🎛️ Channel feature access can now be controlled through both the "USER_PERMISSIONS_FEATURES_CHANNELS" environment variable and group permission toggles in the admin panel, allowing administrators to restrict channel functionality for specific users or groups while defaulting to enabled for all users. [Commit](https://github.com/open-webui/open-webui/commit/f589b7c1895a6a77166c047891acfa21bc0936c4)
|
|
||||||
- 🎨 The model editor interface was refined with access control settings moved to a dedicated modal, group member counts now displayed when configuring permissions, reorganized layout with improved visual hierarchy, and redesigned prompt suggestions cards with tooltips for field guidance. [Commit](https://github.com/open-webui/open-webui/commit/e65d92fc6f49da5ca059e1c65a729e7973354b99), [Commit](https://github.com/open-webui/open-webui/commit/9d39b9b42c653ee2acf2674b2df343ecbceb4954)
|
|
||||||
- 🏗️ Knowledge base file management was rebuilt with a dedicated database table replacing the previous JSON array storage, enabling pagination support for large knowledge bases, significantly faster file listing performance, and more reliable file-knowledge base relationship tracking. [Commit](https://github.com/open-webui/open-webui/commit/d19023288e2ca40f86e2dc3fd9f230540f3e70d7)
|
|
||||||
- ☁️ Azure Document Intelligence model selection was added, allowing administrators to specify which model to use for document processing via the "DOCUMENT_INTELLIGENCE_MODEL" environment variable or admin UI setting, with "prebuilt-layout" as the default. [#19692](https://github.com/open-webui/open-webui/pull/19692), [Docs:#872](https://github.com/open-webui/docs/pull/872)
|
|
||||||
- 🚀 Milvus multitenancy vector database performance was improved by removing manual flush calls after upsert operations, eliminating rate limit errors and reducing load on etcd and MinIO/S3 storage by allowing Milvus to manage segment persistence automatically via its WAL and auto-flush policies. [#19680](https://github.com/open-webui/open-webui/pull/19680)
|
|
||||||
- ✨ Various improvements were implemented across the frontend and backend to enhance performance, stability, and security.
|
|
||||||
- 🌍 Translations for German, French, Portuguese (Brazil), Catalan, Simplified Chinese, and Traditional Chinese were enhanced and expanded.
|
|
||||||
|
|
||||||
### Fixed
|
|
||||||
|
|
||||||
- 🔄 Tool call response token duplication was fixed by removing redundant message history additions in non-native function calling mode, resolving an issue where tool results were included twice in the context and causing 2x token consumption. [#19656](https://github.com/open-webui/open-webui/issues/19656), [Commit](https://github.com/open-webui/open-webui/commit/52ccab8)
|
|
||||||
- 🛡️ Web search domain filtering was corrected to properly block results when any resolved hostname or IP address matches a blocked domain, preventing blocked sites from appearing in search results due to permissive hostname resolution logic that previously allowed results through if any single resolved address passed the filter. [#19670](https://github.com/open-webui/open-webui/pull/19670), [#19669](https://github.com/open-webui/open-webui/issues/19669)
|
|
||||||
- 🧠 Custom models based on Ollama or OpenAI now properly inherit the connection type from their base model, ensuring they appear correctly in the "Local" or "External" model selection tabs instead of only appearing under "All". [#19183](https://github.com/open-webui/open-webui/issues/19183), [Commit](https://github.com/open-webui/open-webui/commit/39f7575)
|
|
||||||
- 🐍 SentenceTransformers embedding initialization was fixed by updating the transformers dependency to version 4.57.3, resolving a regression in v0.6.40 where document ingestion failed with "'NoneType' object has no attribute 'encode'" errors due to a bug in transformers 4.57.2. [#19512](https://github.com/open-webui/open-webui/issues/19512), [#19513](https://github.com/open-webui/open-webui/pull/19513)
|
|
||||||
- 📈 Active user count accuracy was significantly improved by replacing the socket-based USER_POOL tracking with a database-backed heartbeat mechanism, resolving long-standing issues where Redis deployments displayed inflated user counts due to stale sessions never being cleaned up on disconnect. [#16074](https://github.com/open-webui/open-webui/discussions/16074), [Commit](https://github.com/open-webui/open-webui/commit/70948f8803e417459d5203839f8077fdbfbbb213)
|
|
||||||
- 👥 Default group assignment now applies consistently across all user registration methods including OAuth/SSO, LDAP, and admin-created users, fixing an issue where the "DEFAULT_GROUP_ID" setting was only being applied to users who signed up via the email/password signup form. [#19685](https://github.com/open-webui/open-webui/pull/19685)
|
|
||||||
- 🔦 Model list filtering in workspaces was corrected to properly include models shared with user groups, ensuring members can view models they have write access to through group permissions. [#19461](https://github.com/open-webui/open-webui/issues/19461), [Commit](https://github.com/open-webui/open-webui/commit/69722ba973768a5f689f2e2351bf583a8db9bba8)
|
|
||||||
- 🖼️ User profile image display in preview contexts was fixed by resolving a Pydantic validation error that prevented proper rendering. [Commit](https://github.com/open-webui/open-webui/commit/c7eb7136893b0ddfdc5d55ffc7a05bd84a00f5d6)
|
|
||||||
- 🔒 Redis TLS connection failures were resolved by updating the python-socketio dependency to version 5.15.0, restoring support for the "rediss://" URL schema. [#19480](https://github.com/open-webui/open-webui/issues/19480), [#19488](https://github.com/open-webui/open-webui/pull/19488)
|
|
||||||
- 📝 MCP tool server configuration was corrected to properly handle the "Function Name Filter List" as both string and list types, preventing AttributeError when the field is empty and ensuring backward compatibility. [#19486](https://github.com/open-webui/open-webui/issues/19486), [Commit](https://github.com/open-webui/open-webui/commit/c5b73d71843edc024325d4a6e625ec939a747279), [Commit](https://github.com/open-webui/open-webui/commit/477097c2e42985c14892301d0127314629d07df1)
|
|
||||||
- 📎 Web page attachment failures causing TypeError on metadata checks were resolved by correcting async threadpool parameter passing in vector database operations. [#19493](https://github.com/open-webui/open-webui/issues/19493), [Commit](https://github.com/open-webui/open-webui/commit/4370dee79e19d77062c03fba81780cb3b779fca3)
|
|
||||||
- 💾 Model allowlist persistence in multi-worker deployments was fixed by implementing Redis-based shared state for the internal models dictionary, ensuring configuration changes are consistently visible across all worker processes. [#19395](https://github.com/open-webui/open-webui/issues/19395), [Commit](https://github.com/open-webui/open-webui/commit/b5e5617d7f7ad3e4eec9f15f4cc7f07cb5afc2fa)
|
|
||||||
- ⏳ Chat history infinite loading was prevented by enhancing message data structure to properly track parent message relationships, resolving issues where missing parentId fields caused perpetual loading states. [#19225](https://github.com/open-webui/open-webui/issues/19225), [Commit](https://github.com/open-webui/open-webui/commit/ff4b1b9862d15adfa15eac17d2ce066c3d8ae38f)
|
|
||||||
- 🩹 Database migration robustness was improved by automatically detecting and correcting missing primary key constraints on the user table, ensuring successful schema upgrades for databases with non-standard configurations. [#19487](https://github.com/open-webui/open-webui/discussions/19487), [Commit](https://github.com/open-webui/open-webui/commit/453ea9b9a167c0b03d86c46e6efd086bf10056ce)
|
|
||||||
- 🏷️ OAuth group assignment now updates correctly on first login when users transition from admin to user role, ensuring group memberships reflect immediately when group management is enabled. [#19475](https://github.com/open-webui/open-webui/issues/19475), [#19476](https://github.com/open-webui/open-webui/pull/19476)
|
|
||||||
- 💡 Knowledge base file tooltips now properly display the parent collection name when referencing files with the hash symbol, preventing confusion between identically-named files in different collections. [#19491](https://github.com/open-webui/open-webui/issues/19491), [Commit](https://github.com/open-webui/open-webui/commit/3fe5a47b0ff84ac97f8e4ff56a19fa2ec065bf66)
|
|
||||||
- 🔐 Knowledge base file access inconsistencies were resolved where authorized non-admin users received "Not found" or permission errors for certain files due to race conditions during upload causing mismatched collection_name values, with file access validation now properly checking against knowledge base file associations. [#18689](https://github.com/open-webui/open-webui/issues/18689), [#19523](https://github.com/open-webui/open-webui/pull/19523), [Commit](https://github.com/open-webui/open-webui/commit/e301d1962e45900ababd3eabb7e9a2ad275a5761)
|
|
||||||
- 📦 Knowledge API batch file addition endpoint was corrected to properly handle async operations, resolving 500 Internal Server Error responses when adding multiple files simultaneously. [#19538](https://github.com/open-webui/open-webui/issues/19538), [Commit](https://github.com/open-webui/open-webui/commit/28659f60d94feb4f6a99bb1a5b54d7f45e5ea10f)
|
|
||||||
- 🤖 Embedding model auto-update functionality was fixed to properly respect the "RAG_EMBEDDING_MODEL_AUTO_UPDATE" setting by correctly passing the flag to the model path resolver, ensuring models update as expected when the auto-update option is enabled. [#19687](https://github.com/open-webui/open-webui/pull/19687)
|
|
||||||
- 📉 API response payload sizes were dramatically reduced by removing base64-encoded profile images from most endpoints, eliminating multi-megabyte responses caused by high-resolution avatars and enabling better browser caching. [#19519](https://github.com/open-webui/open-webui/issues/19519), [Commit](https://github.com/open-webui/open-webui/commit/384753c4c17f62a68d38af4bbcf55a21ee08e0f2)
|
|
||||||
- 📞 Redundant API calls on the admin user overview page were eliminated by consolidating reactive statements, reducing four duplicate requests to a single efficient call and significantly improving page load performance. [#19509](https://github.com/open-webui/open-webui/issues/19509), [Commit](https://github.com/open-webui/open-webui/commit/9f89cc5e9f7e1c6c9e2bc91177e08df7c79f66f9)
|
|
||||||
- 🧹 Duplicate API calls on the workspace models page were eliminated by removing redundant model list fetching, reducing two identical requests to a single call and improving page responsiveness. [#19517](https://github.com/open-webui/open-webui/issues/19517), [Commit](https://github.com/open-webui/open-webui/commit/d1bbf6be7a4d1d53fa8ad46ca4f62fc4b2e6a8cb)
|
|
||||||
- 🔘 The model valves button was corrected to prevent unintended form submission by adding explicit button type attribute, ensuring it no longer triggers message sending when the input area contains text. [#19534](https://github.com/open-webui/open-webui/pull/19534)
|
|
||||||
- 🗑️ Ollama model deletion was fixed by correcting the request payload format and ensuring the model selector properly displays the placeholder option. [Commit](https://github.com/open-webui/open-webui/commit/0f3156651c64bc5af188a65fc2908bdcecf30c74)
|
|
||||||
- 🎨 Image generation in temporary chats was fixed by correctly handling local chat sessions that are not persisted to the database. [Commit](https://github.com/open-webui/open-webui/commit/a7c7993bbf3a21cb7ba416525b89233cf2ad877f)
|
|
||||||
- 🕵️♂️ Audit logging was fixed by correctly awaiting the async user authentication call, resolving failures where coroutine objects were passed instead of user data. [#19658](https://github.com/open-webui/open-webui/pull/19658), [Commit](https://github.com/open-webui/open-webui/commit/dba86bc)
|
|
||||||
- 🌙 Dark mode select dropdown styling was corrected to use proper background colors, fixing an issue where dropdown borders and hover states appeared white instead of matching the dark theme. [#19693](https://github.com/open-webui/open-webui/pull/19693), [#19442](https://github.com/open-webui/open-webui/issues/19442)
|
|
||||||
- 🔍 Milvus vector database query filtering was fixed by correcting string quote handling in filter expressions and using the proper parameter name for queries, resolving false "duplicate content detected" errors that prevented uploading multiple files to knowledge bases. [#19602](https://github.com/open-webui/open-webui/pull/19602), [#18119](https://github.com/open-webui/open-webui/issues/18119), [#16345](https://github.com/open-webui/open-webui/issues/16345), [#17088](https://github.com/open-webui/open-webui/issues/17088), [#18485](https://github.com/open-webui/open-webui/issues/18485)
|
|
||||||
- 🆙 Milvus multitenancy vector database was updated to use query_iterator() for improved robustness and consistency with the standard Milvus implementation, fixing the same false duplicate detection errors and improving handling of large result sets in multi-tenant deployments. [#19695](https://github.com/open-webui/open-webui/pull/19695)
|
|
||||||
|
|
||||||
### Changed
|
|
||||||
|
|
||||||
- ⚠️ **IMPORTANT for Multi-Instance Deployments** — This release includes database schema changes; multi-worker, multi-server, or load-balanced deployments must update all instances simultaneously rather than performing rolling updates, as running mixed versions will cause application failures due to schema incompatibility between old and new instances.
|
|
||||||
- 👮 Channel creation is now restricted to administrators only, with the channel add button hidden for regular users to maintain organizational control over communication channels. [Commit](https://github.com/open-webui/open-webui/commit/421aba7cd7cd708168b1f2565026c74525a67905)
|
|
||||||
- ➖ The active user count indicator was removed from the bottom-left user menu in the sidebar to streamline the interface. [Commit](https://github.com/open-webui/open-webui/commit/848f3fd4d86ca66656e0ff0335773945af8d7d8d)
|
|
||||||
- 🗂️ The user table was restructured with API keys migrated to a dedicated table supporting future multi-key functionality, OAuth data storage converted to a JSON structure enabling multiple identity providers per user account, and internal column types optimized from TEXT to JSON for the "info" and "settings" fields, with automatic migration preserving all existing data and associations. [#19573](https://github.com/open-webui/open-webui/pull/19573)
|
|
||||||
- 🔄 The knowledge base API was restructured to support the new file relationship model.
|
|
||||||
|
|
||||||
## [0.6.40] - 2025-11-25
|
|
||||||
|
|
||||||
### Fixed
|
|
||||||
|
|
||||||
- 🗄️ A critical PostgreSQL user listing performance issue was resolved by removing a redundant count operation that caused severe database slowdowns and potential timeouts when viewing user lists in admin panels.
|
|
||||||
|
|
||||||
## [0.6.39] - 2025-11-25
|
|
||||||
|
|
||||||
### Added
|
|
||||||
|
|
||||||
- 💬 A user list modal was added to channels, displaying all users with access and featuring search, sorting, and pagination capabilities. [Commit](https://github.com/open-webui/open-webui/commit/c0e120353824be00a2ef63cbde8be5d625bd6fd0)
|
|
||||||
- 💬 Channel navigation now displays the total number of users with access to the channel. [Commit](https://github.com/open-webui/open-webui/commit/3b5710d0cd445cf86423187f5ee7c40472a0df0b)
|
|
||||||
- 🔌 Tool servers and MCP connections now support function name filtering, allowing administrators to selectively enable or block specific functions using allow/block lists. [Commit](https://github.com/open-webui/open-webui/commit/743199f2d097ae1458381bce450d9025a0ab3f3d)
|
|
||||||
- ⚡ A toggle to disable parallel embedding processing was added via "ENABLE_ASYNC_EMBEDDING", allowing sequential processing for rate-limited or resource-constrained local embedding setups. [#19444](https://github.com/open-webui/open-webui/pull/19444)
|
|
||||||
- 🔄 Various improvements were implemented across the frontend and backend to enhance performance, stability, and security.
|
|
||||||
- 🌐 Localization improvements were made for German (de-DE) and Portuguese (Brazil) translations.
|
|
||||||
|
|
||||||
### Fixed
|
|
||||||
|
|
||||||
- 📝 Inline citations now render correctly within markdown lists and nested elements instead of displaying as "undefined" values. [#19452](https://github.com/open-webui/open-webui/issues/19452)
|
|
||||||
- 👥 Group member selection now works correctly without randomly selecting other users or causing the user list to jump around. [#19426](https://github.com/open-webui/open-webui/issues/19426)
|
|
||||||
- 👥 Admin panel user list now displays the correct total user count and properly paginates 30 items per page after fixing database query issues with group member joins. [#19429](https://github.com/open-webui/open-webui/issues/19429)
|
|
||||||
- 🔍 Knowledge base reindexing now works correctly after resolving async execution chain issues by implementing threadpool workers for embedding operations. [#19434](https://github.com/open-webui/open-webui/pull/19434)
|
|
||||||
- 🖼️ OpenAI image generation now works correctly after fixing a connection adapter error caused by incorrect URL formatting. [#19435](https://github.com/open-webui/open-webui/pull/19435)
|
|
||||||
|
|
||||||
### Changed
|
|
||||||
|
|
||||||
- 🔧 BREAKING: Docling configuration has been consolidated from individual environment variables into a single "DOCLING_PARAMS" JSON configuration and now supports API key authentication via "DOCLING_API_KEY", requiring users to migrate existing Docling settings to the new format. [#16841](https://github.com/open-webui/open-webui/issues/16841), [#19427](https://github.com/open-webui/open-webui/pull/19427)
|
|
||||||
- 🔧 The environment variable "REPLACE_IMAGE_URLS_IN_CHAT_RESPONSE" has been renamed to "ENABLE_CHAT_RESPONSE_BASE64_IMAGE_URL_CONVERSION" for naming consistency.
|
|
||||||
|
|
||||||
## [0.6.38] - 2025-11-24
|
|
||||||
|
|
||||||
### Fixed
|
|
||||||
|
|
||||||
- 🔍 Hybrid search now works reliably after recent changes.
|
|
||||||
- 🛠️ Tool server saving now handles errors gracefully, preventing failed saves from impacting the UI.
|
|
||||||
- 🔐 SSO/OIDC code fixed to improve login reliability and better handle edge cases.
|
|
||||||
|
|
||||||
## [0.6.37] - 2025-11-24
|
|
||||||
|
|
||||||
### Added
|
|
||||||
|
|
||||||
- 🔐 Granular sharing permissions are now available with two-tiered control separating group sharing from public sharing, allowing administrators to independently configure whether users can share workspace items with groups or make them publicly accessible, with separate permission toggles for models, knowledge bases, prompts, tools, and notes, configurable via "USER_PERMISSIONS_WORKSPACE_MODELS_ALLOW_SHARING", "USER_PERMISSIONS_WORKSPACE_MODELS_ALLOW_PUBLIC_SHARING", and corresponding environment variables for other workspace item types, while groups can now be configured to opt-out of sharing via the "Allow Group Sharing" setting. [Commit](https://github.com/open-webui/open-webui/commit/7be750bcbb40da91912a0a66b7ab791effdcc3b6), [Commit](https://github.com/open-webui/open-webui/commit/f69e37a8507d6d57382d6670641b367f3127f90a)
|
|
||||||
- 🔐 Password policy enforcement is now available with configurable validation rules, allowing administrators to require specific password complexity requirements via "ENABLE_PASSWORD_VALIDATION" and "PASSWORD_VALIDATION_REGEX_PATTERN" environment variables, with default pattern requiring minimum 8 characters including uppercase, lowercase, digit, and special character. [#17794](https://github.com/open-webui/open-webui/pull/17794)
|
|
||||||
- 🔐 Granular import and export permissions are now available for workspace items, introducing six separate permission toggles for models, prompts, and tools that are disabled by default for enhanced security. [#19242](https://github.com/open-webui/open-webui/pull/19242)
|
|
||||||
- 👥 Default group assignment is now available for new users, allowing administrators to automatically assign newly registered users to a specified group for streamlined access control to models, prompts, and tools, particularly useful for organizations with group-based model access policies. [#19325](https://github.com/open-webui/open-webui/pull/19325), [#17842](https://github.com/open-webui/open-webui/issues/17842)
|
|
||||||
- 🔒 Password-based authentication can now be fully disabled via "ENABLE_PASSWORD_AUTH" environment variable, enforcing SSO-only authentication and preventing password login fallback when SSO is configured. [#19113](https://github.com/open-webui/open-webui/pull/19113)
|
|
||||||
- 🖼️ Large stream chunk handling was implemented to support models that generate images directly in their output responses, with configurable buffer size via "CHAT_STREAM_RESPONSE_CHUNK_MAX_BUFFER_SIZE" environment variable, resolving compatibility issues with models like Gemini 2.5 Flash Image. [#18884](https://github.com/open-webui/open-webui/pull/18884), [#17626](https://github.com/open-webui/open-webui/issues/17626)
|
|
||||||
- 🖼️ Streaming response middleware now handles images in delta updates with automatic base64 conversion, enabling proper display of images from models using the "choices[0].delta.images.image_url" format such as Gemini 2.5 Flash Image Preview on OpenRouter. [#19073](https://github.com/open-webui/open-webui/pull/19073), [#19019](https://github.com/open-webui/open-webui/issues/19019)
|
|
||||||
- 📈 Model list API performance was optimized by pre-fetching user group memberships and removing profile image URLs from response payloads, significantly reducing both database queries and payload size for instances with large model lists, with profile images now served dynamically via dedicated endpoints. [#19097](https://github.com/open-webui/open-webui/pull/19097), [#18950](https://github.com/open-webui/open-webui/issues/18950)
|
|
||||||
- ⏩ Batch file processing performance was improved by reducing database queries by 67% while ensuring data consistency between vector and relational databases. [#18953](https://github.com/open-webui/open-webui/pull/18953)
|
|
||||||
- 🚀 Chat import performance was dramatically improved by replacing individual per-chat API requests with a bulk import endpoint, reducing import time by up to 95% for large chat collections and providing user feedback via toast notifications displaying the number of successfully imported chats. [#17861](https://github.com/open-webui/open-webui/pull/17861)
|
|
||||||
- ⚡ Socket event broadcasting performance was optimized by implementing user-specific rooms, significantly reducing server overhead particularly for users with multiple concurrent sessions. [#18996](https://github.com/open-webui/open-webui/pull/18996)
|
|
||||||
- 🗄️ Weaviate is now supported as a vector database option, providing an additional choice for RAG document storage alongside existing ChromaDB, Milvus, Qdrant, and OpenSearch integrations. [#14747](https://github.com/open-webui/open-webui/pull/14747)
|
|
||||||
- 🗄️ PostgreSQL pgvector now supports HNSW index types and large dimensional embeddings exceeding 2000 dimensions through automatic halfvec type selection, with configurable index methods via "PGVECTOR_INDEX_METHOD", "PGVECTOR_HNSW_M", "PGVECTOR_HNSW_EF_CONSTRUCTION", and "PGVECTOR_IVFFLAT_LISTS" environment variables. [#19158](https://github.com/open-webui/open-webui/pull/19158), [#16890](https://github.com/open-webui/open-webui/issues/16890)
|
|
||||||
- 🔍 Azure AI Search is now supported as a web search provider, enabling integration with Azure's cognitive search services via "AZURE_AI_SEARCH_API_KEY", "AZURE_AI_SEARCH_ENDPOINT", and "AZURE_AI_SEARCH_INDEX_NAME" configuration. [#19104](https://github.com/open-webui/open-webui/pull/19104)
|
|
||||||
- ⚡ External embedding generation now processes API requests in parallel instead of sequential batches, reducing document processing time by 10-50x when using OpenAI, Azure OpenAI, or Ollama embedding providers, with large PDFs now processing in seconds instead of minutes. [#19296](https://github.com/open-webui/open-webui/pull/19296)
|
|
||||||
- 💨 Base64 image conversion is now available for markdown content in chat responses, automatically uploading embedded images exceeding 1KB and replacing them with file URLs to reduce payload size and resource consumption, configurable via "REPLACE_IMAGE_URLS_IN_CHAT_RESPONSE" environment variable. [#19076](https://github.com/open-webui/open-webui/pull/19076)
|
|
||||||
- 🎨 OpenAI image generation now supports additional API parameters including quality settings for GPT Image 1, configurable via "IMAGES_OPENAI_API_PARAMS" environment variable or through the admin interface, enabling cost-effective image generation with low, medium, or high quality options. [#19228](https://github.com/open-webui/open-webui/issues/19228)
|
|
||||||
- 🖼️ Image editing can now be independently enabled or disabled via admin settings, allowing administrators to control whether sequential image prompts trigger image editing or new image generation, configurable via "ENABLE_IMAGE_EDIT" environment variable. [#19284](https://github.com/open-webui/open-webui/issues/19284)
|
|
||||||
- 🔐 SSRF protection was implemented with a configurable URL blocklist that prevents access to cloud metadata endpoints and private networks, with default protections for AWS, Google Cloud, Azure, and Alibaba Cloud metadata services, customizable via "WEB_FETCH_FILTER_LIST" environment variable. [#19201](https://github.com/open-webui/open-webui/pull/19201)
|
|
||||||
- ⚡ Workspace models page now supports server-side pagination dramatically improving load times and usability for instances with large numbers of workspace models.
|
|
||||||
- 🔍 Hybrid search now indexes file metadata including filenames, titles, headings, sources, and snippets alongside document content, enabling keyword queries to surface documents where search terms appear only in metadata, configurable via "ENABLE_RAG_HYBRID_SEARCH_ENRICHED_TEXTS" environment variable. [#19095](https://github.com/open-webui/open-webui/pull/19095)
|
|
||||||
- 📂 Knowledge base upload page now supports folder drag-and-drop with recursive directory handling, enabling batch uploads of entire directory structures instead of requiring individual file selection. [#19320](https://github.com/open-webui/open-webui/pull/19320)
|
|
||||||
- 🤖 Model cloning is now available in admin settings, allowing administrators to quickly create workspace models based on existing base models through a "Clone" option in the model dropdown menu. [#17937](https://github.com/open-webui/open-webui/pull/17937)
|
|
||||||
- 🎨 UI scale adjustment is now available in interface settings, allowing users to increase the size of the entire interface from 1.0x to 1.5x for improved accessibility and readability, particularly beneficial for users with visual impairments. [#19186](https://github.com/open-webui/open-webui/pull/19186)
|
|
||||||
- 📌 Default pinned models can now be configured by administrators for all new users, mirroring the behavior of default models where admin-configured defaults apply only to users who haven't customized their pinned models, configurable via "DEFAULT_PINNED_MODELS" environment variable. [#19273](https://github.com/open-webui/open-webui/pull/19273)
|
|
||||||
- 🎙️ Text-to-Speech and Speech-to-Text services now receive user information headers when "ENABLE_FORWARD_USER_INFO_HEADERS" is enabled, allowing external TTS and STT providers to implement user-specific personalization, rate limiting, and usage tracking. [#19323](https://github.com/open-webui/open-webui/pull/19323), [#19312](https://github.com/open-webui/open-webui/issues/19312)
|
|
||||||
- 🎙️ Voice mode now supports custom system prompts via "VOICE_MODE_PROMPT_TEMPLATE" configuration, allowing administrators to control response style and behavior for voice interactions. [#18607](https://github.com/open-webui/open-webui/pull/18607)
|
|
||||||
- 🔧 WebSocket and Redis configuration options are now available including debug logging controls, custom ping timeout and interval settings, and arbitrary Redis connection options via "WEBSOCKET_SERVER_LOGGING", "WEBSOCKET_SERVER_ENGINEIO_LOGGING", "WEBSOCKET_SERVER_PING_TIMEOUT", "WEBSOCKET_SERVER_PING_INTERVAL", and "WEBSOCKET_REDIS_OPTIONS" environment variables. [#19091](https://github.com/open-webui/open-webui/pull/19091)
|
|
||||||
- 🔧 MCP OAuth dynamic client registration now automatically detects and uses the appropriate token endpoint authentication method from server-supported options, enabling compatibility with OAuth servers that only support "client_secret_basic" instead of "client_secret_post". [#19193](https://github.com/open-webui/open-webui/issues/19193)
|
|
||||||
- 🔧 Custom headers can now be configured for remote MCP and OpenAPI tool server connections, enabling integration with services that require additional authentication headers. [#18918](https://github.com/open-webui/open-webui/issues/18918)
|
|
||||||
- 🔍 Perplexity Search now supports custom API endpoints via "PERPLEXITY_SEARCH_API_URL" configuration and automatically forwards user information headers to enable personalized search experiences. [#19147](https://github.com/open-webui/open-webui/pull/19147)
|
|
||||||
- 🔍 User information headers can now be optionally forwarded to external web search engines when "ENABLE_FORWARD_USER_INFO_HEADERS" is enabled. [#19043](https://github.com/open-webui/open-webui/pull/19043)
|
|
||||||
- 📊 Daily active user metric is now available for monitoring, tracking unique users active since midnight UTC via the "webui.users.active.today" Prometheus gauge. [#19236](https://github.com/open-webui/open-webui/pull/19236), [#19234](https://github.com/open-webui/open-webui/issues/19234)
|
|
||||||
- 📊 Audit log file path is now configurable via "AUDIT_LOGS_FILE_PATH" environment variable, enabling storage in separate volumes or custom locations. [#19173](https://github.com/open-webui/open-webui/pull/19173)
|
|
||||||
- 🎨 Sidebar collapse states for model lists and group information are now persistent across page refreshes, remembering user preferences through browser-based storage. [#19159](https://github.com/open-webui/open-webui/issues/19159)
|
|
||||||
- 🎨 Background image display was enhanced with semi-transparent overlays for navbar and sidebar, creating a seamless and visually cohesive design across the entire interface. [#19157](https://github.com/open-webui/open-webui/issues/19157)
|
|
||||||
- 📋 Tables in chat messages now include a copy button that appears on hover, enabling quick copying of table content alongside the existing CSV export functionality. [#19162](https://github.com/open-webui/open-webui/issues/19162)
|
|
||||||
- 📝 Notes can now be created directly via the "/notes/new" URL endpoint with optional title and content query parameters, enabling faster note creation through bookmarks and shortcuts. [#19195](https://github.com/open-webui/open-webui/issues/19195)
|
|
||||||
- 🏷️ Tag suggestions are now context-aware, displaying only relevant tags when creating or editing models versus chat conversations, preventing confusion between model and chat tags. [#19135](https://github.com/open-webui/open-webui/issues/19135)
|
|
||||||
- ✍️ Prompt autocompletion is now available independently of the rich text input setting, improving accessibility to the feature. [#19150](https://github.com/open-webui/open-webui/issues/19150)
|
|
||||||
- 🔄 Various improvements were implemented across the frontend and backend to enhance performance, stability, and security.
|
|
||||||
- 🌐 Translations for Simplified Chinese, Traditional Chinese, Portuguese (Brazil), Catalan, Spanish (Spain), Finnish, Irish, Farsi, Swedish, Danish, German, Korean, and Thai were improved and expanded.
|
|
||||||
|
|
||||||
### Fixed
|
|
||||||
|
|
||||||
- 🤖 Model update functionality now works correctly, resolving a database parameter binding error that prevented saving changes to model configurations via the Save & Update button. [#19335](https://github.com/open-webui/open-webui/issues/19335)
|
|
||||||
- 🖼️ Multiple input images for image editing and generation are now correctly passed as an array using the "image[]" parameter syntax, enabling proper multi-image reference functionality with models like GPT Image 1. [#19339](https://github.com/open-webui/open-webui/issues/19339)
|
|
||||||
- 📱 PWA installations on iOS now properly refresh after server container restarts, resolving freezing issues by automatically unregistering service workers when version or deployment changes are detected. [#19316](https://github.com/open-webui/open-webui/pull/19316)
|
|
||||||
- 🗄️ S3 Vectors collection detection now correctly handles buckets with more than 2000 indexes by using direct index lookup instead of paginated list scanning, improving performance by approximately 8x and enabling RAG queries to work reliably at scale. [#19238](https://github.com/open-webui/open-webui/pull/19238), [#19233](https://github.com/open-webui/open-webui/issues/19233)
|
|
||||||
- 📈 Feedback retrieval performance was optimized by eliminating N+1 query patterns through database joins, adding server-side pagination and sorting, significantly reducing database load for instances with large feedback datasets. [#17976](https://github.com/open-webui/open-webui/pull/17976)
|
|
||||||
- 🔍 Chat search now works correctly with PostgreSQL when chat data contains null bytes, with comprehensive sanitization preventing null bytes during data writes, cleaning existing data on read, and stripping null bytes during search queries to ensure reliable search functionality. [#15616](https://github.com/open-webui/open-webui/issues/15616)
|
|
||||||
- 🔍 Hybrid search with reranking now correctly handles attribute validation, preventing errors when collection results lack expected structure. [#19025](https://github.com/open-webui/open-webui/pull/19025), [#17046](https://github.com/open-webui/open-webui/issues/17046)
|
|
||||||
- 🔎 Reranking functionality now works correctly after recent refactoring, resolving crashes caused by incorrect function argument handling. [#19270](https://github.com/open-webui/open-webui/pull/19270)
|
|
||||||
- 🤖 Azure OpenAI models now support the "reasoning_effort" parameter, enabling proper configuration of reasoning capabilities for models like GPT-5.1 which default to no reasoning without this setting. [#19290](https://github.com/open-webui/open-webui/issues/19290)
|
|
||||||
- 🤖 Models with very long IDs can now be deleted correctly, resolving URL length limitations that previously prevented management operations on such models. [#18230](https://github.com/open-webui/open-webui/pull/18230)
|
|
||||||
- 🤖 Model-level streaming settings now correctly apply to API requests, ensuring "Stream Chat Response" toggle properly controls the streaming parameter. [#19154](https://github.com/open-webui/open-webui/issues/19154)
|
|
||||||
- 🖼️ Image editing configuration now correctly preserves independent OpenAI API endpoints and keys, preventing them from being overwritten by image generation settings. [#19003](https://github.com/open-webui/open-webui/issues/19003)
|
|
||||||
- 🎨 Gemini image edit settings now display correctly in the admin panel, fixing an incorrect configuration key reference that prevented proper rendering of edit options. [#19200](https://github.com/open-webui/open-webui/pull/19200)
|
|
||||||
- 🖌️ Image generation settings menu now loads correctly, resolving validation errors with AUTOMATIC1111 API authentication parameters. [#19187](https://github.com/open-webui/open-webui/issues/19187), [#19246](https://github.com/open-webui/open-webui/issues/19246)
|
|
||||||
- 📅 Date formatting in chat search and admin user chat search now correctly respects the "DEFAULT_LOCALE" environment variable, displaying dates according to the configured locale instead of always using MM/DD/YYYY format. [#19305](https://github.com/open-webui/open-webui/pull/19305), [#19020](https://github.com/open-webui/open-webui/issues/19020)
|
|
||||||
- 📝 RAG template query placeholder escaping logic was corrected to prevent unintended replacements of context values when query placeholders appear in retrieved content. [#19102](https://github.com/open-webui/open-webui/pull/19102), [#19101](https://github.com/open-webui/open-webui/issues/19101)
|
|
||||||
- 📄 RAG template prompt duplication was eliminated by removing redundant user query section from the default template. [#19099](https://github.com/open-webui/open-webui/pull/19099), [#19098](https://github.com/open-webui/open-webui/issues/19098)
|
|
||||||
- 📋 MinerU local mode configuration no longer incorrectly requires an API key, allowing proper use of local content extraction without external API credentials. [#19258](https://github.com/open-webui/open-webui/issues/19258)
|
|
||||||
- 📊 Excel file uploads now work correctly with the addition of the missing msoffcrypto-tool dependency, resolving import errors introduced by the unstructured package upgrade. [#19153](https://github.com/open-webui/open-webui/issues/19153)
|
|
||||||
- 📑 Docling parameters now properly handle JSON serialization, preventing exceptions and ensuring configuration changes are saved correctly. [#19072](https://github.com/open-webui/open-webui/pull/19072)
|
|
||||||
- 🛠️ UserValves configuration now correctly isolates settings per tool, preventing configuration contamination when multiple tools with UserValves are used simultaneously. [#19185](https://github.com/open-webui/open-webui/pull/19185), [#15569](https://github.com/open-webui/open-webui/issues/15569)
|
|
||||||
- 🔧 Tool selection prompt now correctly handles user messages without duplication, removing redundant query prefixes and improving prompt clarity. [#19122](https://github.com/open-webui/open-webui/pull/19122), [#19121](https://github.com/open-webui/open-webui/issues/19121)
|
|
||||||
- 📝 Notes chat feature now correctly submits messages to the completions endpoint, resolving errors that prevented AI model interactions. [#19079](https://github.com/open-webui/open-webui/pull/19079)
|
|
||||||
- 📝 Note PDF downloads now sanitize HTML content using DOMPurify before rendering, preventing potential DOM-based XSS attacks from malicious content in notes. [Commit](https://github.com/open-webui/open-webui/commit/03cc6ce8eb5c055115406e2304fbf7e3338b8dce)
|
|
||||||
- 📁 Archived chats now have their folder associations automatically removed to prevent unintended deletion when their previous folder is deleted. [#14578](https://github.com/open-webui/open-webui/issues/14578)
|
|
||||||
- 🔐 ElevenLabs API key is now properly obfuscated in the admin settings page, preventing plain text exposure of sensitive credentials. [#19262](https://github.com/open-webui/open-webui/pull/19262), [#19260](https://github.com/open-webui/open-webui/issues/19260)
|
|
||||||
- 🔧 MCP OAuth server metadata discovery now follows the correct specification order, ensuring proper authentication flow compliance. [#19244](https://github.com/open-webui/open-webui/pull/19244)
|
|
||||||
- 🔒 API key endpoint restrictions now properly enforce access controls for all endpoints including SCIM, preventing unintended access when "API_KEY_ALLOWED_ENDPOINTS" is configured. [#19168](https://github.com/open-webui/open-webui/issues/19168)
|
|
||||||
- 🔓 OAuth role claim parsing now supports both flat and nested claim structures, enabling compatibility with OAuth providers that deliver claims as direct properties on the user object rather than nested structures. [#19286](https://github.com/open-webui/open-webui/pull/19286)
|
|
||||||
- 🔑 OAuth MCP server verification now correctly extracts the access token value for authorization headers instead of sending the entire token dictionary. [#19149](https://github.com/open-webui/open-webui/pull/19149), [#19148](https://github.com/open-webui/open-webui/issues/19148)
|
|
||||||
- ⚙️ OAuth dynamic client registration now correctly converts empty strings to None for optional fields, preventing validation failures in MCP package integration. [#19144](https://github.com/open-webui/open-webui/pull/19144), [#19129](https://github.com/open-webui/open-webui/issues/19129)
|
|
||||||
- 🔐 OIDC authentication now correctly passes client credentials in access token requests, ensuring compatibility with providers that require these parameters per RFC 6749. [#19132](https://github.com/open-webui/open-webui/pull/19132), [#19131](https://github.com/open-webui/open-webui/issues/19131)
|
|
||||||
- 🔗 OAuth client creation now respects configured token endpoint authentication methods instead of defaulting to basic authentication, preventing failures with servers that don't support basic auth. [#19165](https://github.com/open-webui/open-webui/pull/19165)
|
|
||||||
- 📋 Text copied from chat responses in Chrome now pastes without background formatting, improving readability when pasting into word processors. [#19083](https://github.com/open-webui/open-webui/issues/19083)
|
|
||||||
|
|
||||||
### Changed
|
|
||||||
|
|
||||||
- 🗄️ Group membership data storage was refactored from JSON arrays to a dedicated relational database table, significantly improving query performance and scalability for instances with large numbers of users and groups, while API responses now return member counts instead of full user ID arrays. [#19239](https://github.com/open-webui/open-webui/pull/19239)
|
|
||||||
- 📄 MinerU parameter handling was refactored to pass parameters directly to the API, improving flexibility and fixing VLM backend configuration. [#19105](https://github.com/open-webui/open-webui/pull/19105), [#18446](https://github.com/open-webui/open-webui/discussions/18446)
|
|
||||||
- 🔐 API key creation is now controlled by granular user and group permissions, with the "ENABLE_API_KEY" environment variable renamed to "ENABLE_API_KEYS" and disabled by default, requiring explicit configuration at both the global and user permission levels, while related environment variables "ENABLE_API_KEY_ENDPOINT_RESTRICTIONS" and "API_KEY_ALLOWED_ENDPOINTS" were renamed to "ENABLE_API_KEYS_ENDPOINT_RESTRICTIONS" and "API_KEYS_ALLOWED_ENDPOINTS" respectively. [#18336](https://github.com/open-webui/open-webui/pull/18336)
|
|
||||||
|
|
||||||
## [0.6.36] - 2025-11-07
|
## [0.6.36] - 2025-11-07
|
||||||
|
|
||||||
### Added
|
### Added
|
||||||
|
|
|
||||||
|
|
@ -55,9 +55,6 @@ ARG USE_RERANKING_MODEL
|
||||||
ARG UID
|
ARG UID
|
||||||
ARG GID
|
ARG GID
|
||||||
|
|
||||||
# Python settings
|
|
||||||
ENV PYTHONUNBUFFERED=1
|
|
||||||
|
|
||||||
## Basis ##
|
## Basis ##
|
||||||
ENV ENV=prod \
|
ENV ENV=prod \
|
||||||
PORT=8080 \
|
PORT=8080 \
|
||||||
|
|
|
||||||
35
INSTALLATION.md
Normal file
35
INSTALLATION.md
Normal file
|
|
@ -0,0 +1,35 @@
|
||||||
|
### Installing Both Ollama and Open WebUI Using Kustomize
|
||||||
|
|
||||||
|
For cpu-only pod
|
||||||
|
|
||||||
|
```bash
|
||||||
|
kubectl apply -f ./kubernetes/manifest/base
|
||||||
|
```
|
||||||
|
|
||||||
|
For gpu-enabled pod
|
||||||
|
|
||||||
|
```bash
|
||||||
|
kubectl apply -k ./kubernetes/manifest
|
||||||
|
```
|
||||||
|
|
||||||
|
### Installing Both Ollama and Open WebUI Using Helm
|
||||||
|
|
||||||
|
Package Helm file first
|
||||||
|
|
||||||
|
```bash
|
||||||
|
helm package ./kubernetes/helm/
|
||||||
|
```
|
||||||
|
|
||||||
|
For cpu-only pod
|
||||||
|
|
||||||
|
```bash
|
||||||
|
helm install ollama-webui ./ollama-webui-*.tgz
|
||||||
|
```
|
||||||
|
|
||||||
|
For gpu-enabled pod
|
||||||
|
|
||||||
|
```bash
|
||||||
|
helm install ollama-webui ./ollama-webui-*.tgz --set ollama.resources.limits.nvidia.com/gpu="1"
|
||||||
|
```
|
||||||
|
|
||||||
|
Check the `kubernetes/helm/values.yaml` file to know which parameters are available for customization
|
||||||
2
LICENSE
2
LICENSE
|
|
@ -1,4 +1,4 @@
|
||||||
Copyright (c) 2023- Open WebUI Inc. [Created by Timothy Jaeryang Baek]
|
Copyright (c) 2023-2025 Timothy Jaeryang Baek (Open WebUI)
|
||||||
All rights reserved.
|
All rights reserved.
|
||||||
|
|
||||||
Redistribution and use in source and binary forms, with or without
|
Redistribution and use in source and binary forms, with or without
|
||||||
|
|
|
||||||
36
README.md
36
README.md
|
|
@ -10,13 +10,11 @@
|
||||||
[](https://discord.gg/5rJgQTnV4s)
|
[](https://discord.gg/5rJgQTnV4s)
|
||||||
[](https://github.com/sponsors/tjbck)
|
[](https://github.com/sponsors/tjbck)
|
||||||
|
|
||||||

|
|
||||||
|
|
||||||
**Open WebUI is an [extensible](https://docs.openwebui.com/features/plugin/), feature-rich, and user-friendly self-hosted AI platform designed to operate entirely offline.** It supports various LLM runners like **Ollama** and **OpenAI-compatible APIs**, with **built-in inference engine** for RAG, making it a **powerful AI deployment solution**.
|
**Open WebUI is an [extensible](https://docs.openwebui.com/features/plugin/), feature-rich, and user-friendly self-hosted AI platform designed to operate entirely offline.** It supports various LLM runners like **Ollama** and **OpenAI-compatible APIs**, with **built-in inference engine** for RAG, making it a **powerful AI deployment solution**.
|
||||||
|
|
||||||
Passionate about open-source AI? [Join our team →](https://careers.openwebui.com/)
|
Passionate about open-source AI? [Join our team →](https://careers.openwebui.com/)
|
||||||
|
|
||||||

|

|
||||||
|
|
||||||
> [!TIP]
|
> [!TIP]
|
||||||
> **Looking for an [Enterprise Plan](https://docs.openwebui.com/enterprise)?** – **[Speak with Our Sales Team Today!](https://docs.openwebui.com/enterprise)**
|
> **Looking for an [Enterprise Plan](https://docs.openwebui.com/enterprise)?** – **[Speak with Our Sales Team Today!](https://docs.openwebui.com/enterprise)**
|
||||||
|
|
@ -33,44 +31,32 @@ For more information, be sure to check out our [Open WebUI Documentation](https:
|
||||||
|
|
||||||
- 🛡️ **Granular Permissions and User Groups**: By allowing administrators to create detailed user roles and permissions, we ensure a secure user environment. This granularity not only enhances security but also allows for customized user experiences, fostering a sense of ownership and responsibility amongst users.
|
- 🛡️ **Granular Permissions and User Groups**: By allowing administrators to create detailed user roles and permissions, we ensure a secure user environment. This granularity not only enhances security but also allows for customized user experiences, fostering a sense of ownership and responsibility amongst users.
|
||||||
|
|
||||||
|
- 🔄 **SCIM 2.0 Support**: Enterprise-grade user and group provisioning through SCIM 2.0 protocol, enabling seamless integration with identity providers like Okta, Azure AD, and Google Workspace for automated user lifecycle management.
|
||||||
|
|
||||||
- 📱 **Responsive Design**: Enjoy a seamless experience across Desktop PC, Laptop, and Mobile devices.
|
- 📱 **Responsive Design**: Enjoy a seamless experience across Desktop PC, Laptop, and Mobile devices.
|
||||||
|
|
||||||
- 📱 **Progressive Web App (PWA) for Mobile**: Enjoy a native app-like experience on your mobile device with our PWA, providing offline access on localhost and a seamless user interface.
|
- 📱 **Progressive Web App (PWA) for Mobile**: Enjoy a native app-like experience on your mobile device with our PWA, providing offline access on localhost and a seamless user interface.
|
||||||
|
|
||||||
- ✒️🔢 **Full Markdown and LaTeX Support**: Elevate your LLM experience with comprehensive Markdown and LaTeX capabilities for enriched interaction.
|
- ✒️🔢 **Full Markdown and LaTeX Support**: Elevate your LLM experience with comprehensive Markdown and LaTeX capabilities for enriched interaction.
|
||||||
|
|
||||||
- 🎤📹 **Hands-Free Voice/Video Call**: Experience seamless communication with integrated hands-free voice and video call features using multiple Speech-to-Text providers (Local Whisper, OpenAI, Deepgram, Azure) and Text-to-Speech engines (Azure, ElevenLabs, OpenAI, Transformers, WebAPI), allowing for dynamic and interactive chat environments.
|
- 🎤📹 **Hands-Free Voice/Video Call**: Experience seamless communication with integrated hands-free voice and video call features, allowing for a more dynamic and interactive chat environment.
|
||||||
|
|
||||||
- 🛠️ **Model Builder**: Easily create Ollama models via the Web UI. Create and add custom characters/agents, customize chat elements, and import models effortlessly through [Open WebUI Community](https://openwebui.com/) integration.
|
- 🛠️ **Model Builder**: Easily create Ollama models via the Web UI. Create and add custom characters/agents, customize chat elements, and import models effortlessly through [Open WebUI Community](https://openwebui.com/) integration.
|
||||||
|
|
||||||
- 🐍 **Native Python Function Calling Tool**: Enhance your LLMs with built-in code editor support in the tools workspace. Bring Your Own Function (BYOF) by simply adding your pure Python functions, enabling seamless integration with LLMs.
|
- 🐍 **Native Python Function Calling Tool**: Enhance your LLMs with built-in code editor support in the tools workspace. Bring Your Own Function (BYOF) by simply adding your pure Python functions, enabling seamless integration with LLMs.
|
||||||
|
|
||||||
- 💾 **Persistent Artifact Storage**: Built-in key-value storage API for artifacts, enabling features like journals, trackers, leaderboards, and collaborative tools with both personal and shared data scopes across sessions.
|
- 📚 **Local RAG Integration**: Dive into the future of chat interactions with groundbreaking Retrieval Augmented Generation (RAG) support. This feature seamlessly integrates document interactions into your chat experience. You can load documents directly into the chat or add files to your document library, effortlessly accessing them using the `#` command before a query.
|
||||||
|
|
||||||
- 📚 **Local RAG Integration**: Dive into the future of chat interactions with groundbreaking Retrieval Augmented Generation (RAG) support using your choice of 9 vector databases and multiple content extraction engines (Tika, Docling, Document Intelligence, Mistral OCR, External loaders). Load documents directly into chat or add files to your document library, effortlessly accessing them using the `#` command before a query.
|
- 🔍 **Web Search for RAG**: Perform web searches using providers like `SearXNG`, `Google PSE`, `Brave Search`, `serpstack`, `serper`, `Serply`, `DuckDuckGo`, `TavilySearch`, `SearchApi` and `Bing` and inject the results directly into your chat experience.
|
||||||
|
|
||||||
- 🔍 **Web Search for RAG**: Perform web searches using 15+ providers including `SearXNG`, `Google PSE`, `Brave Search`, `Kagi`, `Mojeek`, `Tavily`, `Perplexity`, `serpstack`, `serper`, `Serply`, `DuckDuckGo`, `SearchApi`, `SerpApi`, `Bing`, `Jina`, `Exa`, `Sougou`, `Azure AI Search`, and `Ollama Cloud`, injecting results directly into your chat experience.
|
|
||||||
|
|
||||||
- 🌐 **Web Browsing Capability**: Seamlessly integrate websites into your chat experience using the `#` command followed by a URL. This feature allows you to incorporate web content directly into your conversations, enhancing the richness and depth of your interactions.
|
- 🌐 **Web Browsing Capability**: Seamlessly integrate websites into your chat experience using the `#` command followed by a URL. This feature allows you to incorporate web content directly into your conversations, enhancing the richness and depth of your interactions.
|
||||||
|
|
||||||
- 🎨 **Image Generation & Editing Integration**: Create and edit images using multiple engines including OpenAI's DALL-E, Gemini, ComfyUI (local), and AUTOMATIC1111 (local), with support for both generation and prompt-based editing workflows.
|
- 🎨 **Image Generation Integration**: Seamlessly incorporate image generation capabilities using options such as AUTOMATIC1111 API or ComfyUI (local), and OpenAI's DALL-E (external), enriching your chat experience with dynamic visual content.
|
||||||
|
|
||||||
- ⚙️ **Many Models Conversations**: Effortlessly engage with various models simultaneously, harnessing their unique strengths for optimal responses. Enhance your experience by leveraging a diverse set of models in parallel.
|
- ⚙️ **Many Models Conversations**: Effortlessly engage with various models simultaneously, harnessing their unique strengths for optimal responses. Enhance your experience by leveraging a diverse set of models in parallel.
|
||||||
|
|
||||||
- 🔐 **Role-Based Access Control (RBAC)**: Ensure secure access with restricted permissions; only authorized individuals can access your Ollama, and exclusive model creation/pulling rights are reserved for administrators.
|
- 🔐 **Role-Based Access Control (RBAC)**: Ensure secure access with restricted permissions; only authorized individuals can access your Ollama, and exclusive model creation/pulling rights are reserved for administrators.
|
||||||
|
|
||||||
- 🗄️ **Flexible Database & Storage Options**: Choose from SQLite (with optional encryption), PostgreSQL, or configure cloud storage backends (S3, Google Cloud Storage, Azure Blob Storage) for scalable deployments.
|
|
||||||
|
|
||||||
- 🔍 **Advanced Vector Database Support**: Select from 9 vector database options including ChromaDB, PGVector, Qdrant, Milvus, Elasticsearch, OpenSearch, Pinecone, S3Vector, and Oracle 23ai for optimal RAG performance.
|
|
||||||
|
|
||||||
- 🔐 **Enterprise Authentication**: Full support for LDAP/Active Directory integration, SCIM 2.0 automated provisioning, and SSO via trusted headers alongside OAuth providers. Enterprise-grade user and group provisioning through SCIM 2.0 protocol, enabling seamless integration with identity providers like Okta, Azure AD, and Google Workspace for automated user lifecycle management.
|
|
||||||
|
|
||||||
- ☁️ **Cloud-Native Integration**: Native support for Google Drive and OneDrive/SharePoint file picking, enabling seamless document import from enterprise cloud storage.
|
|
||||||
|
|
||||||
- 📊 **Production Observability**: Built-in OpenTelemetry support for traces, metrics, and logs, enabling comprehensive monitoring with your existing observability stack.
|
|
||||||
|
|
||||||
- ⚖️ **Horizontal Scalability**: Redis-backed session management and WebSocket support for multi-worker and multi-node deployments behind load balancers.
|
|
||||||
|
|
||||||
- 🌐🌍 **Multilingual Support**: Experience Open WebUI in your preferred language with our internationalization (i18n) support. Join us in expanding our supported languages! We're actively seeking contributors!
|
- 🌐🌍 **Multilingual Support**: Experience Open WebUI in your preferred language with our internationalization (i18n) support. Join us in expanding our supported languages! We're actively seeking contributors!
|
||||||
|
|
||||||
- 🧩 **Pipelines, Open WebUI Plugin Support**: Seamlessly integrate custom logic and Python libraries into Open WebUI using [Pipelines Plugin Framework](https://github.com/open-webui/pipelines). Launch your Pipelines instance, set the OpenAI URL to the Pipelines URL, and explore endless possibilities. [Examples](https://github.com/open-webui/pipelines/tree/main/examples) include **Function Calling**, User **Rate Limiting** to control access, **Usage Monitoring** with tools like Langfuse, **Live Translation with LibreTranslate** for multilingual support, **Toxic Message Filtering** and much more.
|
- 🧩 **Pipelines, Open WebUI Plugin Support**: Seamlessly integrate custom logic and Python libraries into Open WebUI using [Pipelines Plugin Framework](https://github.com/open-webui/pipelines). Launch your Pipelines instance, set the OpenAI URL to the Pipelines URL, and explore endless possibilities. [Examples](https://github.com/open-webui/pipelines/tree/main/examples) include **Function Calling**, User **Rate Limiting** to control access, **Usage Monitoring** with tools like Langfuse, **Live Translation with LibreTranslate** for multilingual support, **Toxic Message Filtering** and much more.
|
||||||
|
|
@ -190,6 +176,14 @@ docker run -d --network=host -v open-webui:/app/backend/data -e OLLAMA_BASE_URL=
|
||||||
|
|
||||||
### Keeping Your Docker Installation Up-to-Date
|
### Keeping Your Docker Installation Up-to-Date
|
||||||
|
|
||||||
|
In case you want to update your local Docker installation to the latest version, you can do it with [Watchtower](https://containrrr.dev/watchtower/):
|
||||||
|
|
||||||
|
```bash
|
||||||
|
docker run --rm --volume /var/run/docker.sock:/var/run/docker.sock containrrr/watchtower --run-once open-webui
|
||||||
|
```
|
||||||
|
|
||||||
|
In the last part of the command, replace `open-webui` with your container name if it is different.
|
||||||
|
|
||||||
Check our Updating Guide available in our [Open WebUI Documentation](https://docs.openwebui.com/getting-started/updating).
|
Check our Updating Guide available in our [Open WebUI Documentation](https://docs.openwebui.com/getting-started/updating).
|
||||||
|
|
||||||
### Using the Dev Branch 🌙
|
### Using the Dev Branch 🌙
|
||||||
|
|
|
||||||
|
|
@ -287,30 +287,25 @@ class AppConfig:
|
||||||
# WEBUI_AUTH (Required for security)
|
# WEBUI_AUTH (Required for security)
|
||||||
####################################
|
####################################
|
||||||
|
|
||||||
ENABLE_API_KEYS = PersistentConfig(
|
ENABLE_API_KEY = PersistentConfig(
|
||||||
"ENABLE_API_KEYS",
|
"ENABLE_API_KEY",
|
||||||
"auth.enable_api_keys",
|
"auth.api_key.enable",
|
||||||
os.environ.get("ENABLE_API_KEYS", "False").lower() == "true",
|
os.environ.get("ENABLE_API_KEY", "True").lower() == "true",
|
||||||
)
|
)
|
||||||
|
|
||||||
ENABLE_API_KEYS_ENDPOINT_RESTRICTIONS = PersistentConfig(
|
ENABLE_API_KEY_ENDPOINT_RESTRICTIONS = PersistentConfig(
|
||||||
"ENABLE_API_KEYS_ENDPOINT_RESTRICTIONS",
|
"ENABLE_API_KEY_ENDPOINT_RESTRICTIONS",
|
||||||
"auth.api_key.endpoint_restrictions",
|
"auth.api_key.endpoint_restrictions",
|
||||||
os.environ.get(
|
os.environ.get("ENABLE_API_KEY_ENDPOINT_RESTRICTIONS", "False").lower() == "true",
|
||||||
"ENABLE_API_KEYS_ENDPOINT_RESTRICTIONS",
|
|
||||||
os.environ.get("ENABLE_API_KEY_ENDPOINT_RESTRICTIONS", "False"),
|
|
||||||
).lower()
|
|
||||||
== "true",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
API_KEYS_ALLOWED_ENDPOINTS = PersistentConfig(
|
API_KEY_ALLOWED_ENDPOINTS = PersistentConfig(
|
||||||
"API_KEYS_ALLOWED_ENDPOINTS",
|
"API_KEY_ALLOWED_ENDPOINTS",
|
||||||
"auth.api_key.allowed_endpoints",
|
"auth.api_key.allowed_endpoints",
|
||||||
os.environ.get(
|
os.environ.get("API_KEY_ALLOWED_ENDPOINTS", ""),
|
||||||
"API_KEYS_ALLOWED_ENDPOINTS", os.environ.get("API_KEY_ALLOWED_ENDPOINTS", "")
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
JWT_EXPIRES_IN = PersistentConfig(
|
JWT_EXPIRES_IN = PersistentConfig(
|
||||||
"JWT_EXPIRES_IN", "auth.jwt_expiry", os.environ.get("JWT_EXPIRES_IN", "4w")
|
"JWT_EXPIRES_IN", "auth.jwt_expiry", os.environ.get("JWT_EXPIRES_IN", "4w")
|
||||||
)
|
)
|
||||||
|
|
@ -583,16 +578,14 @@ OAUTH_ROLES_CLAIM = PersistentConfig(
|
||||||
os.environ.get("OAUTH_ROLES_CLAIM", "roles"),
|
os.environ.get("OAUTH_ROLES_CLAIM", "roles"),
|
||||||
)
|
)
|
||||||
|
|
||||||
OAUTH_ROLES_SEPARATOR = os.environ.get("OAUTH_ROLES_SEPARATOR", ",")
|
SEP = os.environ.get("OAUTH_ROLES_SEPARATOR", ",")
|
||||||
|
|
||||||
OAUTH_ALLOWED_ROLES = PersistentConfig(
|
OAUTH_ALLOWED_ROLES = PersistentConfig(
|
||||||
"OAUTH_ALLOWED_ROLES",
|
"OAUTH_ALLOWED_ROLES",
|
||||||
"oauth.allowed_roles",
|
"oauth.allowed_roles",
|
||||||
[
|
[
|
||||||
role.strip()
|
role.strip()
|
||||||
for role in os.environ.get(
|
for role in os.environ.get("OAUTH_ALLOWED_ROLES", f"user{SEP}admin").split(SEP)
|
||||||
"OAUTH_ALLOWED_ROLES", f"user{OAUTH_ROLES_SEPARATOR}admin"
|
|
||||||
).split(OAUTH_ROLES_SEPARATOR)
|
|
||||||
if role
|
if role
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
@ -602,9 +595,7 @@ OAUTH_ADMIN_ROLES = PersistentConfig(
|
||||||
"oauth.admin_roles",
|
"oauth.admin_roles",
|
||||||
[
|
[
|
||||||
role.strip()
|
role.strip()
|
||||||
for role in os.environ.get("OAUTH_ADMIN_ROLES", "admin").split(
|
for role in os.environ.get("OAUTH_ADMIN_ROLES", "admin").split(SEP)
|
||||||
OAUTH_ROLES_SEPARATOR
|
|
||||||
)
|
|
||||||
if role
|
if role
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
@ -624,17 +615,6 @@ OAUTH_UPDATE_PICTURE_ON_LOGIN = PersistentConfig(
|
||||||
os.environ.get("OAUTH_UPDATE_PICTURE_ON_LOGIN", "False").lower() == "true",
|
os.environ.get("OAUTH_UPDATE_PICTURE_ON_LOGIN", "False").lower() == "true",
|
||||||
)
|
)
|
||||||
|
|
||||||
OAUTH_ACCESS_TOKEN_REQUEST_INCLUDE_CLIENT_ID = (
|
|
||||||
os.environ.get("OAUTH_ACCESS_TOKEN_REQUEST_INCLUDE_CLIENT_ID", "False").lower()
|
|
||||||
== "true"
|
|
||||||
)
|
|
||||||
|
|
||||||
OAUTH_AUDIENCE = PersistentConfig(
|
|
||||||
"OAUTH_AUDIENCE",
|
|
||||||
"oauth.audience",
|
|
||||||
os.environ.get("OAUTH_AUDIENCE", ""),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def load_oauth_providers():
|
def load_oauth_providers():
|
||||||
OAUTH_PROVIDERS.clear()
|
OAUTH_PROVIDERS.clear()
|
||||||
|
|
@ -1144,7 +1124,6 @@ ENABLE_LOGIN_FORM = PersistentConfig(
|
||||||
os.environ.get("ENABLE_LOGIN_FORM", "True").lower() == "true",
|
os.environ.get("ENABLE_LOGIN_FORM", "True").lower() == "true",
|
||||||
)
|
)
|
||||||
|
|
||||||
ENABLE_PASSWORD_AUTH = os.environ.get("ENABLE_PASSWORD_AUTH", "True").lower() == "true"
|
|
||||||
|
|
||||||
DEFAULT_LOCALE = PersistentConfig(
|
DEFAULT_LOCALE = PersistentConfig(
|
||||||
"DEFAULT_LOCALE",
|
"DEFAULT_LOCALE",
|
||||||
|
|
@ -1156,12 +1135,6 @@ DEFAULT_MODELS = PersistentConfig(
|
||||||
"DEFAULT_MODELS", "ui.default_models", os.environ.get("DEFAULT_MODELS", None)
|
"DEFAULT_MODELS", "ui.default_models", os.environ.get("DEFAULT_MODELS", None)
|
||||||
)
|
)
|
||||||
|
|
||||||
DEFAULT_PINNED_MODELS = PersistentConfig(
|
|
||||||
"DEFAULT_PINNED_MODELS",
|
|
||||||
"ui.default_pinned_models",
|
|
||||||
os.environ.get("DEFAULT_PINNED_MODELS", None),
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
default_prompt_suggestions = json.loads(
|
default_prompt_suggestions = json.loads(
|
||||||
os.environ.get("DEFAULT_PROMPT_SUGGESTIONS", "[]")
|
os.environ.get("DEFAULT_PROMPT_SUGGESTIONS", "[]")
|
||||||
|
|
@ -1218,12 +1191,6 @@ DEFAULT_USER_ROLE = PersistentConfig(
|
||||||
os.getenv("DEFAULT_USER_ROLE", "pending"),
|
os.getenv("DEFAULT_USER_ROLE", "pending"),
|
||||||
)
|
)
|
||||||
|
|
||||||
DEFAULT_GROUP_ID = PersistentConfig(
|
|
||||||
"DEFAULT_GROUP_ID",
|
|
||||||
"ui.default_group_id",
|
|
||||||
os.environ.get("DEFAULT_GROUP_ID", ""),
|
|
||||||
)
|
|
||||||
|
|
||||||
PENDING_USER_OVERLAY_TITLE = PersistentConfig(
|
PENDING_USER_OVERLAY_TITLE = PersistentConfig(
|
||||||
"PENDING_USER_OVERLAY_TITLE",
|
"PENDING_USER_OVERLAY_TITLE",
|
||||||
"ui.pending_user_overlay_title",
|
"ui.pending_user_overlay_title",
|
||||||
|
|
@ -1263,40 +1230,6 @@ USER_PERMISSIONS_WORKSPACE_TOOLS_ACCESS = (
|
||||||
os.environ.get("USER_PERMISSIONS_WORKSPACE_TOOLS_ACCESS", "False").lower() == "true"
|
os.environ.get("USER_PERMISSIONS_WORKSPACE_TOOLS_ACCESS", "False").lower() == "true"
|
||||||
)
|
)
|
||||||
|
|
||||||
USER_PERMISSIONS_WORKSPACE_MODELS_IMPORT = (
|
|
||||||
os.environ.get("USER_PERMISSIONS_WORKSPACE_MODELS_IMPORT", "False").lower()
|
|
||||||
== "true"
|
|
||||||
)
|
|
||||||
|
|
||||||
USER_PERMISSIONS_WORKSPACE_MODELS_EXPORT = (
|
|
||||||
os.environ.get("USER_PERMISSIONS_WORKSPACE_MODELS_EXPORT", "False").lower()
|
|
||||||
== "true"
|
|
||||||
)
|
|
||||||
|
|
||||||
USER_PERMISSIONS_WORKSPACE_PROMPTS_IMPORT = (
|
|
||||||
os.environ.get("USER_PERMISSIONS_WORKSPACE_PROMPTS_IMPORT", "False").lower()
|
|
||||||
== "true"
|
|
||||||
)
|
|
||||||
|
|
||||||
USER_PERMISSIONS_WORKSPACE_PROMPTS_EXPORT = (
|
|
||||||
os.environ.get("USER_PERMISSIONS_WORKSPACE_PROMPTS_EXPORT", "False").lower()
|
|
||||||
== "true"
|
|
||||||
)
|
|
||||||
|
|
||||||
USER_PERMISSIONS_WORKSPACE_TOOLS_IMPORT = (
|
|
||||||
os.environ.get("USER_PERMISSIONS_WORKSPACE_TOOLS_IMPORT", "False").lower() == "true"
|
|
||||||
)
|
|
||||||
|
|
||||||
USER_PERMISSIONS_WORKSPACE_TOOLS_EXPORT = (
|
|
||||||
os.environ.get("USER_PERMISSIONS_WORKSPACE_TOOLS_EXPORT", "False").lower() == "true"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
USER_PERMISSIONS_WORKSPACE_MODELS_ALLOW_SHARING = (
|
|
||||||
os.environ.get("USER_PERMISSIONS_WORKSPACE_MODELS_ALLOW_SHARING", "False").lower()
|
|
||||||
== "true"
|
|
||||||
)
|
|
||||||
|
|
||||||
USER_PERMISSIONS_WORKSPACE_MODELS_ALLOW_PUBLIC_SHARING = (
|
USER_PERMISSIONS_WORKSPACE_MODELS_ALLOW_PUBLIC_SHARING = (
|
||||||
os.environ.get(
|
os.environ.get(
|
||||||
"USER_PERMISSIONS_WORKSPACE_MODELS_ALLOW_PUBLIC_SHARING", "False"
|
"USER_PERMISSIONS_WORKSPACE_MODELS_ALLOW_PUBLIC_SHARING", "False"
|
||||||
|
|
@ -1304,10 +1237,8 @@ USER_PERMISSIONS_WORKSPACE_MODELS_ALLOW_PUBLIC_SHARING = (
|
||||||
== "true"
|
== "true"
|
||||||
)
|
)
|
||||||
|
|
||||||
USER_PERMISSIONS_WORKSPACE_KNOWLEDGE_ALLOW_SHARING = (
|
USER_PERMISSIONS_NOTES_ALLOW_PUBLIC_SHARING = (
|
||||||
os.environ.get(
|
os.environ.get("USER_PERMISSIONS_NOTES_ALLOW_PUBLIC_SHARING", "False").lower()
|
||||||
"USER_PERMISSIONS_WORKSPACE_KNOWLEDGE_ALLOW_SHARING", "False"
|
|
||||||
).lower()
|
|
||||||
== "true"
|
== "true"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -1318,11 +1249,6 @@ USER_PERMISSIONS_WORKSPACE_KNOWLEDGE_ALLOW_PUBLIC_SHARING = (
|
||||||
== "true"
|
== "true"
|
||||||
)
|
)
|
||||||
|
|
||||||
USER_PERMISSIONS_WORKSPACE_PROMPTS_ALLOW_SHARING = (
|
|
||||||
os.environ.get("USER_PERMISSIONS_WORKSPACE_PROMPTS_ALLOW_SHARING", "False").lower()
|
|
||||||
== "true"
|
|
||||||
)
|
|
||||||
|
|
||||||
USER_PERMISSIONS_WORKSPACE_PROMPTS_ALLOW_PUBLIC_SHARING = (
|
USER_PERMISSIONS_WORKSPACE_PROMPTS_ALLOW_PUBLIC_SHARING = (
|
||||||
os.environ.get(
|
os.environ.get(
|
||||||
"USER_PERMISSIONS_WORKSPACE_PROMPTS_ALLOW_PUBLIC_SHARING", "False"
|
"USER_PERMISSIONS_WORKSPACE_PROMPTS_ALLOW_PUBLIC_SHARING", "False"
|
||||||
|
|
@ -1330,12 +1256,6 @@ USER_PERMISSIONS_WORKSPACE_PROMPTS_ALLOW_PUBLIC_SHARING = (
|
||||||
== "true"
|
== "true"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
USER_PERMISSIONS_WORKSPACE_TOOLS_ALLOW_SHARING = (
|
|
||||||
os.environ.get("USER_PERMISSIONS_WORKSPACE_TOOLS_ALLOW_SHARING", "False").lower()
|
|
||||||
== "true"
|
|
||||||
)
|
|
||||||
|
|
||||||
USER_PERMISSIONS_WORKSPACE_TOOLS_ALLOW_PUBLIC_SHARING = (
|
USER_PERMISSIONS_WORKSPACE_TOOLS_ALLOW_PUBLIC_SHARING = (
|
||||||
os.environ.get(
|
os.environ.get(
|
||||||
"USER_PERMISSIONS_WORKSPACE_TOOLS_ALLOW_PUBLIC_SHARING", "False"
|
"USER_PERMISSIONS_WORKSPACE_TOOLS_ALLOW_PUBLIC_SHARING", "False"
|
||||||
|
|
@ -1344,16 +1264,6 @@ USER_PERMISSIONS_WORKSPACE_TOOLS_ALLOW_PUBLIC_SHARING = (
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
USER_PERMISSIONS_NOTES_ALLOW_SHARING = (
|
|
||||||
os.environ.get("USER_PERMISSIONS_NOTES_ALLOW_SHARING", "False").lower() == "true"
|
|
||||||
)
|
|
||||||
|
|
||||||
USER_PERMISSIONS_NOTES_ALLOW_PUBLIC_SHARING = (
|
|
||||||
os.environ.get("USER_PERMISSIONS_NOTES_ALLOW_PUBLIC_SHARING", "False").lower()
|
|
||||||
== "true"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
USER_PERMISSIONS_CHAT_CONTROLS = (
|
USER_PERMISSIONS_CHAT_CONTROLS = (
|
||||||
os.environ.get("USER_PERMISSIONS_CHAT_CONTROLS", "True").lower() == "true"
|
os.environ.get("USER_PERMISSIONS_CHAT_CONTROLS", "True").lower() == "true"
|
||||||
)
|
)
|
||||||
|
|
@ -1452,22 +1362,10 @@ USER_PERMISSIONS_FEATURES_CODE_INTERPRETER = (
|
||||||
== "true"
|
== "true"
|
||||||
)
|
)
|
||||||
|
|
||||||
USER_PERMISSIONS_FEATURES_FOLDERS = (
|
|
||||||
os.environ.get("USER_PERMISSIONS_FEATURES_FOLDERS", "True").lower() == "true"
|
|
||||||
)
|
|
||||||
|
|
||||||
USER_PERMISSIONS_FEATURES_NOTES = (
|
USER_PERMISSIONS_FEATURES_NOTES = (
|
||||||
os.environ.get("USER_PERMISSIONS_FEATURES_NOTES", "True").lower() == "true"
|
os.environ.get("USER_PERMISSIONS_FEATURES_NOTES", "True").lower() == "true"
|
||||||
)
|
)
|
||||||
|
|
||||||
USER_PERMISSIONS_FEATURES_CHANNELS = (
|
|
||||||
os.environ.get("USER_PERMISSIONS_FEATURES_CHANNELS", "True").lower() == "true"
|
|
||||||
)
|
|
||||||
|
|
||||||
USER_PERMISSIONS_FEATURES_API_KEYS = (
|
|
||||||
os.environ.get("USER_PERMISSIONS_FEATURES_API_KEYS", "False").lower() == "true"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
DEFAULT_USER_PERMISSIONS = {
|
DEFAULT_USER_PERMISSIONS = {
|
||||||
"workspace": {
|
"workspace": {
|
||||||
|
|
@ -1475,23 +1373,12 @@ DEFAULT_USER_PERMISSIONS = {
|
||||||
"knowledge": USER_PERMISSIONS_WORKSPACE_KNOWLEDGE_ACCESS,
|
"knowledge": USER_PERMISSIONS_WORKSPACE_KNOWLEDGE_ACCESS,
|
||||||
"prompts": USER_PERMISSIONS_WORKSPACE_PROMPTS_ACCESS,
|
"prompts": USER_PERMISSIONS_WORKSPACE_PROMPTS_ACCESS,
|
||||||
"tools": USER_PERMISSIONS_WORKSPACE_TOOLS_ACCESS,
|
"tools": USER_PERMISSIONS_WORKSPACE_TOOLS_ACCESS,
|
||||||
"models_import": USER_PERMISSIONS_WORKSPACE_MODELS_IMPORT,
|
|
||||||
"models_export": USER_PERMISSIONS_WORKSPACE_MODELS_EXPORT,
|
|
||||||
"prompts_import": USER_PERMISSIONS_WORKSPACE_PROMPTS_IMPORT,
|
|
||||||
"prompts_export": USER_PERMISSIONS_WORKSPACE_PROMPTS_EXPORT,
|
|
||||||
"tools_import": USER_PERMISSIONS_WORKSPACE_TOOLS_IMPORT,
|
|
||||||
"tools_export": USER_PERMISSIONS_WORKSPACE_TOOLS_EXPORT,
|
|
||||||
},
|
},
|
||||||
"sharing": {
|
"sharing": {
|
||||||
"models": USER_PERMISSIONS_WORKSPACE_MODELS_ALLOW_SHARING,
|
|
||||||
"public_models": USER_PERMISSIONS_WORKSPACE_MODELS_ALLOW_PUBLIC_SHARING,
|
"public_models": USER_PERMISSIONS_WORKSPACE_MODELS_ALLOW_PUBLIC_SHARING,
|
||||||
"knowledge": USER_PERMISSIONS_WORKSPACE_KNOWLEDGE_ALLOW_SHARING,
|
|
||||||
"public_knowledge": USER_PERMISSIONS_WORKSPACE_KNOWLEDGE_ALLOW_PUBLIC_SHARING,
|
"public_knowledge": USER_PERMISSIONS_WORKSPACE_KNOWLEDGE_ALLOW_PUBLIC_SHARING,
|
||||||
"prompts": USER_PERMISSIONS_WORKSPACE_PROMPTS_ALLOW_SHARING,
|
|
||||||
"public_prompts": USER_PERMISSIONS_WORKSPACE_PROMPTS_ALLOW_PUBLIC_SHARING,
|
"public_prompts": USER_PERMISSIONS_WORKSPACE_PROMPTS_ALLOW_PUBLIC_SHARING,
|
||||||
"tools": USER_PERMISSIONS_WORKSPACE_TOOLS_ALLOW_SHARING,
|
|
||||||
"public_tools": USER_PERMISSIONS_WORKSPACE_TOOLS_ALLOW_PUBLIC_SHARING,
|
"public_tools": USER_PERMISSIONS_WORKSPACE_TOOLS_ALLOW_PUBLIC_SHARING,
|
||||||
"notes": USER_PERMISSIONS_NOTES_ALLOW_SHARING,
|
|
||||||
"public_notes": USER_PERMISSIONS_NOTES_ALLOW_PUBLIC_SHARING,
|
"public_notes": USER_PERMISSIONS_NOTES_ALLOW_PUBLIC_SHARING,
|
||||||
},
|
},
|
||||||
"chat": {
|
"chat": {
|
||||||
|
|
@ -1516,16 +1403,11 @@ DEFAULT_USER_PERMISSIONS = {
|
||||||
"temporary_enforced": USER_PERMISSIONS_CHAT_TEMPORARY_ENFORCED,
|
"temporary_enforced": USER_PERMISSIONS_CHAT_TEMPORARY_ENFORCED,
|
||||||
},
|
},
|
||||||
"features": {
|
"features": {
|
||||||
# General features
|
|
||||||
"api_keys": USER_PERMISSIONS_FEATURES_API_KEYS,
|
|
||||||
"notes": USER_PERMISSIONS_FEATURES_NOTES,
|
|
||||||
"folders": USER_PERMISSIONS_FEATURES_FOLDERS,
|
|
||||||
"channels": USER_PERMISSIONS_FEATURES_CHANNELS,
|
|
||||||
"direct_tool_servers": USER_PERMISSIONS_FEATURES_DIRECT_TOOL_SERVERS,
|
"direct_tool_servers": USER_PERMISSIONS_FEATURES_DIRECT_TOOL_SERVERS,
|
||||||
# Chat features
|
|
||||||
"web_search": USER_PERMISSIONS_FEATURES_WEB_SEARCH,
|
"web_search": USER_PERMISSIONS_FEATURES_WEB_SEARCH,
|
||||||
"image_generation": USER_PERMISSIONS_FEATURES_IMAGE_GENERATION,
|
"image_generation": USER_PERMISSIONS_FEATURES_IMAGE_GENERATION,
|
||||||
"code_interpreter": USER_PERMISSIONS_FEATURES_CODE_INTERPRETER,
|
"code_interpreter": USER_PERMISSIONS_FEATURES_CODE_INTERPRETER,
|
||||||
|
"notes": USER_PERMISSIONS_FEATURES_NOTES,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -1535,12 +1417,6 @@ USER_PERMISSIONS = PersistentConfig(
|
||||||
DEFAULT_USER_PERMISSIONS,
|
DEFAULT_USER_PERMISSIONS,
|
||||||
)
|
)
|
||||||
|
|
||||||
ENABLE_FOLDERS = PersistentConfig(
|
|
||||||
"ENABLE_FOLDERS",
|
|
||||||
"folders.enable",
|
|
||||||
os.environ.get("ENABLE_FOLDERS", "True").lower() == "true",
|
|
||||||
)
|
|
||||||
|
|
||||||
ENABLE_CHANNELS = PersistentConfig(
|
ENABLE_CHANNELS = PersistentConfig(
|
||||||
"ENABLE_CHANNELS",
|
"ENABLE_CHANNELS",
|
||||||
"channels.enable",
|
"channels.enable",
|
||||||
|
|
@ -1940,38 +1816,6 @@ Output:
|
||||||
#### Output:
|
#### Output:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
VOICE_MODE_PROMPT_TEMPLATE = PersistentConfig(
|
|
||||||
"VOICE_MODE_PROMPT_TEMPLATE",
|
|
||||||
"task.voice.prompt_template",
|
|
||||||
os.environ.get("VOICE_MODE_PROMPT_TEMPLATE", ""),
|
|
||||||
)
|
|
||||||
|
|
||||||
DEFAULT_VOICE_MODE_PROMPT_TEMPLATE = """You are a friendly, concise voice assistant.
|
|
||||||
|
|
||||||
Everything you say will be spoken aloud.
|
|
||||||
Keep responses short, clear, and natural.
|
|
||||||
|
|
||||||
STYLE:
|
|
||||||
- Use simple words and short sentences.
|
|
||||||
- Sound warm and conversational.
|
|
||||||
- Avoid long explanations, lists, or complex phrasing.
|
|
||||||
|
|
||||||
BEHAVIOR:
|
|
||||||
- Give the quickest helpful answer first.
|
|
||||||
- Offer extra detail only if needed.
|
|
||||||
- Ask for clarification only when necessary.
|
|
||||||
|
|
||||||
VOICE OPTIMIZATION:
|
|
||||||
- Break information into small, easy-to-hear chunks.
|
|
||||||
- Avoid dense wording or anything that sounds like reading text.
|
|
||||||
|
|
||||||
ERROR HANDLING:
|
|
||||||
- If unsure, say so briefly and offer options.
|
|
||||||
- If something is unsafe or impossible, decline kindly and suggest a safe alternative.
|
|
||||||
|
|
||||||
Stay consistent, helpful, and easy to listen to."""
|
|
||||||
|
|
||||||
TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = PersistentConfig(
|
TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = PersistentConfig(
|
||||||
"TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE",
|
"TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE",
|
||||||
"task.tools.prompt_template",
|
"task.tools.prompt_template",
|
||||||
|
|
@ -2212,11 +2056,6 @@ ENABLE_QDRANT_MULTITENANCY_MODE = (
|
||||||
)
|
)
|
||||||
QDRANT_COLLECTION_PREFIX = os.environ.get("QDRANT_COLLECTION_PREFIX", "open-webui")
|
QDRANT_COLLECTION_PREFIX = os.environ.get("QDRANT_COLLECTION_PREFIX", "open-webui")
|
||||||
|
|
||||||
WEAVIATE_HTTP_HOST = os.environ.get("WEAVIATE_HTTP_HOST", "")
|
|
||||||
WEAVIATE_HTTP_PORT = int(os.environ.get("WEAVIATE_HTTP_PORT", "8080"))
|
|
||||||
WEAVIATE_GRPC_PORT = int(os.environ.get("WEAVIATE_GRPC_PORT", "50051"))
|
|
||||||
WEAVIATE_API_KEY = os.environ.get("WEAVIATE_API_KEY")
|
|
||||||
|
|
||||||
# OpenSearch
|
# OpenSearch
|
||||||
OPENSEARCH_URI = os.environ.get("OPENSEARCH_URI", "https://localhost:9200")
|
OPENSEARCH_URI = os.environ.get("OPENSEARCH_URI", "https://localhost:9200")
|
||||||
OPENSEARCH_SSL = os.environ.get("OPENSEARCH_SSL", "true").lower() == "true"
|
OPENSEARCH_SSL = os.environ.get("OPENSEARCH_SSL", "true").lower() == "true"
|
||||||
|
|
@ -2247,16 +2086,6 @@ PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH = int(
|
||||||
os.environ.get("PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH", "1536")
|
os.environ.get("PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH", "1536")
|
||||||
)
|
)
|
||||||
|
|
||||||
PGVECTOR_USE_HALFVEC = os.getenv("PGVECTOR_USE_HALFVEC", "false").lower() == "true"
|
|
||||||
|
|
||||||
if PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH > 2000 and not PGVECTOR_USE_HALFVEC:
|
|
||||||
raise ValueError(
|
|
||||||
"PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH is set to "
|
|
||||||
f"{PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH}, which exceeds the 2000 dimension limit of the "
|
|
||||||
"'vector' type. Set PGVECTOR_USE_HALFVEC=true to enable the 'halfvec' "
|
|
||||||
"type required for high-dimensional embeddings."
|
|
||||||
)
|
|
||||||
|
|
||||||
PGVECTOR_CREATE_EXTENSION = (
|
PGVECTOR_CREATE_EXTENSION = (
|
||||||
os.getenv("PGVECTOR_CREATE_EXTENSION", "true").lower() == "true"
|
os.getenv("PGVECTOR_CREATE_EXTENSION", "true").lower() == "true"
|
||||||
)
|
)
|
||||||
|
|
@ -2306,40 +2135,6 @@ else:
|
||||||
except Exception:
|
except Exception:
|
||||||
PGVECTOR_POOL_RECYCLE = 3600
|
PGVECTOR_POOL_RECYCLE = 3600
|
||||||
|
|
||||||
PGVECTOR_INDEX_METHOD = os.getenv("PGVECTOR_INDEX_METHOD", "").strip().lower()
|
|
||||||
if PGVECTOR_INDEX_METHOD not in ("ivfflat", "hnsw", ""):
|
|
||||||
PGVECTOR_INDEX_METHOD = ""
|
|
||||||
|
|
||||||
PGVECTOR_HNSW_M = os.environ.get("PGVECTOR_HNSW_M", 16)
|
|
||||||
|
|
||||||
if PGVECTOR_HNSW_M == "":
|
|
||||||
PGVECTOR_HNSW_M = 16
|
|
||||||
else:
|
|
||||||
try:
|
|
||||||
PGVECTOR_HNSW_M = int(PGVECTOR_HNSW_M)
|
|
||||||
except Exception:
|
|
||||||
PGVECTOR_HNSW_M = 16
|
|
||||||
|
|
||||||
PGVECTOR_HNSW_EF_CONSTRUCTION = os.environ.get("PGVECTOR_HNSW_EF_CONSTRUCTION", 64)
|
|
||||||
|
|
||||||
if PGVECTOR_HNSW_EF_CONSTRUCTION == "":
|
|
||||||
PGVECTOR_HNSW_EF_CONSTRUCTION = 64
|
|
||||||
else:
|
|
||||||
try:
|
|
||||||
PGVECTOR_HNSW_EF_CONSTRUCTION = int(PGVECTOR_HNSW_EF_CONSTRUCTION)
|
|
||||||
except Exception:
|
|
||||||
PGVECTOR_HNSW_EF_CONSTRUCTION = 64
|
|
||||||
|
|
||||||
PGVECTOR_IVFFLAT_LISTS = os.environ.get("PGVECTOR_IVFFLAT_LISTS", 100)
|
|
||||||
|
|
||||||
if PGVECTOR_IVFFLAT_LISTS == "":
|
|
||||||
PGVECTOR_IVFFLAT_LISTS = 100
|
|
||||||
else:
|
|
||||||
try:
|
|
||||||
PGVECTOR_IVFFLAT_LISTS = int(PGVECTOR_IVFFLAT_LISTS)
|
|
||||||
except Exception:
|
|
||||||
PGVECTOR_IVFFLAT_LISTS = 100
|
|
||||||
|
|
||||||
# Pinecone
|
# Pinecone
|
||||||
PINECONE_API_KEY = os.environ.get("PINECONE_API_KEY", None)
|
PINECONE_API_KEY = os.environ.get("PINECONE_API_KEY", None)
|
||||||
PINECONE_ENVIRONMENT = os.environ.get("PINECONE_ENVIRONMENT", None)
|
PINECONE_ENVIRONMENT = os.environ.get("PINECONE_ENVIRONMENT", None)
|
||||||
|
|
@ -2523,12 +2318,6 @@ MINERU_API_URL = PersistentConfig(
|
||||||
os.environ.get("MINERU_API_URL", "http://localhost:8000"),
|
os.environ.get("MINERU_API_URL", "http://localhost:8000"),
|
||||||
)
|
)
|
||||||
|
|
||||||
MINERU_API_TIMEOUT = PersistentConfig(
|
|
||||||
"MINERU_API_TIMEOUT",
|
|
||||||
"rag.mineru_api_timeout",
|
|
||||||
os.environ.get("MINERU_API_TIMEOUT", "300"),
|
|
||||||
)
|
|
||||||
|
|
||||||
MINERU_API_KEY = PersistentConfig(
|
MINERU_API_KEY = PersistentConfig(
|
||||||
"MINERU_API_KEY",
|
"MINERU_API_KEY",
|
||||||
"rag.mineru_api_key",
|
"rag.mineru_api_key",
|
||||||
|
|
@ -2571,12 +2360,6 @@ DOCLING_SERVER_URL = PersistentConfig(
|
||||||
os.getenv("DOCLING_SERVER_URL", "http://docling:5001"),
|
os.getenv("DOCLING_SERVER_URL", "http://docling:5001"),
|
||||||
)
|
)
|
||||||
|
|
||||||
DOCLING_API_KEY = PersistentConfig(
|
|
||||||
"DOCLING_API_KEY",
|
|
||||||
"rag.docling_api_key",
|
|
||||||
os.getenv("DOCLING_API_KEY", ""),
|
|
||||||
)
|
|
||||||
|
|
||||||
docling_params = os.getenv("DOCLING_PARAMS", "")
|
docling_params = os.getenv("DOCLING_PARAMS", "")
|
||||||
try:
|
try:
|
||||||
docling_params = json.loads(docling_params)
|
docling_params = json.loads(docling_params)
|
||||||
|
|
@ -2589,6 +2372,88 @@ DOCLING_PARAMS = PersistentConfig(
|
||||||
docling_params,
|
docling_params,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
DOCLING_DO_OCR = PersistentConfig(
|
||||||
|
"DOCLING_DO_OCR",
|
||||||
|
"rag.docling_do_ocr",
|
||||||
|
os.getenv("DOCLING_DO_OCR", "True").lower() == "true",
|
||||||
|
)
|
||||||
|
|
||||||
|
DOCLING_FORCE_OCR = PersistentConfig(
|
||||||
|
"DOCLING_FORCE_OCR",
|
||||||
|
"rag.docling_force_ocr",
|
||||||
|
os.getenv("DOCLING_FORCE_OCR", "False").lower() == "true",
|
||||||
|
)
|
||||||
|
|
||||||
|
DOCLING_OCR_ENGINE = PersistentConfig(
|
||||||
|
"DOCLING_OCR_ENGINE",
|
||||||
|
"rag.docling_ocr_engine",
|
||||||
|
os.getenv("DOCLING_OCR_ENGINE", "tesseract"),
|
||||||
|
)
|
||||||
|
|
||||||
|
DOCLING_OCR_LANG = PersistentConfig(
|
||||||
|
"DOCLING_OCR_LANG",
|
||||||
|
"rag.docling_ocr_lang",
|
||||||
|
os.getenv("DOCLING_OCR_LANG", "eng,fra,deu,spa"),
|
||||||
|
)
|
||||||
|
|
||||||
|
DOCLING_PDF_BACKEND = PersistentConfig(
|
||||||
|
"DOCLING_PDF_BACKEND",
|
||||||
|
"rag.docling_pdf_backend",
|
||||||
|
os.getenv("DOCLING_PDF_BACKEND", "dlparse_v4"),
|
||||||
|
)
|
||||||
|
|
||||||
|
DOCLING_TABLE_MODE = PersistentConfig(
|
||||||
|
"DOCLING_TABLE_MODE",
|
||||||
|
"rag.docling_table_mode",
|
||||||
|
os.getenv("DOCLING_TABLE_MODE", "accurate"),
|
||||||
|
)
|
||||||
|
|
||||||
|
DOCLING_PIPELINE = PersistentConfig(
|
||||||
|
"DOCLING_PIPELINE",
|
||||||
|
"rag.docling_pipeline",
|
||||||
|
os.getenv("DOCLING_PIPELINE", "standard"),
|
||||||
|
)
|
||||||
|
|
||||||
|
DOCLING_DO_PICTURE_DESCRIPTION = PersistentConfig(
|
||||||
|
"DOCLING_DO_PICTURE_DESCRIPTION",
|
||||||
|
"rag.docling_do_picture_description",
|
||||||
|
os.getenv("DOCLING_DO_PICTURE_DESCRIPTION", "False").lower() == "true",
|
||||||
|
)
|
||||||
|
|
||||||
|
DOCLING_PICTURE_DESCRIPTION_MODE = PersistentConfig(
|
||||||
|
"DOCLING_PICTURE_DESCRIPTION_MODE",
|
||||||
|
"rag.docling_picture_description_mode",
|
||||||
|
os.getenv("DOCLING_PICTURE_DESCRIPTION_MODE", ""),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
docling_picture_description_local = os.getenv("DOCLING_PICTURE_DESCRIPTION_LOCAL", "")
|
||||||
|
try:
|
||||||
|
docling_picture_description_local = json.loads(docling_picture_description_local)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
docling_picture_description_local = {}
|
||||||
|
|
||||||
|
|
||||||
|
DOCLING_PICTURE_DESCRIPTION_LOCAL = PersistentConfig(
|
||||||
|
"DOCLING_PICTURE_DESCRIPTION_LOCAL",
|
||||||
|
"rag.docling_picture_description_local",
|
||||||
|
docling_picture_description_local,
|
||||||
|
)
|
||||||
|
|
||||||
|
docling_picture_description_api = os.getenv("DOCLING_PICTURE_DESCRIPTION_API", "")
|
||||||
|
try:
|
||||||
|
docling_picture_description_api = json.loads(docling_picture_description_api)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
docling_picture_description_api = {}
|
||||||
|
|
||||||
|
|
||||||
|
DOCLING_PICTURE_DESCRIPTION_API = PersistentConfig(
|
||||||
|
"DOCLING_PICTURE_DESCRIPTION_API",
|
||||||
|
"rag.docling_picture_description_api",
|
||||||
|
docling_picture_description_api,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
DOCUMENT_INTELLIGENCE_ENDPOINT = PersistentConfig(
|
DOCUMENT_INTELLIGENCE_ENDPOINT = PersistentConfig(
|
||||||
"DOCUMENT_INTELLIGENCE_ENDPOINT",
|
"DOCUMENT_INTELLIGENCE_ENDPOINT",
|
||||||
"rag.document_intelligence_endpoint",
|
"rag.document_intelligence_endpoint",
|
||||||
|
|
@ -2601,12 +2466,6 @@ DOCUMENT_INTELLIGENCE_KEY = PersistentConfig(
|
||||||
os.getenv("DOCUMENT_INTELLIGENCE_KEY", ""),
|
os.getenv("DOCUMENT_INTELLIGENCE_KEY", ""),
|
||||||
)
|
)
|
||||||
|
|
||||||
DOCUMENT_INTELLIGENCE_MODEL = PersistentConfig(
|
|
||||||
"DOCUMENT_INTELLIGENCE_MODEL",
|
|
||||||
"rag.document_intelligence_model",
|
|
||||||
os.getenv("DOCUMENT_INTELLIGENCE_MODEL", "prebuilt-layout"),
|
|
||||||
)
|
|
||||||
|
|
||||||
MISTRAL_OCR_API_BASE_URL = PersistentConfig(
|
MISTRAL_OCR_API_BASE_URL = PersistentConfig(
|
||||||
"MISTRAL_OCR_API_BASE_URL",
|
"MISTRAL_OCR_API_BASE_URL",
|
||||||
"rag.MISTRAL_OCR_API_BASE_URL",
|
"rag.MISTRAL_OCR_API_BASE_URL",
|
||||||
|
|
@ -2651,13 +2510,6 @@ ENABLE_RAG_HYBRID_SEARCH = PersistentConfig(
|
||||||
os.environ.get("ENABLE_RAG_HYBRID_SEARCH", "").lower() == "true",
|
os.environ.get("ENABLE_RAG_HYBRID_SEARCH", "").lower() == "true",
|
||||||
)
|
)
|
||||||
|
|
||||||
ENABLE_RAG_HYBRID_SEARCH_ENRICHED_TEXTS = PersistentConfig(
|
|
||||||
"ENABLE_RAG_HYBRID_SEARCH_ENRICHED_TEXTS",
|
|
||||||
"rag.enable_hybrid_search_enriched_texts",
|
|
||||||
os.environ.get("ENABLE_RAG_HYBRID_SEARCH_ENRICHED_TEXTS", "False").lower()
|
|
||||||
== "true",
|
|
||||||
)
|
|
||||||
|
|
||||||
RAG_FULL_CONTEXT = PersistentConfig(
|
RAG_FULL_CONTEXT = PersistentConfig(
|
||||||
"RAG_FULL_CONTEXT",
|
"RAG_FULL_CONTEXT",
|
||||||
"rag.full_context",
|
"rag.full_context",
|
||||||
|
|
@ -2752,12 +2604,6 @@ RAG_EMBEDDING_BATCH_SIZE = PersistentConfig(
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
ENABLE_ASYNC_EMBEDDING = PersistentConfig(
|
|
||||||
"ENABLE_ASYNC_EMBEDDING",
|
|
||||||
"rag.enable_async_embedding",
|
|
||||||
os.environ.get("ENABLE_ASYNC_EMBEDDING", "True").lower() == "true",
|
|
||||||
)
|
|
||||||
|
|
||||||
RAG_EMBEDDING_QUERY_PREFIX = os.environ.get("RAG_EMBEDDING_QUERY_PREFIX", None)
|
RAG_EMBEDDING_QUERY_PREFIX = os.environ.get("RAG_EMBEDDING_QUERY_PREFIX", None)
|
||||||
|
|
||||||
RAG_EMBEDDING_CONTENT_PREFIX = os.environ.get("RAG_EMBEDDING_CONTENT_PREFIX", None)
|
RAG_EMBEDDING_CONTENT_PREFIX = os.environ.get("RAG_EMBEDDING_CONTENT_PREFIX", None)
|
||||||
|
|
@ -2802,12 +2648,6 @@ RAG_EXTERNAL_RERANKER_API_KEY = PersistentConfig(
|
||||||
os.environ.get("RAG_EXTERNAL_RERANKER_API_KEY", ""),
|
os.environ.get("RAG_EXTERNAL_RERANKER_API_KEY", ""),
|
||||||
)
|
)
|
||||||
|
|
||||||
RAG_EXTERNAL_RERANKER_TIMEOUT = PersistentConfig(
|
|
||||||
"RAG_EXTERNAL_RERANKER_TIMEOUT",
|
|
||||||
"rag.external_reranker_timeout",
|
|
||||||
os.environ.get("RAG_EXTERNAL_RERANKER_TIMEOUT", ""),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
RAG_TEXT_SPLITTER = PersistentConfig(
|
RAG_TEXT_SPLITTER = PersistentConfig(
|
||||||
"RAG_TEXT_SPLITTER",
|
"RAG_TEXT_SPLITTER",
|
||||||
|
|
@ -2857,6 +2697,10 @@ Provide a clear and direct response to the user's query, including inline citati
|
||||||
<context>
|
<context>
|
||||||
{{CONTEXT}}
|
{{CONTEXT}}
|
||||||
</context>
|
</context>
|
||||||
|
|
||||||
|
<user_query>
|
||||||
|
{{QUERY}}
|
||||||
|
</user_query>
|
||||||
"""
|
"""
|
||||||
|
|
||||||
RAG_TEMPLATE = PersistentConfig(
|
RAG_TEMPLATE = PersistentConfig(
|
||||||
|
|
@ -2909,26 +2753,6 @@ ENABLE_RAG_LOCAL_WEB_FETCH = (
|
||||||
os.getenv("ENABLE_RAG_LOCAL_WEB_FETCH", "False").lower() == "true"
|
os.getenv("ENABLE_RAG_LOCAL_WEB_FETCH", "False").lower() == "true"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
DEFAULT_WEB_FETCH_FILTER_LIST = [
|
|
||||||
"!169.254.169.254",
|
|
||||||
"!fd00:ec2::254",
|
|
||||||
"!metadata.google.internal",
|
|
||||||
"!metadata.azure.com",
|
|
||||||
"!100.100.100.200",
|
|
||||||
]
|
|
||||||
|
|
||||||
web_fetch_filter_list = os.getenv("WEB_FETCH_FILTER_LIST", "")
|
|
||||||
if web_fetch_filter_list == "":
|
|
||||||
web_fetch_filter_list = []
|
|
||||||
else:
|
|
||||||
web_fetch_filter_list = [
|
|
||||||
item.strip() for item in web_fetch_filter_list.split(",") if item.strip()
|
|
||||||
]
|
|
||||||
|
|
||||||
WEB_FETCH_FILTER_LIST = list(set(DEFAULT_WEB_FETCH_FILTER_LIST + web_fetch_filter_list))
|
|
||||||
|
|
||||||
|
|
||||||
YOUTUBE_LOADER_LANGUAGE = PersistentConfig(
|
YOUTUBE_LOADER_LANGUAGE = PersistentConfig(
|
||||||
"YOUTUBE_LOADER_LANGUAGE",
|
"YOUTUBE_LOADER_LANGUAGE",
|
||||||
"rag.youtube_loader_language",
|
"rag.youtube_loader_language",
|
||||||
|
|
@ -2987,14 +2811,13 @@ WEB_SEARCH_DOMAIN_FILTER_LIST = PersistentConfig(
|
||||||
# "wikipedia.com",
|
# "wikipedia.com",
|
||||||
# "wikimedia.org",
|
# "wikimedia.org",
|
||||||
# "wikidata.org",
|
# "wikidata.org",
|
||||||
# "!stackoverflow.com",
|
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
WEB_SEARCH_CONCURRENT_REQUESTS = PersistentConfig(
|
WEB_SEARCH_CONCURRENT_REQUESTS = PersistentConfig(
|
||||||
"WEB_SEARCH_CONCURRENT_REQUESTS",
|
"WEB_SEARCH_CONCURRENT_REQUESTS",
|
||||||
"rag.web.search.concurrent_requests",
|
"rag.web.search.concurrent_requests",
|
||||||
int(os.getenv("WEB_SEARCH_CONCURRENT_REQUESTS", "0")),
|
int(os.getenv("WEB_SEARCH_CONCURRENT_REQUESTS", "10")),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -3011,12 +2834,6 @@ WEB_LOADER_CONCURRENT_REQUESTS = PersistentConfig(
|
||||||
int(os.getenv("WEB_LOADER_CONCURRENT_REQUESTS", "10")),
|
int(os.getenv("WEB_LOADER_CONCURRENT_REQUESTS", "10")),
|
||||||
)
|
)
|
||||||
|
|
||||||
WEB_LOADER_TIMEOUT = PersistentConfig(
|
|
||||||
"WEB_LOADER_TIMEOUT",
|
|
||||||
"rag.web.loader.timeout",
|
|
||||||
os.getenv("WEB_LOADER_TIMEOUT", ""),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
ENABLE_WEB_LOADER_SSL_VERIFICATION = PersistentConfig(
|
ENABLE_WEB_LOADER_SSL_VERIFICATION = PersistentConfig(
|
||||||
"ENABLE_WEB_LOADER_SSL_VERIFICATION",
|
"ENABLE_WEB_LOADER_SSL_VERIFICATION",
|
||||||
|
|
@ -3043,12 +2860,6 @@ SEARXNG_QUERY_URL = PersistentConfig(
|
||||||
os.getenv("SEARXNG_QUERY_URL", ""),
|
os.getenv("SEARXNG_QUERY_URL", ""),
|
||||||
)
|
)
|
||||||
|
|
||||||
SEARXNG_LANGUAGE = PersistentConfig(
|
|
||||||
"SEARXNG_LANGUAGE",
|
|
||||||
"rag.web.search.searxng_language",
|
|
||||||
os.getenv("SEARXNG_LANGUAGE", "all"),
|
|
||||||
)
|
|
||||||
|
|
||||||
YACY_QUERY_URL = PersistentConfig(
|
YACY_QUERY_URL = PersistentConfig(
|
||||||
"YACY_QUERY_URL",
|
"YACY_QUERY_URL",
|
||||||
"rag.web.search.yacy_query_url",
|
"rag.web.search.yacy_query_url",
|
||||||
|
|
@ -3171,24 +2982,6 @@ BING_SEARCH_V7_SUBSCRIPTION_KEY = PersistentConfig(
|
||||||
os.environ.get("BING_SEARCH_V7_SUBSCRIPTION_KEY", ""),
|
os.environ.get("BING_SEARCH_V7_SUBSCRIPTION_KEY", ""),
|
||||||
)
|
)
|
||||||
|
|
||||||
AZURE_AI_SEARCH_API_KEY = PersistentConfig(
|
|
||||||
"AZURE_AI_SEARCH_API_KEY",
|
|
||||||
"rag.web.search.azure_ai_search_api_key",
|
|
||||||
os.environ.get("AZURE_AI_SEARCH_API_KEY", ""),
|
|
||||||
)
|
|
||||||
|
|
||||||
AZURE_AI_SEARCH_ENDPOINT = PersistentConfig(
|
|
||||||
"AZURE_AI_SEARCH_ENDPOINT",
|
|
||||||
"rag.web.search.azure_ai_search_endpoint",
|
|
||||||
os.environ.get("AZURE_AI_SEARCH_ENDPOINT", ""),
|
|
||||||
)
|
|
||||||
|
|
||||||
AZURE_AI_SEARCH_INDEX_NAME = PersistentConfig(
|
|
||||||
"AZURE_AI_SEARCH_INDEX_NAME",
|
|
||||||
"rag.web.search.azure_ai_search_index_name",
|
|
||||||
os.environ.get("AZURE_AI_SEARCH_INDEX_NAME", ""),
|
|
||||||
)
|
|
||||||
|
|
||||||
EXA_API_KEY = PersistentConfig(
|
EXA_API_KEY = PersistentConfig(
|
||||||
"EXA_API_KEY",
|
"EXA_API_KEY",
|
||||||
"rag.web.search.exa_api_key",
|
"rag.web.search.exa_api_key",
|
||||||
|
|
@ -3213,12 +3006,6 @@ PERPLEXITY_SEARCH_CONTEXT_USAGE = PersistentConfig(
|
||||||
os.getenv("PERPLEXITY_SEARCH_CONTEXT_USAGE", "medium"),
|
os.getenv("PERPLEXITY_SEARCH_CONTEXT_USAGE", "medium"),
|
||||||
)
|
)
|
||||||
|
|
||||||
PERPLEXITY_SEARCH_API_URL = PersistentConfig(
|
|
||||||
"PERPLEXITY_SEARCH_API_URL",
|
|
||||||
"rag.web.search.perplexity_search_api_url",
|
|
||||||
os.getenv("PERPLEXITY_SEARCH_API_URL", "https://api.perplexity.ai/search"),
|
|
||||||
)
|
|
||||||
|
|
||||||
SOUGOU_API_SID = PersistentConfig(
|
SOUGOU_API_SID = PersistentConfig(
|
||||||
"SOUGOU_API_SID",
|
"SOUGOU_API_SID",
|
||||||
"rag.web.search.sougou_api_sid",
|
"rag.web.search.sougou_api_sid",
|
||||||
|
|
@ -3344,9 +3131,10 @@ try:
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
automatic1111_params = {}
|
automatic1111_params = {}
|
||||||
|
|
||||||
|
|
||||||
AUTOMATIC1111_PARAMS = PersistentConfig(
|
AUTOMATIC1111_PARAMS = PersistentConfig(
|
||||||
"AUTOMATIC1111_PARAMS",
|
"AUTOMATIC1111_PARAMS",
|
||||||
"image_generation.automatic1111.api_params",
|
"image_generation.automatic1111.api_auth",
|
||||||
automatic1111_params,
|
automatic1111_params,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -3479,16 +3267,10 @@ COMFYUI_WORKFLOW = PersistentConfig(
|
||||||
os.getenv("COMFYUI_WORKFLOW", COMFYUI_DEFAULT_WORKFLOW),
|
os.getenv("COMFYUI_WORKFLOW", COMFYUI_DEFAULT_WORKFLOW),
|
||||||
)
|
)
|
||||||
|
|
||||||
comfyui_workflow_nodes = os.getenv("COMFYUI_WORKFLOW_NODES", "")
|
|
||||||
try:
|
|
||||||
comfyui_workflow_nodes = json.loads(comfyui_workflow_nodes)
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
comfyui_workflow_nodes = []
|
|
||||||
|
|
||||||
COMFYUI_WORKFLOW_NODES = PersistentConfig(
|
COMFYUI_WORKFLOW_NODES = PersistentConfig(
|
||||||
"COMFYUI_WORKFLOW_NODES",
|
"COMFYUI_WORKFLOW",
|
||||||
"image_generation.comfyui.nodes",
|
"image_generation.comfyui.nodes",
|
||||||
comfyui_workflow_nodes,
|
[],
|
||||||
)
|
)
|
||||||
|
|
||||||
IMAGES_OPENAI_API_BASE_URL = PersistentConfig(
|
IMAGES_OPENAI_API_BASE_URL = PersistentConfig(
|
||||||
|
|
@ -3508,18 +3290,6 @@ IMAGES_OPENAI_API_KEY = PersistentConfig(
|
||||||
os.getenv("IMAGES_OPENAI_API_KEY", OPENAI_API_KEY),
|
os.getenv("IMAGES_OPENAI_API_KEY", OPENAI_API_KEY),
|
||||||
)
|
)
|
||||||
|
|
||||||
images_openai_params = os.getenv("IMAGES_OPENAI_PARAMS", "")
|
|
||||||
try:
|
|
||||||
images_openai_params = json.loads(images_openai_params)
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
images_openai_params = {}
|
|
||||||
|
|
||||||
|
|
||||||
IMAGES_OPENAI_API_PARAMS = PersistentConfig(
|
|
||||||
"IMAGES_OPENAI_API_PARAMS", "image_generation.openai.params", images_openai_params
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
IMAGES_GEMINI_API_BASE_URL = PersistentConfig(
|
IMAGES_GEMINI_API_BASE_URL = PersistentConfig(
|
||||||
"IMAGES_GEMINI_API_BASE_URL",
|
"IMAGES_GEMINI_API_BASE_URL",
|
||||||
"image_generation.gemini.api_base_url",
|
"image_generation.gemini.api_base_url",
|
||||||
|
|
@ -3537,11 +3307,6 @@ IMAGES_GEMINI_ENDPOINT_METHOD = PersistentConfig(
|
||||||
os.getenv("IMAGES_GEMINI_ENDPOINT_METHOD", ""),
|
os.getenv("IMAGES_GEMINI_ENDPOINT_METHOD", ""),
|
||||||
)
|
)
|
||||||
|
|
||||||
ENABLE_IMAGE_EDIT = PersistentConfig(
|
|
||||||
"ENABLE_IMAGE_EDIT",
|
|
||||||
"images.edit.enable",
|
|
||||||
os.environ.get("ENABLE_IMAGE_EDIT", "").lower() == "true",
|
|
||||||
)
|
|
||||||
|
|
||||||
IMAGE_EDIT_ENGINE = PersistentConfig(
|
IMAGE_EDIT_ENGINE = PersistentConfig(
|
||||||
"IMAGE_EDIT_ENGINE",
|
"IMAGE_EDIT_ENGINE",
|
||||||
|
|
@ -3605,16 +3370,10 @@ IMAGES_EDIT_COMFYUI_WORKFLOW = PersistentConfig(
|
||||||
os.getenv("IMAGES_EDIT_COMFYUI_WORKFLOW", ""),
|
os.getenv("IMAGES_EDIT_COMFYUI_WORKFLOW", ""),
|
||||||
)
|
)
|
||||||
|
|
||||||
images_edit_comfyui_workflow_nodes = os.getenv("IMAGES_EDIT_COMFYUI_WORKFLOW_NODES", "")
|
|
||||||
try:
|
|
||||||
images_edit_comfyui_workflow_nodes = json.loads(images_edit_comfyui_workflow_nodes)
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
images_edit_comfyui_workflow_nodes = []
|
|
||||||
|
|
||||||
IMAGES_EDIT_COMFYUI_WORKFLOW_NODES = PersistentConfig(
|
IMAGES_EDIT_COMFYUI_WORKFLOW_NODES = PersistentConfig(
|
||||||
"IMAGES_EDIT_COMFYUI_WORKFLOW_NODES",
|
"IMAGES_EDIT_COMFYUI_WORKFLOW_NODES",
|
||||||
"images.edit.comfyui.nodes",
|
"images.edit.comfyui.nodes",
|
||||||
images_edit_comfyui_workflow_nodes,
|
[],
|
||||||
)
|
)
|
||||||
|
|
||||||
####################################
|
####################################
|
||||||
|
|
|
||||||
|
|
@ -45,7 +45,7 @@ class ERROR_MESSAGES(str, Enum):
|
||||||
)
|
)
|
||||||
INVALID_CRED = "The email or password provided is incorrect. Please check for typos and try logging in again."
|
INVALID_CRED = "The email or password provided is incorrect. Please check for typos and try logging in again."
|
||||||
INVALID_EMAIL_FORMAT = "The email format you entered is invalid. Please double-check and make sure you're using a valid email address (e.g., yourname@example.com)."
|
INVALID_EMAIL_FORMAT = "The email format you entered is invalid. Please double-check and make sure you're using a valid email address (e.g., yourname@example.com)."
|
||||||
INCORRECT_PASSWORD = (
|
INVALID_PASSWORD = (
|
||||||
"The password provided is incorrect. Please check for typos and try again."
|
"The password provided is incorrect. Please check for typos and try again."
|
||||||
)
|
)
|
||||||
INVALID_TRUSTED_HEADER = "Your provider has not provided a trusted header. Please contact your administrator for assistance."
|
INVALID_TRUSTED_HEADER = "Your provider has not provided a trusted header. Please contact your administrator for assistance."
|
||||||
|
|
@ -105,10 +105,6 @@ class ERROR_MESSAGES(str, Enum):
|
||||||
)
|
)
|
||||||
FILE_NOT_PROCESSED = "Extracted content is not available for this file. Please ensure that the file is processed before proceeding."
|
FILE_NOT_PROCESSED = "Extracted content is not available for this file. Please ensure that the file is processed before proceeding."
|
||||||
|
|
||||||
INVALID_PASSWORD = lambda err="": (
|
|
||||||
err if err else "The password does not meet the required validation criteria."
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class TASKS(str, Enum):
|
class TASKS(str, Enum):
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
|
|
|
||||||
|
|
@ -8,8 +8,6 @@ import shutil
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from cryptography.hazmat.primitives import serialization
|
from cryptography.hazmat.primitives import serialization
|
||||||
import re
|
|
||||||
|
|
||||||
|
|
||||||
import markdown
|
import markdown
|
||||||
from bs4 import BeautifulSoup
|
from bs4 import BeautifulSoup
|
||||||
|
|
@ -85,7 +83,32 @@ if "cuda_error" in locals():
|
||||||
log.exception(cuda_error)
|
log.exception(cuda_error)
|
||||||
del cuda_error
|
del cuda_error
|
||||||
|
|
||||||
SRC_LOG_LEVELS = {} # Legacy variable, do not remove
|
log_sources = [
|
||||||
|
"AUDIO",
|
||||||
|
"COMFYUI",
|
||||||
|
"CONFIG",
|
||||||
|
"DB",
|
||||||
|
"IMAGES",
|
||||||
|
"MAIN",
|
||||||
|
"MODELS",
|
||||||
|
"OLLAMA",
|
||||||
|
"OPENAI",
|
||||||
|
"RAG",
|
||||||
|
"WEBHOOK",
|
||||||
|
"SOCKET",
|
||||||
|
"OAUTH",
|
||||||
|
]
|
||||||
|
|
||||||
|
SRC_LOG_LEVELS = {}
|
||||||
|
|
||||||
|
for source in log_sources:
|
||||||
|
log_env_var = source + "_LOG_LEVEL"
|
||||||
|
SRC_LOG_LEVELS[source] = os.environ.get(log_env_var, "").upper()
|
||||||
|
if SRC_LOG_LEVELS[source] not in logging.getLevelNamesMapping():
|
||||||
|
SRC_LOG_LEVELS[source] = GLOBAL_LOG_LEVEL
|
||||||
|
log.info(f"{log_env_var}: {SRC_LOG_LEVELS[source]}")
|
||||||
|
|
||||||
|
log.setLevel(SRC_LOG_LEVELS["CONFIG"])
|
||||||
|
|
||||||
WEBUI_NAME = os.environ.get("WEBUI_NAME", "Open WebUI")
|
WEBUI_NAME = os.environ.get("WEBUI_NAME", "Open WebUI")
|
||||||
if WEBUI_NAME != "Open WebUI":
|
if WEBUI_NAME != "Open WebUI":
|
||||||
|
|
@ -112,9 +135,6 @@ else:
|
||||||
PACKAGE_DATA = {"version": "0.0.0"}
|
PACKAGE_DATA = {"version": "0.0.0"}
|
||||||
|
|
||||||
VERSION = PACKAGE_DATA["version"]
|
VERSION = PACKAGE_DATA["version"]
|
||||||
|
|
||||||
|
|
||||||
DEPLOYMENT_ID = os.environ.get("DEPLOYMENT_ID", "")
|
|
||||||
INSTANCE_ID = os.environ.get("INSTANCE_ID", str(uuid4()))
|
INSTANCE_ID = os.environ.get("INSTANCE_ID", str(uuid4()))
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -339,11 +359,6 @@ if DATABASE_USER_ACTIVE_STATUS_UPDATE_INTERVAL is not None:
|
||||||
except Exception:
|
except Exception:
|
||||||
DATABASE_USER_ACTIVE_STATUS_UPDATE_INTERVAL = 0.0
|
DATABASE_USER_ACTIVE_STATUS_UPDATE_INTERVAL = 0.0
|
||||||
|
|
||||||
# Enable public visibility of active user count (when disabled, only admins can see it)
|
|
||||||
ENABLE_PUBLIC_ACTIVE_USERS_COUNT = (
|
|
||||||
os.environ.get("ENABLE_PUBLIC_ACTIVE_USERS_COUNT", "True").lower() == "true"
|
|
||||||
)
|
|
||||||
|
|
||||||
RESET_CONFIG_ON_START = (
|
RESET_CONFIG_ON_START = (
|
||||||
os.environ.get("RESET_CONFIG_ON_START", "False").lower() == "true"
|
os.environ.get("RESET_CONFIG_ON_START", "False").lower() == "true"
|
||||||
)
|
)
|
||||||
|
|
@ -375,13 +390,6 @@ try:
|
||||||
except ValueError:
|
except ValueError:
|
||||||
REDIS_SENTINEL_MAX_RETRY_COUNT = 2
|
REDIS_SENTINEL_MAX_RETRY_COUNT = 2
|
||||||
|
|
||||||
|
|
||||||
REDIS_SOCKET_CONNECT_TIMEOUT = os.environ.get("REDIS_SOCKET_CONNECT_TIMEOUT", "")
|
|
||||||
try:
|
|
||||||
REDIS_SOCKET_CONNECT_TIMEOUT = float(REDIS_SOCKET_CONNECT_TIMEOUT)
|
|
||||||
except ValueError:
|
|
||||||
REDIS_SOCKET_CONNECT_TIMEOUT = None
|
|
||||||
|
|
||||||
####################################
|
####################################
|
||||||
# UVICORN WORKERS
|
# UVICORN WORKERS
|
||||||
####################################
|
####################################
|
||||||
|
|
@ -418,23 +426,6 @@ WEBUI_AUTH_TRUSTED_GROUPS_HEADER = os.environ.get(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
ENABLE_PASSWORD_VALIDATION = (
|
|
||||||
os.environ.get("ENABLE_PASSWORD_VALIDATION", "False").lower() == "true"
|
|
||||||
)
|
|
||||||
PASSWORD_VALIDATION_REGEX_PATTERN = os.environ.get(
|
|
||||||
"PASSWORD_VALIDATION_REGEX_PATTERN",
|
|
||||||
"^(?=.*[a-z])(?=.*[A-Z])(?=.*\d)(?=.*[^\w\s]).{8,}$",
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
PASSWORD_VALIDATION_REGEX_PATTERN = re.compile(PASSWORD_VALIDATION_REGEX_PATTERN)
|
|
||||||
except Exception as e:
|
|
||||||
log.error(f"Invalid PASSWORD_VALIDATION_REGEX_PATTERN: {e}")
|
|
||||||
PASSWORD_VALIDATION_REGEX_PATTERN = re.compile(
|
|
||||||
"^(?=.*[a-z])(?=.*[A-Z])(?=.*\d)(?=.*[^\w\s]).{8,}$"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
BYPASS_MODEL_ACCESS_CONTROL = (
|
BYPASS_MODEL_ACCESS_CONTROL = (
|
||||||
os.environ.get("BYPASS_MODEL_ACCESS_CONTROL", "False").lower() == "true"
|
os.environ.get("BYPASS_MODEL_ACCESS_CONTROL", "False").lower() == "true"
|
||||||
)
|
)
|
||||||
|
|
@ -502,10 +493,7 @@ OAUTH_SESSION_TOKEN_ENCRYPTION_KEY = os.environ.get(
|
||||||
# SCIM Configuration
|
# SCIM Configuration
|
||||||
####################################
|
####################################
|
||||||
|
|
||||||
ENABLE_SCIM = (
|
SCIM_ENABLED = os.environ.get("SCIM_ENABLED", "False").lower() == "true"
|
||||||
os.environ.get("ENABLE_SCIM", os.environ.get("SCIM_ENABLED", "False")).lower()
|
|
||||||
== "true"
|
|
||||||
)
|
|
||||||
SCIM_TOKEN = os.environ.get("SCIM_TOKEN", "")
|
SCIM_TOKEN = os.environ.get("SCIM_TOKEN", "")
|
||||||
|
|
||||||
####################################
|
####################################
|
||||||
|
|
@ -539,10 +527,6 @@ if LICENSE_PUBLIC_KEY:
|
||||||
# MODELS
|
# MODELS
|
||||||
####################################
|
####################################
|
||||||
|
|
||||||
ENABLE_CUSTOM_MODEL_FALLBACK = (
|
|
||||||
os.environ.get("ENABLE_CUSTOM_MODEL_FALLBACK", "False").lower() == "true"
|
|
||||||
)
|
|
||||||
|
|
||||||
MODELS_CACHE_TTL = os.environ.get("MODELS_CACHE_TTL", "1")
|
MODELS_CACHE_TTL = os.environ.get("MODELS_CACHE_TTL", "1")
|
||||||
if MODELS_CACHE_TTL == "":
|
if MODELS_CACHE_TTL == "":
|
||||||
MODELS_CACHE_TTL = None
|
MODELS_CACHE_TTL = None
|
||||||
|
|
@ -557,11 +541,6 @@ else:
|
||||||
# CHAT
|
# CHAT
|
||||||
####################################
|
####################################
|
||||||
|
|
||||||
ENABLE_CHAT_RESPONSE_BASE64_IMAGE_URL_CONVERSION = (
|
|
||||||
os.environ.get("ENABLE_CHAT_RESPONSE_BASE64_IMAGE_URL_CONVERSION", "False").lower()
|
|
||||||
== "true"
|
|
||||||
)
|
|
||||||
|
|
||||||
CHAT_RESPONSE_STREAM_DELTA_CHUNK_SIZE = os.environ.get(
|
CHAT_RESPONSE_STREAM_DELTA_CHUNK_SIZE = os.environ.get(
|
||||||
"CHAT_RESPONSE_STREAM_DELTA_CHUNK_SIZE", "1"
|
"CHAT_RESPONSE_STREAM_DELTA_CHUNK_SIZE", "1"
|
||||||
)
|
)
|
||||||
|
|
@ -590,21 +569,6 @@ else:
|
||||||
CHAT_RESPONSE_MAX_TOOL_CALL_RETRIES = 30
|
CHAT_RESPONSE_MAX_TOOL_CALL_RETRIES = 30
|
||||||
|
|
||||||
|
|
||||||
CHAT_STREAM_RESPONSE_CHUNK_MAX_BUFFER_SIZE = os.environ.get(
|
|
||||||
"CHAT_STREAM_RESPONSE_CHUNK_MAX_BUFFER_SIZE", ""
|
|
||||||
)
|
|
||||||
|
|
||||||
if CHAT_STREAM_RESPONSE_CHUNK_MAX_BUFFER_SIZE == "":
|
|
||||||
CHAT_STREAM_RESPONSE_CHUNK_MAX_BUFFER_SIZE = None
|
|
||||||
else:
|
|
||||||
try:
|
|
||||||
CHAT_STREAM_RESPONSE_CHUNK_MAX_BUFFER_SIZE = int(
|
|
||||||
CHAT_STREAM_RESPONSE_CHUNK_MAX_BUFFER_SIZE
|
|
||||||
)
|
|
||||||
except Exception:
|
|
||||||
CHAT_STREAM_RESPONSE_CHUNK_MAX_BUFFER_SIZE = None
|
|
||||||
|
|
||||||
|
|
||||||
####################################
|
####################################
|
||||||
# WEBSOCKET SUPPORT
|
# WEBSOCKET SUPPORT
|
||||||
####################################
|
####################################
|
||||||
|
|
@ -616,24 +580,6 @@ ENABLE_WEBSOCKET_SUPPORT = (
|
||||||
|
|
||||||
WEBSOCKET_MANAGER = os.environ.get("WEBSOCKET_MANAGER", "")
|
WEBSOCKET_MANAGER = os.environ.get("WEBSOCKET_MANAGER", "")
|
||||||
|
|
||||||
WEBSOCKET_REDIS_OPTIONS = os.environ.get("WEBSOCKET_REDIS_OPTIONS", "")
|
|
||||||
|
|
||||||
|
|
||||||
if WEBSOCKET_REDIS_OPTIONS == "":
|
|
||||||
if REDIS_SOCKET_CONNECT_TIMEOUT:
|
|
||||||
WEBSOCKET_REDIS_OPTIONS = {
|
|
||||||
"socket_connect_timeout": REDIS_SOCKET_CONNECT_TIMEOUT
|
|
||||||
}
|
|
||||||
else:
|
|
||||||
log.debug("No WEBSOCKET_REDIS_OPTIONS provided, defaulting to None")
|
|
||||||
WEBSOCKET_REDIS_OPTIONS = None
|
|
||||||
else:
|
|
||||||
try:
|
|
||||||
WEBSOCKET_REDIS_OPTIONS = json.loads(WEBSOCKET_REDIS_OPTIONS)
|
|
||||||
except Exception:
|
|
||||||
log.warning("Invalid WEBSOCKET_REDIS_OPTIONS, defaulting to None")
|
|
||||||
WEBSOCKET_REDIS_OPTIONS = None
|
|
||||||
|
|
||||||
WEBSOCKET_REDIS_URL = os.environ.get("WEBSOCKET_REDIS_URL", REDIS_URL)
|
WEBSOCKET_REDIS_URL = os.environ.get("WEBSOCKET_REDIS_URL", REDIS_URL)
|
||||||
WEBSOCKET_REDIS_CLUSTER = (
|
WEBSOCKET_REDIS_CLUSTER = (
|
||||||
os.environ.get("WEBSOCKET_REDIS_CLUSTER", str(REDIS_CLUSTER)).lower() == "true"
|
os.environ.get("WEBSOCKET_REDIS_CLUSTER", str(REDIS_CLUSTER)).lower() == "true"
|
||||||
|
|
@ -648,23 +594,6 @@ except ValueError:
|
||||||
|
|
||||||
WEBSOCKET_SENTINEL_HOSTS = os.environ.get("WEBSOCKET_SENTINEL_HOSTS", "")
|
WEBSOCKET_SENTINEL_HOSTS = os.environ.get("WEBSOCKET_SENTINEL_HOSTS", "")
|
||||||
WEBSOCKET_SENTINEL_PORT = os.environ.get("WEBSOCKET_SENTINEL_PORT", "26379")
|
WEBSOCKET_SENTINEL_PORT = os.environ.get("WEBSOCKET_SENTINEL_PORT", "26379")
|
||||||
WEBSOCKET_SERVER_LOGGING = (
|
|
||||||
os.environ.get("WEBSOCKET_SERVER_LOGGING", "False").lower() == "true"
|
|
||||||
)
|
|
||||||
WEBSOCKET_SERVER_ENGINEIO_LOGGING = (
|
|
||||||
os.environ.get("WEBSOCKET_SERVER_LOGGING", "False").lower() == "true"
|
|
||||||
)
|
|
||||||
WEBSOCKET_SERVER_PING_TIMEOUT = os.environ.get("WEBSOCKET_SERVER_PING_TIMEOUT", "20")
|
|
||||||
try:
|
|
||||||
WEBSOCKET_SERVER_PING_TIMEOUT = int(WEBSOCKET_SERVER_PING_TIMEOUT)
|
|
||||||
except ValueError:
|
|
||||||
WEBSOCKET_SERVER_PING_TIMEOUT = 20
|
|
||||||
|
|
||||||
WEBSOCKET_SERVER_PING_INTERVAL = os.environ.get("WEBSOCKET_SERVER_PING_INTERVAL", "25")
|
|
||||||
try:
|
|
||||||
WEBSOCKET_SERVER_PING_INTERVAL = int(WEBSOCKET_SERVER_PING_INTERVAL)
|
|
||||||
except ValueError:
|
|
||||||
WEBSOCKET_SERVER_PING_INTERVAL = 25
|
|
||||||
|
|
||||||
|
|
||||||
AIOHTTP_CLIENT_TIMEOUT = os.environ.get("AIOHTTP_CLIENT_TIMEOUT", "")
|
AIOHTTP_CLIENT_TIMEOUT = os.environ.get("AIOHTTP_CLIENT_TIMEOUT", "")
|
||||||
|
|
@ -777,9 +706,7 @@ if OFFLINE_MODE:
|
||||||
# AUDIT LOGGING
|
# AUDIT LOGGING
|
||||||
####################################
|
####################################
|
||||||
# Where to store log file
|
# Where to store log file
|
||||||
# Defaults to the DATA_DIR/audit.log. To set AUDIT_LOGS_FILE_PATH you need to
|
AUDIT_LOGS_FILE_PATH = f"{DATA_DIR}/audit.log"
|
||||||
# provide the whole path, like: /app/audit.log
|
|
||||||
AUDIT_LOGS_FILE_PATH = os.getenv("AUDIT_LOGS_FILE_PATH", f"{DATA_DIR}/audit.log")
|
|
||||||
# Maximum size of a file before rotating into a new log file
|
# Maximum size of a file before rotating into a new log file
|
||||||
AUDIT_LOG_FILE_ROTATION_SIZE = os.getenv("AUDIT_LOG_FILE_ROTATION_SIZE", "10MB")
|
AUDIT_LOG_FILE_ROTATION_SIZE = os.getenv("AUDIT_LOG_FILE_ROTATION_SIZE", "10MB")
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -37,7 +37,7 @@ from open_webui.utils.plugin import (
|
||||||
from open_webui.utils.tools import get_tools
|
from open_webui.utils.tools import get_tools
|
||||||
from open_webui.utils.access_control import has_access
|
from open_webui.utils.access_control import has_access
|
||||||
|
|
||||||
from open_webui.env import GLOBAL_LOG_LEVEL
|
from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL
|
||||||
|
|
||||||
from open_webui.utils.misc import (
|
from open_webui.utils.misc import (
|
||||||
add_or_update_system_message,
|
add_or_update_system_message,
|
||||||
|
|
@ -54,6 +54,7 @@ from open_webui.utils.payload import (
|
||||||
|
|
||||||
logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
|
logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
log.setLevel(SRC_LOG_LEVELS["MAIN"])
|
||||||
|
|
||||||
|
|
||||||
def get_function_module_by_id(request: Request, pipe_id: str):
|
def get_function_module_by_id(request: Request, pipe_id: str):
|
||||||
|
|
|
||||||
|
|
@ -9,6 +9,7 @@ from open_webui.env import (
|
||||||
OPEN_WEBUI_DIR,
|
OPEN_WEBUI_DIR,
|
||||||
DATABASE_URL,
|
DATABASE_URL,
|
||||||
DATABASE_SCHEMA,
|
DATABASE_SCHEMA,
|
||||||
|
SRC_LOG_LEVELS,
|
||||||
DATABASE_POOL_MAX_OVERFLOW,
|
DATABASE_POOL_MAX_OVERFLOW,
|
||||||
DATABASE_POOL_RECYCLE,
|
DATABASE_POOL_RECYCLE,
|
||||||
DATABASE_POOL_SIZE,
|
DATABASE_POOL_SIZE,
|
||||||
|
|
@ -24,6 +25,7 @@ from sqlalchemy.sql.type_api import _T
|
||||||
from typing_extensions import Self
|
from typing_extensions import Self
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
log.setLevel(SRC_LOG_LEVELS["DB"])
|
||||||
|
|
||||||
|
|
||||||
class JSONField(types.TypeDecorator):
|
class JSONField(types.TypeDecorator):
|
||||||
|
|
@ -90,6 +92,8 @@ if SQLALCHEMY_DATABASE_URL.startswith("sqlite+sqlcipher://"):
|
||||||
|
|
||||||
# Extract database path from SQLCipher URL
|
# Extract database path from SQLCipher URL
|
||||||
db_path = SQLALCHEMY_DATABASE_URL.replace("sqlite+sqlcipher://", "")
|
db_path = SQLALCHEMY_DATABASE_URL.replace("sqlite+sqlcipher://", "")
|
||||||
|
if db_path.startswith("/"):
|
||||||
|
db_path = db_path[1:] # Remove leading slash for relative paths
|
||||||
|
|
||||||
# Create a custom creator function that uses sqlcipher3
|
# Create a custom creator function that uses sqlcipher3
|
||||||
def create_sqlcipher_connection():
|
def create_sqlcipher_connection():
|
||||||
|
|
|
||||||
|
|
@ -2,6 +2,7 @@ import logging
|
||||||
import os
|
import os
|
||||||
from contextvars import ContextVar
|
from contextvars import ContextVar
|
||||||
|
|
||||||
|
from open_webui.env import SRC_LOG_LEVELS
|
||||||
from peewee import *
|
from peewee import *
|
||||||
from peewee import InterfaceError as PeeWeeInterfaceError
|
from peewee import InterfaceError as PeeWeeInterfaceError
|
||||||
from peewee import PostgresqlDatabase
|
from peewee import PostgresqlDatabase
|
||||||
|
|
@ -9,6 +10,7 @@ from playhouse.db_url import connect, parse
|
||||||
from playhouse.shortcuts import ReconnectMixin
|
from playhouse.shortcuts import ReconnectMixin
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
log.setLevel(SRC_LOG_LEVELS["DB"])
|
||||||
|
|
||||||
db_state_default = {"closed": None, "conn": None, "ctx": None, "transactions": None}
|
db_state_default = {"closed": None, "conn": None, "ctx": None, "transactions": None}
|
||||||
db_state = ContextVar("db_state", default=db_state_default.copy())
|
db_state = ContextVar("db_state", default=db_state_default.copy())
|
||||||
|
|
@ -54,6 +56,8 @@ def register_connection(db_url):
|
||||||
# Parse the database path from SQLCipher URL
|
# Parse the database path from SQLCipher URL
|
||||||
# Convert sqlite+sqlcipher:///path/to/db.sqlite to /path/to/db.sqlite
|
# Convert sqlite+sqlcipher:///path/to/db.sqlite to /path/to/db.sqlite
|
||||||
db_path = db_url.replace("sqlite+sqlcipher://", "")
|
db_path = db_url.replace("sqlite+sqlcipher://", "")
|
||||||
|
if db_path.startswith("/"):
|
||||||
|
db_path = db_path[1:] # Remove leading slash for relative paths
|
||||||
|
|
||||||
# Use Peewee's native SqlCipherDatabase with encryption
|
# Use Peewee's native SqlCipherDatabase with encryption
|
||||||
db = SqlCipherDatabase(db_path, passphrase=database_password)
|
db = SqlCipherDatabase(db_path, passphrase=database_password)
|
||||||
|
|
|
||||||
|
|
@ -61,11 +61,11 @@ from open_webui.utils import logger
|
||||||
from open_webui.utils.audit import AuditLevel, AuditLoggingMiddleware
|
from open_webui.utils.audit import AuditLevel, AuditLoggingMiddleware
|
||||||
from open_webui.utils.logger import start_logger
|
from open_webui.utils.logger import start_logger
|
||||||
from open_webui.socket.main import (
|
from open_webui.socket.main import (
|
||||||
MODELS,
|
|
||||||
app as socket_app,
|
app as socket_app,
|
||||||
periodic_usage_pool_cleanup,
|
periodic_usage_pool_cleanup,
|
||||||
get_event_emitter,
|
get_event_emitter,
|
||||||
get_models_in_use,
|
get_models_in_use,
|
||||||
|
get_active_user_ids,
|
||||||
)
|
)
|
||||||
from open_webui.routers import (
|
from open_webui.routers import (
|
||||||
audio,
|
audio,
|
||||||
|
|
@ -160,11 +160,9 @@ from open_webui.config import (
|
||||||
IMAGES_OPENAI_API_BASE_URL,
|
IMAGES_OPENAI_API_BASE_URL,
|
||||||
IMAGES_OPENAI_API_VERSION,
|
IMAGES_OPENAI_API_VERSION,
|
||||||
IMAGES_OPENAI_API_KEY,
|
IMAGES_OPENAI_API_KEY,
|
||||||
IMAGES_OPENAI_API_PARAMS,
|
|
||||||
IMAGES_GEMINI_API_BASE_URL,
|
IMAGES_GEMINI_API_BASE_URL,
|
||||||
IMAGES_GEMINI_API_KEY,
|
IMAGES_GEMINI_API_KEY,
|
||||||
IMAGES_GEMINI_ENDPOINT_METHOD,
|
IMAGES_GEMINI_ENDPOINT_METHOD,
|
||||||
ENABLE_IMAGE_EDIT,
|
|
||||||
IMAGE_EDIT_ENGINE,
|
IMAGE_EDIT_ENGINE,
|
||||||
IMAGE_EDIT_MODEL,
|
IMAGE_EDIT_MODEL,
|
||||||
IMAGE_EDIT_SIZE,
|
IMAGE_EDIT_SIZE,
|
||||||
|
|
@ -208,7 +206,6 @@ from open_webui.config import (
|
||||||
FIRECRAWL_API_KEY,
|
FIRECRAWL_API_KEY,
|
||||||
WEB_LOADER_ENGINE,
|
WEB_LOADER_ENGINE,
|
||||||
WEB_LOADER_CONCURRENT_REQUESTS,
|
WEB_LOADER_CONCURRENT_REQUESTS,
|
||||||
WEB_LOADER_TIMEOUT,
|
|
||||||
WHISPER_MODEL,
|
WHISPER_MODEL,
|
||||||
WHISPER_VAD_FILTER,
|
WHISPER_VAD_FILTER,
|
||||||
WHISPER_LANGUAGE,
|
WHISPER_LANGUAGE,
|
||||||
|
|
@ -227,12 +224,10 @@ from open_webui.config import (
|
||||||
RAG_RERANKING_MODEL,
|
RAG_RERANKING_MODEL,
|
||||||
RAG_EXTERNAL_RERANKER_URL,
|
RAG_EXTERNAL_RERANKER_URL,
|
||||||
RAG_EXTERNAL_RERANKER_API_KEY,
|
RAG_EXTERNAL_RERANKER_API_KEY,
|
||||||
RAG_EXTERNAL_RERANKER_TIMEOUT,
|
|
||||||
RAG_RERANKING_MODEL_AUTO_UPDATE,
|
RAG_RERANKING_MODEL_AUTO_UPDATE,
|
||||||
RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
|
RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
|
||||||
RAG_EMBEDDING_ENGINE,
|
RAG_EMBEDDING_ENGINE,
|
||||||
RAG_EMBEDDING_BATCH_SIZE,
|
RAG_EMBEDDING_BATCH_SIZE,
|
||||||
ENABLE_ASYNC_EMBEDDING,
|
|
||||||
RAG_TOP_K,
|
RAG_TOP_K,
|
||||||
RAG_TOP_K_RERANKER,
|
RAG_TOP_K_RERANKER,
|
||||||
RAG_RELEVANCE_THRESHOLD,
|
RAG_RELEVANCE_THRESHOLD,
|
||||||
|
|
@ -265,18 +260,26 @@ from open_webui.config import (
|
||||||
MINERU_API_MODE,
|
MINERU_API_MODE,
|
||||||
MINERU_API_URL,
|
MINERU_API_URL,
|
||||||
MINERU_API_KEY,
|
MINERU_API_KEY,
|
||||||
MINERU_API_TIMEOUT,
|
|
||||||
MINERU_PARAMS,
|
MINERU_PARAMS,
|
||||||
DATALAB_MARKER_USE_LLM,
|
DATALAB_MARKER_USE_LLM,
|
||||||
EXTERNAL_DOCUMENT_LOADER_URL,
|
EXTERNAL_DOCUMENT_LOADER_URL,
|
||||||
EXTERNAL_DOCUMENT_LOADER_API_KEY,
|
EXTERNAL_DOCUMENT_LOADER_API_KEY,
|
||||||
TIKA_SERVER_URL,
|
TIKA_SERVER_URL,
|
||||||
DOCLING_SERVER_URL,
|
DOCLING_SERVER_URL,
|
||||||
DOCLING_API_KEY,
|
|
||||||
DOCLING_PARAMS,
|
DOCLING_PARAMS,
|
||||||
|
DOCLING_DO_OCR,
|
||||||
|
DOCLING_FORCE_OCR,
|
||||||
|
DOCLING_OCR_ENGINE,
|
||||||
|
DOCLING_OCR_LANG,
|
||||||
|
DOCLING_PDF_BACKEND,
|
||||||
|
DOCLING_TABLE_MODE,
|
||||||
|
DOCLING_PIPELINE,
|
||||||
|
DOCLING_DO_PICTURE_DESCRIPTION,
|
||||||
|
DOCLING_PICTURE_DESCRIPTION_MODE,
|
||||||
|
DOCLING_PICTURE_DESCRIPTION_LOCAL,
|
||||||
|
DOCLING_PICTURE_DESCRIPTION_API,
|
||||||
DOCUMENT_INTELLIGENCE_ENDPOINT,
|
DOCUMENT_INTELLIGENCE_ENDPOINT,
|
||||||
DOCUMENT_INTELLIGENCE_KEY,
|
DOCUMENT_INTELLIGENCE_KEY,
|
||||||
DOCUMENT_INTELLIGENCE_MODEL,
|
|
||||||
MISTRAL_OCR_API_BASE_URL,
|
MISTRAL_OCR_API_BASE_URL,
|
||||||
MISTRAL_OCR_API_KEY,
|
MISTRAL_OCR_API_KEY,
|
||||||
RAG_TEXT_SPLITTER,
|
RAG_TEXT_SPLITTER,
|
||||||
|
|
@ -300,7 +303,6 @@ from open_webui.config import (
|
||||||
SERPAPI_API_KEY,
|
SERPAPI_API_KEY,
|
||||||
SERPAPI_ENGINE,
|
SERPAPI_ENGINE,
|
||||||
SEARXNG_QUERY_URL,
|
SEARXNG_QUERY_URL,
|
||||||
SEARXNG_LANGUAGE,
|
|
||||||
YACY_QUERY_URL,
|
YACY_QUERY_URL,
|
||||||
YACY_USERNAME,
|
YACY_USERNAME,
|
||||||
YACY_PASSWORD,
|
YACY_PASSWORD,
|
||||||
|
|
@ -317,7 +319,6 @@ from open_webui.config import (
|
||||||
PERPLEXITY_API_KEY,
|
PERPLEXITY_API_KEY,
|
||||||
PERPLEXITY_MODEL,
|
PERPLEXITY_MODEL,
|
||||||
PERPLEXITY_SEARCH_CONTEXT_USAGE,
|
PERPLEXITY_SEARCH_CONTEXT_USAGE,
|
||||||
PERPLEXITY_SEARCH_API_URL,
|
|
||||||
SOUGOU_API_SID,
|
SOUGOU_API_SID,
|
||||||
SOUGOU_API_SK,
|
SOUGOU_API_SK,
|
||||||
KAGI_SEARCH_API_KEY,
|
KAGI_SEARCH_API_KEY,
|
||||||
|
|
@ -335,7 +336,6 @@ from open_webui.config import (
|
||||||
ENABLE_ONEDRIVE_PERSONAL,
|
ENABLE_ONEDRIVE_PERSONAL,
|
||||||
ENABLE_ONEDRIVE_BUSINESS,
|
ENABLE_ONEDRIVE_BUSINESS,
|
||||||
ENABLE_RAG_HYBRID_SEARCH,
|
ENABLE_RAG_HYBRID_SEARCH,
|
||||||
ENABLE_RAG_HYBRID_SEARCH_ENRICHED_TEXTS,
|
|
||||||
ENABLE_RAG_LOCAL_WEB_FETCH,
|
ENABLE_RAG_LOCAL_WEB_FETCH,
|
||||||
ENABLE_WEB_LOADER_SSL_VERIFICATION,
|
ENABLE_WEB_LOADER_SSL_VERIFICATION,
|
||||||
ENABLE_GOOGLE_DRIVE_INTEGRATION,
|
ENABLE_GOOGLE_DRIVE_INTEGRATION,
|
||||||
|
|
@ -354,10 +354,9 @@ from open_webui.config import (
|
||||||
JWT_EXPIRES_IN,
|
JWT_EXPIRES_IN,
|
||||||
ENABLE_SIGNUP,
|
ENABLE_SIGNUP,
|
||||||
ENABLE_LOGIN_FORM,
|
ENABLE_LOGIN_FORM,
|
||||||
ENABLE_API_KEYS,
|
ENABLE_API_KEY,
|
||||||
ENABLE_API_KEYS_ENDPOINT_RESTRICTIONS,
|
ENABLE_API_KEY_ENDPOINT_RESTRICTIONS,
|
||||||
API_KEYS_ALLOWED_ENDPOINTS,
|
API_KEY_ALLOWED_ENDPOINTS,
|
||||||
ENABLE_FOLDERS,
|
|
||||||
ENABLE_CHANNELS,
|
ENABLE_CHANNELS,
|
||||||
ENABLE_NOTES,
|
ENABLE_NOTES,
|
||||||
ENABLE_COMMUNITY_SHARING,
|
ENABLE_COMMUNITY_SHARING,
|
||||||
|
|
@ -367,12 +366,10 @@ from open_webui.config import (
|
||||||
BYPASS_ADMIN_ACCESS_CONTROL,
|
BYPASS_ADMIN_ACCESS_CONTROL,
|
||||||
USER_PERMISSIONS,
|
USER_PERMISSIONS,
|
||||||
DEFAULT_USER_ROLE,
|
DEFAULT_USER_ROLE,
|
||||||
DEFAULT_GROUP_ID,
|
|
||||||
PENDING_USER_OVERLAY_CONTENT,
|
PENDING_USER_OVERLAY_CONTENT,
|
||||||
PENDING_USER_OVERLAY_TITLE,
|
PENDING_USER_OVERLAY_TITLE,
|
||||||
DEFAULT_PROMPT_SUGGESTIONS,
|
DEFAULT_PROMPT_SUGGESTIONS,
|
||||||
DEFAULT_MODELS,
|
DEFAULT_MODELS,
|
||||||
DEFAULT_PINNED_MODELS,
|
|
||||||
DEFAULT_ARENA_MODEL,
|
DEFAULT_ARENA_MODEL,
|
||||||
MODEL_ORDER_LIST,
|
MODEL_ORDER_LIST,
|
||||||
EVALUATION_ARENA_MODELS,
|
EVALUATION_ARENA_MODELS,
|
||||||
|
|
@ -431,7 +428,6 @@ from open_webui.config import (
|
||||||
TAGS_GENERATION_PROMPT_TEMPLATE,
|
TAGS_GENERATION_PROMPT_TEMPLATE,
|
||||||
IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE,
|
IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE,
|
||||||
TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
|
TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
|
||||||
VOICE_MODE_PROMPT_TEMPLATE,
|
|
||||||
QUERY_GENERATION_PROMPT_TEMPLATE,
|
QUERY_GENERATION_PROMPT_TEMPLATE,
|
||||||
AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE,
|
AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE,
|
||||||
AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH,
|
AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH,
|
||||||
|
|
@ -439,7 +435,6 @@ from open_webui.config import (
|
||||||
reset_config,
|
reset_config,
|
||||||
)
|
)
|
||||||
from open_webui.env import (
|
from open_webui.env import (
|
||||||
ENABLE_CUSTOM_MODEL_FALLBACK,
|
|
||||||
LICENSE_KEY,
|
LICENSE_KEY,
|
||||||
AUDIT_EXCLUDED_PATHS,
|
AUDIT_EXCLUDED_PATHS,
|
||||||
AUDIT_LOG_LEVEL,
|
AUDIT_LOG_LEVEL,
|
||||||
|
|
@ -452,8 +447,8 @@ from open_webui.env import (
|
||||||
GLOBAL_LOG_LEVEL,
|
GLOBAL_LOG_LEVEL,
|
||||||
MAX_BODY_LOG_SIZE,
|
MAX_BODY_LOG_SIZE,
|
||||||
SAFE_MODE,
|
SAFE_MODE,
|
||||||
|
SRC_LOG_LEVELS,
|
||||||
VERSION,
|
VERSION,
|
||||||
DEPLOYMENT_ID,
|
|
||||||
INSTANCE_ID,
|
INSTANCE_ID,
|
||||||
WEBUI_BUILD_HASH,
|
WEBUI_BUILD_HASH,
|
||||||
WEBUI_SECRET_KEY,
|
WEBUI_SECRET_KEY,
|
||||||
|
|
@ -464,7 +459,7 @@ from open_webui.env import (
|
||||||
WEBUI_AUTH_TRUSTED_NAME_HEADER,
|
WEBUI_AUTH_TRUSTED_NAME_HEADER,
|
||||||
WEBUI_AUTH_SIGNOUT_REDIRECT_URL,
|
WEBUI_AUTH_SIGNOUT_REDIRECT_URL,
|
||||||
# SCIM
|
# SCIM
|
||||||
ENABLE_SCIM,
|
SCIM_ENABLED,
|
||||||
SCIM_TOKEN,
|
SCIM_TOKEN,
|
||||||
ENABLE_COMPRESSION_MIDDLEWARE,
|
ENABLE_COMPRESSION_MIDDLEWARE,
|
||||||
ENABLE_WEBSOCKET_SUPPORT,
|
ENABLE_WEBSOCKET_SUPPORT,
|
||||||
|
|
@ -475,7 +470,6 @@ from open_webui.env import (
|
||||||
EXTERNAL_PWA_MANIFEST_URL,
|
EXTERNAL_PWA_MANIFEST_URL,
|
||||||
AIOHTTP_CLIENT_SESSION_SSL,
|
AIOHTTP_CLIENT_SESSION_SSL,
|
||||||
ENABLE_STAR_SESSIONS_MIDDLEWARE,
|
ENABLE_STAR_SESSIONS_MIDDLEWARE,
|
||||||
ENABLE_PUBLIC_ACTIVE_USERS_COUNT,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -533,6 +527,7 @@ if SAFE_MODE:
|
||||||
|
|
||||||
logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
|
logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
log.setLevel(SRC_LOG_LEVELS["MAIN"])
|
||||||
|
|
||||||
|
|
||||||
class SPAStaticFiles(StaticFiles):
|
class SPAStaticFiles(StaticFiles):
|
||||||
|
|
@ -720,7 +715,7 @@ app.state.config.ENABLE_DIRECT_CONNECTIONS = ENABLE_DIRECT_CONNECTIONS
|
||||||
#
|
#
|
||||||
########################################
|
########################################
|
||||||
|
|
||||||
app.state.ENABLE_SCIM = ENABLE_SCIM
|
app.state.SCIM_ENABLED = SCIM_ENABLED
|
||||||
app.state.SCIM_TOKEN = SCIM_TOKEN
|
app.state.SCIM_TOKEN = SCIM_TOKEN
|
||||||
|
|
||||||
########################################
|
########################################
|
||||||
|
|
@ -742,11 +737,11 @@ app.state.config.WEBUI_URL = WEBUI_URL
|
||||||
app.state.config.ENABLE_SIGNUP = ENABLE_SIGNUP
|
app.state.config.ENABLE_SIGNUP = ENABLE_SIGNUP
|
||||||
app.state.config.ENABLE_LOGIN_FORM = ENABLE_LOGIN_FORM
|
app.state.config.ENABLE_LOGIN_FORM = ENABLE_LOGIN_FORM
|
||||||
|
|
||||||
app.state.config.ENABLE_API_KEYS = ENABLE_API_KEYS
|
app.state.config.ENABLE_API_KEY = ENABLE_API_KEY
|
||||||
app.state.config.ENABLE_API_KEYS_ENDPOINT_RESTRICTIONS = (
|
app.state.config.ENABLE_API_KEY_ENDPOINT_RESTRICTIONS = (
|
||||||
ENABLE_API_KEYS_ENDPOINT_RESTRICTIONS
|
ENABLE_API_KEY_ENDPOINT_RESTRICTIONS
|
||||||
)
|
)
|
||||||
app.state.config.API_KEYS_ALLOWED_ENDPOINTS = API_KEYS_ALLOWED_ENDPOINTS
|
app.state.config.API_KEY_ALLOWED_ENDPOINTS = API_KEY_ALLOWED_ENDPOINTS
|
||||||
|
|
||||||
app.state.config.JWT_EXPIRES_IN = JWT_EXPIRES_IN
|
app.state.config.JWT_EXPIRES_IN = JWT_EXPIRES_IN
|
||||||
|
|
||||||
|
|
@ -755,13 +750,8 @@ app.state.config.ADMIN_EMAIL = ADMIN_EMAIL
|
||||||
|
|
||||||
|
|
||||||
app.state.config.DEFAULT_MODELS = DEFAULT_MODELS
|
app.state.config.DEFAULT_MODELS = DEFAULT_MODELS
|
||||||
app.state.config.DEFAULT_PINNED_MODELS = DEFAULT_PINNED_MODELS
|
|
||||||
app.state.config.MODEL_ORDER_LIST = MODEL_ORDER_LIST
|
|
||||||
|
|
||||||
|
|
||||||
app.state.config.DEFAULT_PROMPT_SUGGESTIONS = DEFAULT_PROMPT_SUGGESTIONS
|
app.state.config.DEFAULT_PROMPT_SUGGESTIONS = DEFAULT_PROMPT_SUGGESTIONS
|
||||||
app.state.config.DEFAULT_USER_ROLE = DEFAULT_USER_ROLE
|
app.state.config.DEFAULT_USER_ROLE = DEFAULT_USER_ROLE
|
||||||
app.state.config.DEFAULT_GROUP_ID = DEFAULT_GROUP_ID
|
|
||||||
|
|
||||||
app.state.config.PENDING_USER_OVERLAY_CONTENT = PENDING_USER_OVERLAY_CONTENT
|
app.state.config.PENDING_USER_OVERLAY_CONTENT = PENDING_USER_OVERLAY_CONTENT
|
||||||
app.state.config.PENDING_USER_OVERLAY_TITLE = PENDING_USER_OVERLAY_TITLE
|
app.state.config.PENDING_USER_OVERLAY_TITLE = PENDING_USER_OVERLAY_TITLE
|
||||||
|
|
@ -771,9 +761,9 @@ app.state.config.RESPONSE_WATERMARK = RESPONSE_WATERMARK
|
||||||
app.state.config.USER_PERMISSIONS = USER_PERMISSIONS
|
app.state.config.USER_PERMISSIONS = USER_PERMISSIONS
|
||||||
app.state.config.WEBHOOK_URL = WEBHOOK_URL
|
app.state.config.WEBHOOK_URL = WEBHOOK_URL
|
||||||
app.state.config.BANNERS = WEBUI_BANNERS
|
app.state.config.BANNERS = WEBUI_BANNERS
|
||||||
|
app.state.config.MODEL_ORDER_LIST = MODEL_ORDER_LIST
|
||||||
|
|
||||||
|
|
||||||
app.state.config.ENABLE_FOLDERS = ENABLE_FOLDERS
|
|
||||||
app.state.config.ENABLE_CHANNELS = ENABLE_CHANNELS
|
app.state.config.ENABLE_CHANNELS = ENABLE_CHANNELS
|
||||||
app.state.config.ENABLE_NOTES = ENABLE_NOTES
|
app.state.config.ENABLE_NOTES = ENABLE_NOTES
|
||||||
app.state.config.ENABLE_COMMUNITY_SHARING = ENABLE_COMMUNITY_SHARING
|
app.state.config.ENABLE_COMMUNITY_SHARING = ENABLE_COMMUNITY_SHARING
|
||||||
|
|
@ -849,9 +839,6 @@ app.state.config.FILE_IMAGE_COMPRESSION_HEIGHT = FILE_IMAGE_COMPRESSION_HEIGHT
|
||||||
app.state.config.RAG_FULL_CONTEXT = RAG_FULL_CONTEXT
|
app.state.config.RAG_FULL_CONTEXT = RAG_FULL_CONTEXT
|
||||||
app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL = BYPASS_EMBEDDING_AND_RETRIEVAL
|
app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL = BYPASS_EMBEDDING_AND_RETRIEVAL
|
||||||
app.state.config.ENABLE_RAG_HYBRID_SEARCH = ENABLE_RAG_HYBRID_SEARCH
|
app.state.config.ENABLE_RAG_HYBRID_SEARCH = ENABLE_RAG_HYBRID_SEARCH
|
||||||
app.state.config.ENABLE_RAG_HYBRID_SEARCH_ENRICHED_TEXTS = (
|
|
||||||
ENABLE_RAG_HYBRID_SEARCH_ENRICHED_TEXTS
|
|
||||||
)
|
|
||||||
app.state.config.ENABLE_WEB_LOADER_SSL_VERIFICATION = ENABLE_WEB_LOADER_SSL_VERIFICATION
|
app.state.config.ENABLE_WEB_LOADER_SSL_VERIFICATION = ENABLE_WEB_LOADER_SSL_VERIFICATION
|
||||||
|
|
||||||
app.state.config.CONTENT_EXTRACTION_ENGINE = CONTENT_EXTRACTION_ENGINE
|
app.state.config.CONTENT_EXTRACTION_ENGINE = CONTENT_EXTRACTION_ENGINE
|
||||||
|
|
@ -872,17 +859,25 @@ app.state.config.EXTERNAL_DOCUMENT_LOADER_URL = EXTERNAL_DOCUMENT_LOADER_URL
|
||||||
app.state.config.EXTERNAL_DOCUMENT_LOADER_API_KEY = EXTERNAL_DOCUMENT_LOADER_API_KEY
|
app.state.config.EXTERNAL_DOCUMENT_LOADER_API_KEY = EXTERNAL_DOCUMENT_LOADER_API_KEY
|
||||||
app.state.config.TIKA_SERVER_URL = TIKA_SERVER_URL
|
app.state.config.TIKA_SERVER_URL = TIKA_SERVER_URL
|
||||||
app.state.config.DOCLING_SERVER_URL = DOCLING_SERVER_URL
|
app.state.config.DOCLING_SERVER_URL = DOCLING_SERVER_URL
|
||||||
app.state.config.DOCLING_API_KEY = DOCLING_API_KEY
|
|
||||||
app.state.config.DOCLING_PARAMS = DOCLING_PARAMS
|
app.state.config.DOCLING_PARAMS = DOCLING_PARAMS
|
||||||
|
app.state.config.DOCLING_DO_OCR = DOCLING_DO_OCR
|
||||||
|
app.state.config.DOCLING_FORCE_OCR = DOCLING_FORCE_OCR
|
||||||
|
app.state.config.DOCLING_OCR_ENGINE = DOCLING_OCR_ENGINE
|
||||||
|
app.state.config.DOCLING_OCR_LANG = DOCLING_OCR_LANG
|
||||||
|
app.state.config.DOCLING_PDF_BACKEND = DOCLING_PDF_BACKEND
|
||||||
|
app.state.config.DOCLING_TABLE_MODE = DOCLING_TABLE_MODE
|
||||||
|
app.state.config.DOCLING_PIPELINE = DOCLING_PIPELINE
|
||||||
|
app.state.config.DOCLING_DO_PICTURE_DESCRIPTION = DOCLING_DO_PICTURE_DESCRIPTION
|
||||||
|
app.state.config.DOCLING_PICTURE_DESCRIPTION_MODE = DOCLING_PICTURE_DESCRIPTION_MODE
|
||||||
|
app.state.config.DOCLING_PICTURE_DESCRIPTION_LOCAL = DOCLING_PICTURE_DESCRIPTION_LOCAL
|
||||||
|
app.state.config.DOCLING_PICTURE_DESCRIPTION_API = DOCLING_PICTURE_DESCRIPTION_API
|
||||||
app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT = DOCUMENT_INTELLIGENCE_ENDPOINT
|
app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT = DOCUMENT_INTELLIGENCE_ENDPOINT
|
||||||
app.state.config.DOCUMENT_INTELLIGENCE_KEY = DOCUMENT_INTELLIGENCE_KEY
|
app.state.config.DOCUMENT_INTELLIGENCE_KEY = DOCUMENT_INTELLIGENCE_KEY
|
||||||
app.state.config.DOCUMENT_INTELLIGENCE_MODEL = DOCUMENT_INTELLIGENCE_MODEL
|
|
||||||
app.state.config.MISTRAL_OCR_API_BASE_URL = MISTRAL_OCR_API_BASE_URL
|
app.state.config.MISTRAL_OCR_API_BASE_URL = MISTRAL_OCR_API_BASE_URL
|
||||||
app.state.config.MISTRAL_OCR_API_KEY = MISTRAL_OCR_API_KEY
|
app.state.config.MISTRAL_OCR_API_KEY = MISTRAL_OCR_API_KEY
|
||||||
app.state.config.MINERU_API_MODE = MINERU_API_MODE
|
app.state.config.MINERU_API_MODE = MINERU_API_MODE
|
||||||
app.state.config.MINERU_API_URL = MINERU_API_URL
|
app.state.config.MINERU_API_URL = MINERU_API_URL
|
||||||
app.state.config.MINERU_API_KEY = MINERU_API_KEY
|
app.state.config.MINERU_API_KEY = MINERU_API_KEY
|
||||||
app.state.config.MINERU_API_TIMEOUT = MINERU_API_TIMEOUT
|
|
||||||
app.state.config.MINERU_PARAMS = MINERU_PARAMS
|
app.state.config.MINERU_PARAMS = MINERU_PARAMS
|
||||||
|
|
||||||
app.state.config.TEXT_SPLITTER = RAG_TEXT_SPLITTER
|
app.state.config.TEXT_SPLITTER = RAG_TEXT_SPLITTER
|
||||||
|
|
@ -894,13 +889,11 @@ app.state.config.CHUNK_OVERLAP = CHUNK_OVERLAP
|
||||||
app.state.config.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE
|
app.state.config.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE
|
||||||
app.state.config.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
|
app.state.config.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
|
||||||
app.state.config.RAG_EMBEDDING_BATCH_SIZE = RAG_EMBEDDING_BATCH_SIZE
|
app.state.config.RAG_EMBEDDING_BATCH_SIZE = RAG_EMBEDDING_BATCH_SIZE
|
||||||
app.state.config.ENABLE_ASYNC_EMBEDDING = ENABLE_ASYNC_EMBEDDING
|
|
||||||
|
|
||||||
app.state.config.RAG_RERANKING_ENGINE = RAG_RERANKING_ENGINE
|
app.state.config.RAG_RERANKING_ENGINE = RAG_RERANKING_ENGINE
|
||||||
app.state.config.RAG_RERANKING_MODEL = RAG_RERANKING_MODEL
|
app.state.config.RAG_RERANKING_MODEL = RAG_RERANKING_MODEL
|
||||||
app.state.config.RAG_EXTERNAL_RERANKER_URL = RAG_EXTERNAL_RERANKER_URL
|
app.state.config.RAG_EXTERNAL_RERANKER_URL = RAG_EXTERNAL_RERANKER_URL
|
||||||
app.state.config.RAG_EXTERNAL_RERANKER_API_KEY = RAG_EXTERNAL_RERANKER_API_KEY
|
app.state.config.RAG_EXTERNAL_RERANKER_API_KEY = RAG_EXTERNAL_RERANKER_API_KEY
|
||||||
app.state.config.RAG_EXTERNAL_RERANKER_TIMEOUT = RAG_EXTERNAL_RERANKER_TIMEOUT
|
|
||||||
|
|
||||||
app.state.config.RAG_TEMPLATE = RAG_TEMPLATE
|
app.state.config.RAG_TEMPLATE = RAG_TEMPLATE
|
||||||
|
|
||||||
|
|
@ -928,7 +921,6 @@ app.state.config.WEB_SEARCH_CONCURRENT_REQUESTS = WEB_SEARCH_CONCURRENT_REQUESTS
|
||||||
|
|
||||||
app.state.config.WEB_LOADER_ENGINE = WEB_LOADER_ENGINE
|
app.state.config.WEB_LOADER_ENGINE = WEB_LOADER_ENGINE
|
||||||
app.state.config.WEB_LOADER_CONCURRENT_REQUESTS = WEB_LOADER_CONCURRENT_REQUESTS
|
app.state.config.WEB_LOADER_CONCURRENT_REQUESTS = WEB_LOADER_CONCURRENT_REQUESTS
|
||||||
app.state.config.WEB_LOADER_TIMEOUT = WEB_LOADER_TIMEOUT
|
|
||||||
|
|
||||||
app.state.config.WEB_SEARCH_TRUST_ENV = WEB_SEARCH_TRUST_ENV
|
app.state.config.WEB_SEARCH_TRUST_ENV = WEB_SEARCH_TRUST_ENV
|
||||||
app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL = (
|
app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL = (
|
||||||
|
|
@ -941,7 +933,6 @@ app.state.config.ENABLE_ONEDRIVE_INTEGRATION = ENABLE_ONEDRIVE_INTEGRATION
|
||||||
|
|
||||||
app.state.config.OLLAMA_CLOUD_WEB_SEARCH_API_KEY = OLLAMA_CLOUD_WEB_SEARCH_API_KEY
|
app.state.config.OLLAMA_CLOUD_WEB_SEARCH_API_KEY = OLLAMA_CLOUD_WEB_SEARCH_API_KEY
|
||||||
app.state.config.SEARXNG_QUERY_URL = SEARXNG_QUERY_URL
|
app.state.config.SEARXNG_QUERY_URL = SEARXNG_QUERY_URL
|
||||||
app.state.config.SEARXNG_LANGUAGE = SEARXNG_LANGUAGE
|
|
||||||
app.state.config.YACY_QUERY_URL = YACY_QUERY_URL
|
app.state.config.YACY_QUERY_URL = YACY_QUERY_URL
|
||||||
app.state.config.YACY_USERNAME = YACY_USERNAME
|
app.state.config.YACY_USERNAME = YACY_USERNAME
|
||||||
app.state.config.YACY_PASSWORD = YACY_PASSWORD
|
app.state.config.YACY_PASSWORD = YACY_PASSWORD
|
||||||
|
|
@ -967,7 +958,6 @@ app.state.config.EXA_API_KEY = EXA_API_KEY
|
||||||
app.state.config.PERPLEXITY_API_KEY = PERPLEXITY_API_KEY
|
app.state.config.PERPLEXITY_API_KEY = PERPLEXITY_API_KEY
|
||||||
app.state.config.PERPLEXITY_MODEL = PERPLEXITY_MODEL
|
app.state.config.PERPLEXITY_MODEL = PERPLEXITY_MODEL
|
||||||
app.state.config.PERPLEXITY_SEARCH_CONTEXT_USAGE = PERPLEXITY_SEARCH_CONTEXT_USAGE
|
app.state.config.PERPLEXITY_SEARCH_CONTEXT_USAGE = PERPLEXITY_SEARCH_CONTEXT_USAGE
|
||||||
app.state.config.PERPLEXITY_SEARCH_API_URL = PERPLEXITY_SEARCH_API_URL
|
|
||||||
app.state.config.SOUGOU_API_SID = SOUGOU_API_SID
|
app.state.config.SOUGOU_API_SID = SOUGOU_API_SID
|
||||||
app.state.config.SOUGOU_API_SK = SOUGOU_API_SK
|
app.state.config.SOUGOU_API_SK = SOUGOU_API_SK
|
||||||
app.state.config.EXTERNAL_WEB_SEARCH_URL = EXTERNAL_WEB_SEARCH_URL
|
app.state.config.EXTERNAL_WEB_SEARCH_URL = EXTERNAL_WEB_SEARCH_URL
|
||||||
|
|
@ -992,7 +982,9 @@ app.state.YOUTUBE_LOADER_TRANSLATION = None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
app.state.ef = get_ef(
|
app.state.ef = get_ef(
|
||||||
app.state.config.RAG_EMBEDDING_ENGINE, app.state.config.RAG_EMBEDDING_MODEL
|
app.state.config.RAG_EMBEDDING_ENGINE,
|
||||||
|
app.state.config.RAG_EMBEDDING_MODEL,
|
||||||
|
RAG_EMBEDDING_MODEL_AUTO_UPDATE,
|
||||||
)
|
)
|
||||||
if (
|
if (
|
||||||
app.state.config.ENABLE_RAG_HYBRID_SEARCH
|
app.state.config.ENABLE_RAG_HYBRID_SEARCH
|
||||||
|
|
@ -1003,7 +995,7 @@ try:
|
||||||
app.state.config.RAG_RERANKING_MODEL,
|
app.state.config.RAG_RERANKING_MODEL,
|
||||||
app.state.config.RAG_EXTERNAL_RERANKER_URL,
|
app.state.config.RAG_EXTERNAL_RERANKER_URL,
|
||||||
app.state.config.RAG_EXTERNAL_RERANKER_API_KEY,
|
app.state.config.RAG_EXTERNAL_RERANKER_API_KEY,
|
||||||
app.state.config.RAG_EXTERNAL_RERANKER_TIMEOUT,
|
RAG_RERANKING_MODEL_AUTO_UPDATE,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
app.state.rf = None
|
app.state.rf = None
|
||||||
|
|
@ -1040,7 +1032,6 @@ app.state.EMBEDDING_FUNCTION = get_embedding_function(
|
||||||
if app.state.config.RAG_EMBEDDING_ENGINE == "azure_openai"
|
if app.state.config.RAG_EMBEDDING_ENGINE == "azure_openai"
|
||||||
else None
|
else None
|
||||||
),
|
),
|
||||||
enable_async=app.state.config.ENABLE_ASYNC_EMBEDDING,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
app.state.RERANKING_FUNCTION = get_reranking_function(
|
app.state.RERANKING_FUNCTION = get_reranking_function(
|
||||||
|
|
@ -1096,7 +1087,6 @@ app.state.config.IMAGE_STEPS = IMAGE_STEPS
|
||||||
app.state.config.IMAGES_OPENAI_API_BASE_URL = IMAGES_OPENAI_API_BASE_URL
|
app.state.config.IMAGES_OPENAI_API_BASE_URL = IMAGES_OPENAI_API_BASE_URL
|
||||||
app.state.config.IMAGES_OPENAI_API_VERSION = IMAGES_OPENAI_API_VERSION
|
app.state.config.IMAGES_OPENAI_API_VERSION = IMAGES_OPENAI_API_VERSION
|
||||||
app.state.config.IMAGES_OPENAI_API_KEY = IMAGES_OPENAI_API_KEY
|
app.state.config.IMAGES_OPENAI_API_KEY = IMAGES_OPENAI_API_KEY
|
||||||
app.state.config.IMAGES_OPENAI_API_PARAMS = IMAGES_OPENAI_API_PARAMS
|
|
||||||
|
|
||||||
app.state.config.IMAGES_GEMINI_API_BASE_URL = IMAGES_GEMINI_API_BASE_URL
|
app.state.config.IMAGES_GEMINI_API_BASE_URL = IMAGES_GEMINI_API_BASE_URL
|
||||||
app.state.config.IMAGES_GEMINI_API_KEY = IMAGES_GEMINI_API_KEY
|
app.state.config.IMAGES_GEMINI_API_KEY = IMAGES_GEMINI_API_KEY
|
||||||
|
|
@ -1112,7 +1102,6 @@ app.state.config.COMFYUI_WORKFLOW = COMFYUI_WORKFLOW
|
||||||
app.state.config.COMFYUI_WORKFLOW_NODES = COMFYUI_WORKFLOW_NODES
|
app.state.config.COMFYUI_WORKFLOW_NODES = COMFYUI_WORKFLOW_NODES
|
||||||
|
|
||||||
|
|
||||||
app.state.config.ENABLE_IMAGE_EDIT = ENABLE_IMAGE_EDIT
|
|
||||||
app.state.config.IMAGE_EDIT_ENGINE = IMAGE_EDIT_ENGINE
|
app.state.config.IMAGE_EDIT_ENGINE = IMAGE_EDIT_ENGINE
|
||||||
app.state.config.IMAGE_EDIT_MODEL = IMAGE_EDIT_MODEL
|
app.state.config.IMAGE_EDIT_MODEL = IMAGE_EDIT_MODEL
|
||||||
app.state.config.IMAGE_EDIT_SIZE = IMAGE_EDIT_SIZE
|
app.state.config.IMAGE_EDIT_SIZE = IMAGE_EDIT_SIZE
|
||||||
|
|
@ -1217,7 +1206,6 @@ app.state.config.AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE = (
|
||||||
app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH = (
|
app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH = (
|
||||||
AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH
|
AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH
|
||||||
)
|
)
|
||||||
app.state.config.VOICE_MODE_PROMPT_TEMPLATE = VOICE_MODE_PROMPT_TEMPLATE
|
|
||||||
|
|
||||||
|
|
||||||
########################################
|
########################################
|
||||||
|
|
@ -1226,11 +1214,7 @@ app.state.config.VOICE_MODE_PROMPT_TEMPLATE = VOICE_MODE_PROMPT_TEMPLATE
|
||||||
#
|
#
|
||||||
########################################
|
########################################
|
||||||
|
|
||||||
app.state.MODELS = MODELS
|
app.state.MODELS = {}
|
||||||
|
|
||||||
# Add the middleware to the app
|
|
||||||
if ENABLE_COMPRESSION_MIDDLEWARE:
|
|
||||||
app.add_middleware(CompressMiddleware)
|
|
||||||
|
|
||||||
|
|
||||||
class RedirectMiddleware(BaseHTTPMiddleware):
|
class RedirectMiddleware(BaseHTTPMiddleware):
|
||||||
|
|
@ -1273,53 +1257,14 @@ class RedirectMiddleware(BaseHTTPMiddleware):
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
# Add the middleware to the app
|
||||||
|
if ENABLE_COMPRESSION_MIDDLEWARE:
|
||||||
|
app.add_middleware(CompressMiddleware)
|
||||||
|
|
||||||
app.add_middleware(RedirectMiddleware)
|
app.add_middleware(RedirectMiddleware)
|
||||||
app.add_middleware(SecurityHeadersMiddleware)
|
app.add_middleware(SecurityHeadersMiddleware)
|
||||||
|
|
||||||
|
|
||||||
class APIKeyRestrictionMiddleware(BaseHTTPMiddleware):
|
|
||||||
async def dispatch(self, request: Request, call_next):
|
|
||||||
auth_header = request.headers.get("Authorization")
|
|
||||||
token = None
|
|
||||||
|
|
||||||
if auth_header:
|
|
||||||
scheme, token = auth_header.split(" ")
|
|
||||||
|
|
||||||
# Only apply restrictions if an sk- API key is used
|
|
||||||
if token and token.startswith("sk-"):
|
|
||||||
# Check if restrictions are enabled
|
|
||||||
if request.app.state.config.ENABLE_API_KEYS_ENDPOINT_RESTRICTIONS:
|
|
||||||
allowed_paths = [
|
|
||||||
path.strip()
|
|
||||||
for path in str(
|
|
||||||
request.app.state.config.API_KEYS_ALLOWED_ENDPOINTS
|
|
||||||
).split(",")
|
|
||||||
if path.strip()
|
|
||||||
]
|
|
||||||
|
|
||||||
request_path = request.url.path
|
|
||||||
|
|
||||||
# Match exact path or prefix path
|
|
||||||
is_allowed = any(
|
|
||||||
request_path == allowed or request_path.startswith(allowed + "/")
|
|
||||||
for allowed in allowed_paths
|
|
||||||
)
|
|
||||||
|
|
||||||
if not is_allowed:
|
|
||||||
return JSONResponse(
|
|
||||||
status_code=status.HTTP_403_FORBIDDEN,
|
|
||||||
content={
|
|
||||||
"detail": "API key not allowed to access this endpoint."
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
response = await call_next(request)
|
|
||||||
return response
|
|
||||||
|
|
||||||
|
|
||||||
app.add_middleware(APIKeyRestrictionMiddleware)
|
|
||||||
|
|
||||||
|
|
||||||
@app.middleware("http")
|
@app.middleware("http")
|
||||||
async def commit_session_after_request(request: Request, call_next):
|
async def commit_session_after_request(request: Request, call_next):
|
||||||
response = await call_next(request)
|
response = await call_next(request)
|
||||||
|
|
@ -1335,7 +1280,7 @@ async def check_url(request: Request, call_next):
|
||||||
request.headers.get("Authorization")
|
request.headers.get("Authorization")
|
||||||
)
|
)
|
||||||
|
|
||||||
request.state.enable_api_keys = app.state.config.ENABLE_API_KEYS
|
request.state.enable_api_key = app.state.config.ENABLE_API_KEY
|
||||||
response = await call_next(request)
|
response = await call_next(request)
|
||||||
process_time = int(time.time()) - start_time
|
process_time = int(time.time()) - start_time
|
||||||
response.headers["X-Process-Time"] = str(process_time)
|
response.headers["X-Process-Time"] = str(process_time)
|
||||||
|
|
@ -1410,7 +1355,7 @@ app.include_router(
|
||||||
app.include_router(utils.router, prefix="/api/v1/utils", tags=["utils"])
|
app.include_router(utils.router, prefix="/api/v1/utils", tags=["utils"])
|
||||||
|
|
||||||
# SCIM 2.0 API for identity management
|
# SCIM 2.0 API for identity management
|
||||||
if ENABLE_SCIM:
|
if SCIM_ENABLED:
|
||||||
app.include_router(scim.router, prefix="/api/v1/scim/v2", tags=["scim"])
|
app.include_router(scim.router, prefix="/api/v1/scim/v2", tags=["scim"])
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -1447,10 +1392,6 @@ async def get_models(
|
||||||
if "pipeline" in model and model["pipeline"].get("type", None) == "filter":
|
if "pipeline" in model and model["pipeline"].get("type", None) == "filter":
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Remove profile image URL to reduce payload size
|
|
||||||
if model.get("info", {}).get("meta", {}).get("profile_image_url"):
|
|
||||||
model["info"]["meta"].pop("profile_image_url", None)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
model_tags = [
|
model_tags = [
|
||||||
tag.get("name")
|
tag.get("name")
|
||||||
|
|
@ -1540,7 +1481,6 @@ async def chat_completion(
|
||||||
|
|
||||||
metadata = {}
|
metadata = {}
|
||||||
try:
|
try:
|
||||||
model_info = None
|
|
||||||
if not model_item.get("direct", False):
|
if not model_item.get("direct", False):
|
||||||
if model_id not in request.app.state.MODELS:
|
if model_id not in request.app.state.MODELS:
|
||||||
raise Exception("Model not found")
|
raise Exception("Model not found")
|
||||||
|
|
@ -1558,6 +1498,7 @@ async def chat_completion(
|
||||||
raise e
|
raise e
|
||||||
else:
|
else:
|
||||||
model = model_item
|
model = model_item
|
||||||
|
model_info = None
|
||||||
|
|
||||||
request.state.direct = True
|
request.state.direct = True
|
||||||
request.state.model = model
|
request.state.model = model
|
||||||
|
|
@ -1566,26 +1507,6 @@ async def chat_completion(
|
||||||
model_info.params.model_dump() if model_info and model_info.params else {}
|
model_info.params.model_dump() if model_info and model_info.params else {}
|
||||||
)
|
)
|
||||||
|
|
||||||
# Check base model existence for custom models
|
|
||||||
if model_info_params.get("base_model_id"):
|
|
||||||
base_model_id = model_info_params.get("base_model_id")
|
|
||||||
if base_model_id not in request.app.state.MODELS:
|
|
||||||
if ENABLE_CUSTOM_MODEL_FALLBACK:
|
|
||||||
default_models = (
|
|
||||||
request.app.state.config.DEFAULT_MODELS or ""
|
|
||||||
).split(",")
|
|
||||||
|
|
||||||
fallback_model_id = (
|
|
||||||
default_models[0].strip() if default_models[0] else None
|
|
||||||
)
|
|
||||||
|
|
||||||
if fallback_model_id:
|
|
||||||
request.base_model_id = fallback_model_id
|
|
||||||
else:
|
|
||||||
raise Exception("Model not found")
|
|
||||||
else:
|
|
||||||
raise Exception("Model not found")
|
|
||||||
|
|
||||||
# Chat Params
|
# Chat Params
|
||||||
stream_delta_chunk_size = form_data.get("params", {}).get(
|
stream_delta_chunk_size = form_data.get("params", {}).get(
|
||||||
"stream_delta_chunk_size"
|
"stream_delta_chunk_size"
|
||||||
|
|
@ -1593,9 +1514,6 @@ async def chat_completion(
|
||||||
reasoning_tags = form_data.get("params", {}).get("reasoning_tags")
|
reasoning_tags = form_data.get("params", {}).get("reasoning_tags")
|
||||||
|
|
||||||
# Model Params
|
# Model Params
|
||||||
if model_info_params.get("stream_response") is not None:
|
|
||||||
form_data["stream"] = model_info_params.get("stream_response")
|
|
||||||
|
|
||||||
if model_info_params.get("stream_delta_chunk_size"):
|
if model_info_params.get("stream_delta_chunk_size"):
|
||||||
stream_delta_chunk_size = model_info_params.get("stream_delta_chunk_size")
|
stream_delta_chunk_size = model_info_params.get("stream_delta_chunk_size")
|
||||||
|
|
||||||
|
|
@ -1606,8 +1524,6 @@ async def chat_completion(
|
||||||
"user_id": user.id,
|
"user_id": user.id,
|
||||||
"chat_id": form_data.pop("chat_id", None),
|
"chat_id": form_data.pop("chat_id", None),
|
||||||
"message_id": form_data.pop("id", None),
|
"message_id": form_data.pop("id", None),
|
||||||
"parent_message": form_data.pop("parent_message", None),
|
|
||||||
"parent_message_id": form_data.pop("parent_id", None),
|
|
||||||
"session_id": form_data.pop("session_id", None),
|
"session_id": form_data.pop("session_id", None),
|
||||||
"filter_ids": form_data.pop("filter_ids", []),
|
"filter_ids": form_data.pop("filter_ids", []),
|
||||||
"tool_ids": form_data.get("tool_ids", None),
|
"tool_ids": form_data.get("tool_ids", None),
|
||||||
|
|
@ -1631,38 +1547,15 @@ async def chat_completion(
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
if metadata.get("chat_id") and user:
|
if metadata.get("chat_id") and (user and user.role != "admin"):
|
||||||
if not metadata["chat_id"].startswith(
|
if not metadata["chat_id"].startswith("local:"):
|
||||||
"local:"
|
|
||||||
): # temporary chats are not stored
|
|
||||||
|
|
||||||
# Verify chat ownership
|
|
||||||
chat = Chats.get_chat_by_id_and_user_id(metadata["chat_id"], user.id)
|
chat = Chats.get_chat_by_id_and_user_id(metadata["chat_id"], user.id)
|
||||||
if chat is None and user.role != "admin": # admins can access any chat
|
if chat is None:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
detail=ERROR_MESSAGES.DEFAULT(),
|
detail=ERROR_MESSAGES.DEFAULT(),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Insert chat files from parent message if any
|
|
||||||
parent_message = metadata.get("parent_message", {})
|
|
||||||
parent_message_files = parent_message.get("files", [])
|
|
||||||
if parent_message_files:
|
|
||||||
try:
|
|
||||||
Chats.insert_chat_files(
|
|
||||||
metadata["chat_id"],
|
|
||||||
parent_message.get("id"),
|
|
||||||
[
|
|
||||||
file_item.get("id")
|
|
||||||
for file_item in parent_message_files
|
|
||||||
if file_item.get("type") == "file"
|
|
||||||
],
|
|
||||||
user.id,
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
log.debug(f"Error inserting chat files: {e}")
|
|
||||||
pass
|
|
||||||
|
|
||||||
request.state.metadata = metadata
|
request.state.metadata = metadata
|
||||||
form_data["metadata"] = metadata
|
form_data["metadata"] = metadata
|
||||||
|
|
||||||
|
|
@ -1687,7 +1580,6 @@ async def chat_completion(
|
||||||
metadata["chat_id"],
|
metadata["chat_id"],
|
||||||
metadata["message_id"],
|
metadata["message_id"],
|
||||||
{
|
{
|
||||||
"parentId": metadata.get("parent_message_id", None),
|
|
||||||
"model": model_id,
|
"model": model_id,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
@ -1720,7 +1612,6 @@ async def chat_completion(
|
||||||
metadata["chat_id"],
|
metadata["chat_id"],
|
||||||
metadata["message_id"],
|
metadata["message_id"],
|
||||||
{
|
{
|
||||||
"parentId": metadata.get("parent_message_id", None),
|
|
||||||
"error": {"content": str(e)},
|
"error": {"content": str(e)},
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
@ -1892,16 +1783,14 @@ async def get_app_config(request: Request):
|
||||||
"auth_trusted_header": bool(app.state.AUTH_TRUSTED_EMAIL_HEADER),
|
"auth_trusted_header": bool(app.state.AUTH_TRUSTED_EMAIL_HEADER),
|
||||||
"enable_signup_password_confirmation": ENABLE_SIGNUP_PASSWORD_CONFIRMATION,
|
"enable_signup_password_confirmation": ENABLE_SIGNUP_PASSWORD_CONFIRMATION,
|
||||||
"enable_ldap": app.state.config.ENABLE_LDAP,
|
"enable_ldap": app.state.config.ENABLE_LDAP,
|
||||||
"enable_api_keys": app.state.config.ENABLE_API_KEYS,
|
"enable_api_key": app.state.config.ENABLE_API_KEY,
|
||||||
"enable_signup": app.state.config.ENABLE_SIGNUP,
|
"enable_signup": app.state.config.ENABLE_SIGNUP,
|
||||||
"enable_login_form": app.state.config.ENABLE_LOGIN_FORM,
|
"enable_login_form": app.state.config.ENABLE_LOGIN_FORM,
|
||||||
"enable_websocket": ENABLE_WEBSOCKET_SUPPORT,
|
"enable_websocket": ENABLE_WEBSOCKET_SUPPORT,
|
||||||
"enable_version_update_check": ENABLE_VERSION_UPDATE_CHECK,
|
"enable_version_update_check": ENABLE_VERSION_UPDATE_CHECK,
|
||||||
"enable_public_active_users_count": ENABLE_PUBLIC_ACTIVE_USERS_COUNT,
|
|
||||||
**(
|
**(
|
||||||
{
|
{
|
||||||
"enable_direct_connections": app.state.config.ENABLE_DIRECT_CONNECTIONS,
|
"enable_direct_connections": app.state.config.ENABLE_DIRECT_CONNECTIONS,
|
||||||
"enable_folders": app.state.config.ENABLE_FOLDERS,
|
|
||||||
"enable_channels": app.state.config.ENABLE_CHANNELS,
|
"enable_channels": app.state.config.ENABLE_CHANNELS,
|
||||||
"enable_notes": app.state.config.ENABLE_NOTES,
|
"enable_notes": app.state.config.ENABLE_NOTES,
|
||||||
"enable_web_search": app.state.config.ENABLE_WEB_SEARCH,
|
"enable_web_search": app.state.config.ENABLE_WEB_SEARCH,
|
||||||
|
|
@ -1932,7 +1821,6 @@ async def get_app_config(request: Request):
|
||||||
**(
|
**(
|
||||||
{
|
{
|
||||||
"default_models": app.state.config.DEFAULT_MODELS,
|
"default_models": app.state.config.DEFAULT_MODELS,
|
||||||
"default_pinned_models": app.state.config.DEFAULT_PINNED_MODELS,
|
|
||||||
"default_prompt_suggestions": app.state.config.DEFAULT_PROMPT_SUGGESTIONS,
|
"default_prompt_suggestions": app.state.config.DEFAULT_PROMPT_SUGGESTIONS,
|
||||||
"user_count": user_count,
|
"user_count": user_count,
|
||||||
"code": {
|
"code": {
|
||||||
|
|
@ -2034,7 +1922,6 @@ async def update_webhook_url(form_data: UrlForm, user=Depends(get_admin_user)):
|
||||||
async def get_app_version():
|
async def get_app_version():
|
||||||
return {
|
return {
|
||||||
"version": VERSION,
|
"version": VERSION,
|
||||||
"deployment_id": DEPLOYMENT_ID,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -2074,19 +1961,7 @@ async def get_current_usage(user=Depends(get_verified_user)):
|
||||||
This is an experimental endpoint and subject to change.
|
This is an experimental endpoint and subject to change.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# If public visibility is disabled, only allow admins to access this endpoint
|
return {"model_ids": get_models_in_use(), "user_ids": get_active_user_ids()}
|
||||||
if not ENABLE_PUBLIC_ACTIVE_USERS_COUNT and user.role != "admin":
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_403_FORBIDDEN,
|
|
||||||
detail="Access denied. Only administrators can view usage statistics.",
|
|
||||||
)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"model_ids": get_models_in_use(),
|
|
||||||
"user_count": Users.get_active_user_count(),
|
|
||||||
}
|
|
||||||
except HTTPException:
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.error(f"Error getting usage statistics: {e}")
|
log.error(f"Error getting usage statistics: {e}")
|
||||||
raise HTTPException(status_code=500, detail="Internal Server Error")
|
raise HTTPException(status_code=500, detail="Internal Server Error")
|
||||||
|
|
@ -2149,7 +2024,7 @@ except Exception as e:
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def register_client(request, client_id: str) -> bool:
|
async def register_client(self, request, client_id: str) -> bool:
|
||||||
server_type, server_id = client_id.split(":", 1)
|
server_type, server_id = client_id.split(":", 1)
|
||||||
|
|
||||||
connection = None
|
connection = None
|
||||||
|
|
|
||||||
|
|
@ -1,103 +0,0 @@
|
||||||
"""Update messages and channel member table
|
|
||||||
|
|
||||||
Revision ID: 2f1211949ecc
|
|
||||||
Revises: 37f288994c47
|
|
||||||
Create Date: 2025-11-27 03:07:56.200231
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
from typing import Sequence, Union
|
|
||||||
|
|
||||||
from alembic import op
|
|
||||||
import sqlalchemy as sa
|
|
||||||
import open_webui.internal.db
|
|
||||||
|
|
||||||
|
|
||||||
# revision identifiers, used by Alembic.
|
|
||||||
revision: str = "2f1211949ecc"
|
|
||||||
down_revision: Union[str, None] = "37f288994c47"
|
|
||||||
branch_labels: Union[str, Sequence[str], None] = None
|
|
||||||
depends_on: Union[str, Sequence[str], None] = None
|
|
||||||
|
|
||||||
|
|
||||||
def upgrade() -> None:
|
|
||||||
# New columns to be added to channel_member table
|
|
||||||
op.add_column("channel_member", sa.Column("status", sa.Text(), nullable=True))
|
|
||||||
op.add_column(
|
|
||||||
"channel_member",
|
|
||||||
sa.Column(
|
|
||||||
"is_active",
|
|
||||||
sa.Boolean(),
|
|
||||||
nullable=False,
|
|
||||||
default=True,
|
|
||||||
server_default=sa.sql.expression.true(),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
op.add_column(
|
|
||||||
"channel_member",
|
|
||||||
sa.Column(
|
|
||||||
"is_channel_muted",
|
|
||||||
sa.Boolean(),
|
|
||||||
nullable=False,
|
|
||||||
default=False,
|
|
||||||
server_default=sa.sql.expression.false(),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
op.add_column(
|
|
||||||
"channel_member",
|
|
||||||
sa.Column(
|
|
||||||
"is_channel_pinned",
|
|
||||||
sa.Boolean(),
|
|
||||||
nullable=False,
|
|
||||||
default=False,
|
|
||||||
server_default=sa.sql.expression.false(),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
op.add_column("channel_member", sa.Column("data", sa.JSON(), nullable=True))
|
|
||||||
op.add_column("channel_member", sa.Column("meta", sa.JSON(), nullable=True))
|
|
||||||
|
|
||||||
op.add_column(
|
|
||||||
"channel_member", sa.Column("joined_at", sa.BigInteger(), nullable=False)
|
|
||||||
)
|
|
||||||
op.add_column(
|
|
||||||
"channel_member", sa.Column("left_at", sa.BigInteger(), nullable=True)
|
|
||||||
)
|
|
||||||
|
|
||||||
op.add_column(
|
|
||||||
"channel_member", sa.Column("last_read_at", sa.BigInteger(), nullable=True)
|
|
||||||
)
|
|
||||||
|
|
||||||
op.add_column(
|
|
||||||
"channel_member", sa.Column("updated_at", sa.BigInteger(), nullable=True)
|
|
||||||
)
|
|
||||||
|
|
||||||
# New columns to be added to message table
|
|
||||||
op.add_column(
|
|
||||||
"message",
|
|
||||||
sa.Column(
|
|
||||||
"is_pinned",
|
|
||||||
sa.Boolean(),
|
|
||||||
nullable=False,
|
|
||||||
default=False,
|
|
||||||
server_default=sa.sql.expression.false(),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
op.add_column("message", sa.Column("pinned_at", sa.BigInteger(), nullable=True))
|
|
||||||
op.add_column("message", sa.Column("pinned_by", sa.Text(), nullable=True))
|
|
||||||
|
|
||||||
|
|
||||||
def downgrade() -> None:
|
|
||||||
op.drop_column("channel_member", "updated_at")
|
|
||||||
op.drop_column("channel_member", "last_read_at")
|
|
||||||
|
|
||||||
op.drop_column("channel_member", "meta")
|
|
||||||
op.drop_column("channel_member", "data")
|
|
||||||
|
|
||||||
op.drop_column("channel_member", "is_channel_pinned")
|
|
||||||
op.drop_column("channel_member", "is_channel_muted")
|
|
||||||
|
|
||||||
op.drop_column("message", "pinned_by")
|
|
||||||
op.drop_column("message", "pinned_at")
|
|
||||||
op.drop_column("message", "is_pinned")
|
|
||||||
|
|
@ -1,146 +0,0 @@
|
||||||
"""add_group_member_table
|
|
||||||
|
|
||||||
Revision ID: 37f288994c47
|
|
||||||
Revises: a5c220713937
|
|
||||||
Create Date: 2025-11-17 03:45:25.123939
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
import uuid
|
|
||||||
import time
|
|
||||||
import json
|
|
||||||
from typing import Sequence, Union
|
|
||||||
|
|
||||||
from alembic import op
|
|
||||||
import sqlalchemy as sa
|
|
||||||
|
|
||||||
|
|
||||||
# revision identifiers, used by Alembic.
|
|
||||||
revision: str = "37f288994c47"
|
|
||||||
down_revision: Union[str, None] = "a5c220713937"
|
|
||||||
branch_labels: Union[str, Sequence[str], None] = None
|
|
||||||
depends_on: Union[str, Sequence[str], None] = None
|
|
||||||
|
|
||||||
|
|
||||||
def upgrade() -> None:
|
|
||||||
# 1. Create new table
|
|
||||||
op.create_table(
|
|
||||||
"group_member",
|
|
||||||
sa.Column("id", sa.Text(), primary_key=True, unique=True, nullable=False),
|
|
||||||
sa.Column(
|
|
||||||
"group_id",
|
|
||||||
sa.Text(),
|
|
||||||
sa.ForeignKey("group.id", ondelete="CASCADE"),
|
|
||||||
nullable=False,
|
|
||||||
),
|
|
||||||
sa.Column(
|
|
||||||
"user_id",
|
|
||||||
sa.Text(),
|
|
||||||
sa.ForeignKey("user.id", ondelete="CASCADE"),
|
|
||||||
nullable=False,
|
|
||||||
),
|
|
||||||
sa.Column("created_at", sa.BigInteger(), nullable=True),
|
|
||||||
sa.Column("updated_at", sa.BigInteger(), nullable=True),
|
|
||||||
sa.UniqueConstraint("group_id", "user_id", name="uq_group_member_group_user"),
|
|
||||||
)
|
|
||||||
|
|
||||||
connection = op.get_bind()
|
|
||||||
|
|
||||||
# 2. Read existing group with user_ids JSON column
|
|
||||||
group_table = sa.Table(
|
|
||||||
"group",
|
|
||||||
sa.MetaData(),
|
|
||||||
sa.Column("id", sa.Text()),
|
|
||||||
sa.Column("user_ids", sa.JSON()), # JSON stored as text in SQLite + PG
|
|
||||||
)
|
|
||||||
|
|
||||||
results = connection.execute(
|
|
||||||
sa.select(group_table.c.id, group_table.c.user_ids)
|
|
||||||
).fetchall()
|
|
||||||
|
|
||||||
print(results)
|
|
||||||
|
|
||||||
# 3. Insert members into group_member table
|
|
||||||
gm_table = sa.Table(
|
|
||||||
"group_member",
|
|
||||||
sa.MetaData(),
|
|
||||||
sa.Column("id", sa.Text()),
|
|
||||||
sa.Column("group_id", sa.Text()),
|
|
||||||
sa.Column("user_id", sa.Text()),
|
|
||||||
sa.Column("created_at", sa.BigInteger()),
|
|
||||||
sa.Column("updated_at", sa.BigInteger()),
|
|
||||||
)
|
|
||||||
|
|
||||||
now = int(time.time())
|
|
||||||
for group_id, user_ids in results:
|
|
||||||
if not user_ids:
|
|
||||||
continue
|
|
||||||
|
|
||||||
if isinstance(user_ids, str):
|
|
||||||
try:
|
|
||||||
user_ids = json.loads(user_ids)
|
|
||||||
except Exception:
|
|
||||||
continue # skip invalid JSON
|
|
||||||
|
|
||||||
if not isinstance(user_ids, list):
|
|
||||||
continue
|
|
||||||
|
|
||||||
rows = [
|
|
||||||
{
|
|
||||||
"id": str(uuid.uuid4()),
|
|
||||||
"group_id": group_id,
|
|
||||||
"user_id": uid,
|
|
||||||
"created_at": now,
|
|
||||||
"updated_at": now,
|
|
||||||
}
|
|
||||||
for uid in user_ids
|
|
||||||
]
|
|
||||||
|
|
||||||
if rows:
|
|
||||||
connection.execute(gm_table.insert(), rows)
|
|
||||||
|
|
||||||
# 4. Optionally drop the old column
|
|
||||||
with op.batch_alter_table("group") as batch:
|
|
||||||
batch.drop_column("user_ids")
|
|
||||||
|
|
||||||
|
|
||||||
def downgrade():
|
|
||||||
# Reverse: restore user_ids column
|
|
||||||
with op.batch_alter_table("group") as batch:
|
|
||||||
batch.add_column(sa.Column("user_ids", sa.JSON()))
|
|
||||||
|
|
||||||
connection = op.get_bind()
|
|
||||||
gm_table = sa.Table(
|
|
||||||
"group_member",
|
|
||||||
sa.MetaData(),
|
|
||||||
sa.Column("group_id", sa.Text()),
|
|
||||||
sa.Column("user_id", sa.Text()),
|
|
||||||
sa.Column("created_at", sa.BigInteger()),
|
|
||||||
sa.Column("updated_at", sa.BigInteger()),
|
|
||||||
)
|
|
||||||
|
|
||||||
group_table = sa.Table(
|
|
||||||
"group",
|
|
||||||
sa.MetaData(),
|
|
||||||
sa.Column("id", sa.Text()),
|
|
||||||
sa.Column("user_ids", sa.JSON()),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Build JSON arrays again
|
|
||||||
results = connection.execute(sa.select(group_table.c.id)).fetchall()
|
|
||||||
|
|
||||||
for (group_id,) in results:
|
|
||||||
members = connection.execute(
|
|
||||||
sa.select(gm_table.c.user_id).where(gm_table.c.group_id == group_id)
|
|
||||||
).fetchall()
|
|
||||||
|
|
||||||
member_ids = [m[0] for m in members]
|
|
||||||
|
|
||||||
connection.execute(
|
|
||||||
group_table.update()
|
|
||||||
.where(group_table.c.id == group_id)
|
|
||||||
.values(user_ids=member_ids)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Drop the new table
|
|
||||||
op.drop_table("group_member")
|
|
||||||
|
|
@ -20,46 +20,18 @@ depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
def upgrade() -> None:
|
def upgrade() -> None:
|
||||||
# Ensure 'id' column in 'user' table is unique and primary key (ForeignKey constraint)
|
|
||||||
inspector = sa.inspect(op.get_bind())
|
|
||||||
columns = inspector.get_columns("user")
|
|
||||||
|
|
||||||
pk_columns = inspector.get_pk_constraint("user")["constrained_columns"]
|
|
||||||
id_column = next((col for col in columns if col["name"] == "id"), None)
|
|
||||||
|
|
||||||
if id_column and not id_column.get("unique", False):
|
|
||||||
unique_constraints = inspector.get_unique_constraints("user")
|
|
||||||
unique_columns = {tuple(u["column_names"]) for u in unique_constraints}
|
|
||||||
|
|
||||||
with op.batch_alter_table("user") as batch_op:
|
|
||||||
# If primary key is wrong, drop it
|
|
||||||
if pk_columns and pk_columns != ["id"]:
|
|
||||||
batch_op.drop_constraint(
|
|
||||||
inspector.get_pk_constraint("user")["name"], type_="primary"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Add unique constraint if missing
|
|
||||||
if ("id",) not in unique_columns:
|
|
||||||
batch_op.create_unique_constraint("uq_user_id", ["id"])
|
|
||||||
|
|
||||||
# Re-create correct primary key
|
|
||||||
batch_op.create_primary_key("pk_user_id", ["id"])
|
|
||||||
|
|
||||||
# Create oauth_session table
|
# Create oauth_session table
|
||||||
op.create_table(
|
op.create_table(
|
||||||
"oauth_session",
|
"oauth_session",
|
||||||
sa.Column("id", sa.Text(), primary_key=True, nullable=False, unique=True),
|
sa.Column("id", sa.Text(), nullable=False),
|
||||||
sa.Column(
|
sa.Column("user_id", sa.Text(), nullable=False),
|
||||||
"user_id",
|
|
||||||
sa.Text(),
|
|
||||||
sa.ForeignKey("user.id", ondelete="CASCADE"),
|
|
||||||
nullable=False,
|
|
||||||
),
|
|
||||||
sa.Column("provider", sa.Text(), nullable=False),
|
sa.Column("provider", sa.Text(), nullable=False),
|
||||||
sa.Column("token", sa.Text(), nullable=False),
|
sa.Column("token", sa.Text(), nullable=False),
|
||||||
sa.Column("expires_at", sa.BigInteger(), nullable=False),
|
sa.Column("expires_at", sa.BigInteger(), nullable=False),
|
||||||
sa.Column("created_at", sa.BigInteger(), nullable=False),
|
sa.Column("created_at", sa.BigInteger(), nullable=False),
|
||||||
sa.Column("updated_at", sa.BigInteger(), nullable=False),
|
sa.Column("updated_at", sa.BigInteger(), nullable=False),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create indexes for better performance
|
# Create indexes for better performance
|
||||||
|
|
|
||||||
|
|
@ -1,169 +0,0 @@
|
||||||
"""Add knowledge_file table
|
|
||||||
|
|
||||||
Revision ID: 3e0e00844bb0
|
|
||||||
Revises: 90ef40d4714e
|
|
||||||
Create Date: 2025-12-02 06:54:19.401334
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
from typing import Sequence, Union
|
|
||||||
|
|
||||||
from alembic import op
|
|
||||||
import sqlalchemy as sa
|
|
||||||
from sqlalchemy import inspect
|
|
||||||
import open_webui.internal.db
|
|
||||||
|
|
||||||
import time
|
|
||||||
import json
|
|
||||||
import uuid
|
|
||||||
|
|
||||||
# revision identifiers, used by Alembic.
|
|
||||||
revision: str = "3e0e00844bb0"
|
|
||||||
down_revision: Union[str, None] = "90ef40d4714e"
|
|
||||||
branch_labels: Union[str, Sequence[str], None] = None
|
|
||||||
depends_on: Union[str, Sequence[str], None] = None
|
|
||||||
|
|
||||||
|
|
||||||
def upgrade() -> None:
|
|
||||||
op.create_table(
|
|
||||||
"knowledge_file",
|
|
||||||
sa.Column("id", sa.Text(), primary_key=True),
|
|
||||||
sa.Column("user_id", sa.Text(), nullable=False),
|
|
||||||
sa.Column(
|
|
||||||
"knowledge_id",
|
|
||||||
sa.Text(),
|
|
||||||
sa.ForeignKey("knowledge.id", ondelete="CASCADE"),
|
|
||||||
nullable=False,
|
|
||||||
),
|
|
||||||
sa.Column(
|
|
||||||
"file_id",
|
|
||||||
sa.Text(),
|
|
||||||
sa.ForeignKey("file.id", ondelete="CASCADE"),
|
|
||||||
nullable=False,
|
|
||||||
),
|
|
||||||
sa.Column("created_at", sa.BigInteger(), nullable=False),
|
|
||||||
sa.Column("updated_at", sa.BigInteger(), nullable=False),
|
|
||||||
# indexes
|
|
||||||
sa.Index("ix_knowledge_file_knowledge_id", "knowledge_id"),
|
|
||||||
sa.Index("ix_knowledge_file_file_id", "file_id"),
|
|
||||||
sa.Index("ix_knowledge_file_user_id", "user_id"),
|
|
||||||
# unique constraints
|
|
||||||
sa.UniqueConstraint(
|
|
||||||
"knowledge_id", "file_id", name="uq_knowledge_file_knowledge_file"
|
|
||||||
), # prevent duplicate entries
|
|
||||||
)
|
|
||||||
|
|
||||||
connection = op.get_bind()
|
|
||||||
|
|
||||||
# 2. Read existing group with user_ids JSON column
|
|
||||||
knowledge_table = sa.Table(
|
|
||||||
"knowledge",
|
|
||||||
sa.MetaData(),
|
|
||||||
sa.Column("id", sa.Text()),
|
|
||||||
sa.Column("user_id", sa.Text()),
|
|
||||||
sa.Column("data", sa.JSON()), # JSON stored as text in SQLite + PG
|
|
||||||
)
|
|
||||||
|
|
||||||
results = connection.execute(
|
|
||||||
sa.select(
|
|
||||||
knowledge_table.c.id, knowledge_table.c.user_id, knowledge_table.c.data
|
|
||||||
)
|
|
||||||
).fetchall()
|
|
||||||
|
|
||||||
# 3. Insert members into group_member table
|
|
||||||
kf_table = sa.Table(
|
|
||||||
"knowledge_file",
|
|
||||||
sa.MetaData(),
|
|
||||||
sa.Column("id", sa.Text()),
|
|
||||||
sa.Column("user_id", sa.Text()),
|
|
||||||
sa.Column("knowledge_id", sa.Text()),
|
|
||||||
sa.Column("file_id", sa.Text()),
|
|
||||||
sa.Column("created_at", sa.BigInteger()),
|
|
||||||
sa.Column("updated_at", sa.BigInteger()),
|
|
||||||
)
|
|
||||||
|
|
||||||
file_table = sa.Table(
|
|
||||||
"file",
|
|
||||||
sa.MetaData(),
|
|
||||||
sa.Column("id", sa.Text()),
|
|
||||||
)
|
|
||||||
|
|
||||||
now = int(time.time())
|
|
||||||
for knowledge_id, user_id, data in results:
|
|
||||||
if not data:
|
|
||||||
continue
|
|
||||||
|
|
||||||
if isinstance(data, str):
|
|
||||||
try:
|
|
||||||
data = json.loads(data)
|
|
||||||
except Exception:
|
|
||||||
continue # skip invalid JSON
|
|
||||||
|
|
||||||
if not isinstance(data, dict):
|
|
||||||
continue
|
|
||||||
|
|
||||||
file_ids = data.get("file_ids", [])
|
|
||||||
|
|
||||||
for file_id in file_ids:
|
|
||||||
file_exists = connection.execute(
|
|
||||||
sa.select(file_table.c.id).where(file_table.c.id == file_id)
|
|
||||||
).fetchone()
|
|
||||||
|
|
||||||
if not file_exists:
|
|
||||||
continue # skip non-existing files
|
|
||||||
|
|
||||||
row = {
|
|
||||||
"id": str(uuid.uuid4()),
|
|
||||||
"user_id": user_id,
|
|
||||||
"knowledge_id": knowledge_id,
|
|
||||||
"file_id": file_id,
|
|
||||||
"created_at": now,
|
|
||||||
"updated_at": now,
|
|
||||||
}
|
|
||||||
connection.execute(kf_table.insert().values(**row))
|
|
||||||
|
|
||||||
with op.batch_alter_table("knowledge") as batch:
|
|
||||||
batch.drop_column("data")
|
|
||||||
|
|
||||||
|
|
||||||
def downgrade() -> None:
|
|
||||||
# 1. Add back the old data column
|
|
||||||
op.add_column("knowledge", sa.Column("data", sa.JSON(), nullable=True))
|
|
||||||
|
|
||||||
connection = op.get_bind()
|
|
||||||
|
|
||||||
# 2. Read knowledge_file entries and reconstruct data JSON
|
|
||||||
knowledge_table = sa.Table(
|
|
||||||
"knowledge",
|
|
||||||
sa.MetaData(),
|
|
||||||
sa.Column("id", sa.Text()),
|
|
||||||
sa.Column("data", sa.JSON()),
|
|
||||||
)
|
|
||||||
|
|
||||||
kf_table = sa.Table(
|
|
||||||
"knowledge_file",
|
|
||||||
sa.MetaData(),
|
|
||||||
sa.Column("id", sa.Text()),
|
|
||||||
sa.Column("knowledge_id", sa.Text()),
|
|
||||||
sa.Column("file_id", sa.Text()),
|
|
||||||
)
|
|
||||||
|
|
||||||
results = connection.execute(sa.select(knowledge_table.c.id)).fetchall()
|
|
||||||
|
|
||||||
for (knowledge_id,) in results:
|
|
||||||
file_ids = connection.execute(
|
|
||||||
sa.select(kf_table.c.file_id).where(kf_table.c.knowledge_id == knowledge_id)
|
|
||||||
).fetchall()
|
|
||||||
|
|
||||||
file_ids_list = [fid for (fid,) in file_ids]
|
|
||||||
|
|
||||||
data_json = {"file_ids": file_ids_list}
|
|
||||||
|
|
||||||
connection.execute(
|
|
||||||
knowledge_table.update()
|
|
||||||
.where(knowledge_table.c.id == knowledge_id)
|
|
||||||
.values(data=data_json)
|
|
||||||
)
|
|
||||||
|
|
||||||
# 3. Drop the knowledge_file table
|
|
||||||
op.drop_table("knowledge_file")
|
|
||||||
|
|
@ -1,54 +0,0 @@
|
||||||
"""Add channel file table
|
|
||||||
|
|
||||||
Revision ID: 6283dc0e4d8d
|
|
||||||
Revises: 3e0e00844bb0
|
|
||||||
Create Date: 2025-12-10 15:11:39.424601
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
from typing import Sequence, Union
|
|
||||||
|
|
||||||
from alembic import op
|
|
||||||
import sqlalchemy as sa
|
|
||||||
import open_webui.internal.db
|
|
||||||
|
|
||||||
|
|
||||||
# revision identifiers, used by Alembic.
|
|
||||||
revision: str = "6283dc0e4d8d"
|
|
||||||
down_revision: Union[str, None] = "3e0e00844bb0"
|
|
||||||
branch_labels: Union[str, Sequence[str], None] = None
|
|
||||||
depends_on: Union[str, Sequence[str], None] = None
|
|
||||||
|
|
||||||
|
|
||||||
def upgrade() -> None:
|
|
||||||
op.create_table(
|
|
||||||
"channel_file",
|
|
||||||
sa.Column("id", sa.Text(), primary_key=True),
|
|
||||||
sa.Column("user_id", sa.Text(), nullable=False),
|
|
||||||
sa.Column(
|
|
||||||
"channel_id",
|
|
||||||
sa.Text(),
|
|
||||||
sa.ForeignKey("channel.id", ondelete="CASCADE"),
|
|
||||||
nullable=False,
|
|
||||||
),
|
|
||||||
sa.Column(
|
|
||||||
"file_id",
|
|
||||||
sa.Text(),
|
|
||||||
sa.ForeignKey("file.id", ondelete="CASCADE"),
|
|
||||||
nullable=False,
|
|
||||||
),
|
|
||||||
sa.Column("created_at", sa.BigInteger(), nullable=False),
|
|
||||||
sa.Column("updated_at", sa.BigInteger(), nullable=False),
|
|
||||||
# indexes
|
|
||||||
sa.Index("ix_channel_file_channel_id", "channel_id"),
|
|
||||||
sa.Index("ix_channel_file_file_id", "file_id"),
|
|
||||||
sa.Index("ix_channel_file_user_id", "user_id"),
|
|
||||||
# unique constraints
|
|
||||||
sa.UniqueConstraint(
|
|
||||||
"channel_id", "file_id", name="uq_channel_file_channel_file"
|
|
||||||
), # prevent duplicate entries
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def downgrade() -> None:
|
|
||||||
op.drop_table("channel_file")
|
|
||||||
|
|
@ -1,49 +0,0 @@
|
||||||
"""Update channel file and knowledge table
|
|
||||||
|
|
||||||
Revision ID: 81cc2ce44d79
|
|
||||||
Revises: 6283dc0e4d8d
|
|
||||||
Create Date: 2025-12-10 16:07:58.001282
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
from typing import Sequence, Union
|
|
||||||
|
|
||||||
from alembic import op
|
|
||||||
import sqlalchemy as sa
|
|
||||||
import open_webui.internal.db
|
|
||||||
|
|
||||||
|
|
||||||
# revision identifiers, used by Alembic.
|
|
||||||
revision: str = "81cc2ce44d79"
|
|
||||||
down_revision: Union[str, None] = "6283dc0e4d8d"
|
|
||||||
branch_labels: Union[str, Sequence[str], None] = None
|
|
||||||
depends_on: Union[str, Sequence[str], None] = None
|
|
||||||
|
|
||||||
|
|
||||||
def upgrade() -> None:
|
|
||||||
# Add message_id column to channel_file table
|
|
||||||
with op.batch_alter_table("channel_file", schema=None) as batch_op:
|
|
||||||
batch_op.add_column(
|
|
||||||
sa.Column(
|
|
||||||
"message_id",
|
|
||||||
sa.Text(),
|
|
||||||
sa.ForeignKey(
|
|
||||||
"message.id", ondelete="CASCADE", name="fk_channel_file_message_id"
|
|
||||||
),
|
|
||||||
nullable=True,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Add data column to knowledge table
|
|
||||||
with op.batch_alter_table("knowledge", schema=None) as batch_op:
|
|
||||||
batch_op.add_column(sa.Column("data", sa.JSON(), nullable=True))
|
|
||||||
|
|
||||||
|
|
||||||
def downgrade() -> None:
|
|
||||||
# Remove message_id column from channel_file table
|
|
||||||
with op.batch_alter_table("channel_file", schema=None) as batch_op:
|
|
||||||
batch_op.drop_column("message_id")
|
|
||||||
|
|
||||||
# Remove data column from knowledge table
|
|
||||||
with op.batch_alter_table("knowledge", schema=None) as batch_op:
|
|
||||||
batch_op.drop_column("data")
|
|
||||||
|
|
@ -1,81 +0,0 @@
|
||||||
"""Update channel and channel members table
|
|
||||||
|
|
||||||
Revision ID: 90ef40d4714e
|
|
||||||
Revises: b10670c03dd5
|
|
||||||
Create Date: 2025-11-30 06:33:38.790341
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
from typing import Sequence, Union
|
|
||||||
|
|
||||||
from alembic import op
|
|
||||||
import sqlalchemy as sa
|
|
||||||
import open_webui.internal.db
|
|
||||||
|
|
||||||
|
|
||||||
# revision identifiers, used by Alembic.
|
|
||||||
revision: str = "90ef40d4714e"
|
|
||||||
down_revision: Union[str, None] = "b10670c03dd5"
|
|
||||||
branch_labels: Union[str, Sequence[str], None] = None
|
|
||||||
depends_on: Union[str, Sequence[str], None] = None
|
|
||||||
|
|
||||||
|
|
||||||
def upgrade() -> None:
|
|
||||||
# Update 'channel' table
|
|
||||||
op.add_column("channel", sa.Column("is_private", sa.Boolean(), nullable=True))
|
|
||||||
|
|
||||||
op.add_column("channel", sa.Column("archived_at", sa.BigInteger(), nullable=True))
|
|
||||||
op.add_column("channel", sa.Column("archived_by", sa.Text(), nullable=True))
|
|
||||||
|
|
||||||
op.add_column("channel", sa.Column("deleted_at", sa.BigInteger(), nullable=True))
|
|
||||||
op.add_column("channel", sa.Column("deleted_by", sa.Text(), nullable=True))
|
|
||||||
|
|
||||||
op.add_column("channel", sa.Column("updated_by", sa.Text(), nullable=True))
|
|
||||||
|
|
||||||
# Update 'channel_member' table
|
|
||||||
op.add_column("channel_member", sa.Column("role", sa.Text(), nullable=True))
|
|
||||||
op.add_column("channel_member", sa.Column("invited_by", sa.Text(), nullable=True))
|
|
||||||
op.add_column(
|
|
||||||
"channel_member", sa.Column("invited_at", sa.BigInteger(), nullable=True)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create 'channel_webhook' table
|
|
||||||
op.create_table(
|
|
||||||
"channel_webhook",
|
|
||||||
sa.Column("id", sa.Text(), primary_key=True, unique=True, nullable=False),
|
|
||||||
sa.Column("user_id", sa.Text(), nullable=False),
|
|
||||||
sa.Column(
|
|
||||||
"channel_id",
|
|
||||||
sa.Text(),
|
|
||||||
sa.ForeignKey("channel.id", ondelete="CASCADE"),
|
|
||||||
nullable=False,
|
|
||||||
),
|
|
||||||
sa.Column("name", sa.Text(), nullable=False),
|
|
||||||
sa.Column("profile_image_url", sa.Text(), nullable=True),
|
|
||||||
sa.Column("token", sa.Text(), nullable=False),
|
|
||||||
sa.Column("last_used_at", sa.BigInteger(), nullable=True),
|
|
||||||
sa.Column("created_at", sa.BigInteger(), nullable=False),
|
|
||||||
sa.Column("updated_at", sa.BigInteger(), nullable=False),
|
|
||||||
)
|
|
||||||
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
def downgrade() -> None:
|
|
||||||
# Downgrade 'channel' table
|
|
||||||
op.drop_column("channel", "is_private")
|
|
||||||
op.drop_column("channel", "archived_at")
|
|
||||||
op.drop_column("channel", "archived_by")
|
|
||||||
op.drop_column("channel", "deleted_at")
|
|
||||||
op.drop_column("channel", "deleted_by")
|
|
||||||
op.drop_column("channel", "updated_by")
|
|
||||||
|
|
||||||
# Downgrade 'channel_member' table
|
|
||||||
op.drop_column("channel_member", "role")
|
|
||||||
op.drop_column("channel_member", "invited_by")
|
|
||||||
op.drop_column("channel_member", "invited_at")
|
|
||||||
|
|
||||||
# Drop 'channel_webhook' table
|
|
||||||
op.drop_table("channel_webhook")
|
|
||||||
|
|
||||||
pass
|
|
||||||
|
|
@ -1,251 +0,0 @@
|
||||||
"""Update user table
|
|
||||||
|
|
||||||
Revision ID: b10670c03dd5
|
|
||||||
Revises: 2f1211949ecc
|
|
||||||
Create Date: 2025-11-28 04:55:31.737538
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
from typing import Sequence, Union
|
|
||||||
|
|
||||||
from alembic import op
|
|
||||||
import sqlalchemy as sa
|
|
||||||
|
|
||||||
|
|
||||||
import open_webui.internal.db
|
|
||||||
import json
|
|
||||||
import time
|
|
||||||
|
|
||||||
# revision identifiers, used by Alembic.
|
|
||||||
revision: str = "b10670c03dd5"
|
|
||||||
down_revision: Union[str, None] = "2f1211949ecc"
|
|
||||||
branch_labels: Union[str, Sequence[str], None] = None
|
|
||||||
depends_on: Union[str, Sequence[str], None] = None
|
|
||||||
|
|
||||||
|
|
||||||
def _drop_sqlite_indexes_for_column(table_name, column_name, conn):
|
|
||||||
"""
|
|
||||||
SQLite requires manual removal of any indexes referencing a column
|
|
||||||
before ALTER TABLE ... DROP COLUMN can succeed.
|
|
||||||
"""
|
|
||||||
indexes = conn.execute(sa.text(f"PRAGMA index_list('{table_name}')")).fetchall()
|
|
||||||
|
|
||||||
for idx in indexes:
|
|
||||||
index_name = idx[1] # index name
|
|
||||||
# Get indexed columns
|
|
||||||
idx_info = conn.execute(
|
|
||||||
sa.text(f"PRAGMA index_info('{index_name}')")
|
|
||||||
).fetchall()
|
|
||||||
|
|
||||||
indexed_cols = [row[2] for row in idx_info] # col names
|
|
||||||
if column_name in indexed_cols:
|
|
||||||
conn.execute(sa.text(f"DROP INDEX IF EXISTS {index_name}"))
|
|
||||||
|
|
||||||
|
|
||||||
def _convert_column_to_json(table: str, column: str):
|
|
||||||
conn = op.get_bind()
|
|
||||||
dialect = conn.dialect.name
|
|
||||||
|
|
||||||
# SQLite cannot ALTER COLUMN → must recreate column
|
|
||||||
if dialect == "sqlite":
|
|
||||||
# 1. Add temporary column
|
|
||||||
op.add_column(table, sa.Column(f"{column}_json", sa.JSON(), nullable=True))
|
|
||||||
|
|
||||||
# 2. Load old data
|
|
||||||
rows = conn.execute(sa.text(f'SELECT id, {column} FROM "{table}"')).fetchall()
|
|
||||||
|
|
||||||
for row in rows:
|
|
||||||
uid, raw = row
|
|
||||||
if raw is None:
|
|
||||||
parsed = None
|
|
||||||
else:
|
|
||||||
try:
|
|
||||||
parsed = json.loads(raw)
|
|
||||||
except Exception:
|
|
||||||
parsed = None # fallback safe behavior
|
|
||||||
|
|
||||||
conn.execute(
|
|
||||||
sa.text(f'UPDATE "{table}" SET {column}_json = :val WHERE id = :id'),
|
|
||||||
{"val": json.dumps(parsed) if parsed else None, "id": uid},
|
|
||||||
)
|
|
||||||
|
|
||||||
# 3. Drop old TEXT column
|
|
||||||
op.drop_column(table, column)
|
|
||||||
|
|
||||||
# 4. Rename new JSON column → original name
|
|
||||||
op.alter_column(table, f"{column}_json", new_column_name=column)
|
|
||||||
|
|
||||||
else:
|
|
||||||
# PostgreSQL supports direct CAST
|
|
||||||
op.alter_column(
|
|
||||||
table,
|
|
||||||
column,
|
|
||||||
type_=sa.JSON(),
|
|
||||||
postgresql_using=f"{column}::json",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _convert_column_to_text(table: str, column: str):
|
|
||||||
conn = op.get_bind()
|
|
||||||
dialect = conn.dialect.name
|
|
||||||
|
|
||||||
if dialect == "sqlite":
|
|
||||||
op.add_column(table, sa.Column(f"{column}_text", sa.Text(), nullable=True))
|
|
||||||
|
|
||||||
rows = conn.execute(sa.text(f'SELECT id, {column} FROM "{table}"')).fetchall()
|
|
||||||
|
|
||||||
for uid, raw in rows:
|
|
||||||
conn.execute(
|
|
||||||
sa.text(f'UPDATE "{table}" SET {column}_text = :val WHERE id = :id'),
|
|
||||||
{"val": json.dumps(raw) if raw else None, "id": uid},
|
|
||||||
)
|
|
||||||
|
|
||||||
op.drop_column(table, column)
|
|
||||||
op.alter_column(table, f"{column}_text", new_column_name=column)
|
|
||||||
|
|
||||||
else:
|
|
||||||
op.alter_column(
|
|
||||||
table,
|
|
||||||
column,
|
|
||||||
type_=sa.Text(),
|
|
||||||
postgresql_using=f"to_json({column})::text",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def upgrade() -> None:
|
|
||||||
op.add_column(
|
|
||||||
"user", sa.Column("profile_banner_image_url", sa.Text(), nullable=True)
|
|
||||||
)
|
|
||||||
op.add_column("user", sa.Column("timezone", sa.String(), nullable=True))
|
|
||||||
|
|
||||||
op.add_column("user", sa.Column("presence_state", sa.String(), nullable=True))
|
|
||||||
op.add_column("user", sa.Column("status_emoji", sa.String(), nullable=True))
|
|
||||||
op.add_column("user", sa.Column("status_message", sa.Text(), nullable=True))
|
|
||||||
op.add_column(
|
|
||||||
"user", sa.Column("status_expires_at", sa.BigInteger(), nullable=True)
|
|
||||||
)
|
|
||||||
|
|
||||||
op.add_column("user", sa.Column("oauth", sa.JSON(), nullable=True))
|
|
||||||
|
|
||||||
# Convert info (TEXT/JSONField) → JSON
|
|
||||||
_convert_column_to_json("user", "info")
|
|
||||||
# Convert settings (TEXT/JSONField) → JSON
|
|
||||||
_convert_column_to_json("user", "settings")
|
|
||||||
|
|
||||||
op.create_table(
|
|
||||||
"api_key",
|
|
||||||
sa.Column("id", sa.Text(), primary_key=True, unique=True),
|
|
||||||
sa.Column("user_id", sa.Text(), sa.ForeignKey("user.id", ondelete="CASCADE")),
|
|
||||||
sa.Column("key", sa.Text(), unique=True, nullable=False),
|
|
||||||
sa.Column("data", sa.JSON(), nullable=True),
|
|
||||||
sa.Column("expires_at", sa.BigInteger(), nullable=True),
|
|
||||||
sa.Column("last_used_at", sa.BigInteger(), nullable=True),
|
|
||||||
sa.Column("created_at", sa.BigInteger(), nullable=False),
|
|
||||||
sa.Column("updated_at", sa.BigInteger(), nullable=False),
|
|
||||||
)
|
|
||||||
|
|
||||||
conn = op.get_bind()
|
|
||||||
users = conn.execute(
|
|
||||||
sa.text('SELECT id, oauth_sub FROM "user" WHERE oauth_sub IS NOT NULL')
|
|
||||||
).fetchall()
|
|
||||||
|
|
||||||
for uid, oauth_sub in users:
|
|
||||||
if oauth_sub:
|
|
||||||
# Example formats supported:
|
|
||||||
# provider@sub
|
|
||||||
# plain sub (stored as {"oidc": {"sub": sub}})
|
|
||||||
if "@" in oauth_sub:
|
|
||||||
provider, sub = oauth_sub.split("@", 1)
|
|
||||||
else:
|
|
||||||
provider, sub = "oidc", oauth_sub
|
|
||||||
|
|
||||||
oauth_json = json.dumps({provider: {"sub": sub}})
|
|
||||||
conn.execute(
|
|
||||||
sa.text('UPDATE "user" SET oauth = :oauth WHERE id = :id'),
|
|
||||||
{"oauth": oauth_json, "id": uid},
|
|
||||||
)
|
|
||||||
|
|
||||||
users_with_keys = conn.execute(
|
|
||||||
sa.text('SELECT id, api_key FROM "user" WHERE api_key IS NOT NULL')
|
|
||||||
).fetchall()
|
|
||||||
now = int(time.time())
|
|
||||||
|
|
||||||
for uid, api_key in users_with_keys:
|
|
||||||
if api_key:
|
|
||||||
conn.execute(
|
|
||||||
sa.text(
|
|
||||||
"""
|
|
||||||
INSERT INTO api_key (id, user_id, key, created_at, updated_at)
|
|
||||||
VALUES (:id, :user_id, :key, :created_at, :updated_at)
|
|
||||||
"""
|
|
||||||
),
|
|
||||||
{
|
|
||||||
"id": f"key_{uid}",
|
|
||||||
"user_id": uid,
|
|
||||||
"key": api_key,
|
|
||||||
"created_at": now,
|
|
||||||
"updated_at": now,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
if conn.dialect.name == "sqlite":
|
|
||||||
_drop_sqlite_indexes_for_column("user", "api_key", conn)
|
|
||||||
_drop_sqlite_indexes_for_column("user", "oauth_sub", conn)
|
|
||||||
|
|
||||||
with op.batch_alter_table("user") as batch_op:
|
|
||||||
batch_op.drop_column("api_key")
|
|
||||||
batch_op.drop_column("oauth_sub")
|
|
||||||
|
|
||||||
|
|
||||||
def downgrade() -> None:
|
|
||||||
# --- 1. Restore old oauth_sub column ---
|
|
||||||
op.add_column("user", sa.Column("oauth_sub", sa.Text(), nullable=True))
|
|
||||||
|
|
||||||
conn = op.get_bind()
|
|
||||||
users = conn.execute(
|
|
||||||
sa.text('SELECT id, oauth FROM "user" WHERE oauth IS NOT NULL')
|
|
||||||
).fetchall()
|
|
||||||
|
|
||||||
for uid, oauth in users:
|
|
||||||
try:
|
|
||||||
data = json.loads(oauth)
|
|
||||||
provider = list(data.keys())[0]
|
|
||||||
sub = data[provider].get("sub")
|
|
||||||
oauth_sub = f"{provider}@{sub}"
|
|
||||||
except Exception:
|
|
||||||
oauth_sub = None
|
|
||||||
|
|
||||||
conn.execute(
|
|
||||||
sa.text('UPDATE "user" SET oauth_sub = :oauth_sub WHERE id = :id'),
|
|
||||||
{"oauth_sub": oauth_sub, "id": uid},
|
|
||||||
)
|
|
||||||
|
|
||||||
op.drop_column("user", "oauth")
|
|
||||||
|
|
||||||
# --- 2. Restore api_key field ---
|
|
||||||
op.add_column("user", sa.Column("api_key", sa.String(), nullable=True))
|
|
||||||
|
|
||||||
# Restore values from api_key
|
|
||||||
keys = conn.execute(sa.text("SELECT user_id, key FROM api_key")).fetchall()
|
|
||||||
for uid, key in keys:
|
|
||||||
conn.execute(
|
|
||||||
sa.text('UPDATE "user" SET api_key = :key WHERE id = :id'),
|
|
||||||
{"key": key, "id": uid},
|
|
||||||
)
|
|
||||||
|
|
||||||
# Drop new table
|
|
||||||
op.drop_table("api_key")
|
|
||||||
|
|
||||||
with op.batch_alter_table("user") as batch_op:
|
|
||||||
batch_op.drop_column("profile_banner_image_url")
|
|
||||||
batch_op.drop_column("timezone")
|
|
||||||
|
|
||||||
batch_op.drop_column("presence_state")
|
|
||||||
batch_op.drop_column("status_emoji")
|
|
||||||
batch_op.drop_column("status_message")
|
|
||||||
batch_op.drop_column("status_expires_at")
|
|
||||||
|
|
||||||
# Convert info (JSON) → TEXT
|
|
||||||
_convert_column_to_text("user", "info")
|
|
||||||
# Convert settings (JSON) → TEXT
|
|
||||||
_convert_column_to_text("user", "settings")
|
|
||||||
|
|
@ -1,57 +0,0 @@
|
||||||
"""Add chat_file table
|
|
||||||
|
|
||||||
Revision ID: c440947495f3
|
|
||||||
Revises: 81cc2ce44d79
|
|
||||||
Create Date: 2025-12-21 20:27:41.694897
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
from typing import Sequence, Union
|
|
||||||
|
|
||||||
from alembic import op
|
|
||||||
import sqlalchemy as sa
|
|
||||||
|
|
||||||
|
|
||||||
# revision identifiers, used by Alembic.
|
|
||||||
revision: str = "c440947495f3"
|
|
||||||
down_revision: Union[str, None] = "81cc2ce44d79"
|
|
||||||
branch_labels: Union[str, Sequence[str], None] = None
|
|
||||||
depends_on: Union[str, Sequence[str], None] = None
|
|
||||||
|
|
||||||
|
|
||||||
def upgrade() -> None:
|
|
||||||
op.create_table(
|
|
||||||
"chat_file",
|
|
||||||
sa.Column("id", sa.Text(), primary_key=True),
|
|
||||||
sa.Column("user_id", sa.Text(), nullable=False),
|
|
||||||
sa.Column(
|
|
||||||
"chat_id",
|
|
||||||
sa.Text(),
|
|
||||||
sa.ForeignKey("chat.id", ondelete="CASCADE"),
|
|
||||||
nullable=False,
|
|
||||||
),
|
|
||||||
sa.Column(
|
|
||||||
"file_id",
|
|
||||||
sa.Text(),
|
|
||||||
sa.ForeignKey("file.id", ondelete="CASCADE"),
|
|
||||||
nullable=False,
|
|
||||||
),
|
|
||||||
sa.Column("message_id", sa.Text(), nullable=True),
|
|
||||||
sa.Column("created_at", sa.BigInteger(), nullable=False),
|
|
||||||
sa.Column("updated_at", sa.BigInteger(), nullable=False),
|
|
||||||
# indexes
|
|
||||||
sa.Index("ix_chat_file_chat_id", "chat_id"),
|
|
||||||
sa.Index("ix_chat_file_file_id", "file_id"),
|
|
||||||
sa.Index("ix_chat_file_message_id", "message_id"),
|
|
||||||
sa.Index("ix_chat_file_user_id", "user_id"),
|
|
||||||
# unique constraints
|
|
||||||
sa.UniqueConstraint(
|
|
||||||
"chat_id", "file_id", name="uq_chat_file_chat_file"
|
|
||||||
), # prevent duplicate entries
|
|
||||||
)
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
def downgrade() -> None:
|
|
||||||
op.drop_table("chat_file")
|
|
||||||
pass
|
|
||||||
|
|
@ -3,11 +3,14 @@ import uuid
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from open_webui.internal.db import Base, get_db
|
from open_webui.internal.db import Base, get_db
|
||||||
from open_webui.models.users import UserModel, UserProfileImageResponse, Users
|
from open_webui.models.users import UserModel, Users
|
||||||
|
from open_webui.env import SRC_LOG_LEVELS
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from sqlalchemy import Boolean, Column, String, Text
|
from sqlalchemy import Boolean, Column, String, Text
|
||||||
|
from open_webui.utils.auth import verify_password
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
log.setLevel(SRC_LOG_LEVELS["MODELS"])
|
||||||
|
|
||||||
####################
|
####################
|
||||||
# DB MODEL
|
# DB MODEL
|
||||||
|
|
@ -17,7 +20,7 @@ log = logging.getLogger(__name__)
|
||||||
class Auth(Base):
|
class Auth(Base):
|
||||||
__tablename__ = "auth"
|
__tablename__ = "auth"
|
||||||
|
|
||||||
id = Column(String, primary_key=True, unique=True)
|
id = Column(String, primary_key=True)
|
||||||
email = Column(String)
|
email = Column(String)
|
||||||
password = Column(Text)
|
password = Column(Text)
|
||||||
active = Column(Boolean)
|
active = Column(Boolean)
|
||||||
|
|
@ -44,7 +47,15 @@ class ApiKey(BaseModel):
|
||||||
api_key: Optional[str] = None
|
api_key: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
class SigninResponse(Token, UserProfileImageResponse):
|
class UserResponse(BaseModel):
|
||||||
|
id: str
|
||||||
|
email: str
|
||||||
|
name: str
|
||||||
|
role: str
|
||||||
|
profile_image_url: str
|
||||||
|
|
||||||
|
|
||||||
|
class SigninResponse(Token, UserResponse):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -86,7 +97,7 @@ class AuthsTable:
|
||||||
name: str,
|
name: str,
|
||||||
profile_image_url: str = "/user.png",
|
profile_image_url: str = "/user.png",
|
||||||
role: str = "pending",
|
role: str = "pending",
|
||||||
oauth: Optional[dict] = None,
|
oauth_sub: Optional[str] = None,
|
||||||
) -> Optional[UserModel]:
|
) -> Optional[UserModel]:
|
||||||
with get_db() as db:
|
with get_db() as db:
|
||||||
log.info("insert_new_auth")
|
log.info("insert_new_auth")
|
||||||
|
|
@ -100,7 +111,7 @@ class AuthsTable:
|
||||||
db.add(result)
|
db.add(result)
|
||||||
|
|
||||||
user = Users.insert_new_user(
|
user = Users.insert_new_user(
|
||||||
id, name, email, profile_image_url, role, oauth=oauth
|
id, name, email, profile_image_url, role, oauth_sub
|
||||||
)
|
)
|
||||||
|
|
||||||
db.commit()
|
db.commit()
|
||||||
|
|
@ -111,9 +122,7 @@ class AuthsTable:
|
||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def authenticate_user(
|
def authenticate_user(self, email: str, password: str) -> Optional[UserModel]:
|
||||||
self, email: str, verify_password: callable
|
|
||||||
) -> Optional[UserModel]:
|
|
||||||
log.info(f"authenticate_user: {email}")
|
log.info(f"authenticate_user: {email}")
|
||||||
|
|
||||||
user = Users.get_user_by_email(email)
|
user = Users.get_user_by_email(email)
|
||||||
|
|
@ -124,7 +133,7 @@ class AuthsTable:
|
||||||
with get_db() as db:
|
with get_db() as db:
|
||||||
auth = db.query(Auth).filter_by(id=user.id, active=True).first()
|
auth = db.query(Auth).filter_by(id=user.id, active=True).first()
|
||||||
if auth:
|
if auth:
|
||||||
if verify_password(auth.password):
|
if verify_password(password, auth.password):
|
||||||
return user
|
return user
|
||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
|
|
|
||||||
|
|
@ -4,24 +4,10 @@ import uuid
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from open_webui.internal.db import Base, get_db
|
from open_webui.internal.db import Base, get_db
|
||||||
from open_webui.models.groups import Groups
|
from open_webui.utils.access_control import has_access
|
||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict
|
from pydantic import BaseModel, ConfigDict
|
||||||
from sqlalchemy.dialects.postgresql import JSONB
|
from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON
|
||||||
|
|
||||||
|
|
||||||
from sqlalchemy import (
|
|
||||||
BigInteger,
|
|
||||||
Boolean,
|
|
||||||
Column,
|
|
||||||
ForeignKey,
|
|
||||||
String,
|
|
||||||
Text,
|
|
||||||
JSON,
|
|
||||||
UniqueConstraint,
|
|
||||||
case,
|
|
||||||
cast,
|
|
||||||
)
|
|
||||||
from sqlalchemy import or_, func, select, and_, text
|
from sqlalchemy import or_, func, select, and_, text
|
||||||
from sqlalchemy.sql import exists
|
from sqlalchemy.sql import exists
|
||||||
|
|
||||||
|
|
@ -33,30 +19,19 @@ from sqlalchemy.sql import exists
|
||||||
class Channel(Base):
|
class Channel(Base):
|
||||||
__tablename__ = "channel"
|
__tablename__ = "channel"
|
||||||
|
|
||||||
id = Column(Text, primary_key=True, unique=True)
|
id = Column(Text, primary_key=True)
|
||||||
user_id = Column(Text)
|
user_id = Column(Text)
|
||||||
type = Column(Text, nullable=True)
|
type = Column(Text, nullable=True)
|
||||||
|
|
||||||
name = Column(Text)
|
name = Column(Text)
|
||||||
description = Column(Text, nullable=True)
|
description = Column(Text, nullable=True)
|
||||||
|
|
||||||
# Used to indicate if the channel is private (for 'group' type channels)
|
|
||||||
is_private = Column(Boolean, nullable=True)
|
|
||||||
|
|
||||||
data = Column(JSON, nullable=True)
|
data = Column(JSON, nullable=True)
|
||||||
meta = Column(JSON, nullable=True)
|
meta = Column(JSON, nullable=True)
|
||||||
access_control = Column(JSON, nullable=True)
|
access_control = Column(JSON, nullable=True)
|
||||||
|
|
||||||
created_at = Column(BigInteger)
|
created_at = Column(BigInteger)
|
||||||
|
|
||||||
updated_at = Column(BigInteger)
|
updated_at = Column(BigInteger)
|
||||||
updated_by = Column(Text, nullable=True)
|
|
||||||
|
|
||||||
archived_at = Column(BigInteger, nullable=True)
|
|
||||||
archived_by = Column(Text, nullable=True)
|
|
||||||
|
|
||||||
deleted_at = Column(BigInteger, nullable=True)
|
|
||||||
deleted_by = Column(Text, nullable=True)
|
|
||||||
|
|
||||||
|
|
||||||
class ChannelModel(BaseModel):
|
class ChannelModel(BaseModel):
|
||||||
|
|
@ -64,157 +39,17 @@ class ChannelModel(BaseModel):
|
||||||
|
|
||||||
id: str
|
id: str
|
||||||
user_id: str
|
user_id: str
|
||||||
|
|
||||||
type: Optional[str] = None
|
type: Optional[str] = None
|
||||||
|
|
||||||
name: str
|
name: str
|
||||||
description: Optional[str] = None
|
description: Optional[str] = None
|
||||||
|
|
||||||
is_private: Optional[bool] = None
|
|
||||||
|
|
||||||
data: Optional[dict] = None
|
data: Optional[dict] = None
|
||||||
meta: Optional[dict] = None
|
meta: Optional[dict] = None
|
||||||
access_control: Optional[dict] = None
|
access_control: Optional[dict] = None
|
||||||
|
|
||||||
created_at: int # timestamp in epoch (time_ns)
|
created_at: int # timestamp in epoch
|
||||||
|
updated_at: int # timestamp in epoch
|
||||||
updated_at: int # timestamp in epoch (time_ns)
|
|
||||||
updated_by: Optional[str] = None
|
|
||||||
|
|
||||||
archived_at: Optional[int] = None # timestamp in epoch (time_ns)
|
|
||||||
archived_by: Optional[str] = None
|
|
||||||
|
|
||||||
deleted_at: Optional[int] = None # timestamp in epoch (time_ns)
|
|
||||||
deleted_by: Optional[str] = None
|
|
||||||
|
|
||||||
|
|
||||||
class ChannelMember(Base):
|
|
||||||
__tablename__ = "channel_member"
|
|
||||||
|
|
||||||
id = Column(Text, primary_key=True, unique=True)
|
|
||||||
channel_id = Column(Text, nullable=False)
|
|
||||||
user_id = Column(Text, nullable=False)
|
|
||||||
|
|
||||||
role = Column(Text, nullable=True)
|
|
||||||
status = Column(Text, nullable=True)
|
|
||||||
|
|
||||||
is_active = Column(Boolean, nullable=False, default=True)
|
|
||||||
|
|
||||||
is_channel_muted = Column(Boolean, nullable=False, default=False)
|
|
||||||
is_channel_pinned = Column(Boolean, nullable=False, default=False)
|
|
||||||
|
|
||||||
data = Column(JSON, nullable=True)
|
|
||||||
meta = Column(JSON, nullable=True)
|
|
||||||
|
|
||||||
invited_at = Column(BigInteger, nullable=True)
|
|
||||||
invited_by = Column(Text, nullable=True)
|
|
||||||
|
|
||||||
joined_at = Column(BigInteger)
|
|
||||||
left_at = Column(BigInteger, nullable=True)
|
|
||||||
|
|
||||||
last_read_at = Column(BigInteger, nullable=True)
|
|
||||||
|
|
||||||
created_at = Column(BigInteger)
|
|
||||||
updated_at = Column(BigInteger)
|
|
||||||
|
|
||||||
|
|
||||||
class ChannelMemberModel(BaseModel):
|
|
||||||
model_config = ConfigDict(from_attributes=True)
|
|
||||||
|
|
||||||
id: str
|
|
||||||
channel_id: str
|
|
||||||
user_id: str
|
|
||||||
|
|
||||||
role: Optional[str] = None
|
|
||||||
status: Optional[str] = None
|
|
||||||
|
|
||||||
is_active: bool = True
|
|
||||||
|
|
||||||
is_channel_muted: bool = False
|
|
||||||
is_channel_pinned: bool = False
|
|
||||||
|
|
||||||
data: Optional[dict] = None
|
|
||||||
meta: Optional[dict] = None
|
|
||||||
|
|
||||||
invited_at: Optional[int] = None # timestamp in epoch (time_ns)
|
|
||||||
invited_by: Optional[str] = None
|
|
||||||
|
|
||||||
joined_at: Optional[int] = None # timestamp in epoch (time_ns)
|
|
||||||
left_at: Optional[int] = None # timestamp in epoch (time_ns)
|
|
||||||
|
|
||||||
last_read_at: Optional[int] = None # timestamp in epoch (time_ns)
|
|
||||||
|
|
||||||
created_at: Optional[int] = None # timestamp in epoch (time_ns)
|
|
||||||
updated_at: Optional[int] = None # timestamp in epoch (time_ns)
|
|
||||||
|
|
||||||
|
|
||||||
class ChannelFile(Base):
|
|
||||||
__tablename__ = "channel_file"
|
|
||||||
|
|
||||||
id = Column(Text, unique=True, primary_key=True)
|
|
||||||
user_id = Column(Text, nullable=False)
|
|
||||||
|
|
||||||
channel_id = Column(
|
|
||||||
Text, ForeignKey("channel.id", ondelete="CASCADE"), nullable=False
|
|
||||||
)
|
|
||||||
message_id = Column(
|
|
||||||
Text, ForeignKey("message.id", ondelete="CASCADE"), nullable=True
|
|
||||||
)
|
|
||||||
file_id = Column(Text, ForeignKey("file.id", ondelete="CASCADE"), nullable=False)
|
|
||||||
|
|
||||||
created_at = Column(BigInteger, nullable=False)
|
|
||||||
updated_at = Column(BigInteger, nullable=False)
|
|
||||||
|
|
||||||
__table_args__ = (
|
|
||||||
UniqueConstraint("channel_id", "file_id", name="uq_channel_file_channel_file"),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class ChannelFileModel(BaseModel):
|
|
||||||
model_config = ConfigDict(from_attributes=True)
|
|
||||||
|
|
||||||
id: str
|
|
||||||
|
|
||||||
channel_id: str
|
|
||||||
file_id: str
|
|
||||||
user_id: str
|
|
||||||
|
|
||||||
created_at: int # timestamp in epoch (time_ns)
|
|
||||||
updated_at: int # timestamp in epoch (time_ns)
|
|
||||||
|
|
||||||
|
|
||||||
class ChannelWebhook(Base):
|
|
||||||
__tablename__ = "channel_webhook"
|
|
||||||
|
|
||||||
id = Column(Text, primary_key=True, unique=True)
|
|
||||||
channel_id = Column(Text, nullable=False)
|
|
||||||
user_id = Column(Text, nullable=False)
|
|
||||||
|
|
||||||
name = Column(Text, nullable=False)
|
|
||||||
profile_image_url = Column(Text, nullable=True)
|
|
||||||
|
|
||||||
token = Column(Text, nullable=False)
|
|
||||||
last_used_at = Column(BigInteger, nullable=True)
|
|
||||||
|
|
||||||
created_at = Column(BigInteger, nullable=False)
|
|
||||||
updated_at = Column(BigInteger, nullable=False)
|
|
||||||
|
|
||||||
|
|
||||||
class ChannelWebhookModel(BaseModel):
|
|
||||||
model_config = ConfigDict(from_attributes=True)
|
|
||||||
|
|
||||||
id: str
|
|
||||||
channel_id: str
|
|
||||||
user_id: str
|
|
||||||
|
|
||||||
name: str
|
|
||||||
profile_image_url: Optional[str] = None
|
|
||||||
|
|
||||||
token: str
|
|
||||||
last_used_at: Optional[int] = None # timestamp in epoch (time_ns)
|
|
||||||
|
|
||||||
created_at: int # timestamp in epoch (time_ns)
|
|
||||||
updated_at: int # timestamp in epoch (time_ns)
|
|
||||||
|
|
||||||
|
|
||||||
####################
|
####################
|
||||||
|
|
@ -223,94 +58,26 @@ class ChannelWebhookModel(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
class ChannelResponse(ChannelModel):
|
class ChannelResponse(ChannelModel):
|
||||||
is_manager: bool = False
|
|
||||||
write_access: bool = False
|
write_access: bool = False
|
||||||
|
|
||||||
user_count: Optional[int] = None
|
|
||||||
|
|
||||||
|
|
||||||
class ChannelForm(BaseModel):
|
class ChannelForm(BaseModel):
|
||||||
name: str = ""
|
name: str
|
||||||
description: Optional[str] = None
|
description: Optional[str] = None
|
||||||
is_private: Optional[bool] = None
|
|
||||||
data: Optional[dict] = None
|
data: Optional[dict] = None
|
||||||
meta: Optional[dict] = None
|
meta: Optional[dict] = None
|
||||||
access_control: Optional[dict] = None
|
access_control: Optional[dict] = None
|
||||||
group_ids: Optional[list[str]] = None
|
|
||||||
user_ids: Optional[list[str]] = None
|
|
||||||
|
|
||||||
|
|
||||||
class CreateChannelForm(ChannelForm):
|
|
||||||
type: Optional[str] = None
|
|
||||||
|
|
||||||
|
|
||||||
class ChannelTable:
|
class ChannelTable:
|
||||||
|
|
||||||
def _collect_unique_user_ids(
|
|
||||||
self,
|
|
||||||
invited_by: str,
|
|
||||||
user_ids: Optional[list[str]] = None,
|
|
||||||
group_ids: Optional[list[str]] = None,
|
|
||||||
) -> set[str]:
|
|
||||||
"""
|
|
||||||
Collect unique user ids from:
|
|
||||||
- invited_by
|
|
||||||
- user_ids
|
|
||||||
- each group in group_ids
|
|
||||||
Returns a set for efficient SQL diffing.
|
|
||||||
"""
|
|
||||||
users = set(user_ids or [])
|
|
||||||
users.add(invited_by)
|
|
||||||
|
|
||||||
for group_id in group_ids or []:
|
|
||||||
users.update(Groups.get_group_user_ids_by_id(group_id))
|
|
||||||
|
|
||||||
return users
|
|
||||||
|
|
||||||
def _create_membership_models(
|
|
||||||
self,
|
|
||||||
channel_id: str,
|
|
||||||
invited_by: str,
|
|
||||||
user_ids: set[str],
|
|
||||||
) -> list[ChannelMember]:
|
|
||||||
"""
|
|
||||||
Takes a set of NEW user IDs (already filtered to exclude existing members).
|
|
||||||
Returns ORM ChannelMember objects to be added.
|
|
||||||
"""
|
|
||||||
now = int(time.time_ns())
|
|
||||||
memberships = []
|
|
||||||
|
|
||||||
for uid in user_ids:
|
|
||||||
model = ChannelMemberModel(
|
|
||||||
**{
|
|
||||||
"id": str(uuid.uuid4()),
|
|
||||||
"channel_id": channel_id,
|
|
||||||
"user_id": uid,
|
|
||||||
"status": "joined",
|
|
||||||
"is_active": True,
|
|
||||||
"is_channel_muted": False,
|
|
||||||
"is_channel_pinned": False,
|
|
||||||
"invited_at": now,
|
|
||||||
"invited_by": invited_by,
|
|
||||||
"joined_at": now,
|
|
||||||
"left_at": None,
|
|
||||||
"last_read_at": now,
|
|
||||||
"created_at": now,
|
|
||||||
"updated_at": now,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
memberships.append(ChannelMember(**model.model_dump()))
|
|
||||||
|
|
||||||
return memberships
|
|
||||||
|
|
||||||
def insert_new_channel(
|
def insert_new_channel(
|
||||||
self, form_data: CreateChannelForm, user_id: str
|
self, type: Optional[str], form_data: ChannelForm, user_id: str
|
||||||
) -> Optional[ChannelModel]:
|
) -> Optional[ChannelModel]:
|
||||||
with get_db() as db:
|
with get_db() as db:
|
||||||
channel = ChannelModel(
|
channel = ChannelModel(
|
||||||
**{
|
**{
|
||||||
**form_data.model_dump(),
|
**form_data.model_dump(),
|
||||||
"type": form_data.type if form_data.type else None,
|
"type": type,
|
||||||
"name": form_data.name.lower(),
|
"name": form_data.name.lower(),
|
||||||
"id": str(uuid.uuid4()),
|
"id": str(uuid.uuid4()),
|
||||||
"user_id": user_id,
|
"user_id": user_id,
|
||||||
|
|
@ -318,21 +85,9 @@ class ChannelTable:
|
||||||
"updated_at": int(time.time_ns()),
|
"updated_at": int(time.time_ns()),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
new_channel = Channel(**channel.model_dump())
|
new_channel = Channel(**channel.model_dump())
|
||||||
|
|
||||||
if form_data.type in ["group", "dm"]:
|
|
||||||
users = self._collect_unique_user_ids(
|
|
||||||
invited_by=user_id,
|
|
||||||
user_ids=form_data.user_ids,
|
|
||||||
group_ids=form_data.group_ids,
|
|
||||||
)
|
|
||||||
memberships = self._create_membership_models(
|
|
||||||
channel_id=new_channel.id,
|
|
||||||
invited_by=user_id,
|
|
||||||
user_ids=users,
|
|
||||||
)
|
|
||||||
|
|
||||||
db.add_all(memberships)
|
|
||||||
db.add(new_channel)
|
db.add(new_channel)
|
||||||
db.commit()
|
db.commit()
|
||||||
return channel
|
return channel
|
||||||
|
|
@ -342,481 +97,22 @@ class ChannelTable:
|
||||||
channels = db.query(Channel).all()
|
channels = db.query(Channel).all()
|
||||||
return [ChannelModel.model_validate(channel) for channel in channels]
|
return [ChannelModel.model_validate(channel) for channel in channels]
|
||||||
|
|
||||||
def _has_permission(self, db, query, filter: dict, permission: str = "read"):
|
def get_channels_by_user_id(
|
||||||
group_ids = filter.get("group_ids", [])
|
self, user_id: str, permission: str = "read"
|
||||||
user_id = filter.get("user_id")
|
) -> list[ChannelModel]:
|
||||||
|
channels = self.get_channels()
|
||||||
dialect_name = db.bind.dialect.name
|
return [
|
||||||
|
channel
|
||||||
# Public access
|
for channel in channels
|
||||||
conditions = []
|
if channel.user_id == user_id
|
||||||
if group_ids or user_id:
|
or has_access(user_id, permission, channel.access_control)
|
||||||
conditions.extend(
|
]
|
||||||
[
|
|
||||||
Channel.access_control.is_(None),
|
|
||||||
cast(Channel.access_control, String) == "null",
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
# User-level permission
|
|
||||||
if user_id:
|
|
||||||
conditions.append(Channel.user_id == user_id)
|
|
||||||
|
|
||||||
# Group-level permission
|
|
||||||
if group_ids:
|
|
||||||
group_conditions = []
|
|
||||||
for gid in group_ids:
|
|
||||||
if dialect_name == "sqlite":
|
|
||||||
group_conditions.append(
|
|
||||||
Channel.access_control[permission]["group_ids"].contains([gid])
|
|
||||||
)
|
|
||||||
elif dialect_name == "postgresql":
|
|
||||||
group_conditions.append(
|
|
||||||
cast(
|
|
||||||
Channel.access_control[permission]["group_ids"],
|
|
||||||
JSONB,
|
|
||||||
).contains([gid])
|
|
||||||
)
|
|
||||||
conditions.append(or_(*group_conditions))
|
|
||||||
|
|
||||||
if conditions:
|
|
||||||
query = query.filter(or_(*conditions))
|
|
||||||
|
|
||||||
return query
|
|
||||||
|
|
||||||
def get_channels_by_user_id(self, user_id: str) -> list[ChannelModel]:
|
|
||||||
with get_db() as db:
|
|
||||||
user_group_ids = [
|
|
||||||
group.id for group in Groups.get_groups_by_member_id(user_id)
|
|
||||||
]
|
|
||||||
|
|
||||||
membership_channels = (
|
|
||||||
db.query(Channel)
|
|
||||||
.join(ChannelMember, Channel.id == ChannelMember.channel_id)
|
|
||||||
.filter(
|
|
||||||
Channel.deleted_at.is_(None),
|
|
||||||
Channel.archived_at.is_(None),
|
|
||||||
Channel.type.in_(["group", "dm"]),
|
|
||||||
ChannelMember.user_id == user_id,
|
|
||||||
ChannelMember.is_active.is_(True),
|
|
||||||
)
|
|
||||||
.all()
|
|
||||||
)
|
|
||||||
|
|
||||||
query = db.query(Channel).filter(
|
|
||||||
Channel.deleted_at.is_(None),
|
|
||||||
Channel.archived_at.is_(None),
|
|
||||||
or_(
|
|
||||||
Channel.type.is_(None), # True NULL/None
|
|
||||||
Channel.type == "", # Empty string
|
|
||||||
and_(Channel.type != "group", Channel.type != "dm"),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
query = self._has_permission(
|
|
||||||
db, query, {"user_id": user_id, "group_ids": user_group_ids}
|
|
||||||
)
|
|
||||||
|
|
||||||
standard_channels = query.all()
|
|
||||||
|
|
||||||
all_channels = membership_channels + standard_channels
|
|
||||||
return [ChannelModel.model_validate(c) for c in all_channels]
|
|
||||||
|
|
||||||
def get_dm_channel_by_user_ids(self, user_ids: list[str]) -> Optional[ChannelModel]:
|
|
||||||
with get_db() as db:
|
|
||||||
# Ensure uniqueness in case a list with duplicates is passed
|
|
||||||
unique_user_ids = list(set(user_ids))
|
|
||||||
|
|
||||||
match_count = func.sum(
|
|
||||||
case(
|
|
||||||
(ChannelMember.user_id.in_(unique_user_ids), 1),
|
|
||||||
else_=0,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
subquery = (
|
|
||||||
db.query(ChannelMember.channel_id)
|
|
||||||
.group_by(ChannelMember.channel_id)
|
|
||||||
# 1. Channel must have exactly len(user_ids) members
|
|
||||||
.having(func.count(ChannelMember.user_id) == len(unique_user_ids))
|
|
||||||
# 2. All those members must be in unique_user_ids
|
|
||||||
.having(match_count == len(unique_user_ids))
|
|
||||||
.subquery()
|
|
||||||
)
|
|
||||||
|
|
||||||
channel = (
|
|
||||||
db.query(Channel)
|
|
||||||
.filter(
|
|
||||||
Channel.id.in_(subquery),
|
|
||||||
Channel.type == "dm",
|
|
||||||
)
|
|
||||||
.first()
|
|
||||||
)
|
|
||||||
|
|
||||||
return ChannelModel.model_validate(channel) if channel else None
|
|
||||||
|
|
||||||
def add_members_to_channel(
|
|
||||||
self,
|
|
||||||
channel_id: str,
|
|
||||||
invited_by: str,
|
|
||||||
user_ids: Optional[list[str]] = None,
|
|
||||||
group_ids: Optional[list[str]] = None,
|
|
||||||
) -> list[ChannelMemberModel]:
|
|
||||||
with get_db() as db:
|
|
||||||
# 1. Collect all user_ids including groups + inviter
|
|
||||||
requested_users = self._collect_unique_user_ids(
|
|
||||||
invited_by, user_ids, group_ids
|
|
||||||
)
|
|
||||||
|
|
||||||
existing_users = {
|
|
||||||
row.user_id
|
|
||||||
for row in db.query(ChannelMember.user_id)
|
|
||||||
.filter(ChannelMember.channel_id == channel_id)
|
|
||||||
.all()
|
|
||||||
}
|
|
||||||
|
|
||||||
new_user_ids = requested_users - existing_users
|
|
||||||
if not new_user_ids:
|
|
||||||
return [] # Nothing to add
|
|
||||||
|
|
||||||
new_memberships = self._create_membership_models(
|
|
||||||
channel_id, invited_by, new_user_ids
|
|
||||||
)
|
|
||||||
|
|
||||||
db.add_all(new_memberships)
|
|
||||||
db.commit()
|
|
||||||
|
|
||||||
return [
|
|
||||||
ChannelMemberModel.model_validate(membership)
|
|
||||||
for membership in new_memberships
|
|
||||||
]
|
|
||||||
|
|
||||||
def remove_members_from_channel(
|
|
||||||
self,
|
|
||||||
channel_id: str,
|
|
||||||
user_ids: list[str],
|
|
||||||
) -> int:
|
|
||||||
with get_db() as db:
|
|
||||||
result = (
|
|
||||||
db.query(ChannelMember)
|
|
||||||
.filter(
|
|
||||||
ChannelMember.channel_id == channel_id,
|
|
||||||
ChannelMember.user_id.in_(user_ids),
|
|
||||||
)
|
|
||||||
.delete(synchronize_session=False)
|
|
||||||
)
|
|
||||||
db.commit()
|
|
||||||
return result # number of rows deleted
|
|
||||||
|
|
||||||
def is_user_channel_manager(self, channel_id: str, user_id: str) -> bool:
|
|
||||||
with get_db() as db:
|
|
||||||
# Check if the user is the creator of the channel
|
|
||||||
# or has a 'manager' role in ChannelMember
|
|
||||||
channel = db.query(Channel).filter(Channel.id == channel_id).first()
|
|
||||||
if channel and channel.user_id == user_id:
|
|
||||||
return True
|
|
||||||
|
|
||||||
membership = (
|
|
||||||
db.query(ChannelMember)
|
|
||||||
.filter(
|
|
||||||
ChannelMember.channel_id == channel_id,
|
|
||||||
ChannelMember.user_id == user_id,
|
|
||||||
ChannelMember.role == "manager",
|
|
||||||
)
|
|
||||||
.first()
|
|
||||||
)
|
|
||||||
return membership is not None
|
|
||||||
|
|
||||||
def join_channel(
|
|
||||||
self, channel_id: str, user_id: str
|
|
||||||
) -> Optional[ChannelMemberModel]:
|
|
||||||
with get_db() as db:
|
|
||||||
# Check if the membership already exists
|
|
||||||
existing_membership = (
|
|
||||||
db.query(ChannelMember)
|
|
||||||
.filter(
|
|
||||||
ChannelMember.channel_id == channel_id,
|
|
||||||
ChannelMember.user_id == user_id,
|
|
||||||
)
|
|
||||||
.first()
|
|
||||||
)
|
|
||||||
if existing_membership:
|
|
||||||
return ChannelMemberModel.model_validate(existing_membership)
|
|
||||||
|
|
||||||
# Create new membership
|
|
||||||
channel_member = ChannelMemberModel(
|
|
||||||
**{
|
|
||||||
"id": str(uuid.uuid4()),
|
|
||||||
"channel_id": channel_id,
|
|
||||||
"user_id": user_id,
|
|
||||||
"status": "joined",
|
|
||||||
"is_active": True,
|
|
||||||
"is_channel_muted": False,
|
|
||||||
"is_channel_pinned": False,
|
|
||||||
"joined_at": int(time.time_ns()),
|
|
||||||
"left_at": None,
|
|
||||||
"last_read_at": int(time.time_ns()),
|
|
||||||
"created_at": int(time.time_ns()),
|
|
||||||
"updated_at": int(time.time_ns()),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
new_membership = ChannelMember(**channel_member.model_dump())
|
|
||||||
|
|
||||||
db.add(new_membership)
|
|
||||||
db.commit()
|
|
||||||
return channel_member
|
|
||||||
|
|
||||||
def leave_channel(self, channel_id: str, user_id: str) -> bool:
|
|
||||||
with get_db() as db:
|
|
||||||
membership = (
|
|
||||||
db.query(ChannelMember)
|
|
||||||
.filter(
|
|
||||||
ChannelMember.channel_id == channel_id,
|
|
||||||
ChannelMember.user_id == user_id,
|
|
||||||
)
|
|
||||||
.first()
|
|
||||||
)
|
|
||||||
if not membership:
|
|
||||||
return False
|
|
||||||
|
|
||||||
membership.status = "left"
|
|
||||||
membership.is_active = False
|
|
||||||
membership.left_at = int(time.time_ns())
|
|
||||||
membership.updated_at = int(time.time_ns())
|
|
||||||
|
|
||||||
db.commit()
|
|
||||||
return True
|
|
||||||
|
|
||||||
def get_member_by_channel_and_user_id(
|
|
||||||
self, channel_id: str, user_id: str
|
|
||||||
) -> Optional[ChannelMemberModel]:
|
|
||||||
with get_db() as db:
|
|
||||||
membership = (
|
|
||||||
db.query(ChannelMember)
|
|
||||||
.filter(
|
|
||||||
ChannelMember.channel_id == channel_id,
|
|
||||||
ChannelMember.user_id == user_id,
|
|
||||||
)
|
|
||||||
.first()
|
|
||||||
)
|
|
||||||
return ChannelMemberModel.model_validate(membership) if membership else None
|
|
||||||
|
|
||||||
def get_members_by_channel_id(self, channel_id: str) -> list[ChannelMemberModel]:
|
|
||||||
with get_db() as db:
|
|
||||||
memberships = (
|
|
||||||
db.query(ChannelMember)
|
|
||||||
.filter(ChannelMember.channel_id == channel_id)
|
|
||||||
.all()
|
|
||||||
)
|
|
||||||
return [
|
|
||||||
ChannelMemberModel.model_validate(membership)
|
|
||||||
for membership in memberships
|
|
||||||
]
|
|
||||||
|
|
||||||
def pin_channel(self, channel_id: str, user_id: str, is_pinned: bool) -> bool:
|
|
||||||
with get_db() as db:
|
|
||||||
membership = (
|
|
||||||
db.query(ChannelMember)
|
|
||||||
.filter(
|
|
||||||
ChannelMember.channel_id == channel_id,
|
|
||||||
ChannelMember.user_id == user_id,
|
|
||||||
)
|
|
||||||
.first()
|
|
||||||
)
|
|
||||||
if not membership:
|
|
||||||
return False
|
|
||||||
|
|
||||||
membership.is_channel_pinned = is_pinned
|
|
||||||
membership.updated_at = int(time.time_ns())
|
|
||||||
|
|
||||||
db.commit()
|
|
||||||
return True
|
|
||||||
|
|
||||||
def update_member_last_read_at(self, channel_id: str, user_id: str) -> bool:
|
|
||||||
with get_db() as db:
|
|
||||||
membership = (
|
|
||||||
db.query(ChannelMember)
|
|
||||||
.filter(
|
|
||||||
ChannelMember.channel_id == channel_id,
|
|
||||||
ChannelMember.user_id == user_id,
|
|
||||||
)
|
|
||||||
.first()
|
|
||||||
)
|
|
||||||
if not membership:
|
|
||||||
return False
|
|
||||||
|
|
||||||
membership.last_read_at = int(time.time_ns())
|
|
||||||
membership.updated_at = int(time.time_ns())
|
|
||||||
|
|
||||||
db.commit()
|
|
||||||
return True
|
|
||||||
|
|
||||||
def update_member_active_status(
|
|
||||||
self, channel_id: str, user_id: str, is_active: bool
|
|
||||||
) -> bool:
|
|
||||||
with get_db() as db:
|
|
||||||
membership = (
|
|
||||||
db.query(ChannelMember)
|
|
||||||
.filter(
|
|
||||||
ChannelMember.channel_id == channel_id,
|
|
||||||
ChannelMember.user_id == user_id,
|
|
||||||
)
|
|
||||||
.first()
|
|
||||||
)
|
|
||||||
if not membership:
|
|
||||||
return False
|
|
||||||
|
|
||||||
membership.is_active = is_active
|
|
||||||
membership.updated_at = int(time.time_ns())
|
|
||||||
|
|
||||||
db.commit()
|
|
||||||
return True
|
|
||||||
|
|
||||||
def is_user_channel_member(self, channel_id: str, user_id: str) -> bool:
|
|
||||||
with get_db() as db:
|
|
||||||
membership = (
|
|
||||||
db.query(ChannelMember)
|
|
||||||
.filter(
|
|
||||||
ChannelMember.channel_id == channel_id,
|
|
||||||
ChannelMember.user_id == user_id,
|
|
||||||
)
|
|
||||||
.first()
|
|
||||||
)
|
|
||||||
return membership is not None
|
|
||||||
|
|
||||||
def get_channel_by_id(self, id: str) -> Optional[ChannelModel]:
|
def get_channel_by_id(self, id: str) -> Optional[ChannelModel]:
|
||||||
with get_db() as db:
|
with get_db() as db:
|
||||||
channel = db.query(Channel).filter(Channel.id == id).first()
|
channel = db.query(Channel).filter(Channel.id == id).first()
|
||||||
return ChannelModel.model_validate(channel) if channel else None
|
return ChannelModel.model_validate(channel) if channel else None
|
||||||
|
|
||||||
def get_channels_by_file_id(self, file_id: str) -> list[ChannelModel]:
|
|
||||||
with get_db() as db:
|
|
||||||
channel_files = (
|
|
||||||
db.query(ChannelFile).filter(ChannelFile.file_id == file_id).all()
|
|
||||||
)
|
|
||||||
channel_ids = [cf.channel_id for cf in channel_files]
|
|
||||||
channels = db.query(Channel).filter(Channel.id.in_(channel_ids)).all()
|
|
||||||
return [ChannelModel.model_validate(channel) for channel in channels]
|
|
||||||
|
|
||||||
def get_channels_by_file_id_and_user_id(
|
|
||||||
self, file_id: str, user_id: str
|
|
||||||
) -> list[ChannelModel]:
|
|
||||||
with get_db() as db:
|
|
||||||
# 1. Determine which channels have this file
|
|
||||||
channel_file_rows = (
|
|
||||||
db.query(ChannelFile).filter(ChannelFile.file_id == file_id).all()
|
|
||||||
)
|
|
||||||
channel_ids = [row.channel_id for row in channel_file_rows]
|
|
||||||
|
|
||||||
if not channel_ids:
|
|
||||||
return []
|
|
||||||
|
|
||||||
# 2. Load all channel rows that still exist
|
|
||||||
channels = (
|
|
||||||
db.query(Channel)
|
|
||||||
.filter(
|
|
||||||
Channel.id.in_(channel_ids),
|
|
||||||
Channel.deleted_at.is_(None),
|
|
||||||
Channel.archived_at.is_(None),
|
|
||||||
)
|
|
||||||
.all()
|
|
||||||
)
|
|
||||||
if not channels:
|
|
||||||
return []
|
|
||||||
|
|
||||||
# Preload user's group membership
|
|
||||||
user_group_ids = [g.id for g in Groups.get_groups_by_member_id(user_id)]
|
|
||||||
|
|
||||||
allowed_channels = []
|
|
||||||
|
|
||||||
for channel in channels:
|
|
||||||
# --- Case A: group or dm => user must be an active member ---
|
|
||||||
if channel.type in ["group", "dm"]:
|
|
||||||
membership = (
|
|
||||||
db.query(ChannelMember)
|
|
||||||
.filter(
|
|
||||||
ChannelMember.channel_id == channel.id,
|
|
||||||
ChannelMember.user_id == user_id,
|
|
||||||
ChannelMember.is_active.is_(True),
|
|
||||||
)
|
|
||||||
.first()
|
|
||||||
)
|
|
||||||
if membership:
|
|
||||||
allowed_channels.append(ChannelModel.model_validate(channel))
|
|
||||||
continue
|
|
||||||
|
|
||||||
# --- Case B: standard channel => rely on ACL permissions ---
|
|
||||||
query = db.query(Channel).filter(Channel.id == channel.id)
|
|
||||||
|
|
||||||
query = self._has_permission(
|
|
||||||
db,
|
|
||||||
query,
|
|
||||||
{"user_id": user_id, "group_ids": user_group_ids},
|
|
||||||
permission="read",
|
|
||||||
)
|
|
||||||
|
|
||||||
allowed = query.first()
|
|
||||||
if allowed:
|
|
||||||
allowed_channels.append(ChannelModel.model_validate(allowed))
|
|
||||||
|
|
||||||
return allowed_channels
|
|
||||||
|
|
||||||
def get_channel_by_id_and_user_id(
|
|
||||||
self, id: str, user_id: str
|
|
||||||
) -> Optional[ChannelModel]:
|
|
||||||
with get_db() as db:
|
|
||||||
# Fetch the channel
|
|
||||||
channel: Channel = (
|
|
||||||
db.query(Channel)
|
|
||||||
.filter(
|
|
||||||
Channel.id == id,
|
|
||||||
Channel.deleted_at.is_(None),
|
|
||||||
Channel.archived_at.is_(None),
|
|
||||||
)
|
|
||||||
.first()
|
|
||||||
)
|
|
||||||
|
|
||||||
if not channel:
|
|
||||||
return None
|
|
||||||
|
|
||||||
# If the channel is a group or dm, read access requires membership (active)
|
|
||||||
if channel.type in ["group", "dm"]:
|
|
||||||
membership = (
|
|
||||||
db.query(ChannelMember)
|
|
||||||
.filter(
|
|
||||||
ChannelMember.channel_id == id,
|
|
||||||
ChannelMember.user_id == user_id,
|
|
||||||
ChannelMember.is_active.is_(True),
|
|
||||||
)
|
|
||||||
.first()
|
|
||||||
)
|
|
||||||
if membership:
|
|
||||||
return ChannelModel.model_validate(channel)
|
|
||||||
else:
|
|
||||||
return None
|
|
||||||
|
|
||||||
# For channels that are NOT group/dm, fall back to ACL-based read access
|
|
||||||
query = db.query(Channel).filter(Channel.id == id)
|
|
||||||
|
|
||||||
# Determine user groups
|
|
||||||
user_group_ids = [
|
|
||||||
group.id for group in Groups.get_groups_by_member_id(user_id)
|
|
||||||
]
|
|
||||||
|
|
||||||
# Apply ACL rules
|
|
||||||
query = self._has_permission(
|
|
||||||
db,
|
|
||||||
query,
|
|
||||||
{"user_id": user_id, "group_ids": user_group_ids},
|
|
||||||
permission="read",
|
|
||||||
)
|
|
||||||
|
|
||||||
channel_allowed = query.first()
|
|
||||||
return (
|
|
||||||
ChannelModel.model_validate(channel_allowed)
|
|
||||||
if channel_allowed
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
|
|
||||||
def update_channel_by_id(
|
def update_channel_by_id(
|
||||||
self, id: str, form_data: ChannelForm
|
self, id: str, form_data: ChannelForm
|
||||||
) -> Optional[ChannelModel]:
|
) -> Optional[ChannelModel]:
|
||||||
|
|
@ -826,77 +122,14 @@ class ChannelTable:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
channel.name = form_data.name
|
channel.name = form_data.name
|
||||||
channel.description = form_data.description
|
|
||||||
channel.is_private = form_data.is_private
|
|
||||||
|
|
||||||
channel.data = form_data.data
|
channel.data = form_data.data
|
||||||
channel.meta = form_data.meta
|
channel.meta = form_data.meta
|
||||||
|
|
||||||
channel.access_control = form_data.access_control
|
channel.access_control = form_data.access_control
|
||||||
channel.updated_at = int(time.time_ns())
|
channel.updated_at = int(time.time_ns())
|
||||||
|
|
||||||
db.commit()
|
db.commit()
|
||||||
return ChannelModel.model_validate(channel) if channel else None
|
return ChannelModel.model_validate(channel) if channel else None
|
||||||
|
|
||||||
def add_file_to_channel_by_id(
|
|
||||||
self, channel_id: str, file_id: str, user_id: str
|
|
||||||
) -> Optional[ChannelFileModel]:
|
|
||||||
with get_db() as db:
|
|
||||||
channel_file = ChannelFileModel(
|
|
||||||
**{
|
|
||||||
"id": str(uuid.uuid4()),
|
|
||||||
"channel_id": channel_id,
|
|
||||||
"file_id": file_id,
|
|
||||||
"user_id": user_id,
|
|
||||||
"created_at": int(time.time()),
|
|
||||||
"updated_at": int(time.time()),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
result = ChannelFile(**channel_file.model_dump())
|
|
||||||
db.add(result)
|
|
||||||
db.commit()
|
|
||||||
db.refresh(result)
|
|
||||||
if result:
|
|
||||||
return ChannelFileModel.model_validate(result)
|
|
||||||
else:
|
|
||||||
return None
|
|
||||||
except Exception:
|
|
||||||
return None
|
|
||||||
|
|
||||||
def set_file_message_id_in_channel_by_id(
|
|
||||||
self, channel_id: str, file_id: str, message_id: str
|
|
||||||
) -> bool:
|
|
||||||
try:
|
|
||||||
with get_db() as db:
|
|
||||||
channel_file = (
|
|
||||||
db.query(ChannelFile)
|
|
||||||
.filter_by(channel_id=channel_id, file_id=file_id)
|
|
||||||
.first()
|
|
||||||
)
|
|
||||||
if not channel_file:
|
|
||||||
return False
|
|
||||||
|
|
||||||
channel_file.message_id = message_id
|
|
||||||
channel_file.updated_at = int(time.time())
|
|
||||||
|
|
||||||
db.commit()
|
|
||||||
return True
|
|
||||||
except Exception:
|
|
||||||
return False
|
|
||||||
|
|
||||||
def remove_file_from_channel_by_id(self, channel_id: str, file_id: str) -> bool:
|
|
||||||
try:
|
|
||||||
with get_db() as db:
|
|
||||||
db.query(ChannelFile).filter_by(
|
|
||||||
channel_id=channel_id, file_id=file_id
|
|
||||||
).delete()
|
|
||||||
db.commit()
|
|
||||||
return True
|
|
||||||
except Exception:
|
|
||||||
return False
|
|
||||||
|
|
||||||
def delete_channel_by_id(self, id: str):
|
def delete_channel_by_id(self, id: str):
|
||||||
with get_db() as db:
|
with get_db() as db:
|
||||||
db.query(Channel).filter(Channel.id == id).delete()
|
db.query(Channel).filter(Channel.id == id).delete()
|
||||||
|
|
|
||||||
|
|
@ -7,20 +7,10 @@ from typing import Optional
|
||||||
from open_webui.internal.db import Base, get_db
|
from open_webui.internal.db import Base, get_db
|
||||||
from open_webui.models.tags import TagModel, Tag, Tags
|
from open_webui.models.tags import TagModel, Tag, Tags
|
||||||
from open_webui.models.folders import Folders
|
from open_webui.models.folders import Folders
|
||||||
from open_webui.utils.misc import sanitize_data_for_db, sanitize_text_for_db
|
from open_webui.env import SRC_LOG_LEVELS
|
||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict
|
from pydantic import BaseModel, ConfigDict
|
||||||
from sqlalchemy import (
|
from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON, Index
|
||||||
BigInteger,
|
|
||||||
Boolean,
|
|
||||||
Column,
|
|
||||||
ForeignKey,
|
|
||||||
String,
|
|
||||||
Text,
|
|
||||||
JSON,
|
|
||||||
Index,
|
|
||||||
UniqueConstraint,
|
|
||||||
)
|
|
||||||
from sqlalchemy import or_, func, select, and_, text
|
from sqlalchemy import or_, func, select, and_, text
|
||||||
from sqlalchemy.sql import exists
|
from sqlalchemy.sql import exists
|
||||||
from sqlalchemy.sql.expression import bindparam
|
from sqlalchemy.sql.expression import bindparam
|
||||||
|
|
@ -30,12 +20,13 @@ from sqlalchemy.sql.expression import bindparam
|
||||||
####################
|
####################
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
log.setLevel(SRC_LOG_LEVELS["MODELS"])
|
||||||
|
|
||||||
|
|
||||||
class Chat(Base):
|
class Chat(Base):
|
||||||
__tablename__ = "chat"
|
__tablename__ = "chat"
|
||||||
|
|
||||||
id = Column(String, primary_key=True, unique=True)
|
id = Column(String, primary_key=True)
|
||||||
user_id = Column(String)
|
user_id = Column(String)
|
||||||
title = Column(Text)
|
title = Column(Text)
|
||||||
chat = Column(JSON)
|
chat = Column(JSON)
|
||||||
|
|
@ -84,38 +75,6 @@ class ChatModel(BaseModel):
|
||||||
folder_id: Optional[str] = None
|
folder_id: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
class ChatFile(Base):
|
|
||||||
__tablename__ = "chat_file"
|
|
||||||
|
|
||||||
id = Column(Text, unique=True, primary_key=True)
|
|
||||||
user_id = Column(Text, nullable=False)
|
|
||||||
|
|
||||||
chat_id = Column(Text, ForeignKey("chat.id", ondelete="CASCADE"), nullable=False)
|
|
||||||
message_id = Column(Text, nullable=True)
|
|
||||||
file_id = Column(Text, ForeignKey("file.id", ondelete="CASCADE"), nullable=False)
|
|
||||||
|
|
||||||
created_at = Column(BigInteger, nullable=False)
|
|
||||||
updated_at = Column(BigInteger, nullable=False)
|
|
||||||
|
|
||||||
__table_args__ = (
|
|
||||||
UniqueConstraint("chat_id", "file_id", name="uq_chat_file_chat_file"),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class ChatFileModel(BaseModel):
|
|
||||||
id: str
|
|
||||||
user_id: str
|
|
||||||
|
|
||||||
chat_id: str
|
|
||||||
message_id: Optional[str] = None
|
|
||||||
file_id: str
|
|
||||||
|
|
||||||
created_at: int
|
|
||||||
updated_at: int
|
|
||||||
|
|
||||||
model_config = ConfigDict(from_attributes=True)
|
|
||||||
|
|
||||||
|
|
||||||
####################
|
####################
|
||||||
# Forms
|
# Forms
|
||||||
####################
|
####################
|
||||||
|
|
@ -133,10 +92,6 @@ class ChatImportForm(ChatForm):
|
||||||
updated_at: Optional[int] = None
|
updated_at: Optional[int] = None
|
||||||
|
|
||||||
|
|
||||||
class ChatsImportForm(BaseModel):
|
|
||||||
chats: list[ChatImportForm]
|
|
||||||
|
|
||||||
|
|
||||||
class ChatTitleMessagesForm(BaseModel):
|
class ChatTitleMessagesForm(BaseModel):
|
||||||
title: str
|
title: str
|
||||||
messages: list[dict]
|
messages: list[dict]
|
||||||
|
|
@ -167,77 +122,7 @@ class ChatTitleIdResponse(BaseModel):
|
||||||
created_at: int
|
created_at: int
|
||||||
|
|
||||||
|
|
||||||
class ChatListResponse(BaseModel):
|
|
||||||
items: list[ChatModel]
|
|
||||||
total: int
|
|
||||||
|
|
||||||
|
|
||||||
class ChatUsageStatsResponse(BaseModel):
|
|
||||||
id: str # chat id
|
|
||||||
|
|
||||||
models: dict = {} # models used in the chat with their usage counts
|
|
||||||
message_count: int # number of messages in the chat
|
|
||||||
|
|
||||||
history_models: dict = {} # models used in the chat history with their usage counts
|
|
||||||
history_message_count: int # number of messages in the chat history
|
|
||||||
history_user_message_count: int # number of user messages in the chat history
|
|
||||||
history_assistant_message_count: (
|
|
||||||
int # number of assistant messages in the chat history
|
|
||||||
)
|
|
||||||
|
|
||||||
average_response_time: (
|
|
||||||
float # average response time of assistant messages in seconds
|
|
||||||
)
|
|
||||||
average_user_message_content_length: (
|
|
||||||
float # average length of user message contents
|
|
||||||
)
|
|
||||||
average_assistant_message_content_length: (
|
|
||||||
float # average length of assistant message contents
|
|
||||||
)
|
|
||||||
|
|
||||||
tags: list[str] = [] # tags associated with the chat
|
|
||||||
|
|
||||||
last_message_at: int # timestamp of the last message
|
|
||||||
updated_at: int
|
|
||||||
created_at: int
|
|
||||||
|
|
||||||
model_config = ConfigDict(extra="allow")
|
|
||||||
|
|
||||||
|
|
||||||
class ChatUsageStatsListResponse(BaseModel):
|
|
||||||
items: list[ChatUsageStatsResponse]
|
|
||||||
total: int
|
|
||||||
model_config = ConfigDict(extra="allow")
|
|
||||||
|
|
||||||
|
|
||||||
class ChatTable:
|
class ChatTable:
|
||||||
def _clean_null_bytes(self, obj):
|
|
||||||
"""Recursively remove null bytes from strings in dict/list structures."""
|
|
||||||
return sanitize_data_for_db(obj)
|
|
||||||
|
|
||||||
def _sanitize_chat_row(self, chat_item):
|
|
||||||
"""
|
|
||||||
Clean a Chat SQLAlchemy model's title + chat JSON,
|
|
||||||
and return True if anything changed.
|
|
||||||
"""
|
|
||||||
changed = False
|
|
||||||
|
|
||||||
# Clean title
|
|
||||||
if chat_item.title:
|
|
||||||
cleaned = self._clean_null_bytes(chat_item.title)
|
|
||||||
if cleaned != chat_item.title:
|
|
||||||
chat_item.title = cleaned
|
|
||||||
changed = True
|
|
||||||
|
|
||||||
# Clean JSON
|
|
||||||
if chat_item.chat:
|
|
||||||
cleaned = self._clean_null_bytes(chat_item.chat)
|
|
||||||
if cleaned != chat_item.chat:
|
|
||||||
chat_item.chat = cleaned
|
|
||||||
changed = True
|
|
||||||
|
|
||||||
return changed
|
|
||||||
|
|
||||||
def insert_new_chat(self, user_id: str, form_data: ChatForm) -> Optional[ChatModel]:
|
def insert_new_chat(self, user_id: str, form_data: ChatForm) -> Optional[ChatModel]:
|
||||||
with get_db() as db:
|
with get_db() as db:
|
||||||
id = str(uuid.uuid4())
|
id = str(uuid.uuid4())
|
||||||
|
|
@ -245,76 +130,68 @@ class ChatTable:
|
||||||
**{
|
**{
|
||||||
"id": id,
|
"id": id,
|
||||||
"user_id": user_id,
|
"user_id": user_id,
|
||||||
"title": self._clean_null_bytes(
|
"title": (
|
||||||
form_data.chat["title"]
|
form_data.chat["title"]
|
||||||
if "title" in form_data.chat
|
if "title" in form_data.chat
|
||||||
else "New Chat"
|
else "New Chat"
|
||||||
),
|
),
|
||||||
"chat": self._clean_null_bytes(form_data.chat),
|
"chat": form_data.chat,
|
||||||
"folder_id": form_data.folder_id,
|
"folder_id": form_data.folder_id,
|
||||||
"created_at": int(time.time()),
|
"created_at": int(time.time()),
|
||||||
"updated_at": int(time.time()),
|
"updated_at": int(time.time()),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
chat_item = Chat(**chat.model_dump())
|
result = Chat(**chat.model_dump())
|
||||||
db.add(chat_item)
|
db.add(result)
|
||||||
db.commit()
|
db.commit()
|
||||||
db.refresh(chat_item)
|
db.refresh(result)
|
||||||
return ChatModel.model_validate(chat_item) if chat_item else None
|
return ChatModel.model_validate(result) if result else None
|
||||||
|
|
||||||
def _chat_import_form_to_chat_model(
|
def import_chat(
|
||||||
self, user_id: str, form_data: ChatImportForm
|
self, user_id: str, form_data: ChatImportForm
|
||||||
) -> ChatModel:
|
) -> Optional[ChatModel]:
|
||||||
id = str(uuid.uuid4())
|
|
||||||
chat = ChatModel(
|
|
||||||
**{
|
|
||||||
"id": id,
|
|
||||||
"user_id": user_id,
|
|
||||||
"title": self._clean_null_bytes(
|
|
||||||
form_data.chat["title"] if "title" in form_data.chat else "New Chat"
|
|
||||||
),
|
|
||||||
"chat": self._clean_null_bytes(form_data.chat),
|
|
||||||
"meta": form_data.meta,
|
|
||||||
"pinned": form_data.pinned,
|
|
||||||
"folder_id": form_data.folder_id,
|
|
||||||
"created_at": (
|
|
||||||
form_data.created_at if form_data.created_at else int(time.time())
|
|
||||||
),
|
|
||||||
"updated_at": (
|
|
||||||
form_data.updated_at if form_data.updated_at else int(time.time())
|
|
||||||
),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
return chat
|
|
||||||
|
|
||||||
def import_chats(
|
|
||||||
self, user_id: str, chat_import_forms: list[ChatImportForm]
|
|
||||||
) -> list[ChatModel]:
|
|
||||||
with get_db() as db:
|
with get_db() as db:
|
||||||
chats = []
|
id = str(uuid.uuid4())
|
||||||
|
chat = ChatModel(
|
||||||
|
**{
|
||||||
|
"id": id,
|
||||||
|
"user_id": user_id,
|
||||||
|
"title": (
|
||||||
|
form_data.chat["title"]
|
||||||
|
if "title" in form_data.chat
|
||||||
|
else "New Chat"
|
||||||
|
),
|
||||||
|
"chat": form_data.chat,
|
||||||
|
"meta": form_data.meta,
|
||||||
|
"pinned": form_data.pinned,
|
||||||
|
"folder_id": form_data.folder_id,
|
||||||
|
"created_at": (
|
||||||
|
form_data.created_at
|
||||||
|
if form_data.created_at
|
||||||
|
else int(time.time())
|
||||||
|
),
|
||||||
|
"updated_at": (
|
||||||
|
form_data.updated_at
|
||||||
|
if form_data.updated_at
|
||||||
|
else int(time.time())
|
||||||
|
),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
for form_data in chat_import_forms:
|
result = Chat(**chat.model_dump())
|
||||||
chat = self._chat_import_form_to_chat_model(user_id, form_data)
|
db.add(result)
|
||||||
chats.append(Chat(**chat.model_dump()))
|
|
||||||
|
|
||||||
db.add_all(chats)
|
|
||||||
db.commit()
|
db.commit()
|
||||||
return [ChatModel.model_validate(chat) for chat in chats]
|
db.refresh(result)
|
||||||
|
return ChatModel.model_validate(result) if result else None
|
||||||
|
|
||||||
def update_chat_by_id(self, id: str, chat: dict) -> Optional[ChatModel]:
|
def update_chat_by_id(self, id: str, chat: dict) -> Optional[ChatModel]:
|
||||||
try:
|
try:
|
||||||
with get_db() as db:
|
with get_db() as db:
|
||||||
chat_item = db.get(Chat, id)
|
chat_item = db.get(Chat, id)
|
||||||
chat_item.chat = self._clean_null_bytes(chat)
|
chat_item.chat = chat
|
||||||
chat_item.title = (
|
chat_item.title = chat["title"] if "title" in chat else "New Chat"
|
||||||
self._clean_null_bytes(chat["title"])
|
|
||||||
if "title" in chat
|
|
||||||
else "New Chat"
|
|
||||||
)
|
|
||||||
|
|
||||||
chat_item.updated_at = int(time.time())
|
chat_item.updated_at = int(time.time())
|
||||||
|
|
||||||
db.commit()
|
db.commit()
|
||||||
db.refresh(chat_item)
|
db.refresh(chat_item)
|
||||||
|
|
||||||
|
|
@ -384,7 +261,7 @@ class ChatTable:
|
||||||
|
|
||||||
# Sanitize message content for null characters before upserting
|
# Sanitize message content for null characters before upserting
|
||||||
if isinstance(message.get("content"), str):
|
if isinstance(message.get("content"), str):
|
||||||
message["content"] = sanitize_text_for_db(message["content"])
|
message["content"] = message["content"].replace("\x00", "")
|
||||||
|
|
||||||
chat = chat.chat
|
chat = chat.chat
|
||||||
history = chat.get("history", {})
|
history = chat.get("history", {})
|
||||||
|
|
@ -420,27 +297,6 @@ class ChatTable:
|
||||||
chat["history"] = history
|
chat["history"] = history
|
||||||
return self.update_chat_by_id(id, chat)
|
return self.update_chat_by_id(id, chat)
|
||||||
|
|
||||||
def add_message_files_by_id_and_message_id(
|
|
||||||
self, id: str, message_id: str, files: list[dict]
|
|
||||||
) -> list[dict]:
|
|
||||||
chat = self.get_chat_by_id(id)
|
|
||||||
if chat is None:
|
|
||||||
return None
|
|
||||||
|
|
||||||
chat = chat.chat
|
|
||||||
history = chat.get("history", {})
|
|
||||||
|
|
||||||
message_files = []
|
|
||||||
|
|
||||||
if message_id in history.get("messages", {}):
|
|
||||||
message_files = history["messages"][message_id].get("files", [])
|
|
||||||
message_files = message_files + files
|
|
||||||
history["messages"][message_id]["files"] = message_files
|
|
||||||
|
|
||||||
chat["history"] = history
|
|
||||||
self.update_chat_by_id(id, chat)
|
|
||||||
return message_files
|
|
||||||
|
|
||||||
def insert_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]:
|
def insert_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]:
|
||||||
with get_db() as db:
|
with get_db() as db:
|
||||||
# Get the existing chat to share
|
# Get the existing chat to share
|
||||||
|
|
@ -549,7 +405,6 @@ class ChatTable:
|
||||||
with get_db() as db:
|
with get_db() as db:
|
||||||
chat = db.get(Chat, id)
|
chat = db.get(Chat, id)
|
||||||
chat.archived = not chat.archived
|
chat.archived = not chat.archived
|
||||||
chat.folder_id = None
|
|
||||||
chat.updated_at = int(time.time())
|
chat.updated_at = int(time.time())
|
||||||
db.commit()
|
db.commit()
|
||||||
db.refresh(chat)
|
db.refresh(chat)
|
||||||
|
|
@ -706,15 +561,8 @@ class ChatTable:
|
||||||
def get_chat_by_id(self, id: str) -> Optional[ChatModel]:
|
def get_chat_by_id(self, id: str) -> Optional[ChatModel]:
|
||||||
try:
|
try:
|
||||||
with get_db() as db:
|
with get_db() as db:
|
||||||
chat_item = db.get(Chat, id)
|
chat = db.get(Chat, id)
|
||||||
if chat_item is None:
|
return ChatModel.model_validate(chat)
|
||||||
return None
|
|
||||||
|
|
||||||
if self._sanitize_chat_row(chat_item):
|
|
||||||
db.commit()
|
|
||||||
db.refresh(chat_item)
|
|
||||||
|
|
||||||
return ChatModel.model_validate(chat_item)
|
|
||||||
except Exception:
|
except Exception:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
@ -749,31 +597,14 @@ class ChatTable:
|
||||||
)
|
)
|
||||||
return [ChatModel.model_validate(chat) for chat in all_chats]
|
return [ChatModel.model_validate(chat) for chat in all_chats]
|
||||||
|
|
||||||
def get_chats_by_user_id(
|
def get_chats_by_user_id(self, user_id: str) -> list[ChatModel]:
|
||||||
self, user_id: str, skip: Optional[int] = None, limit: Optional[int] = None
|
|
||||||
) -> ChatListResponse:
|
|
||||||
with get_db() as db:
|
with get_db() as db:
|
||||||
query = (
|
all_chats = (
|
||||||
db.query(Chat)
|
db.query(Chat)
|
||||||
.filter_by(user_id=user_id)
|
.filter_by(user_id=user_id)
|
||||||
.order_by(Chat.updated_at.desc())
|
.order_by(Chat.updated_at.desc())
|
||||||
)
|
)
|
||||||
|
return [ChatModel.model_validate(chat) for chat in all_chats]
|
||||||
total = query.count()
|
|
||||||
|
|
||||||
if skip is not None:
|
|
||||||
query = query.offset(skip)
|
|
||||||
if limit is not None:
|
|
||||||
query = query.limit(limit)
|
|
||||||
|
|
||||||
all_chats = query.all()
|
|
||||||
|
|
||||||
return ChatListResponse(
|
|
||||||
**{
|
|
||||||
"items": [ChatModel.model_validate(chat) for chat in all_chats],
|
|
||||||
"total": total,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_pinned_chats_by_user_id(self, user_id: str) -> list[ChatModel]:
|
def get_pinned_chats_by_user_id(self, user_id: str) -> list[ChatModel]:
|
||||||
with get_db() as db:
|
with get_db() as db:
|
||||||
|
|
@ -804,7 +635,7 @@ class ChatTable:
|
||||||
"""
|
"""
|
||||||
Filters chats based on a search query using Python, allowing pagination using skip and limit.
|
Filters chats based on a search query using Python, allowing pagination using skip and limit.
|
||||||
"""
|
"""
|
||||||
search_text = sanitize_text_for_db(search_text).lower().strip()
|
search_text = search_text.replace("\u0000", "").lower().strip()
|
||||||
|
|
||||||
if not search_text:
|
if not search_text:
|
||||||
return self.get_chat_list_by_user_id(
|
return self.get_chat_list_by_user_id(
|
||||||
|
|
@ -936,30 +767,24 @@ class ChatTable:
|
||||||
elif dialect_name == "postgresql":
|
elif dialect_name == "postgresql":
|
||||||
# PostgreSQL doesn't allow null bytes in text. We filter those out by checking
|
# PostgreSQL doesn't allow null bytes in text. We filter those out by checking
|
||||||
# the JSON representation for \u0000 before attempting text extraction
|
# the JSON representation for \u0000 before attempting text extraction
|
||||||
|
postgres_content_sql = (
|
||||||
# Safety filter: JSON field must not contain \u0000
|
"EXISTS ("
|
||||||
query = query.filter(text("Chat.chat::text NOT LIKE '%\\\\u0000%'"))
|
" SELECT 1 "
|
||||||
|
" FROM json_array_elements(Chat.chat->'messages') AS message "
|
||||||
# Safety filter: title must not contain actual null bytes
|
" WHERE message->'content' IS NOT NULL "
|
||||||
query = query.filter(text("Chat.title::text NOT LIKE '%\\x00%'"))
|
" AND (message->'content')::text NOT LIKE '%\\u0000%' "
|
||||||
|
" AND LOWER(message->>'content') LIKE '%' || :content_key || '%'"
|
||||||
postgres_content_sql = """
|
")"
|
||||||
EXISTS (
|
|
||||||
SELECT 1
|
|
||||||
FROM json_array_elements(Chat.chat->'messages') AS message
|
|
||||||
WHERE json_typeof(message->'content') = 'string'
|
|
||||||
AND LOWER(message->>'content') LIKE '%' || :content_key || '%'
|
|
||||||
)
|
)
|
||||||
"""
|
|
||||||
|
|
||||||
postgres_content_clause = text(postgres_content_sql)
|
postgres_content_clause = text(postgres_content_sql)
|
||||||
|
# Also filter out chats with null bytes in title
|
||||||
|
query = query.filter(text("Chat.title::text NOT LIKE '%\\x00%'"))
|
||||||
query = query.filter(
|
query = query.filter(
|
||||||
or_(
|
or_(
|
||||||
Chat.title.ilike(bindparam("title_key")),
|
Chat.title.ilike(bindparam("title_key")),
|
||||||
postgres_content_clause,
|
postgres_content_clause,
|
||||||
)
|
).params(title_key=f"%{search_text}%", content_key=search_text)
|
||||||
).params(title_key=f"%{search_text}%", content_key=search_text.lower())
|
)
|
||||||
|
|
||||||
# Check if there are any tags to filter, it should have all the tags
|
# Check if there are any tags to filter, it should have all the tags
|
||||||
if "none" in tag_ids:
|
if "none" in tag_ids:
|
||||||
|
|
@ -1234,20 +1059,6 @@ class ChatTable:
|
||||||
except Exception:
|
except Exception:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def move_chats_by_user_id_and_folder_id(
|
|
||||||
self, user_id: str, folder_id: str, new_folder_id: Optional[str]
|
|
||||||
) -> bool:
|
|
||||||
try:
|
|
||||||
with get_db() as db:
|
|
||||||
db.query(Chat).filter_by(user_id=user_id, folder_id=folder_id).update(
|
|
||||||
{"folder_id": new_folder_id}
|
|
||||||
)
|
|
||||||
db.commit()
|
|
||||||
|
|
||||||
return True
|
|
||||||
except Exception:
|
|
||||||
return False
|
|
||||||
|
|
||||||
def delete_shared_chats_by_user_id(self, user_id: str) -> bool:
|
def delete_shared_chats_by_user_id(self, user_id: str) -> bool:
|
||||||
try:
|
try:
|
||||||
with get_db() as db:
|
with get_db() as db:
|
||||||
|
|
@ -1261,93 +1072,5 @@ class ChatTable:
|
||||||
except Exception:
|
except Exception:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def insert_chat_files(
|
|
||||||
self, chat_id: str, message_id: str, file_ids: list[str], user_id: str
|
|
||||||
) -> Optional[list[ChatFileModel]]:
|
|
||||||
if not file_ids:
|
|
||||||
return None
|
|
||||||
|
|
||||||
chat_message_file_ids = [
|
|
||||||
item.id
|
|
||||||
for item in self.get_chat_files_by_chat_id_and_message_id(
|
|
||||||
chat_id, message_id
|
|
||||||
)
|
|
||||||
]
|
|
||||||
# Remove duplicates and existing file_ids
|
|
||||||
file_ids = list(
|
|
||||||
set(
|
|
||||||
[
|
|
||||||
file_id
|
|
||||||
for file_id in file_ids
|
|
||||||
if file_id and file_id not in chat_message_file_ids
|
|
||||||
]
|
|
||||||
)
|
|
||||||
)
|
|
||||||
if not file_ids:
|
|
||||||
return None
|
|
||||||
|
|
||||||
try:
|
|
||||||
with get_db() as db:
|
|
||||||
now = int(time.time())
|
|
||||||
|
|
||||||
chat_files = [
|
|
||||||
ChatFileModel(
|
|
||||||
id=str(uuid.uuid4()),
|
|
||||||
user_id=user_id,
|
|
||||||
chat_id=chat_id,
|
|
||||||
message_id=message_id,
|
|
||||||
file_id=file_id,
|
|
||||||
created_at=now,
|
|
||||||
updated_at=now,
|
|
||||||
)
|
|
||||||
for file_id in file_ids
|
|
||||||
]
|
|
||||||
|
|
||||||
results = [
|
|
||||||
ChatFile(**chat_file.model_dump()) for chat_file in chat_files
|
|
||||||
]
|
|
||||||
|
|
||||||
db.add_all(results)
|
|
||||||
db.commit()
|
|
||||||
|
|
||||||
return chat_files
|
|
||||||
except Exception:
|
|
||||||
return None
|
|
||||||
|
|
||||||
def get_chat_files_by_chat_id_and_message_id(
|
|
||||||
self, chat_id: str, message_id: str
|
|
||||||
) -> list[ChatFileModel]:
|
|
||||||
with get_db() as db:
|
|
||||||
all_chat_files = (
|
|
||||||
db.query(ChatFile)
|
|
||||||
.filter_by(chat_id=chat_id, message_id=message_id)
|
|
||||||
.order_by(ChatFile.created_at.asc())
|
|
||||||
.all()
|
|
||||||
)
|
|
||||||
return [
|
|
||||||
ChatFileModel.model_validate(chat_file) for chat_file in all_chat_files
|
|
||||||
]
|
|
||||||
|
|
||||||
def delete_chat_file(self, chat_id: str, file_id: str) -> bool:
|
|
||||||
try:
|
|
||||||
with get_db() as db:
|
|
||||||
db.query(ChatFile).filter_by(chat_id=chat_id, file_id=file_id).delete()
|
|
||||||
db.commit()
|
|
||||||
return True
|
|
||||||
except Exception:
|
|
||||||
return False
|
|
||||||
|
|
||||||
def get_shared_chats_by_file_id(self, file_id: str) -> list[ChatModel]:
|
|
||||||
with get_db() as db:
|
|
||||||
# Join Chat and ChatFile tables to get shared chats associated with the file_id
|
|
||||||
all_chats = (
|
|
||||||
db.query(Chat)
|
|
||||||
.join(ChatFile, Chat.id == ChatFile.chat_id)
|
|
||||||
.filter(ChatFile.file_id == file_id, Chat.share_id.isnot(None))
|
|
||||||
.all()
|
|
||||||
)
|
|
||||||
|
|
||||||
return [ChatModel.model_validate(chat) for chat in all_chats]
|
|
||||||
|
|
||||||
|
|
||||||
Chats = ChatTable()
|
Chats = ChatTable()
|
||||||
|
|
|
||||||
|
|
@ -4,12 +4,14 @@ import uuid
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from open_webui.internal.db import Base, get_db
|
from open_webui.internal.db import Base, get_db
|
||||||
from open_webui.models.users import User
|
from open_webui.models.chats import Chats
|
||||||
|
|
||||||
|
from open_webui.env import SRC_LOG_LEVELS
|
||||||
from pydantic import BaseModel, ConfigDict
|
from pydantic import BaseModel, ConfigDict
|
||||||
from sqlalchemy import BigInteger, Column, Text, JSON, Boolean
|
from sqlalchemy import BigInteger, Column, Text, JSON, Boolean
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
log.setLevel(SRC_LOG_LEVELS["MODELS"])
|
||||||
|
|
||||||
|
|
||||||
####################
|
####################
|
||||||
|
|
@ -19,7 +21,7 @@ log = logging.getLogger(__name__)
|
||||||
|
|
||||||
class Feedback(Base):
|
class Feedback(Base):
|
||||||
__tablename__ = "feedback"
|
__tablename__ = "feedback"
|
||||||
id = Column(Text, primary_key=True, unique=True)
|
id = Column(Text, primary_key=True)
|
||||||
user_id = Column(Text)
|
user_id = Column(Text)
|
||||||
version = Column(BigInteger, default=0)
|
version = Column(BigInteger, default=0)
|
||||||
type = Column(Text)
|
type = Column(Text)
|
||||||
|
|
@ -60,13 +62,6 @@ class FeedbackResponse(BaseModel):
|
||||||
updated_at: int
|
updated_at: int
|
||||||
|
|
||||||
|
|
||||||
class FeedbackIdResponse(BaseModel):
|
|
||||||
id: str
|
|
||||||
user_id: str
|
|
||||||
created_at: int
|
|
||||||
updated_at: int
|
|
||||||
|
|
||||||
|
|
||||||
class RatingData(BaseModel):
|
class RatingData(BaseModel):
|
||||||
rating: Optional[str | int] = None
|
rating: Optional[str | int] = None
|
||||||
model_id: Optional[str] = None
|
model_id: Optional[str] = None
|
||||||
|
|
@ -97,28 +92,6 @@ class FeedbackForm(BaseModel):
|
||||||
model_config = ConfigDict(extra="allow")
|
model_config = ConfigDict(extra="allow")
|
||||||
|
|
||||||
|
|
||||||
class UserResponse(BaseModel):
|
|
||||||
id: str
|
|
||||||
name: str
|
|
||||||
email: str
|
|
||||||
role: str = "pending"
|
|
||||||
|
|
||||||
last_active_at: int # timestamp in epoch
|
|
||||||
updated_at: int # timestamp in epoch
|
|
||||||
created_at: int # timestamp in epoch
|
|
||||||
|
|
||||||
model_config = ConfigDict(from_attributes=True)
|
|
||||||
|
|
||||||
|
|
||||||
class FeedbackUserResponse(FeedbackResponse):
|
|
||||||
user: Optional[UserResponse] = None
|
|
||||||
|
|
||||||
|
|
||||||
class FeedbackListResponse(BaseModel):
|
|
||||||
items: list[FeedbackUserResponse]
|
|
||||||
total: int
|
|
||||||
|
|
||||||
|
|
||||||
class FeedbackTable:
|
class FeedbackTable:
|
||||||
def insert_new_feedback(
|
def insert_new_feedback(
|
||||||
self, user_id: str, form_data: FeedbackForm
|
self, user_id: str, form_data: FeedbackForm
|
||||||
|
|
@ -170,70 +143,6 @@ class FeedbackTable:
|
||||||
except Exception:
|
except Exception:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_feedback_items(
|
|
||||||
self, filter: dict = {}, skip: int = 0, limit: int = 30
|
|
||||||
) -> FeedbackListResponse:
|
|
||||||
with get_db() as db:
|
|
||||||
query = db.query(Feedback, User).join(User, Feedback.user_id == User.id)
|
|
||||||
|
|
||||||
if filter:
|
|
||||||
order_by = filter.get("order_by")
|
|
||||||
direction = filter.get("direction")
|
|
||||||
|
|
||||||
if order_by == "username":
|
|
||||||
if direction == "asc":
|
|
||||||
query = query.order_by(User.name.asc())
|
|
||||||
else:
|
|
||||||
query = query.order_by(User.name.desc())
|
|
||||||
elif order_by == "model_id":
|
|
||||||
# it's stored in feedback.data['model_id']
|
|
||||||
if direction == "asc":
|
|
||||||
query = query.order_by(
|
|
||||||
Feedback.data["model_id"].as_string().asc()
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
query = query.order_by(
|
|
||||||
Feedback.data["model_id"].as_string().desc()
|
|
||||||
)
|
|
||||||
elif order_by == "rating":
|
|
||||||
# it's stored in feedback.data['rating']
|
|
||||||
if direction == "asc":
|
|
||||||
query = query.order_by(
|
|
||||||
Feedback.data["rating"].as_string().asc()
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
query = query.order_by(
|
|
||||||
Feedback.data["rating"].as_string().desc()
|
|
||||||
)
|
|
||||||
elif order_by == "updated_at":
|
|
||||||
if direction == "asc":
|
|
||||||
query = query.order_by(Feedback.updated_at.asc())
|
|
||||||
else:
|
|
||||||
query = query.order_by(Feedback.updated_at.desc())
|
|
||||||
|
|
||||||
else:
|
|
||||||
query = query.order_by(Feedback.created_at.desc())
|
|
||||||
|
|
||||||
# Count BEFORE pagination
|
|
||||||
total = query.count()
|
|
||||||
|
|
||||||
if skip:
|
|
||||||
query = query.offset(skip)
|
|
||||||
if limit:
|
|
||||||
query = query.limit(limit)
|
|
||||||
|
|
||||||
items = query.all()
|
|
||||||
|
|
||||||
feedbacks = []
|
|
||||||
for feedback, user in items:
|
|
||||||
feedback_model = FeedbackModel.model_validate(feedback)
|
|
||||||
user_model = UserResponse.model_validate(user)
|
|
||||||
feedbacks.append(
|
|
||||||
FeedbackUserResponse(**feedback_model.model_dump(), user=user_model)
|
|
||||||
)
|
|
||||||
|
|
||||||
return FeedbackListResponse(items=feedbacks, total=total)
|
|
||||||
|
|
||||||
def get_all_feedbacks(self) -> list[FeedbackModel]:
|
def get_all_feedbacks(self) -> list[FeedbackModel]:
|
||||||
with get_db() as db:
|
with get_db() as db:
|
||||||
return [
|
return [
|
||||||
|
|
|
||||||
|
|
@ -3,10 +3,12 @@ import time
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from open_webui.internal.db import Base, JSONField, get_db
|
from open_webui.internal.db import Base, JSONField, get_db
|
||||||
|
from open_webui.env import SRC_LOG_LEVELS
|
||||||
from pydantic import BaseModel, ConfigDict
|
from pydantic import BaseModel, ConfigDict
|
||||||
from sqlalchemy import BigInteger, Column, String, Text, JSON
|
from sqlalchemy import BigInteger, Column, String, Text, JSON
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
log.setLevel(SRC_LOG_LEVELS["MODELS"])
|
||||||
|
|
||||||
####################
|
####################
|
||||||
# Files DB Schema
|
# Files DB Schema
|
||||||
|
|
@ -15,7 +17,7 @@ log = logging.getLogger(__name__)
|
||||||
|
|
||||||
class File(Base):
|
class File(Base):
|
||||||
__tablename__ = "file"
|
__tablename__ = "file"
|
||||||
id = Column(String, primary_key=True, unique=True)
|
id = Column(String, primary_key=True)
|
||||||
user_id = Column(String)
|
user_id = Column(String)
|
||||||
hash = Column(Text, nullable=True)
|
hash = Column(Text, nullable=True)
|
||||||
|
|
||||||
|
|
@ -81,7 +83,7 @@ class FileModelResponse(BaseModel):
|
||||||
class FileMetadataResponse(BaseModel):
|
class FileMetadataResponse(BaseModel):
|
||||||
id: str
|
id: str
|
||||||
hash: Optional[str] = None
|
hash: Optional[str] = None
|
||||||
meta: Optional[dict] = None
|
meta: dict
|
||||||
created_at: int # timestamp in epoch
|
created_at: int # timestamp in epoch
|
||||||
updated_at: int # timestamp in epoch
|
updated_at: int # timestamp in epoch
|
||||||
|
|
||||||
|
|
@ -96,17 +98,6 @@ class FileForm(BaseModel):
|
||||||
access_control: Optional[dict] = None
|
access_control: Optional[dict] = None
|
||||||
|
|
||||||
|
|
||||||
class FileUpdateForm(BaseModel):
|
|
||||||
hash: Optional[str] = None
|
|
||||||
data: Optional[dict] = None
|
|
||||||
meta: Optional[dict] = None
|
|
||||||
|
|
||||||
|
|
||||||
class FileListResponse(BaseModel):
|
|
||||||
items: list[FileModel]
|
|
||||||
total: int
|
|
||||||
|
|
||||||
|
|
||||||
class FilesTable:
|
class FilesTable:
|
||||||
def insert_new_file(self, user_id: str, form_data: FileForm) -> Optional[FileModel]:
|
def insert_new_file(self, user_id: str, form_data: FileForm) -> Optional[FileModel]:
|
||||||
with get_db() as db:
|
with get_db() as db:
|
||||||
|
|
@ -213,35 +204,11 @@ class FilesTable:
|
||||||
for file in db.query(File).filter_by(user_id=user_id).all()
|
for file in db.query(File).filter_by(user_id=user_id).all()
|
||||||
]
|
]
|
||||||
|
|
||||||
def update_file_by_id(
|
|
||||||
self, id: str, form_data: FileUpdateForm
|
|
||||||
) -> Optional[FileModel]:
|
|
||||||
with get_db() as db:
|
|
||||||
try:
|
|
||||||
file = db.query(File).filter_by(id=id).first()
|
|
||||||
|
|
||||||
if form_data.hash is not None:
|
|
||||||
file.hash = form_data.hash
|
|
||||||
|
|
||||||
if form_data.data is not None:
|
|
||||||
file.data = {**(file.data if file.data else {}), **form_data.data}
|
|
||||||
|
|
||||||
if form_data.meta is not None:
|
|
||||||
file.meta = {**(file.meta if file.meta else {}), **form_data.meta}
|
|
||||||
|
|
||||||
file.updated_at = int(time.time())
|
|
||||||
db.commit()
|
|
||||||
return FileModel.model_validate(file)
|
|
||||||
except Exception as e:
|
|
||||||
log.exception(f"Error updating file completely by id: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
def update_file_hash_by_id(self, id: str, hash: str) -> Optional[FileModel]:
|
def update_file_hash_by_id(self, id: str, hash: str) -> Optional[FileModel]:
|
||||||
with get_db() as db:
|
with get_db() as db:
|
||||||
try:
|
try:
|
||||||
file = db.query(File).filter_by(id=id).first()
|
file = db.query(File).filter_by(id=id).first()
|
||||||
file.hash = hash
|
file.hash = hash
|
||||||
file.updated_at = int(time.time())
|
|
||||||
db.commit()
|
db.commit()
|
||||||
|
|
||||||
return FileModel.model_validate(file)
|
return FileModel.model_validate(file)
|
||||||
|
|
@ -253,7 +220,6 @@ class FilesTable:
|
||||||
try:
|
try:
|
||||||
file = db.query(File).filter_by(id=id).first()
|
file = db.query(File).filter_by(id=id).first()
|
||||||
file.data = {**(file.data if file.data else {}), **data}
|
file.data = {**(file.data if file.data else {}), **data}
|
||||||
file.updated_at = int(time.time())
|
|
||||||
db.commit()
|
db.commit()
|
||||||
return FileModel.model_validate(file)
|
return FileModel.model_validate(file)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
@ -265,7 +231,6 @@ class FilesTable:
|
||||||
try:
|
try:
|
||||||
file = db.query(File).filter_by(id=id).first()
|
file = db.query(File).filter_by(id=id).first()
|
||||||
file.meta = {**(file.meta if file.meta else {}), **meta}
|
file.meta = {**(file.meta if file.meta else {}), **meta}
|
||||||
file.updated_at = int(time.time())
|
|
||||||
db.commit()
|
db.commit()
|
||||||
return FileModel.model_validate(file)
|
return FileModel.model_validate(file)
|
||||||
except Exception:
|
except Exception:
|
||||||
|
|
|
||||||
|
|
@ -9,9 +9,11 @@ from pydantic import BaseModel, ConfigDict
|
||||||
from sqlalchemy import BigInteger, Column, Text, JSON, Boolean, func
|
from sqlalchemy import BigInteger, Column, Text, JSON, Boolean, func
|
||||||
|
|
||||||
from open_webui.internal.db import Base, get_db
|
from open_webui.internal.db import Base, get_db
|
||||||
|
from open_webui.env import SRC_LOG_LEVELS
|
||||||
|
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
log.setLevel(SRC_LOG_LEVELS["MODELS"])
|
||||||
|
|
||||||
|
|
||||||
####################
|
####################
|
||||||
|
|
@ -21,7 +23,7 @@ log = logging.getLogger(__name__)
|
||||||
|
|
||||||
class Folder(Base):
|
class Folder(Base):
|
||||||
__tablename__ = "folder"
|
__tablename__ = "folder"
|
||||||
id = Column(Text, primary_key=True, unique=True)
|
id = Column(Text, primary_key=True)
|
||||||
parent_id = Column(Text, nullable=True)
|
parent_id = Column(Text, nullable=True)
|
||||||
user_id = Column(Text)
|
user_id = Column(Text)
|
||||||
name = Column(Text)
|
name = Column(Text)
|
||||||
|
|
|
||||||
|
|
@ -4,10 +4,12 @@ from typing import Optional
|
||||||
|
|
||||||
from open_webui.internal.db import Base, JSONField, get_db
|
from open_webui.internal.db import Base, JSONField, get_db
|
||||||
from open_webui.models.users import Users, UserModel
|
from open_webui.models.users import Users, UserModel
|
||||||
|
from open_webui.env import SRC_LOG_LEVELS
|
||||||
from pydantic import BaseModel, ConfigDict
|
from pydantic import BaseModel, ConfigDict
|
||||||
from sqlalchemy import BigInteger, Boolean, Column, String, Text, Index
|
from sqlalchemy import BigInteger, Boolean, Column, String, Text, Index
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
log.setLevel(SRC_LOG_LEVELS["MODELS"])
|
||||||
|
|
||||||
####################
|
####################
|
||||||
# Functions DB Schema
|
# Functions DB Schema
|
||||||
|
|
@ -17,7 +19,7 @@ log = logging.getLogger(__name__)
|
||||||
class Function(Base):
|
class Function(Base):
|
||||||
__tablename__ = "function"
|
__tablename__ = "function"
|
||||||
|
|
||||||
id = Column(String, primary_key=True, unique=True)
|
id = Column(String, primary_key=True)
|
||||||
user_id = Column(String)
|
user_id = Column(String)
|
||||||
name = Column(Text)
|
name = Column(Text)
|
||||||
type = Column(Text)
|
type = Column(Text)
|
||||||
|
|
|
||||||
|
|
@ -5,26 +5,17 @@ from typing import Optional
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
from open_webui.internal.db import Base, get_db
|
from open_webui.internal.db import Base, get_db
|
||||||
|
from open_webui.env import SRC_LOG_LEVELS
|
||||||
|
|
||||||
from open_webui.models.files import FileMetadataResponse
|
from open_webui.models.files import FileMetadataResponse
|
||||||
|
|
||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict
|
from pydantic import BaseModel, ConfigDict
|
||||||
from sqlalchemy import (
|
from sqlalchemy import BigInteger, Column, String, Text, JSON, func
|
||||||
BigInteger,
|
|
||||||
Column,
|
|
||||||
String,
|
|
||||||
Text,
|
|
||||||
JSON,
|
|
||||||
and_,
|
|
||||||
func,
|
|
||||||
ForeignKey,
|
|
||||||
cast,
|
|
||||||
or_,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
log.setLevel(SRC_LOG_LEVELS["MODELS"])
|
||||||
|
|
||||||
####################
|
####################
|
||||||
# UserGroup DB Schema
|
# UserGroup DB Schema
|
||||||
|
|
@ -44,12 +35,14 @@ class Group(Base):
|
||||||
meta = Column(JSON, nullable=True)
|
meta = Column(JSON, nullable=True)
|
||||||
|
|
||||||
permissions = Column(JSON, nullable=True)
|
permissions = Column(JSON, nullable=True)
|
||||||
|
user_ids = Column(JSON, nullable=True)
|
||||||
|
|
||||||
created_at = Column(BigInteger)
|
created_at = Column(BigInteger)
|
||||||
updated_at = Column(BigInteger)
|
updated_at = Column(BigInteger)
|
||||||
|
|
||||||
|
|
||||||
class GroupModel(BaseModel):
|
class GroupModel(BaseModel):
|
||||||
|
model_config = ConfigDict(from_attributes=True)
|
||||||
id: str
|
id: str
|
||||||
user_id: str
|
user_id: str
|
||||||
|
|
||||||
|
|
@ -60,64 +53,44 @@ class GroupModel(BaseModel):
|
||||||
meta: Optional[dict] = None
|
meta: Optional[dict] = None
|
||||||
|
|
||||||
permissions: Optional[dict] = None
|
permissions: Optional[dict] = None
|
||||||
|
user_ids: list[str] = []
|
||||||
|
|
||||||
created_at: int # timestamp in epoch
|
created_at: int # timestamp in epoch
|
||||||
updated_at: int # timestamp in epoch
|
updated_at: int # timestamp in epoch
|
||||||
|
|
||||||
model_config = ConfigDict(from_attributes=True)
|
|
||||||
|
|
||||||
|
|
||||||
class GroupMember(Base):
|
|
||||||
__tablename__ = "group_member"
|
|
||||||
|
|
||||||
id = Column(Text, unique=True, primary_key=True)
|
|
||||||
group_id = Column(
|
|
||||||
Text,
|
|
||||||
ForeignKey("group.id", ondelete="CASCADE"),
|
|
||||||
nullable=False,
|
|
||||||
)
|
|
||||||
user_id = Column(Text, nullable=False)
|
|
||||||
created_at = Column(BigInteger, nullable=True)
|
|
||||||
updated_at = Column(BigInteger, nullable=True)
|
|
||||||
|
|
||||||
|
|
||||||
class GroupMemberModel(BaseModel):
|
|
||||||
id: str
|
|
||||||
group_id: str
|
|
||||||
user_id: str
|
|
||||||
created_at: Optional[int] = None # timestamp in epoch
|
|
||||||
updated_at: Optional[int] = None # timestamp in epoch
|
|
||||||
|
|
||||||
|
|
||||||
####################
|
####################
|
||||||
# Forms
|
# Forms
|
||||||
####################
|
####################
|
||||||
|
|
||||||
|
|
||||||
class GroupResponse(GroupModel):
|
class GroupResponse(BaseModel):
|
||||||
member_count: Optional[int] = None
|
id: str
|
||||||
|
user_id: str
|
||||||
|
name: str
|
||||||
|
description: str
|
||||||
|
permissions: Optional[dict] = None
|
||||||
|
data: Optional[dict] = None
|
||||||
|
meta: Optional[dict] = None
|
||||||
|
user_ids: list[str] = []
|
||||||
|
created_at: int # timestamp in epoch
|
||||||
|
updated_at: int # timestamp in epoch
|
||||||
|
|
||||||
|
|
||||||
class GroupForm(BaseModel):
|
class GroupForm(BaseModel):
|
||||||
name: str
|
name: str
|
||||||
description: str
|
description: str
|
||||||
permissions: Optional[dict] = None
|
permissions: Optional[dict] = None
|
||||||
data: Optional[dict] = None
|
|
||||||
|
|
||||||
|
|
||||||
class UserIdsForm(BaseModel):
|
class UserIdsForm(BaseModel):
|
||||||
user_ids: Optional[list[str]] = None
|
user_ids: Optional[list[str]] = None
|
||||||
|
|
||||||
|
|
||||||
class GroupUpdateForm(GroupForm):
|
class GroupUpdateForm(GroupForm, UserIdsForm):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class GroupListResponse(BaseModel):
|
|
||||||
items: list[GroupResponse] = []
|
|
||||||
total: int = 0
|
|
||||||
|
|
||||||
|
|
||||||
class GroupTable:
|
class GroupTable:
|
||||||
def insert_new_group(
|
def insert_new_group(
|
||||||
self, user_id: str, form_data: GroupForm
|
self, user_id: str, form_data: GroupForm
|
||||||
|
|
@ -146,94 +119,24 @@ class GroupTable:
|
||||||
except Exception:
|
except Exception:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_all_groups(self) -> list[GroupModel]:
|
def get_groups(self) -> list[GroupModel]:
|
||||||
with get_db() as db:
|
with get_db() as db:
|
||||||
groups = db.query(Group).order_by(Group.updated_at.desc()).all()
|
|
||||||
return [GroupModel.model_validate(group) for group in groups]
|
|
||||||
|
|
||||||
def get_groups(self, filter) -> list[GroupResponse]:
|
|
||||||
with get_db() as db:
|
|
||||||
query = db.query(Group)
|
|
||||||
|
|
||||||
if filter:
|
|
||||||
if "query" in filter:
|
|
||||||
query = query.filter(Group.name.ilike(f"%{filter['query']}%"))
|
|
||||||
if "member_id" in filter:
|
|
||||||
query = query.join(
|
|
||||||
GroupMember, GroupMember.group_id == Group.id
|
|
||||||
).filter(GroupMember.user_id == filter["member_id"])
|
|
||||||
|
|
||||||
if "share" in filter:
|
|
||||||
share_value = filter["share"]
|
|
||||||
json_share = Group.data["config"]["share"].as_boolean()
|
|
||||||
|
|
||||||
if share_value:
|
|
||||||
query = query.filter(
|
|
||||||
or_(
|
|
||||||
Group.data.is_(None),
|
|
||||||
json_share.is_(None),
|
|
||||||
json_share == True,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
query = query.filter(
|
|
||||||
and_(Group.data.isnot(None), json_share == False)
|
|
||||||
)
|
|
||||||
groups = query.order_by(Group.updated_at.desc()).all()
|
|
||||||
return [
|
return [
|
||||||
GroupResponse.model_validate(
|
GroupModel.model_validate(group)
|
||||||
{
|
for group in db.query(Group).order_by(Group.updated_at.desc()).all()
|
||||||
**GroupModel.model_validate(group).model_dump(),
|
|
||||||
"member_count": self.get_group_member_count_by_id(group.id),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
for group in groups
|
|
||||||
]
|
]
|
||||||
|
|
||||||
def search_groups(
|
|
||||||
self, filter: Optional[dict] = None, skip: int = 0, limit: int = 30
|
|
||||||
) -> GroupListResponse:
|
|
||||||
with get_db() as db:
|
|
||||||
query = db.query(Group)
|
|
||||||
|
|
||||||
if filter:
|
|
||||||
if "query" in filter:
|
|
||||||
query = query.filter(Group.name.ilike(f"%{filter['query']}%"))
|
|
||||||
if "member_id" in filter:
|
|
||||||
query = query.join(
|
|
||||||
GroupMember, GroupMember.group_id == Group.id
|
|
||||||
).filter(GroupMember.user_id == filter["member_id"])
|
|
||||||
|
|
||||||
if "share" in filter:
|
|
||||||
# 'share' is stored in data JSON, support both sqlite and postgres
|
|
||||||
share_value = filter["share"]
|
|
||||||
print("Filtering by share:", share_value)
|
|
||||||
query = query.filter(
|
|
||||||
Group.data.op("->>")("share") == str(share_value)
|
|
||||||
)
|
|
||||||
|
|
||||||
total = query.count()
|
|
||||||
query = query.order_by(Group.updated_at.desc())
|
|
||||||
groups = query.offset(skip).limit(limit).all()
|
|
||||||
|
|
||||||
return {
|
|
||||||
"items": [
|
|
||||||
GroupResponse.model_validate(
|
|
||||||
**GroupModel.model_validate(group).model_dump(),
|
|
||||||
member_count=self.get_group_member_count_by_id(group.id),
|
|
||||||
)
|
|
||||||
for group in groups
|
|
||||||
],
|
|
||||||
"total": total,
|
|
||||||
}
|
|
||||||
|
|
||||||
def get_groups_by_member_id(self, user_id: str) -> list[GroupModel]:
|
def get_groups_by_member_id(self, user_id: str) -> list[GroupModel]:
|
||||||
with get_db() as db:
|
with get_db() as db:
|
||||||
return [
|
return [
|
||||||
GroupModel.model_validate(group)
|
GroupModel.model_validate(group)
|
||||||
for group in db.query(Group)
|
for group in db.query(Group)
|
||||||
.join(GroupMember, GroupMember.group_id == Group.id)
|
.filter(
|
||||||
.filter(GroupMember.user_id == user_id)
|
func.json_array_length(Group.user_ids) > 0
|
||||||
|
) # Ensure array exists
|
||||||
|
.filter(
|
||||||
|
Group.user_ids.cast(String).like(f'%"{user_id}"%')
|
||||||
|
) # String-based check
|
||||||
.order_by(Group.updated_at.desc())
|
.order_by(Group.updated_at.desc())
|
||||||
.all()
|
.all()
|
||||||
]
|
]
|
||||||
|
|
@ -246,63 +149,12 @@ class GroupTable:
|
||||||
except Exception:
|
except Exception:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_group_user_ids_by_id(self, id: str) -> Optional[list[str]]:
|
def get_group_user_ids_by_id(self, id: str) -> Optional[str]:
|
||||||
with get_db() as db:
|
group = self.get_group_by_id(id)
|
||||||
members = (
|
if group:
|
||||||
db.query(GroupMember.user_id).filter(GroupMember.group_id == id).all()
|
return group.user_ids
|
||||||
)
|
else:
|
||||||
|
return None
|
||||||
if not members:
|
|
||||||
return None
|
|
||||||
|
|
||||||
return [m[0] for m in members]
|
|
||||||
|
|
||||||
def get_group_user_ids_by_ids(self, group_ids: list[str]) -> dict[str, list[str]]:
|
|
||||||
with get_db() as db:
|
|
||||||
members = (
|
|
||||||
db.query(GroupMember.group_id, GroupMember.user_id)
|
|
||||||
.filter(GroupMember.group_id.in_(group_ids))
|
|
||||||
.all()
|
|
||||||
)
|
|
||||||
|
|
||||||
group_user_ids: dict[str, list[str]] = {
|
|
||||||
group_id: [] for group_id in group_ids
|
|
||||||
}
|
|
||||||
|
|
||||||
for group_id, user_id in members:
|
|
||||||
group_user_ids[group_id].append(user_id)
|
|
||||||
|
|
||||||
return group_user_ids
|
|
||||||
|
|
||||||
def set_group_user_ids_by_id(self, group_id: str, user_ids: list[str]) -> None:
|
|
||||||
with get_db() as db:
|
|
||||||
# Delete existing members
|
|
||||||
db.query(GroupMember).filter(GroupMember.group_id == group_id).delete()
|
|
||||||
|
|
||||||
# Insert new members
|
|
||||||
now = int(time.time())
|
|
||||||
new_members = [
|
|
||||||
GroupMember(
|
|
||||||
id=str(uuid.uuid4()),
|
|
||||||
group_id=group_id,
|
|
||||||
user_id=user_id,
|
|
||||||
created_at=now,
|
|
||||||
updated_at=now,
|
|
||||||
)
|
|
||||||
for user_id in user_ids
|
|
||||||
]
|
|
||||||
|
|
||||||
db.add_all(new_members)
|
|
||||||
db.commit()
|
|
||||||
|
|
||||||
def get_group_member_count_by_id(self, id: str) -> int:
|
|
||||||
with get_db() as db:
|
|
||||||
count = (
|
|
||||||
db.query(func.count(GroupMember.user_id))
|
|
||||||
.filter(GroupMember.group_id == id)
|
|
||||||
.scalar()
|
|
||||||
)
|
|
||||||
return count if count else 0
|
|
||||||
|
|
||||||
def update_group_by_id(
|
def update_group_by_id(
|
||||||
self, id: str, form_data: GroupUpdateForm, overwrite: bool = False
|
self, id: str, form_data: GroupUpdateForm, overwrite: bool = False
|
||||||
|
|
@ -343,29 +195,20 @@ class GroupTable:
|
||||||
def remove_user_from_all_groups(self, user_id: str) -> bool:
|
def remove_user_from_all_groups(self, user_id: str) -> bool:
|
||||||
with get_db() as db:
|
with get_db() as db:
|
||||||
try:
|
try:
|
||||||
# Find all groups the user belongs to
|
groups = self.get_groups_by_member_id(user_id)
|
||||||
groups = (
|
|
||||||
db.query(Group)
|
|
||||||
.join(GroupMember, GroupMember.group_id == Group.id)
|
|
||||||
.filter(GroupMember.user_id == user_id)
|
|
||||||
.all()
|
|
||||||
)
|
|
||||||
|
|
||||||
# Remove the user from each group
|
|
||||||
for group in groups:
|
for group in groups:
|
||||||
db.query(GroupMember).filter(
|
group.user_ids.remove(user_id)
|
||||||
GroupMember.group_id == group.id, GroupMember.user_id == user_id
|
|
||||||
).delete()
|
|
||||||
|
|
||||||
db.query(Group).filter_by(id=group.id).update(
|
db.query(Group).filter_by(id=group.id).update(
|
||||||
{"updated_at": int(time.time())}
|
{
|
||||||
|
"user_ids": group.user_ids,
|
||||||
|
"updated_at": int(time.time()),
|
||||||
|
}
|
||||||
)
|
)
|
||||||
|
db.commit()
|
||||||
|
|
||||||
db.commit()
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
except Exception:
|
except Exception:
|
||||||
db.rollback()
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def create_groups_by_group_names(
|
def create_groups_by_group_names(
|
||||||
|
|
@ -373,7 +216,7 @@ class GroupTable:
|
||||||
) -> list[GroupModel]:
|
) -> list[GroupModel]:
|
||||||
|
|
||||||
# check for existing groups
|
# check for existing groups
|
||||||
existing_groups = self.get_all_groups()
|
existing_groups = self.get_groups()
|
||||||
existing_group_names = {group.name for group in existing_groups}
|
existing_group_names = {group.name for group in existing_groups}
|
||||||
|
|
||||||
new_groups = []
|
new_groups = []
|
||||||
|
|
@ -403,61 +246,37 @@ class GroupTable:
|
||||||
def sync_groups_by_group_names(self, user_id: str, group_names: list[str]) -> bool:
|
def sync_groups_by_group_names(self, user_id: str, group_names: list[str]) -> bool:
|
||||||
with get_db() as db:
|
with get_db() as db:
|
||||||
try:
|
try:
|
||||||
now = int(time.time())
|
groups = db.query(Group).filter(Group.name.in_(group_names)).all()
|
||||||
|
group_ids = [group.id for group in groups]
|
||||||
|
|
||||||
# 1. Groups that SHOULD contain the user
|
# Remove user from groups not in the new list
|
||||||
target_groups = (
|
existing_groups = self.get_groups_by_member_id(user_id)
|
||||||
db.query(Group).filter(Group.name.in_(group_names)).all()
|
|
||||||
)
|
|
||||||
target_group_ids = {g.id for g in target_groups}
|
|
||||||
|
|
||||||
# 2. Groups the user is CURRENTLY in
|
for group in existing_groups:
|
||||||
existing_group_ids = {
|
if group.id not in group_ids:
|
||||||
g.id
|
group.user_ids.remove(user_id)
|
||||||
for g in db.query(Group)
|
db.query(Group).filter_by(id=group.id).update(
|
||||||
.join(GroupMember, GroupMember.group_id == Group.id)
|
{
|
||||||
.filter(GroupMember.user_id == user_id)
|
"user_ids": group.user_ids,
|
||||||
.all()
|
"updated_at": int(time.time()),
|
||||||
}
|
}
|
||||||
|
|
||||||
# 3. Determine adds + removals
|
|
||||||
groups_to_add = target_group_ids - existing_group_ids
|
|
||||||
groups_to_remove = existing_group_ids - target_group_ids
|
|
||||||
|
|
||||||
# 4. Remove in one bulk delete
|
|
||||||
if groups_to_remove:
|
|
||||||
db.query(GroupMember).filter(
|
|
||||||
GroupMember.user_id == user_id,
|
|
||||||
GroupMember.group_id.in_(groups_to_remove),
|
|
||||||
).delete(synchronize_session=False)
|
|
||||||
|
|
||||||
db.query(Group).filter(Group.id.in_(groups_to_remove)).update(
|
|
||||||
{"updated_at": now}, synchronize_session=False
|
|
||||||
)
|
|
||||||
|
|
||||||
# 5. Bulk insert missing memberships
|
|
||||||
for group_id in groups_to_add:
|
|
||||||
db.add(
|
|
||||||
GroupMember(
|
|
||||||
id=str(uuid.uuid4()),
|
|
||||||
group_id=group_id,
|
|
||||||
user_id=user_id,
|
|
||||||
created_at=now,
|
|
||||||
updated_at=now,
|
|
||||||
)
|
)
|
||||||
)
|
|
||||||
|
|
||||||
if groups_to_add:
|
# Add user to new groups
|
||||||
db.query(Group).filter(Group.id.in_(groups_to_add)).update(
|
for group in groups:
|
||||||
{"updated_at": now}, synchronize_session=False
|
if user_id not in group.user_ids:
|
||||||
)
|
group.user_ids.append(user_id)
|
||||||
|
db.query(Group).filter_by(id=group.id).update(
|
||||||
|
{
|
||||||
|
"user_ids": group.user_ids,
|
||||||
|
"updated_at": int(time.time()),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
db.commit()
|
db.commit()
|
||||||
return True
|
return True
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.exception(e)
|
log.exception(e)
|
||||||
db.rollback()
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def add_users_to_group(
|
def add_users_to_group(
|
||||||
|
|
@ -469,31 +288,21 @@ class GroupTable:
|
||||||
if not group:
|
if not group:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
now = int(time.time())
|
group_user_ids = group.user_ids
|
||||||
|
if not group_user_ids or not isinstance(group_user_ids, list):
|
||||||
|
group_user_ids = []
|
||||||
|
|
||||||
for user_id in user_ids or []:
|
group_user_ids = list(set(group_user_ids)) # Deduplicate
|
||||||
try:
|
|
||||||
db.add(
|
|
||||||
GroupMember(
|
|
||||||
id=str(uuid.uuid4()),
|
|
||||||
group_id=id,
|
|
||||||
user_id=user_id,
|
|
||||||
created_at=now,
|
|
||||||
updated_at=now,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
db.flush() # Detect unique constraint violation early
|
|
||||||
except Exception:
|
|
||||||
db.rollback() # Clear failed INSERT
|
|
||||||
db.begin() # Start a new transaction
|
|
||||||
continue # Duplicate → ignore
|
|
||||||
|
|
||||||
group.updated_at = now
|
for user_id in user_ids:
|
||||||
|
if user_id not in group_user_ids:
|
||||||
|
group_user_ids.append(user_id)
|
||||||
|
|
||||||
|
group.user_ids = group_user_ids
|
||||||
|
group.updated_at = int(time.time())
|
||||||
db.commit()
|
db.commit()
|
||||||
db.refresh(group)
|
db.refresh(group)
|
||||||
|
|
||||||
return GroupModel.model_validate(group)
|
return GroupModel.model_validate(group)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.exception(e)
|
log.exception(e)
|
||||||
return None
|
return None
|
||||||
|
|
@ -507,22 +316,23 @@ class GroupTable:
|
||||||
if not group:
|
if not group:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
if not user_ids:
|
group_user_ids = group.user_ids
|
||||||
|
|
||||||
|
if not group_user_ids or not isinstance(group_user_ids, list):
|
||||||
return GroupModel.model_validate(group)
|
return GroupModel.model_validate(group)
|
||||||
|
|
||||||
# Remove each user from group_member
|
group_user_ids = list(set(group_user_ids)) # Deduplicate
|
||||||
for user_id in user_ids:
|
|
||||||
db.query(GroupMember).filter(
|
|
||||||
GroupMember.group_id == id, GroupMember.user_id == user_id
|
|
||||||
).delete()
|
|
||||||
|
|
||||||
# Update group timestamp
|
for user_id in user_ids:
|
||||||
|
if user_id in group_user_ids:
|
||||||
|
group_user_ids.remove(user_id)
|
||||||
|
|
||||||
|
group.user_ids = group_user_ids
|
||||||
group.updated_at = int(time.time())
|
group.updated_at = int(time.time())
|
||||||
|
|
||||||
db.commit()
|
db.commit()
|
||||||
db.refresh(group)
|
db.refresh(group)
|
||||||
return GroupModel.model_validate(group)
|
return GroupModel.model_validate(group)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.exception(e)
|
log.exception(e)
|
||||||
return None
|
return None
|
||||||
|
|
|
||||||
|
|
@ -5,34 +5,20 @@ from typing import Optional
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
from open_webui.internal.db import Base, get_db
|
from open_webui.internal.db import Base, get_db
|
||||||
|
from open_webui.env import SRC_LOG_LEVELS
|
||||||
|
|
||||||
from open_webui.models.files import (
|
from open_webui.models.files import FileMetadataResponse
|
||||||
File,
|
|
||||||
FileModel,
|
|
||||||
FileMetadataResponse,
|
|
||||||
FileModelResponse,
|
|
||||||
)
|
|
||||||
from open_webui.models.groups import Groups
|
from open_webui.models.groups import Groups
|
||||||
from open_webui.models.users import User, UserModel, Users, UserResponse
|
from open_webui.models.users import Users, UserResponse
|
||||||
|
|
||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict
|
from pydantic import BaseModel, ConfigDict
|
||||||
from sqlalchemy import (
|
from sqlalchemy import BigInteger, Column, String, Text, JSON
|
||||||
BigInteger,
|
|
||||||
Column,
|
|
||||||
ForeignKey,
|
|
||||||
String,
|
|
||||||
Text,
|
|
||||||
JSON,
|
|
||||||
UniqueConstraint,
|
|
||||||
or_,
|
|
||||||
)
|
|
||||||
|
|
||||||
from open_webui.utils.access_control import has_access
|
from open_webui.utils.access_control import has_access
|
||||||
from open_webui.utils.db.access_control import has_permission
|
|
||||||
|
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
log.setLevel(SRC_LOG_LEVELS["MODELS"])
|
||||||
|
|
||||||
####################
|
####################
|
||||||
# Knowledge DB Schema
|
# Knowledge DB Schema
|
||||||
|
|
@ -48,7 +34,9 @@ class Knowledge(Base):
|
||||||
name = Column(Text)
|
name = Column(Text)
|
||||||
description = Column(Text)
|
description = Column(Text)
|
||||||
|
|
||||||
|
data = Column(JSON, nullable=True)
|
||||||
meta = Column(JSON, nullable=True)
|
meta = Column(JSON, nullable=True)
|
||||||
|
|
||||||
access_control = Column(JSON, nullable=True) # Controls data access levels.
|
access_control = Column(JSON, nullable=True) # Controls data access levels.
|
||||||
# Defines access control rules for this entry.
|
# Defines access control rules for this entry.
|
||||||
# - `None`: Public access, available to all users with the "user" role.
|
# - `None`: Public access, available to all users with the "user" role.
|
||||||
|
|
@ -79,6 +67,7 @@ class KnowledgeModel(BaseModel):
|
||||||
name: str
|
name: str
|
||||||
description: str
|
description: str
|
||||||
|
|
||||||
|
data: Optional[dict] = None
|
||||||
meta: Optional[dict] = None
|
meta: Optional[dict] = None
|
||||||
|
|
||||||
access_control: Optional[dict] = None
|
access_control: Optional[dict] = None
|
||||||
|
|
@ -87,42 +76,11 @@ class KnowledgeModel(BaseModel):
|
||||||
updated_at: int # timestamp in epoch
|
updated_at: int # timestamp in epoch
|
||||||
|
|
||||||
|
|
||||||
class KnowledgeFile(Base):
|
|
||||||
__tablename__ = "knowledge_file"
|
|
||||||
|
|
||||||
id = Column(Text, unique=True, primary_key=True)
|
|
||||||
|
|
||||||
knowledge_id = Column(
|
|
||||||
Text, ForeignKey("knowledge.id", ondelete="CASCADE"), nullable=False
|
|
||||||
)
|
|
||||||
file_id = Column(Text, ForeignKey("file.id", ondelete="CASCADE"), nullable=False)
|
|
||||||
user_id = Column(Text, nullable=False)
|
|
||||||
|
|
||||||
created_at = Column(BigInteger, nullable=False)
|
|
||||||
updated_at = Column(BigInteger, nullable=False)
|
|
||||||
|
|
||||||
__table_args__ = (
|
|
||||||
UniqueConstraint(
|
|
||||||
"knowledge_id", "file_id", name="uq_knowledge_file_knowledge_file"
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class KnowledgeFileModel(BaseModel):
|
|
||||||
id: str
|
|
||||||
knowledge_id: str
|
|
||||||
file_id: str
|
|
||||||
user_id: str
|
|
||||||
|
|
||||||
created_at: int # timestamp in epoch
|
|
||||||
updated_at: int # timestamp in epoch
|
|
||||||
|
|
||||||
model_config = ConfigDict(from_attributes=True)
|
|
||||||
|
|
||||||
|
|
||||||
####################
|
####################
|
||||||
# Forms
|
# Forms
|
||||||
####################
|
####################
|
||||||
|
|
||||||
|
|
||||||
class KnowledgeUserModel(KnowledgeModel):
|
class KnowledgeUserModel(KnowledgeModel):
|
||||||
user: Optional[UserResponse] = None
|
user: Optional[UserResponse] = None
|
||||||
|
|
||||||
|
|
@ -132,29 +90,16 @@ class KnowledgeResponse(KnowledgeModel):
|
||||||
|
|
||||||
|
|
||||||
class KnowledgeUserResponse(KnowledgeUserModel):
|
class KnowledgeUserResponse(KnowledgeUserModel):
|
||||||
pass
|
files: Optional[list[FileMetadataResponse | dict]] = None
|
||||||
|
|
||||||
|
|
||||||
class KnowledgeForm(BaseModel):
|
class KnowledgeForm(BaseModel):
|
||||||
name: str
|
name: str
|
||||||
description: str
|
description: str
|
||||||
|
data: Optional[dict] = None
|
||||||
access_control: Optional[dict] = None
|
access_control: Optional[dict] = None
|
||||||
|
|
||||||
|
|
||||||
class FileUserResponse(FileModelResponse):
|
|
||||||
user: Optional[UserResponse] = None
|
|
||||||
|
|
||||||
|
|
||||||
class KnowledgeListResponse(BaseModel):
|
|
||||||
items: list[KnowledgeUserModel]
|
|
||||||
total: int
|
|
||||||
|
|
||||||
|
|
||||||
class KnowledgeFileListResponse(BaseModel):
|
|
||||||
items: list[FileUserResponse]
|
|
||||||
total: int
|
|
||||||
|
|
||||||
|
|
||||||
class KnowledgeTable:
|
class KnowledgeTable:
|
||||||
def insert_new_knowledge(
|
def insert_new_knowledge(
|
||||||
self, user_id: str, form_data: KnowledgeForm
|
self, user_id: str, form_data: KnowledgeForm
|
||||||
|
|
@ -182,13 +127,12 @@ class KnowledgeTable:
|
||||||
except Exception:
|
except Exception:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_knowledge_bases(
|
def get_knowledge_bases(self) -> list[KnowledgeUserModel]:
|
||||||
self, skip: int = 0, limit: int = 30
|
|
||||||
) -> list[KnowledgeUserModel]:
|
|
||||||
with get_db() as db:
|
with get_db() as db:
|
||||||
all_knowledge = (
|
all_knowledge = (
|
||||||
db.query(Knowledge).order_by(Knowledge.updated_at.desc()).all()
|
db.query(Knowledge).order_by(Knowledge.updated_at.desc()).all()
|
||||||
)
|
)
|
||||||
|
|
||||||
user_ids = list(set(knowledge.user_id for knowledge in all_knowledge))
|
user_ids = list(set(knowledge.user_id for knowledge in all_knowledge))
|
||||||
|
|
||||||
users = Users.get_users_by_user_ids(user_ids) if user_ids else []
|
users = Users.get_users_by_user_ids(user_ids) if user_ids else []
|
||||||
|
|
@ -207,126 +151,6 @@ class KnowledgeTable:
|
||||||
)
|
)
|
||||||
return knowledge_bases
|
return knowledge_bases
|
||||||
|
|
||||||
def search_knowledge_bases(
|
|
||||||
self, user_id: str, filter: dict, skip: int = 0, limit: int = 30
|
|
||||||
) -> KnowledgeListResponse:
|
|
||||||
try:
|
|
||||||
with get_db() as db:
|
|
||||||
query = db.query(Knowledge, User).outerjoin(
|
|
||||||
User, User.id == Knowledge.user_id
|
|
||||||
)
|
|
||||||
|
|
||||||
if filter:
|
|
||||||
query_key = filter.get("query")
|
|
||||||
if query_key:
|
|
||||||
query = query.filter(
|
|
||||||
or_(
|
|
||||||
Knowledge.name.ilike(f"%{query_key}%"),
|
|
||||||
Knowledge.description.ilike(f"%{query_key}%"),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
view_option = filter.get("view_option")
|
|
||||||
if view_option == "created":
|
|
||||||
query = query.filter(Knowledge.user_id == user_id)
|
|
||||||
elif view_option == "shared":
|
|
||||||
query = query.filter(Knowledge.user_id != user_id)
|
|
||||||
|
|
||||||
query = has_permission(db, Knowledge, query, filter)
|
|
||||||
|
|
||||||
query = query.order_by(Knowledge.updated_at.desc())
|
|
||||||
|
|
||||||
total = query.count()
|
|
||||||
if skip:
|
|
||||||
query = query.offset(skip)
|
|
||||||
if limit:
|
|
||||||
query = query.limit(limit)
|
|
||||||
|
|
||||||
items = query.all()
|
|
||||||
|
|
||||||
knowledge_bases = []
|
|
||||||
for knowledge_base, user in items:
|
|
||||||
knowledge_bases.append(
|
|
||||||
KnowledgeUserModel.model_validate(
|
|
||||||
{
|
|
||||||
**KnowledgeModel.model_validate(
|
|
||||||
knowledge_base
|
|
||||||
).model_dump(),
|
|
||||||
"user": (
|
|
||||||
UserModel.model_validate(user).model_dump()
|
|
||||||
if user
|
|
||||||
else None
|
|
||||||
),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
return KnowledgeListResponse(items=knowledge_bases, total=total)
|
|
||||||
except Exception as e:
|
|
||||||
print(e)
|
|
||||||
return KnowledgeListResponse(items=[], total=0)
|
|
||||||
|
|
||||||
def search_knowledge_files(
|
|
||||||
self, filter: dict, skip: int = 0, limit: int = 30
|
|
||||||
) -> KnowledgeFileListResponse:
|
|
||||||
"""
|
|
||||||
Scalable version: search files across all knowledge bases the user has
|
|
||||||
READ access to, without loading all KBs or using large IN() lists.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
with get_db() as db:
|
|
||||||
# Base query: join Knowledge → KnowledgeFile → File
|
|
||||||
query = (
|
|
||||||
db.query(File, User)
|
|
||||||
.join(KnowledgeFile, File.id == KnowledgeFile.file_id)
|
|
||||||
.join(Knowledge, KnowledgeFile.knowledge_id == Knowledge.id)
|
|
||||||
.outerjoin(User, User.id == KnowledgeFile.user_id)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Apply access-control directly to the joined query
|
|
||||||
# This makes the database handle filtering, even with 10k+ KBs
|
|
||||||
query = has_permission(db, Knowledge, query, filter)
|
|
||||||
|
|
||||||
# Apply filename search
|
|
||||||
if filter:
|
|
||||||
q = filter.get("query")
|
|
||||||
if q:
|
|
||||||
query = query.filter(File.filename.ilike(f"%{q}%"))
|
|
||||||
|
|
||||||
# Order by file changes
|
|
||||||
query = query.order_by(File.updated_at.desc())
|
|
||||||
|
|
||||||
# Count before pagination
|
|
||||||
total = query.count()
|
|
||||||
|
|
||||||
if skip:
|
|
||||||
query = query.offset(skip)
|
|
||||||
if limit:
|
|
||||||
query = query.limit(limit)
|
|
||||||
|
|
||||||
rows = query.all()
|
|
||||||
|
|
||||||
items = []
|
|
||||||
for file, user in rows:
|
|
||||||
items.append(
|
|
||||||
FileUserResponse(
|
|
||||||
**FileModel.model_validate(file).model_dump(),
|
|
||||||
user=(
|
|
||||||
UserResponse(
|
|
||||||
**UserModel.model_validate(user).model_dump()
|
|
||||||
)
|
|
||||||
if user
|
|
||||||
else None
|
|
||||||
),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
return KnowledgeFileListResponse(items=items, total=total)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print("search_knowledge_files error:", e)
|
|
||||||
return KnowledgeFileListResponse(items=[], total=0)
|
|
||||||
|
|
||||||
def check_access_by_user_id(self, id, user_id, permission="write") -> bool:
|
def check_access_by_user_id(self, id, user_id, permission="write") -> bool:
|
||||||
knowledge = self.get_knowledge_by_id(id)
|
knowledge = self.get_knowledge_by_id(id)
|
||||||
if not knowledge:
|
if not knowledge:
|
||||||
|
|
@ -358,197 +182,6 @@ class KnowledgeTable:
|
||||||
except Exception:
|
except Exception:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_knowledge_by_id_and_user_id(
|
|
||||||
self, id: str, user_id: str
|
|
||||||
) -> Optional[KnowledgeModel]:
|
|
||||||
knowledge = self.get_knowledge_by_id(id)
|
|
||||||
if not knowledge:
|
|
||||||
return None
|
|
||||||
|
|
||||||
if knowledge.user_id == user_id:
|
|
||||||
return knowledge
|
|
||||||
|
|
||||||
user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user_id)}
|
|
||||||
if has_access(user_id, "write", knowledge.access_control, user_group_ids):
|
|
||||||
return knowledge
|
|
||||||
return None
|
|
||||||
|
|
||||||
def get_knowledges_by_file_id(self, file_id: str) -> list[KnowledgeModel]:
|
|
||||||
try:
|
|
||||||
with get_db() as db:
|
|
||||||
knowledges = (
|
|
||||||
db.query(Knowledge)
|
|
||||||
.join(KnowledgeFile, Knowledge.id == KnowledgeFile.knowledge_id)
|
|
||||||
.filter(KnowledgeFile.file_id == file_id)
|
|
||||||
.all()
|
|
||||||
)
|
|
||||||
return [
|
|
||||||
KnowledgeModel.model_validate(knowledge) for knowledge in knowledges
|
|
||||||
]
|
|
||||||
except Exception:
|
|
||||||
return []
|
|
||||||
|
|
||||||
def search_files_by_id(
|
|
||||||
self,
|
|
||||||
knowledge_id: str,
|
|
||||||
user_id: str,
|
|
||||||
filter: dict,
|
|
||||||
skip: int = 0,
|
|
||||||
limit: int = 30,
|
|
||||||
) -> KnowledgeFileListResponse:
|
|
||||||
try:
|
|
||||||
with get_db() as db:
|
|
||||||
query = (
|
|
||||||
db.query(File, User)
|
|
||||||
.join(KnowledgeFile, File.id == KnowledgeFile.file_id)
|
|
||||||
.outerjoin(User, User.id == KnowledgeFile.user_id)
|
|
||||||
.filter(KnowledgeFile.knowledge_id == knowledge_id)
|
|
||||||
)
|
|
||||||
|
|
||||||
if filter:
|
|
||||||
query_key = filter.get("query")
|
|
||||||
if query_key:
|
|
||||||
query = query.filter(or_(File.filename.ilike(f"%{query_key}%")))
|
|
||||||
|
|
||||||
view_option = filter.get("view_option")
|
|
||||||
if view_option == "created":
|
|
||||||
query = query.filter(KnowledgeFile.user_id == user_id)
|
|
||||||
elif view_option == "shared":
|
|
||||||
query = query.filter(KnowledgeFile.user_id != user_id)
|
|
||||||
|
|
||||||
order_by = filter.get("order_by")
|
|
||||||
direction = filter.get("direction")
|
|
||||||
|
|
||||||
if order_by == "name":
|
|
||||||
if direction == "asc":
|
|
||||||
query = query.order_by(File.filename.asc())
|
|
||||||
else:
|
|
||||||
query = query.order_by(File.filename.desc())
|
|
||||||
elif order_by == "created_at":
|
|
||||||
if direction == "asc":
|
|
||||||
query = query.order_by(File.created_at.asc())
|
|
||||||
else:
|
|
||||||
query = query.order_by(File.created_at.desc())
|
|
||||||
elif order_by == "updated_at":
|
|
||||||
if direction == "asc":
|
|
||||||
query = query.order_by(File.updated_at.asc())
|
|
||||||
else:
|
|
||||||
query = query.order_by(File.updated_at.desc())
|
|
||||||
else:
|
|
||||||
query = query.order_by(File.updated_at.desc())
|
|
||||||
|
|
||||||
else:
|
|
||||||
query = query.order_by(File.updated_at.desc())
|
|
||||||
|
|
||||||
# Count BEFORE pagination
|
|
||||||
total = query.count()
|
|
||||||
|
|
||||||
if skip:
|
|
||||||
query = query.offset(skip)
|
|
||||||
if limit:
|
|
||||||
query = query.limit(limit)
|
|
||||||
|
|
||||||
items = query.all()
|
|
||||||
|
|
||||||
files = []
|
|
||||||
for file, user in items:
|
|
||||||
files.append(
|
|
||||||
FileUserResponse(
|
|
||||||
**FileModel.model_validate(file).model_dump(),
|
|
||||||
user=(
|
|
||||||
UserResponse(
|
|
||||||
**UserModel.model_validate(user).model_dump()
|
|
||||||
)
|
|
||||||
if user
|
|
||||||
else None
|
|
||||||
),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
return KnowledgeFileListResponse(items=files, total=total)
|
|
||||||
except Exception as e:
|
|
||||||
print(e)
|
|
||||||
return KnowledgeFileListResponse(items=[], total=0)
|
|
||||||
|
|
||||||
def get_files_by_id(self, knowledge_id: str) -> list[FileModel]:
|
|
||||||
try:
|
|
||||||
with get_db() as db:
|
|
||||||
files = (
|
|
||||||
db.query(File)
|
|
||||||
.join(KnowledgeFile, File.id == KnowledgeFile.file_id)
|
|
||||||
.filter(KnowledgeFile.knowledge_id == knowledge_id)
|
|
||||||
.all()
|
|
||||||
)
|
|
||||||
return [FileModel.model_validate(file) for file in files]
|
|
||||||
except Exception:
|
|
||||||
return []
|
|
||||||
|
|
||||||
def get_file_metadatas_by_id(self, knowledge_id: str) -> list[FileMetadataResponse]:
|
|
||||||
try:
|
|
||||||
with get_db() as db:
|
|
||||||
files = self.get_files_by_id(knowledge_id)
|
|
||||||
return [FileMetadataResponse(**file.model_dump()) for file in files]
|
|
||||||
except Exception:
|
|
||||||
return []
|
|
||||||
|
|
||||||
def add_file_to_knowledge_by_id(
|
|
||||||
self, knowledge_id: str, file_id: str, user_id: str
|
|
||||||
) -> Optional[KnowledgeFileModel]:
|
|
||||||
with get_db() as db:
|
|
||||||
knowledge_file = KnowledgeFileModel(
|
|
||||||
**{
|
|
||||||
"id": str(uuid.uuid4()),
|
|
||||||
"knowledge_id": knowledge_id,
|
|
||||||
"file_id": file_id,
|
|
||||||
"user_id": user_id,
|
|
||||||
"created_at": int(time.time()),
|
|
||||||
"updated_at": int(time.time()),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
result = KnowledgeFile(**knowledge_file.model_dump())
|
|
||||||
db.add(result)
|
|
||||||
db.commit()
|
|
||||||
db.refresh(result)
|
|
||||||
if result:
|
|
||||||
return KnowledgeFileModel.model_validate(result)
|
|
||||||
else:
|
|
||||||
return None
|
|
||||||
except Exception:
|
|
||||||
return None
|
|
||||||
|
|
||||||
def remove_file_from_knowledge_by_id(self, knowledge_id: str, file_id: str) -> bool:
|
|
||||||
try:
|
|
||||||
with get_db() as db:
|
|
||||||
db.query(KnowledgeFile).filter_by(
|
|
||||||
knowledge_id=knowledge_id, file_id=file_id
|
|
||||||
).delete()
|
|
||||||
db.commit()
|
|
||||||
return True
|
|
||||||
except Exception:
|
|
||||||
return False
|
|
||||||
|
|
||||||
def reset_knowledge_by_id(self, id: str) -> Optional[KnowledgeModel]:
|
|
||||||
try:
|
|
||||||
with get_db() as db:
|
|
||||||
# Delete all knowledge_file entries for this knowledge_id
|
|
||||||
db.query(KnowledgeFile).filter_by(knowledge_id=id).delete()
|
|
||||||
db.commit()
|
|
||||||
|
|
||||||
# Update the knowledge entry's updated_at timestamp
|
|
||||||
db.query(Knowledge).filter_by(id=id).update(
|
|
||||||
{
|
|
||||||
"updated_at": int(time.time()),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
db.commit()
|
|
||||||
|
|
||||||
return self.get_knowledge_by_id(id=id)
|
|
||||||
except Exception as e:
|
|
||||||
log.exception(e)
|
|
||||||
return None
|
|
||||||
|
|
||||||
def update_knowledge_by_id(
|
def update_knowledge_by_id(
|
||||||
self, id: str, form_data: KnowledgeForm, overwrite: bool = False
|
self, id: str, form_data: KnowledgeForm, overwrite: bool = False
|
||||||
) -> Optional[KnowledgeModel]:
|
) -> Optional[KnowledgeModel]:
|
||||||
|
|
|
||||||
|
|
@ -14,7 +14,7 @@ from sqlalchemy import BigInteger, Column, String, Text
|
||||||
class Memory(Base):
|
class Memory(Base):
|
||||||
__tablename__ = "memory"
|
__tablename__ = "memory"
|
||||||
|
|
||||||
id = Column(String, primary_key=True, unique=True)
|
id = Column(String, primary_key=True)
|
||||||
user_id = Column(String)
|
user_id = Column(String)
|
||||||
content = Column(Text)
|
content = Column(Text)
|
||||||
updated_at = Column(BigInteger)
|
updated_at = Column(BigInteger)
|
||||||
|
|
|
||||||
|
|
@ -5,11 +5,10 @@ from typing import Optional
|
||||||
|
|
||||||
from open_webui.internal.db import Base, get_db
|
from open_webui.internal.db import Base, get_db
|
||||||
from open_webui.models.tags import TagModel, Tag, Tags
|
from open_webui.models.tags import TagModel, Tag, Tags
|
||||||
from open_webui.models.users import Users, User, UserNameResponse
|
from open_webui.models.users import Users, UserNameResponse
|
||||||
from open_webui.models.channels import Channels, ChannelMember
|
|
||||||
|
|
||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict, field_validator
|
from pydantic import BaseModel, ConfigDict
|
||||||
from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON
|
from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON
|
||||||
from sqlalchemy import or_, func, select, and_, text
|
from sqlalchemy import or_, func, select, and_, text
|
||||||
from sqlalchemy.sql import exists
|
from sqlalchemy.sql import exists
|
||||||
|
|
@ -21,7 +20,7 @@ from sqlalchemy.sql import exists
|
||||||
|
|
||||||
class MessageReaction(Base):
|
class MessageReaction(Base):
|
||||||
__tablename__ = "message_reaction"
|
__tablename__ = "message_reaction"
|
||||||
id = Column(Text, primary_key=True, unique=True)
|
id = Column(Text, primary_key=True)
|
||||||
user_id = Column(Text)
|
user_id = Column(Text)
|
||||||
message_id = Column(Text)
|
message_id = Column(Text)
|
||||||
name = Column(Text)
|
name = Column(Text)
|
||||||
|
|
@ -40,7 +39,7 @@ class MessageReactionModel(BaseModel):
|
||||||
|
|
||||||
class Message(Base):
|
class Message(Base):
|
||||||
__tablename__ = "message"
|
__tablename__ = "message"
|
||||||
id = Column(Text, primary_key=True, unique=True)
|
id = Column(Text, primary_key=True)
|
||||||
|
|
||||||
user_id = Column(Text)
|
user_id = Column(Text)
|
||||||
channel_id = Column(Text, nullable=True)
|
channel_id = Column(Text, nullable=True)
|
||||||
|
|
@ -48,11 +47,6 @@ class Message(Base):
|
||||||
reply_to_id = Column(Text, nullable=True)
|
reply_to_id = Column(Text, nullable=True)
|
||||||
parent_id = Column(Text, nullable=True)
|
parent_id = Column(Text, nullable=True)
|
||||||
|
|
||||||
# Pins
|
|
||||||
is_pinned = Column(Boolean, nullable=False, default=False)
|
|
||||||
pinned_at = Column(BigInteger, nullable=True)
|
|
||||||
pinned_by = Column(Text, nullable=True)
|
|
||||||
|
|
||||||
content = Column(Text)
|
content = Column(Text)
|
||||||
data = Column(JSON, nullable=True)
|
data = Column(JSON, nullable=True)
|
||||||
meta = Column(JSON, nullable=True)
|
meta = Column(JSON, nullable=True)
|
||||||
|
|
@ -71,17 +65,12 @@ class MessageModel(BaseModel):
|
||||||
reply_to_id: Optional[str] = None
|
reply_to_id: Optional[str] = None
|
||||||
parent_id: Optional[str] = None
|
parent_id: Optional[str] = None
|
||||||
|
|
||||||
# Pins
|
|
||||||
is_pinned: bool = False
|
|
||||||
pinned_by: Optional[str] = None
|
|
||||||
pinned_at: Optional[int] = None # timestamp in epoch (time_ns)
|
|
||||||
|
|
||||||
content: str
|
content: str
|
||||||
data: Optional[dict] = None
|
data: Optional[dict] = None
|
||||||
meta: Optional[dict] = None
|
meta: Optional[dict] = None
|
||||||
|
|
||||||
created_at: int # timestamp in epoch (time_ns)
|
created_at: int # timestamp in epoch
|
||||||
updated_at: int # timestamp in epoch (time_ns)
|
updated_at: int # timestamp in epoch
|
||||||
|
|
||||||
|
|
||||||
####################
|
####################
|
||||||
|
|
@ -90,7 +79,6 @@ class MessageModel(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
class MessageForm(BaseModel):
|
class MessageForm(BaseModel):
|
||||||
temp_id: Optional[str] = None
|
|
||||||
content: str
|
content: str
|
||||||
reply_to_id: Optional[str] = None
|
reply_to_id: Optional[str] = None
|
||||||
parent_id: Optional[str] = None
|
parent_id: Optional[str] = None
|
||||||
|
|
@ -100,7 +88,7 @@ class MessageForm(BaseModel):
|
||||||
|
|
||||||
class Reactions(BaseModel):
|
class Reactions(BaseModel):
|
||||||
name: str
|
name: str
|
||||||
users: list[dict]
|
user_ids: list[str]
|
||||||
count: int
|
count: int
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -108,25 +96,8 @@ class MessageUserResponse(MessageModel):
|
||||||
user: Optional[UserNameResponse] = None
|
user: Optional[UserNameResponse] = None
|
||||||
|
|
||||||
|
|
||||||
class MessageUserSlimResponse(MessageUserResponse):
|
|
||||||
data: bool | None = None
|
|
||||||
|
|
||||||
@field_validator("data", mode="before")
|
|
||||||
def convert_data_to_bool(cls, v):
|
|
||||||
# No data or not a dict → False
|
|
||||||
if not isinstance(v, dict):
|
|
||||||
return False
|
|
||||||
|
|
||||||
# True if ANY value in the dict is non-empty
|
|
||||||
return any(bool(val) for val in v.values())
|
|
||||||
|
|
||||||
|
|
||||||
class MessageReplyToResponse(MessageUserResponse):
|
class MessageReplyToResponse(MessageUserResponse):
|
||||||
reply_to_message: Optional[MessageUserSlimResponse] = None
|
reply_to_message: Optional[MessageUserResponse] = None
|
||||||
|
|
||||||
|
|
||||||
class MessageWithReactionsResponse(MessageUserSlimResponse):
|
|
||||||
reactions: list[Reactions]
|
|
||||||
|
|
||||||
|
|
||||||
class MessageResponse(MessageReplyToResponse):
|
class MessageResponse(MessageReplyToResponse):
|
||||||
|
|
@ -140,11 +111,9 @@ class MessageTable:
|
||||||
self, form_data: MessageForm, channel_id: str, user_id: str
|
self, form_data: MessageForm, channel_id: str, user_id: str
|
||||||
) -> Optional[MessageModel]:
|
) -> Optional[MessageModel]:
|
||||||
with get_db() as db:
|
with get_db() as db:
|
||||||
channel_member = Channels.join_channel(channel_id, user_id)
|
|
||||||
|
|
||||||
id = str(uuid.uuid4())
|
id = str(uuid.uuid4())
|
||||||
ts = int(time.time_ns())
|
|
||||||
|
|
||||||
|
ts = int(time.time_ns())
|
||||||
message = MessageModel(
|
message = MessageModel(
|
||||||
**{
|
**{
|
||||||
"id": id,
|
"id": id,
|
||||||
|
|
@ -152,9 +121,6 @@ class MessageTable:
|
||||||
"channel_id": channel_id,
|
"channel_id": channel_id,
|
||||||
"reply_to_id": form_data.reply_to_id,
|
"reply_to_id": form_data.reply_to_id,
|
||||||
"parent_id": form_data.parent_id,
|
"parent_id": form_data.parent_id,
|
||||||
"is_pinned": False,
|
|
||||||
"pinned_at": None,
|
|
||||||
"pinned_by": None,
|
|
||||||
"content": form_data.content,
|
"content": form_data.content,
|
||||||
"data": form_data.data,
|
"data": form_data.data,
|
||||||
"meta": form_data.meta,
|
"meta": form_data.meta,
|
||||||
|
|
@ -162,8 +128,8 @@ class MessageTable:
|
||||||
"updated_at": ts,
|
"updated_at": ts,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
result = Message(**message.model_dump())
|
|
||||||
|
|
||||||
|
result = Message(**message.model_dump())
|
||||||
db.add(result)
|
db.add(result)
|
||||||
db.commit()
|
db.commit()
|
||||||
db.refresh(result)
|
db.refresh(result)
|
||||||
|
|
@ -314,30 +280,6 @@ class MessageTable:
|
||||||
)
|
)
|
||||||
return messages
|
return messages
|
||||||
|
|
||||||
def get_last_message_by_channel_id(self, channel_id: str) -> Optional[MessageModel]:
|
|
||||||
with get_db() as db:
|
|
||||||
message = (
|
|
||||||
db.query(Message)
|
|
||||||
.filter_by(channel_id=channel_id)
|
|
||||||
.order_by(Message.created_at.desc())
|
|
||||||
.first()
|
|
||||||
)
|
|
||||||
return MessageModel.model_validate(message) if message else None
|
|
||||||
|
|
||||||
def get_pinned_messages_by_channel_id(
|
|
||||||
self, channel_id: str, skip: int = 0, limit: int = 50
|
|
||||||
) -> list[MessageModel]:
|
|
||||||
with get_db() as db:
|
|
||||||
all_messages = (
|
|
||||||
db.query(Message)
|
|
||||||
.filter_by(channel_id=channel_id, is_pinned=True)
|
|
||||||
.order_by(Message.pinned_at.desc())
|
|
||||||
.offset(skip)
|
|
||||||
.limit(limit)
|
|
||||||
.all()
|
|
||||||
)
|
|
||||||
return [MessageModel.model_validate(message) for message in all_messages]
|
|
||||||
|
|
||||||
def update_message_by_id(
|
def update_message_by_id(
|
||||||
self, id: str, form_data: MessageForm
|
self, id: str, form_data: MessageForm
|
||||||
) -> Optional[MessageModel]:
|
) -> Optional[MessageModel]:
|
||||||
|
|
@ -357,44 +299,10 @@ class MessageTable:
|
||||||
db.refresh(message)
|
db.refresh(message)
|
||||||
return MessageModel.model_validate(message) if message else None
|
return MessageModel.model_validate(message) if message else None
|
||||||
|
|
||||||
def update_is_pinned_by_id(
|
|
||||||
self, id: str, is_pinned: bool, pinned_by: Optional[str] = None
|
|
||||||
) -> Optional[MessageModel]:
|
|
||||||
with get_db() as db:
|
|
||||||
message = db.get(Message, id)
|
|
||||||
message.is_pinned = is_pinned
|
|
||||||
message.pinned_at = int(time.time_ns()) if is_pinned else None
|
|
||||||
message.pinned_by = pinned_by if is_pinned else None
|
|
||||||
db.commit()
|
|
||||||
db.refresh(message)
|
|
||||||
return MessageModel.model_validate(message) if message else None
|
|
||||||
|
|
||||||
def get_unread_message_count(
|
|
||||||
self, channel_id: str, user_id: str, last_read_at: Optional[int] = None
|
|
||||||
) -> int:
|
|
||||||
with get_db() as db:
|
|
||||||
query = db.query(Message).filter(
|
|
||||||
Message.channel_id == channel_id,
|
|
||||||
Message.parent_id == None, # only count top-level messages
|
|
||||||
Message.created_at > (last_read_at if last_read_at else 0),
|
|
||||||
)
|
|
||||||
if user_id:
|
|
||||||
query = query.filter(Message.user_id != user_id)
|
|
||||||
return query.count()
|
|
||||||
|
|
||||||
def add_reaction_to_message(
|
def add_reaction_to_message(
|
||||||
self, id: str, user_id: str, name: str
|
self, id: str, user_id: str, name: str
|
||||||
) -> Optional[MessageReactionModel]:
|
) -> Optional[MessageReactionModel]:
|
||||||
with get_db() as db:
|
with get_db() as db:
|
||||||
# check for existing reaction
|
|
||||||
existing_reaction = (
|
|
||||||
db.query(MessageReaction)
|
|
||||||
.filter_by(message_id=id, user_id=user_id, name=name)
|
|
||||||
.first()
|
|
||||||
)
|
|
||||||
if existing_reaction:
|
|
||||||
return MessageReactionModel.model_validate(existing_reaction)
|
|
||||||
|
|
||||||
reaction_id = str(uuid.uuid4())
|
reaction_id = str(uuid.uuid4())
|
||||||
reaction = MessageReactionModel(
|
reaction = MessageReactionModel(
|
||||||
id=reaction_id,
|
id=reaction_id,
|
||||||
|
|
@ -411,30 +319,17 @@ class MessageTable:
|
||||||
|
|
||||||
def get_reactions_by_message_id(self, id: str) -> list[Reactions]:
|
def get_reactions_by_message_id(self, id: str) -> list[Reactions]:
|
||||||
with get_db() as db:
|
with get_db() as db:
|
||||||
# JOIN User so all user info is fetched in one query
|
all_reactions = db.query(MessageReaction).filter_by(message_id=id).all()
|
||||||
results = (
|
|
||||||
db.query(MessageReaction, User)
|
|
||||||
.join(User, MessageReaction.user_id == User.id)
|
|
||||||
.filter(MessageReaction.message_id == id)
|
|
||||||
.all()
|
|
||||||
)
|
|
||||||
|
|
||||||
reactions = {}
|
reactions = {}
|
||||||
|
for reaction in all_reactions:
|
||||||
for reaction, user in results:
|
|
||||||
if reaction.name not in reactions:
|
if reaction.name not in reactions:
|
||||||
reactions[reaction.name] = {
|
reactions[reaction.name] = {
|
||||||
"name": reaction.name,
|
"name": reaction.name,
|
||||||
"users": [],
|
"user_ids": [],
|
||||||
"count": 0,
|
"count": 0,
|
||||||
}
|
}
|
||||||
|
reactions[reaction.name]["user_ids"].append(reaction.user_id)
|
||||||
reactions[reaction.name]["users"].append(
|
|
||||||
{
|
|
||||||
"id": user.id,
|
|
||||||
"name": user.name,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
reactions[reaction.name]["count"] += 1
|
reactions[reaction.name]["count"] += 1
|
||||||
|
|
||||||
return [Reactions(**reaction) for reaction in reactions.values()]
|
return [Reactions(**reaction) for reaction in reactions.values()]
|
||||||
|
|
|
||||||
|
|
@ -3,17 +3,16 @@ import time
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from open_webui.internal.db import Base, JSONField, get_db
|
from open_webui.internal.db import Base, JSONField, get_db
|
||||||
|
from open_webui.env import SRC_LOG_LEVELS
|
||||||
|
|
||||||
from open_webui.models.groups import Groups
|
from open_webui.models.groups import Groups
|
||||||
from open_webui.models.users import User, UserModel, Users, UserResponse
|
from open_webui.models.users import Users, UserResponse
|
||||||
|
|
||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict
|
from pydantic import BaseModel, ConfigDict
|
||||||
|
|
||||||
from sqlalchemy import String, cast, or_, and_, func
|
from sqlalchemy import or_, and_, func
|
||||||
from sqlalchemy.dialects import postgresql, sqlite
|
from sqlalchemy.dialects import postgresql, sqlite
|
||||||
|
|
||||||
from sqlalchemy.dialects.postgresql import JSONB
|
|
||||||
from sqlalchemy import BigInteger, Column, Text, JSON, Boolean
|
from sqlalchemy import BigInteger, Column, Text, JSON, Boolean
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -21,6 +20,7 @@ from open_webui.utils.access_control import has_access
|
||||||
|
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
log.setLevel(SRC_LOG_LEVELS["MODELS"])
|
||||||
|
|
||||||
|
|
||||||
####################
|
####################
|
||||||
|
|
@ -53,7 +53,7 @@ class ModelMeta(BaseModel):
|
||||||
class Model(Base):
|
class Model(Base):
|
||||||
__tablename__ = "model"
|
__tablename__ = "model"
|
||||||
|
|
||||||
id = Column(Text, primary_key=True, unique=True)
|
id = Column(Text, primary_key=True)
|
||||||
"""
|
"""
|
||||||
The model's id as used in the API. If set to an existing model, it will override the model.
|
The model's id as used in the API. If set to an existing model, it will override the model.
|
||||||
"""
|
"""
|
||||||
|
|
@ -133,11 +133,6 @@ class ModelResponse(ModelModel):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class ModelListResponse(BaseModel):
|
|
||||||
items: list[ModelUserResponse]
|
|
||||||
total: int
|
|
||||||
|
|
||||||
|
|
||||||
class ModelForm(BaseModel):
|
class ModelForm(BaseModel):
|
||||||
id: str
|
id: str
|
||||||
base_model_id: Optional[str] = None
|
base_model_id: Optional[str] = None
|
||||||
|
|
@ -220,135 +215,6 @@ class ModelsTable:
|
||||||
or has_access(user_id, permission, model.access_control, user_group_ids)
|
or has_access(user_id, permission, model.access_control, user_group_ids)
|
||||||
]
|
]
|
||||||
|
|
||||||
def _has_permission(self, db, query, filter: dict, permission: str = "read"):
|
|
||||||
group_ids = filter.get("group_ids", [])
|
|
||||||
user_id = filter.get("user_id")
|
|
||||||
|
|
||||||
dialect_name = db.bind.dialect.name
|
|
||||||
|
|
||||||
# Public access
|
|
||||||
conditions = []
|
|
||||||
if group_ids or user_id:
|
|
||||||
conditions.extend(
|
|
||||||
[
|
|
||||||
Model.access_control.is_(None),
|
|
||||||
cast(Model.access_control, String) == "null",
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
# User-level permission
|
|
||||||
if user_id:
|
|
||||||
conditions.append(Model.user_id == user_id)
|
|
||||||
|
|
||||||
# Group-level permission
|
|
||||||
if group_ids:
|
|
||||||
group_conditions = []
|
|
||||||
for gid in group_ids:
|
|
||||||
if dialect_name == "sqlite":
|
|
||||||
group_conditions.append(
|
|
||||||
Model.access_control[permission]["group_ids"].contains([gid])
|
|
||||||
)
|
|
||||||
elif dialect_name == "postgresql":
|
|
||||||
group_conditions.append(
|
|
||||||
cast(
|
|
||||||
Model.access_control[permission]["group_ids"],
|
|
||||||
JSONB,
|
|
||||||
).contains([gid])
|
|
||||||
)
|
|
||||||
conditions.append(or_(*group_conditions))
|
|
||||||
|
|
||||||
if conditions:
|
|
||||||
query = query.filter(or_(*conditions))
|
|
||||||
|
|
||||||
return query
|
|
||||||
|
|
||||||
def search_models(
|
|
||||||
self, user_id: str, filter: dict = {}, skip: int = 0, limit: int = 30
|
|
||||||
) -> ModelListResponse:
|
|
||||||
with get_db() as db:
|
|
||||||
# Join GroupMember so we can order by group_id when requested
|
|
||||||
query = db.query(Model, User).outerjoin(User, User.id == Model.user_id)
|
|
||||||
query = query.filter(Model.base_model_id != None)
|
|
||||||
|
|
||||||
if filter:
|
|
||||||
query_key = filter.get("query")
|
|
||||||
if query_key:
|
|
||||||
query = query.filter(
|
|
||||||
or_(
|
|
||||||
Model.name.ilike(f"%{query_key}%"),
|
|
||||||
Model.base_model_id.ilike(f"%{query_key}%"),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
view_option = filter.get("view_option")
|
|
||||||
if view_option == "created":
|
|
||||||
query = query.filter(Model.user_id == user_id)
|
|
||||||
elif view_option == "shared":
|
|
||||||
query = query.filter(Model.user_id != user_id)
|
|
||||||
|
|
||||||
# Apply access control filtering
|
|
||||||
query = self._has_permission(
|
|
||||||
db,
|
|
||||||
query,
|
|
||||||
filter,
|
|
||||||
permission="write",
|
|
||||||
)
|
|
||||||
|
|
||||||
tag = filter.get("tag")
|
|
||||||
if tag:
|
|
||||||
# TODO: This is a simple implementation and should be improved for performance
|
|
||||||
like_pattern = f'%"{tag.lower()}"%' # `"tag"` inside JSON array
|
|
||||||
meta_text = func.lower(cast(Model.meta, String))
|
|
||||||
|
|
||||||
query = query.filter(meta_text.like(like_pattern))
|
|
||||||
|
|
||||||
order_by = filter.get("order_by")
|
|
||||||
direction = filter.get("direction")
|
|
||||||
|
|
||||||
if order_by == "name":
|
|
||||||
if direction == "asc":
|
|
||||||
query = query.order_by(Model.name.asc())
|
|
||||||
else:
|
|
||||||
query = query.order_by(Model.name.desc())
|
|
||||||
elif order_by == "created_at":
|
|
||||||
if direction == "asc":
|
|
||||||
query = query.order_by(Model.created_at.asc())
|
|
||||||
else:
|
|
||||||
query = query.order_by(Model.created_at.desc())
|
|
||||||
elif order_by == "updated_at":
|
|
||||||
if direction == "asc":
|
|
||||||
query = query.order_by(Model.updated_at.asc())
|
|
||||||
else:
|
|
||||||
query = query.order_by(Model.updated_at.desc())
|
|
||||||
|
|
||||||
else:
|
|
||||||
query = query.order_by(Model.created_at.desc())
|
|
||||||
|
|
||||||
# Count BEFORE pagination
|
|
||||||
total = query.count()
|
|
||||||
|
|
||||||
if skip:
|
|
||||||
query = query.offset(skip)
|
|
||||||
if limit:
|
|
||||||
query = query.limit(limit)
|
|
||||||
|
|
||||||
items = query.all()
|
|
||||||
|
|
||||||
models = []
|
|
||||||
for model, user in items:
|
|
||||||
models.append(
|
|
||||||
ModelUserResponse(
|
|
||||||
**ModelModel.model_validate(model).model_dump(),
|
|
||||||
user=(
|
|
||||||
UserResponse(**UserModel.model_validate(user).model_dump())
|
|
||||||
if user
|
|
||||||
else None
|
|
||||||
),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
return ModelListResponse(items=models, total=total)
|
|
||||||
|
|
||||||
def get_model_by_id(self, id: str) -> Optional[ModelModel]:
|
def get_model_by_id(self, id: str) -> Optional[ModelModel]:
|
||||||
try:
|
try:
|
||||||
with get_db() as db:
|
with get_db() as db:
|
||||||
|
|
@ -357,14 +223,6 @@ class ModelsTable:
|
||||||
except Exception:
|
except Exception:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_models_by_ids(self, ids: list[str]) -> list[ModelModel]:
|
|
||||||
try:
|
|
||||||
with get_db() as db:
|
|
||||||
models = db.query(Model).filter(Model.id.in_(ids)).all()
|
|
||||||
return [ModelModel.model_validate(model) for model in models]
|
|
||||||
except Exception:
|
|
||||||
return []
|
|
||||||
|
|
||||||
def toggle_model_by_id(self, id: str) -> Optional[ModelModel]:
|
def toggle_model_by_id(self, id: str) -> Optional[ModelModel]:
|
||||||
with get_db() as db:
|
with get_db() as db:
|
||||||
try:
|
try:
|
||||||
|
|
@ -386,9 +244,11 @@ class ModelsTable:
|
||||||
try:
|
try:
|
||||||
with get_db() as db:
|
with get_db() as db:
|
||||||
# update only the fields that are present in the model
|
# update only the fields that are present in the model
|
||||||
data = model.model_dump(exclude={"id"})
|
result = (
|
||||||
result = db.query(Model).filter_by(id=id).update(data)
|
db.query(Model)
|
||||||
|
.filter_by(id=id)
|
||||||
|
.update(model.model_dump(exclude={"id"}))
|
||||||
|
)
|
||||||
db.commit()
|
db.commit()
|
||||||
|
|
||||||
model = db.get(Model, id)
|
model = db.get(Model, id)
|
||||||
|
|
|
||||||
|
|
@ -7,15 +7,12 @@ from functools import lru_cache
|
||||||
from open_webui.internal.db import Base, get_db
|
from open_webui.internal.db import Base, get_db
|
||||||
from open_webui.models.groups import Groups
|
from open_webui.models.groups import Groups
|
||||||
from open_webui.utils.access_control import has_access
|
from open_webui.utils.access_control import has_access
|
||||||
from open_webui.models.users import User, UserModel, Users, UserResponse
|
from open_webui.models.users import Users, UserResponse
|
||||||
|
|
||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict
|
from pydantic import BaseModel, ConfigDict
|
||||||
from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON
|
from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON
|
||||||
from sqlalchemy.dialects.postgresql import JSONB
|
from sqlalchemy import or_, func, select, and_, text
|
||||||
|
|
||||||
|
|
||||||
from sqlalchemy import or_, func, select, and_, text, cast, or_, and_, func
|
|
||||||
from sqlalchemy.sql import exists
|
from sqlalchemy.sql import exists
|
||||||
|
|
||||||
####################
|
####################
|
||||||
|
|
@ -26,7 +23,7 @@ from sqlalchemy.sql import exists
|
||||||
class Note(Base):
|
class Note(Base):
|
||||||
__tablename__ = "note"
|
__tablename__ = "note"
|
||||||
|
|
||||||
id = Column(Text, primary_key=True, unique=True)
|
id = Column(Text, primary_key=True)
|
||||||
user_id = Column(Text)
|
user_id = Column(Text)
|
||||||
|
|
||||||
title = Column(Text)
|
title = Column(Text)
|
||||||
|
|
@ -78,138 +75,7 @@ class NoteUserResponse(NoteModel):
|
||||||
user: Optional[UserResponse] = None
|
user: Optional[UserResponse] = None
|
||||||
|
|
||||||
|
|
||||||
class NoteItemResponse(BaseModel):
|
|
||||||
id: str
|
|
||||||
title: str
|
|
||||||
data: Optional[dict]
|
|
||||||
updated_at: int
|
|
||||||
created_at: int
|
|
||||||
user: Optional[UserResponse] = None
|
|
||||||
|
|
||||||
|
|
||||||
class NoteListResponse(BaseModel):
|
|
||||||
items: list[NoteUserResponse]
|
|
||||||
total: int
|
|
||||||
|
|
||||||
|
|
||||||
class NoteTable:
|
class NoteTable:
|
||||||
def _has_permission(self, db, query, filter: dict, permission: str = "read"):
|
|
||||||
group_ids = filter.get("group_ids", [])
|
|
||||||
user_id = filter.get("user_id")
|
|
||||||
dialect_name = db.bind.dialect.name
|
|
||||||
|
|
||||||
conditions = []
|
|
||||||
|
|
||||||
# Handle read_only permission separately
|
|
||||||
if permission == "read_only":
|
|
||||||
# For read_only, we want items where:
|
|
||||||
# 1. User has explicit read permission (via groups or user-level)
|
|
||||||
# 2. BUT does NOT have write permission
|
|
||||||
# 3. Public items are NOT considered read_only
|
|
||||||
|
|
||||||
read_conditions = []
|
|
||||||
|
|
||||||
# Group-level read permission
|
|
||||||
if group_ids:
|
|
||||||
group_read_conditions = []
|
|
||||||
for gid in group_ids:
|
|
||||||
if dialect_name == "sqlite":
|
|
||||||
group_read_conditions.append(
|
|
||||||
Note.access_control["read"]["group_ids"].contains([gid])
|
|
||||||
)
|
|
||||||
elif dialect_name == "postgresql":
|
|
||||||
group_read_conditions.append(
|
|
||||||
cast(
|
|
||||||
Note.access_control["read"]["group_ids"],
|
|
||||||
JSONB,
|
|
||||||
).contains([gid])
|
|
||||||
)
|
|
||||||
|
|
||||||
if group_read_conditions:
|
|
||||||
read_conditions.append(or_(*group_read_conditions))
|
|
||||||
|
|
||||||
# Combine read conditions
|
|
||||||
if read_conditions:
|
|
||||||
has_read = or_(*read_conditions)
|
|
||||||
else:
|
|
||||||
# If no read conditions, return empty result
|
|
||||||
return query.filter(False)
|
|
||||||
|
|
||||||
# Now exclude items where user has write permission
|
|
||||||
write_exclusions = []
|
|
||||||
|
|
||||||
# Exclude items owned by user (they have implicit write)
|
|
||||||
if user_id:
|
|
||||||
write_exclusions.append(Note.user_id != user_id)
|
|
||||||
|
|
||||||
# Exclude items where user has explicit write permission via groups
|
|
||||||
if group_ids:
|
|
||||||
group_write_conditions = []
|
|
||||||
for gid in group_ids:
|
|
||||||
if dialect_name == "sqlite":
|
|
||||||
group_write_conditions.append(
|
|
||||||
Note.access_control["write"]["group_ids"].contains([gid])
|
|
||||||
)
|
|
||||||
elif dialect_name == "postgresql":
|
|
||||||
group_write_conditions.append(
|
|
||||||
cast(
|
|
||||||
Note.access_control["write"]["group_ids"],
|
|
||||||
JSONB,
|
|
||||||
).contains([gid])
|
|
||||||
)
|
|
||||||
|
|
||||||
if group_write_conditions:
|
|
||||||
# User should NOT have write permission
|
|
||||||
write_exclusions.append(~or_(*group_write_conditions))
|
|
||||||
|
|
||||||
# Exclude public items (items without access_control)
|
|
||||||
write_exclusions.append(Note.access_control.isnot(None))
|
|
||||||
write_exclusions.append(cast(Note.access_control, String) != "null")
|
|
||||||
|
|
||||||
# Combine: has read AND does not have write AND not public
|
|
||||||
if write_exclusions:
|
|
||||||
query = query.filter(and_(has_read, *write_exclusions))
|
|
||||||
else:
|
|
||||||
query = query.filter(has_read)
|
|
||||||
|
|
||||||
return query
|
|
||||||
|
|
||||||
# Original logic for other permissions (read, write, etc.)
|
|
||||||
# Public access conditions
|
|
||||||
if group_ids or user_id:
|
|
||||||
conditions.extend(
|
|
||||||
[
|
|
||||||
Note.access_control.is_(None),
|
|
||||||
cast(Note.access_control, String) == "null",
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
# User-level permission (owner has all permissions)
|
|
||||||
if user_id:
|
|
||||||
conditions.append(Note.user_id == user_id)
|
|
||||||
|
|
||||||
# Group-level permission
|
|
||||||
if group_ids:
|
|
||||||
group_conditions = []
|
|
||||||
for gid in group_ids:
|
|
||||||
if dialect_name == "sqlite":
|
|
||||||
group_conditions.append(
|
|
||||||
Note.access_control[permission]["group_ids"].contains([gid])
|
|
||||||
)
|
|
||||||
elif dialect_name == "postgresql":
|
|
||||||
group_conditions.append(
|
|
||||||
cast(
|
|
||||||
Note.access_control[permission]["group_ids"],
|
|
||||||
JSONB,
|
|
||||||
).contains([gid])
|
|
||||||
)
|
|
||||||
conditions.append(or_(*group_conditions))
|
|
||||||
|
|
||||||
if conditions:
|
|
||||||
query = query.filter(or_(*conditions))
|
|
||||||
|
|
||||||
return query
|
|
||||||
|
|
||||||
def insert_new_note(
|
def insert_new_note(
|
||||||
self,
|
self,
|
||||||
form_data: NoteForm,
|
form_data: NoteForm,
|
||||||
|
|
@ -244,107 +110,15 @@ class NoteTable:
|
||||||
notes = query.all()
|
notes = query.all()
|
||||||
return [NoteModel.model_validate(note) for note in notes]
|
return [NoteModel.model_validate(note) for note in notes]
|
||||||
|
|
||||||
def search_notes(
|
|
||||||
self, user_id: str, filter: dict = {}, skip: int = 0, limit: int = 30
|
|
||||||
) -> NoteListResponse:
|
|
||||||
with get_db() as db:
|
|
||||||
query = db.query(Note, User).outerjoin(User, User.id == Note.user_id)
|
|
||||||
if filter:
|
|
||||||
query_key = filter.get("query")
|
|
||||||
if query_key:
|
|
||||||
query = query.filter(
|
|
||||||
or_(
|
|
||||||
Note.title.ilike(f"%{query_key}%"),
|
|
||||||
cast(Note.data["content"]["md"], Text).ilike(
|
|
||||||
f"%{query_key}%"
|
|
||||||
),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
view_option = filter.get("view_option")
|
|
||||||
if view_option == "created":
|
|
||||||
query = query.filter(Note.user_id == user_id)
|
|
||||||
elif view_option == "shared":
|
|
||||||
query = query.filter(Note.user_id != user_id)
|
|
||||||
|
|
||||||
# Apply access control filtering
|
|
||||||
if "permission" in filter:
|
|
||||||
permission = filter["permission"]
|
|
||||||
else:
|
|
||||||
permission = "write"
|
|
||||||
|
|
||||||
query = self._has_permission(
|
|
||||||
db,
|
|
||||||
query,
|
|
||||||
filter,
|
|
||||||
permission=permission,
|
|
||||||
)
|
|
||||||
|
|
||||||
order_by = filter.get("order_by")
|
|
||||||
direction = filter.get("direction")
|
|
||||||
|
|
||||||
if order_by == "name":
|
|
||||||
if direction == "asc":
|
|
||||||
query = query.order_by(Note.title.asc())
|
|
||||||
else:
|
|
||||||
query = query.order_by(Note.title.desc())
|
|
||||||
elif order_by == "created_at":
|
|
||||||
if direction == "asc":
|
|
||||||
query = query.order_by(Note.created_at.asc())
|
|
||||||
else:
|
|
||||||
query = query.order_by(Note.created_at.desc())
|
|
||||||
elif order_by == "updated_at":
|
|
||||||
if direction == "asc":
|
|
||||||
query = query.order_by(Note.updated_at.asc())
|
|
||||||
else:
|
|
||||||
query = query.order_by(Note.updated_at.desc())
|
|
||||||
else:
|
|
||||||
query = query.order_by(Note.updated_at.desc())
|
|
||||||
|
|
||||||
else:
|
|
||||||
query = query.order_by(Note.updated_at.desc())
|
|
||||||
|
|
||||||
# Count BEFORE pagination
|
|
||||||
total = query.count()
|
|
||||||
|
|
||||||
if skip:
|
|
||||||
query = query.offset(skip)
|
|
||||||
if limit:
|
|
||||||
query = query.limit(limit)
|
|
||||||
|
|
||||||
items = query.all()
|
|
||||||
|
|
||||||
notes = []
|
|
||||||
for note, user in items:
|
|
||||||
notes.append(
|
|
||||||
NoteUserResponse(
|
|
||||||
**NoteModel.model_validate(note).model_dump(),
|
|
||||||
user=(
|
|
||||||
UserResponse(**UserModel.model_validate(user).model_dump())
|
|
||||||
if user
|
|
||||||
else None
|
|
||||||
),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
return NoteListResponse(items=notes, total=total)
|
|
||||||
|
|
||||||
def get_notes_by_user_id(
|
def get_notes_by_user_id(
|
||||||
self,
|
self,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
permission: str = "read",
|
|
||||||
skip: Optional[int] = None,
|
skip: Optional[int] = None,
|
||||||
limit: Optional[int] = None,
|
limit: Optional[int] = None,
|
||||||
) -> list[NoteModel]:
|
) -> list[NoteModel]:
|
||||||
with get_db() as db:
|
with get_db() as db:
|
||||||
user_group_ids = [
|
query = db.query(Note).filter(Note.user_id == user_id)
|
||||||
group.id for group in Groups.get_groups_by_member_id(user_id)
|
query = query.order_by(Note.updated_at.desc())
|
||||||
]
|
|
||||||
|
|
||||||
query = db.query(Note).order_by(Note.updated_at.desc())
|
|
||||||
query = self._has_permission(
|
|
||||||
db, query, {"user_id": user_id, "group_ids": user_group_ids}, permission
|
|
||||||
)
|
|
||||||
|
|
||||||
if skip is not None:
|
if skip is not None:
|
||||||
query = query.offset(skip)
|
query = query.offset(skip)
|
||||||
|
|
@ -354,6 +128,56 @@ class NoteTable:
|
||||||
notes = query.all()
|
notes = query.all()
|
||||||
return [NoteModel.model_validate(note) for note in notes]
|
return [NoteModel.model_validate(note) for note in notes]
|
||||||
|
|
||||||
|
def get_notes_by_permission(
|
||||||
|
self,
|
||||||
|
user_id: str,
|
||||||
|
permission: str = "write",
|
||||||
|
skip: Optional[int] = None,
|
||||||
|
limit: Optional[int] = None,
|
||||||
|
) -> list[NoteModel]:
|
||||||
|
with get_db() as db:
|
||||||
|
user_groups = Groups.get_groups_by_member_id(user_id)
|
||||||
|
user_group_ids = {group.id for group in user_groups}
|
||||||
|
|
||||||
|
# Order newest-first. We stream to keep memory usage low.
|
||||||
|
query = (
|
||||||
|
db.query(Note)
|
||||||
|
.order_by(Note.updated_at.desc())
|
||||||
|
.execution_options(stream_results=True)
|
||||||
|
.yield_per(256)
|
||||||
|
)
|
||||||
|
|
||||||
|
results: list[NoteModel] = []
|
||||||
|
n_skipped = 0
|
||||||
|
|
||||||
|
for note in query:
|
||||||
|
# Fast-pass #1: owner
|
||||||
|
if note.user_id == user_id:
|
||||||
|
permitted = True
|
||||||
|
# Fast-pass #2: public/open
|
||||||
|
elif note.access_control is None:
|
||||||
|
# Technically this should mean public access for both read and write, but we'll only do read for now
|
||||||
|
# We might want to change this behavior later
|
||||||
|
permitted = permission == "read"
|
||||||
|
else:
|
||||||
|
permitted = has_access(
|
||||||
|
user_id, permission, note.access_control, user_group_ids
|
||||||
|
)
|
||||||
|
|
||||||
|
if not permitted:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Apply skip AFTER permission filtering so it counts only accessible notes
|
||||||
|
if skip and n_skipped < skip:
|
||||||
|
n_skipped += 1
|
||||||
|
continue
|
||||||
|
|
||||||
|
results.append(NoteModel.model_validate(note))
|
||||||
|
if limit is not None and len(results) >= limit:
|
||||||
|
break
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
def get_note_by_id(self, id: str) -> Optional[NoteModel]:
|
def get_note_by_id(self, id: str) -> Optional[NoteModel]:
|
||||||
with get_db() as db:
|
with get_db() as db:
|
||||||
note = db.query(Note).filter(Note.id == id).first()
|
note = db.query(Note).filter(Note.id == id).first()
|
||||||
|
|
|
||||||
|
|
@ -9,12 +9,13 @@ import json
|
||||||
from cryptography.fernet import Fernet
|
from cryptography.fernet import Fernet
|
||||||
|
|
||||||
from open_webui.internal.db import Base, get_db
|
from open_webui.internal.db import Base, get_db
|
||||||
from open_webui.env import OAUTH_SESSION_TOKEN_ENCRYPTION_KEY
|
from open_webui.env import SRC_LOG_LEVELS, OAUTH_SESSION_TOKEN_ENCRYPTION_KEY
|
||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict
|
from pydantic import BaseModel, ConfigDict
|
||||||
from sqlalchemy import BigInteger, Column, String, Text, Index
|
from sqlalchemy import BigInteger, Column, String, Text, Index
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
log.setLevel(SRC_LOG_LEVELS["MODELS"])
|
||||||
|
|
||||||
####################
|
####################
|
||||||
# DB MODEL
|
# DB MODEL
|
||||||
|
|
@ -24,7 +25,7 @@ log = logging.getLogger(__name__)
|
||||||
class OAuthSession(Base):
|
class OAuthSession(Base):
|
||||||
__tablename__ = "oauth_session"
|
__tablename__ = "oauth_session"
|
||||||
|
|
||||||
id = Column(Text, primary_key=True, unique=True)
|
id = Column(Text, primary_key=True)
|
||||||
user_id = Column(Text, nullable=False)
|
user_id = Column(Text, nullable=False)
|
||||||
provider = Column(Text, nullable=False)
|
provider = Column(Text, nullable=False)
|
||||||
token = Column(
|
token = Column(
|
||||||
|
|
|
||||||
|
|
@ -6,10 +6,12 @@ from typing import Optional
|
||||||
from open_webui.internal.db import Base, get_db
|
from open_webui.internal.db import Base, get_db
|
||||||
|
|
||||||
|
|
||||||
|
from open_webui.env import SRC_LOG_LEVELS
|
||||||
from pydantic import BaseModel, ConfigDict
|
from pydantic import BaseModel, ConfigDict
|
||||||
from sqlalchemy import BigInteger, Column, String, JSON, PrimaryKeyConstraint, Index
|
from sqlalchemy import BigInteger, Column, String, JSON, PrimaryKeyConstraint, Index
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
log.setLevel(SRC_LOG_LEVELS["MODELS"])
|
||||||
|
|
||||||
|
|
||||||
####################
|
####################
|
||||||
|
|
|
||||||
|
|
@ -6,6 +6,7 @@ from open_webui.internal.db import Base, JSONField, get_db
|
||||||
from open_webui.models.users import Users, UserResponse
|
from open_webui.models.users import Users, UserResponse
|
||||||
from open_webui.models.groups import Groups
|
from open_webui.models.groups import Groups
|
||||||
|
|
||||||
|
from open_webui.env import SRC_LOG_LEVELS
|
||||||
from pydantic import BaseModel, ConfigDict
|
from pydantic import BaseModel, ConfigDict
|
||||||
from sqlalchemy import BigInteger, Column, String, Text, JSON
|
from sqlalchemy import BigInteger, Column, String, Text, JSON
|
||||||
|
|
||||||
|
|
@ -13,6 +14,7 @@ from open_webui.utils.access_control import has_access
|
||||||
|
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
log.setLevel(SRC_LOG_LEVELS["MODELS"])
|
||||||
|
|
||||||
####################
|
####################
|
||||||
# Tools DB Schema
|
# Tools DB Schema
|
||||||
|
|
@ -22,7 +24,7 @@ log = logging.getLogger(__name__)
|
||||||
class Tool(Base):
|
class Tool(Base):
|
||||||
__tablename__ = "tool"
|
__tablename__ = "tool"
|
||||||
|
|
||||||
id = Column(String, primary_key=True, unique=True)
|
id = Column(String, primary_key=True)
|
||||||
user_id = Column(String)
|
user_id = Column(String)
|
||||||
name = Column(Text)
|
name = Column(Text)
|
||||||
content = Column(Text)
|
content = Column(Text)
|
||||||
|
|
|
||||||
|
|
@ -5,29 +5,14 @@ from open_webui.internal.db import Base, JSONField, get_db
|
||||||
|
|
||||||
|
|
||||||
from open_webui.env import DATABASE_USER_ACTIVE_STATUS_UPDATE_INTERVAL
|
from open_webui.env import DATABASE_USER_ACTIVE_STATUS_UPDATE_INTERVAL
|
||||||
|
|
||||||
from open_webui.models.chats import Chats
|
from open_webui.models.chats import Chats
|
||||||
from open_webui.models.groups import Groups, GroupMember
|
from open_webui.models.groups import Groups
|
||||||
from open_webui.models.channels import ChannelMember
|
|
||||||
|
|
||||||
from open_webui.utils.misc import throttle
|
from open_webui.utils.misc import throttle
|
||||||
|
|
||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict
|
from pydantic import BaseModel, ConfigDict
|
||||||
from sqlalchemy import (
|
from sqlalchemy import BigInteger, Column, String, Text, Date
|
||||||
BigInteger,
|
from sqlalchemy import or_
|
||||||
JSON,
|
|
||||||
Column,
|
|
||||||
String,
|
|
||||||
Boolean,
|
|
||||||
Text,
|
|
||||||
Date,
|
|
||||||
exists,
|
|
||||||
select,
|
|
||||||
cast,
|
|
||||||
)
|
|
||||||
from sqlalchemy import or_, case
|
|
||||||
from sqlalchemy.dialects.postgresql import JSONB
|
|
||||||
|
|
||||||
import datetime
|
import datetime
|
||||||
|
|
||||||
|
|
@ -36,71 +21,59 @@ import datetime
|
||||||
####################
|
####################
|
||||||
|
|
||||||
|
|
||||||
|
class User(Base):
|
||||||
|
__tablename__ = "user"
|
||||||
|
|
||||||
|
id = Column(String, primary_key=True)
|
||||||
|
name = Column(String)
|
||||||
|
|
||||||
|
email = Column(String)
|
||||||
|
username = Column(String(50), nullable=True)
|
||||||
|
|
||||||
|
role = Column(String)
|
||||||
|
profile_image_url = Column(Text)
|
||||||
|
|
||||||
|
bio = Column(Text, nullable=True)
|
||||||
|
gender = Column(Text, nullable=True)
|
||||||
|
date_of_birth = Column(Date, nullable=True)
|
||||||
|
|
||||||
|
info = Column(JSONField, nullable=True)
|
||||||
|
settings = Column(JSONField, nullable=True)
|
||||||
|
|
||||||
|
api_key = Column(String, nullable=True, unique=True)
|
||||||
|
oauth_sub = Column(Text, unique=True)
|
||||||
|
|
||||||
|
last_active_at = Column(BigInteger)
|
||||||
|
|
||||||
|
updated_at = Column(BigInteger)
|
||||||
|
created_at = Column(BigInteger)
|
||||||
|
|
||||||
|
|
||||||
class UserSettings(BaseModel):
|
class UserSettings(BaseModel):
|
||||||
ui: Optional[dict] = {}
|
ui: Optional[dict] = {}
|
||||||
model_config = ConfigDict(extra="allow")
|
model_config = ConfigDict(extra="allow")
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class User(Base):
|
|
||||||
__tablename__ = "user"
|
|
||||||
|
|
||||||
id = Column(String, primary_key=True, unique=True)
|
|
||||||
email = Column(String)
|
|
||||||
username = Column(String(50), nullable=True)
|
|
||||||
role = Column(String)
|
|
||||||
|
|
||||||
name = Column(String)
|
|
||||||
|
|
||||||
profile_image_url = Column(Text)
|
|
||||||
profile_banner_image_url = Column(Text, nullable=True)
|
|
||||||
|
|
||||||
bio = Column(Text, nullable=True)
|
|
||||||
gender = Column(Text, nullable=True)
|
|
||||||
date_of_birth = Column(Date, nullable=True)
|
|
||||||
timezone = Column(String, nullable=True)
|
|
||||||
|
|
||||||
presence_state = Column(String, nullable=True)
|
|
||||||
status_emoji = Column(String, nullable=True)
|
|
||||||
status_message = Column(Text, nullable=True)
|
|
||||||
status_expires_at = Column(BigInteger, nullable=True)
|
|
||||||
|
|
||||||
info = Column(JSON, nullable=True)
|
|
||||||
settings = Column(JSON, nullable=True)
|
|
||||||
|
|
||||||
oauth = Column(JSON, nullable=True)
|
|
||||||
|
|
||||||
last_active_at = Column(BigInteger)
|
|
||||||
updated_at = Column(BigInteger)
|
|
||||||
created_at = Column(BigInteger)
|
|
||||||
|
|
||||||
|
|
||||||
class UserModel(BaseModel):
|
class UserModel(BaseModel):
|
||||||
id: str
|
id: str
|
||||||
|
name: str
|
||||||
|
|
||||||
email: str
|
email: str
|
||||||
username: Optional[str] = None
|
username: Optional[str] = None
|
||||||
|
|
||||||
role: str = "pending"
|
role: str = "pending"
|
||||||
|
|
||||||
name: str
|
|
||||||
|
|
||||||
profile_image_url: str
|
profile_image_url: str
|
||||||
profile_banner_image_url: Optional[str] = None
|
|
||||||
|
|
||||||
bio: Optional[str] = None
|
bio: Optional[str] = None
|
||||||
gender: Optional[str] = None
|
gender: Optional[str] = None
|
||||||
date_of_birth: Optional[datetime.date] = None
|
date_of_birth: Optional[datetime.date] = None
|
||||||
timezone: Optional[str] = None
|
|
||||||
|
|
||||||
presence_state: Optional[str] = None
|
|
||||||
status_emoji: Optional[str] = None
|
|
||||||
status_message: Optional[str] = None
|
|
||||||
status_expires_at: Optional[int] = None
|
|
||||||
|
|
||||||
info: Optional[dict] = None
|
info: Optional[dict] = None
|
||||||
settings: Optional[UserSettings] = None
|
settings: Optional[UserSettings] = None
|
||||||
|
|
||||||
oauth: Optional[dict] = None
|
api_key: Optional[str] = None
|
||||||
|
oauth_sub: Optional[str] = None
|
||||||
|
|
||||||
last_active_at: int # timestamp in epoch
|
last_active_at: int # timestamp in epoch
|
||||||
updated_at: int # timestamp in epoch
|
updated_at: int # timestamp in epoch
|
||||||
|
|
@ -109,38 +82,6 @@ class UserModel(BaseModel):
|
||||||
model_config = ConfigDict(from_attributes=True)
|
model_config = ConfigDict(from_attributes=True)
|
||||||
|
|
||||||
|
|
||||||
class UserStatusModel(UserModel):
|
|
||||||
is_active: bool = False
|
|
||||||
|
|
||||||
model_config = ConfigDict(from_attributes=True)
|
|
||||||
|
|
||||||
|
|
||||||
class ApiKey(Base):
|
|
||||||
__tablename__ = "api_key"
|
|
||||||
|
|
||||||
id = Column(Text, primary_key=True, unique=True)
|
|
||||||
user_id = Column(Text, nullable=False)
|
|
||||||
key = Column(Text, unique=True, nullable=False)
|
|
||||||
data = Column(JSON, nullable=True)
|
|
||||||
expires_at = Column(BigInteger, nullable=True)
|
|
||||||
last_used_at = Column(BigInteger, nullable=True)
|
|
||||||
created_at = Column(BigInteger, nullable=False)
|
|
||||||
updated_at = Column(BigInteger, nullable=False)
|
|
||||||
|
|
||||||
|
|
||||||
class ApiKeyModel(BaseModel):
|
|
||||||
id: str
|
|
||||||
user_id: str
|
|
||||||
key: str
|
|
||||||
data: Optional[dict] = None
|
|
||||||
expires_at: Optional[int] = None
|
|
||||||
last_used_at: Optional[int] = None
|
|
||||||
created_at: int # timestamp in epoch
|
|
||||||
updated_at: int # timestamp in epoch
|
|
||||||
|
|
||||||
model_config = ConfigDict(from_attributes=True)
|
|
||||||
|
|
||||||
|
|
||||||
####################
|
####################
|
||||||
# Forms
|
# Forms
|
||||||
####################
|
####################
|
||||||
|
|
@ -154,31 +95,12 @@ class UpdateProfileForm(BaseModel):
|
||||||
date_of_birth: Optional[datetime.date] = None
|
date_of_birth: Optional[datetime.date] = None
|
||||||
|
|
||||||
|
|
||||||
class UserGroupIdsModel(UserModel):
|
|
||||||
group_ids: list[str] = []
|
|
||||||
|
|
||||||
|
|
||||||
class UserModelResponse(UserModel):
|
|
||||||
model_config = ConfigDict(extra="allow")
|
|
||||||
|
|
||||||
|
|
||||||
class UserListResponse(BaseModel):
|
class UserListResponse(BaseModel):
|
||||||
users: list[UserModelResponse]
|
users: list[UserModel]
|
||||||
total: int
|
total: int
|
||||||
|
|
||||||
|
|
||||||
class UserGroupIdsListResponse(BaseModel):
|
class UserInfoResponse(BaseModel):
|
||||||
users: list[UserGroupIdsModel]
|
|
||||||
total: int
|
|
||||||
|
|
||||||
|
|
||||||
class UserStatus(BaseModel):
|
|
||||||
status_emoji: Optional[str] = None
|
|
||||||
status_message: Optional[str] = None
|
|
||||||
status_expires_at: Optional[int] = None
|
|
||||||
|
|
||||||
|
|
||||||
class UserInfoResponse(UserStatus):
|
|
||||||
id: str
|
id: str
|
||||||
name: str
|
name: str
|
||||||
email: str
|
email: str
|
||||||
|
|
@ -190,12 +112,6 @@ class UserIdNameResponse(BaseModel):
|
||||||
name: str
|
name: str
|
||||||
|
|
||||||
|
|
||||||
class UserIdNameStatusResponse(UserStatus):
|
|
||||||
id: str
|
|
||||||
name: str
|
|
||||||
is_active: Optional[bool] = None
|
|
||||||
|
|
||||||
|
|
||||||
class UserInfoListResponse(BaseModel):
|
class UserInfoListResponse(BaseModel):
|
||||||
users: list[UserInfoResponse]
|
users: list[UserInfoResponse]
|
||||||
total: int
|
total: int
|
||||||
|
|
@ -206,18 +122,18 @@ class UserIdNameListResponse(BaseModel):
|
||||||
total: int
|
total: int
|
||||||
|
|
||||||
|
|
||||||
|
class UserResponse(BaseModel):
|
||||||
|
id: str
|
||||||
|
name: str
|
||||||
|
email: str
|
||||||
|
role: str
|
||||||
|
profile_image_url: str
|
||||||
|
|
||||||
|
|
||||||
class UserNameResponse(BaseModel):
|
class UserNameResponse(BaseModel):
|
||||||
id: str
|
id: str
|
||||||
name: str
|
name: str
|
||||||
role: str
|
role: str
|
||||||
|
|
||||||
|
|
||||||
class UserResponse(UserNameResponse):
|
|
||||||
email: str
|
|
||||||
|
|
||||||
|
|
||||||
class UserProfileImageResponse(UserNameResponse):
|
|
||||||
email: str
|
|
||||||
profile_image_url: str
|
profile_image_url: str
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -242,20 +158,20 @@ class UsersTable:
|
||||||
email: str,
|
email: str,
|
||||||
profile_image_url: str = "/user.png",
|
profile_image_url: str = "/user.png",
|
||||||
role: str = "pending",
|
role: str = "pending",
|
||||||
oauth: Optional[dict] = None,
|
oauth_sub: Optional[str] = None,
|
||||||
) -> Optional[UserModel]:
|
) -> Optional[UserModel]:
|
||||||
with get_db() as db:
|
with get_db() as db:
|
||||||
user = UserModel(
|
user = UserModel(
|
||||||
**{
|
**{
|
||||||
"id": id,
|
"id": id,
|
||||||
"email": email,
|
|
||||||
"name": name,
|
"name": name,
|
||||||
|
"email": email,
|
||||||
"role": role,
|
"role": role,
|
||||||
"profile_image_url": profile_image_url,
|
"profile_image_url": profile_image_url,
|
||||||
"last_active_at": int(time.time()),
|
"last_active_at": int(time.time()),
|
||||||
"created_at": int(time.time()),
|
"created_at": int(time.time()),
|
||||||
"updated_at": int(time.time()),
|
"updated_at": int(time.time()),
|
||||||
"oauth": oauth,
|
"oauth_sub": oauth_sub,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
result = User(**user.model_dump())
|
result = User(**user.model_dump())
|
||||||
|
|
@ -278,13 +194,8 @@ class UsersTable:
|
||||||
def get_user_by_api_key(self, api_key: str) -> Optional[UserModel]:
|
def get_user_by_api_key(self, api_key: str) -> Optional[UserModel]:
|
||||||
try:
|
try:
|
||||||
with get_db() as db:
|
with get_db() as db:
|
||||||
user = (
|
user = db.query(User).filter_by(api_key=api_key).first()
|
||||||
db.query(User)
|
return UserModel.model_validate(user)
|
||||||
.join(ApiKey, User.id == ApiKey.user_id)
|
|
||||||
.filter(ApiKey.key == api_key)
|
|
||||||
.first()
|
|
||||||
)
|
|
||||||
return UserModel.model_validate(user) if user else None
|
|
||||||
except Exception:
|
except Exception:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
@ -296,23 +207,12 @@ class UsersTable:
|
||||||
except Exception:
|
except Exception:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_user_by_oauth_sub(self, provider: str, sub: str) -> Optional[UserModel]:
|
def get_user_by_oauth_sub(self, sub: str) -> Optional[UserModel]:
|
||||||
try:
|
try:
|
||||||
with get_db() as db: # type: Session
|
with get_db() as db:
|
||||||
dialect_name = db.bind.dialect.name
|
user = db.query(User).filter_by(oauth_sub=sub).first()
|
||||||
|
return UserModel.model_validate(user)
|
||||||
query = db.query(User)
|
except Exception:
|
||||||
if dialect_name == "sqlite":
|
|
||||||
query = query.filter(User.oauth.contains({provider: {"sub": sub}}))
|
|
||||||
elif dialect_name == "postgresql":
|
|
||||||
query = query.filter(
|
|
||||||
User.oauth[provider].cast(JSONB)["sub"].astext == sub
|
|
||||||
)
|
|
||||||
|
|
||||||
user = query.first()
|
|
||||||
return UserModel.model_validate(user) if user else None
|
|
||||||
except Exception as e:
|
|
||||||
# You may want to log the exception here
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_users(
|
def get_users(
|
||||||
|
|
@ -322,7 +222,6 @@ class UsersTable:
|
||||||
limit: Optional[int] = None,
|
limit: Optional[int] = None,
|
||||||
) -> dict:
|
) -> dict:
|
||||||
with get_db() as db:
|
with get_db() as db:
|
||||||
# Join GroupMember so we can order by group_id when requested
|
|
||||||
query = db.query(User)
|
query = db.query(User)
|
||||||
|
|
||||||
if filter:
|
if filter:
|
||||||
|
|
@ -335,76 +234,14 @@ class UsersTable:
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
channel_id = filter.get("channel_id")
|
|
||||||
if channel_id:
|
|
||||||
query = query.filter(
|
|
||||||
exists(
|
|
||||||
select(ChannelMember.id).where(
|
|
||||||
ChannelMember.user_id == User.id,
|
|
||||||
ChannelMember.channel_id == channel_id,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
user_ids = filter.get("user_ids")
|
|
||||||
group_ids = filter.get("group_ids")
|
|
||||||
|
|
||||||
if isinstance(user_ids, list) and isinstance(group_ids, list):
|
|
||||||
# If both are empty lists, return no users
|
|
||||||
if not user_ids and not group_ids:
|
|
||||||
return {"users": [], "total": 0}
|
|
||||||
|
|
||||||
if user_ids:
|
|
||||||
query = query.filter(User.id.in_(user_ids))
|
|
||||||
|
|
||||||
if group_ids:
|
|
||||||
query = query.filter(
|
|
||||||
exists(
|
|
||||||
select(GroupMember.id).where(
|
|
||||||
GroupMember.user_id == User.id,
|
|
||||||
GroupMember.group_id.in_(group_ids),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
roles = filter.get("roles")
|
|
||||||
if roles:
|
|
||||||
include_roles = [role for role in roles if not role.startswith("!")]
|
|
||||||
exclude_roles = [role[1:] for role in roles if role.startswith("!")]
|
|
||||||
|
|
||||||
if include_roles:
|
|
||||||
query = query.filter(User.role.in_(include_roles))
|
|
||||||
if exclude_roles:
|
|
||||||
query = query.filter(~User.role.in_(exclude_roles))
|
|
||||||
|
|
||||||
order_by = filter.get("order_by")
|
order_by = filter.get("order_by")
|
||||||
direction = filter.get("direction")
|
direction = filter.get("direction")
|
||||||
|
|
||||||
if order_by and order_by.startswith("group_id:"):
|
if order_by == "name":
|
||||||
group_id = order_by.split(":", 1)[1]
|
|
||||||
|
|
||||||
# Subquery that checks if the user belongs to the group
|
|
||||||
membership_exists = exists(
|
|
||||||
select(GroupMember.id).where(
|
|
||||||
GroupMember.user_id == User.id,
|
|
||||||
GroupMember.group_id == group_id,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# CASE: user in group → 1, user not in group → 0
|
|
||||||
group_sort = case((membership_exists, 1), else_=0)
|
|
||||||
|
|
||||||
if direction == "asc":
|
|
||||||
query = query.order_by(group_sort.asc(), User.name.asc())
|
|
||||||
else:
|
|
||||||
query = query.order_by(group_sort.desc(), User.name.asc())
|
|
||||||
|
|
||||||
elif order_by == "name":
|
|
||||||
if direction == "asc":
|
if direction == "asc":
|
||||||
query = query.order_by(User.name.asc())
|
query = query.order_by(User.name.asc())
|
||||||
else:
|
else:
|
||||||
query = query.order_by(User.name.desc())
|
query = query.order_by(User.name.desc())
|
||||||
|
|
||||||
elif order_by == "email":
|
elif order_by == "email":
|
||||||
if direction == "asc":
|
if direction == "asc":
|
||||||
query = query.order_by(User.email.asc())
|
query = query.order_by(User.email.asc())
|
||||||
|
|
@ -437,32 +274,18 @@ class UsersTable:
|
||||||
else:
|
else:
|
||||||
query = query.order_by(User.created_at.desc())
|
query = query.order_by(User.created_at.desc())
|
||||||
|
|
||||||
# Count BEFORE pagination
|
if skip:
|
||||||
total = query.count()
|
|
||||||
|
|
||||||
# correct pagination logic
|
|
||||||
if skip is not None:
|
|
||||||
query = query.offset(skip)
|
query = query.offset(skip)
|
||||||
if limit is not None:
|
if limit:
|
||||||
query = query.limit(limit)
|
query = query.limit(limit)
|
||||||
|
|
||||||
users = query.all()
|
users = query.all()
|
||||||
return {
|
return {
|
||||||
"users": [UserModel.model_validate(user) for user in users],
|
"users": [UserModel.model_validate(user) for user in users],
|
||||||
"total": total,
|
"total": db.query(User).count(),
|
||||||
}
|
}
|
||||||
|
|
||||||
def get_users_by_group_id(self, group_id: str) -> list[UserModel]:
|
def get_users_by_user_ids(self, user_ids: list[str]) -> list[UserModel]:
|
||||||
with get_db() as db:
|
|
||||||
users = (
|
|
||||||
db.query(User)
|
|
||||||
.join(GroupMember, User.id == GroupMember.user_id)
|
|
||||||
.filter(GroupMember.group_id == group_id)
|
|
||||||
.all()
|
|
||||||
)
|
|
||||||
return [UserModel.model_validate(user) for user in users]
|
|
||||||
|
|
||||||
def get_users_by_user_ids(self, user_ids: list[str]) -> list[UserStatusModel]:
|
|
||||||
with get_db() as db:
|
with get_db() as db:
|
||||||
users = db.query(User).filter(User.id.in_(user_ids)).all()
|
users = db.query(User).filter(User.id.in_(user_ids)).all()
|
||||||
return [UserModel.model_validate(user) for user in users]
|
return [UserModel.model_validate(user) for user in users]
|
||||||
|
|
@ -499,15 +322,6 @@ class UsersTable:
|
||||||
except Exception:
|
except Exception:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_num_users_active_today(self) -> Optional[int]:
|
|
||||||
with get_db() as db:
|
|
||||||
current_timestamp = int(datetime.datetime.now().timestamp())
|
|
||||||
today_midnight_timestamp = current_timestamp - (current_timestamp % 86400)
|
|
||||||
query = db.query(User).filter(
|
|
||||||
User.last_active_at > today_midnight_timestamp
|
|
||||||
)
|
|
||||||
return query.count()
|
|
||||||
|
|
||||||
def update_user_role_by_id(self, id: str, role: str) -> Optional[UserModel]:
|
def update_user_role_by_id(self, id: str, role: str) -> Optional[UserModel]:
|
||||||
try:
|
try:
|
||||||
with get_db() as db:
|
with get_db() as db:
|
||||||
|
|
@ -518,21 +332,6 @@ class UsersTable:
|
||||||
except Exception:
|
except Exception:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def update_user_status_by_id(
|
|
||||||
self, id: str, form_data: UserStatus
|
|
||||||
) -> Optional[UserModel]:
|
|
||||||
try:
|
|
||||||
with get_db() as db:
|
|
||||||
db.query(User).filter_by(id=id).update(
|
|
||||||
{**form_data.model_dump(exclude_none=True)}
|
|
||||||
)
|
|
||||||
db.commit()
|
|
||||||
|
|
||||||
user = db.query(User).filter_by(id=id).first()
|
|
||||||
return UserModel.model_validate(user)
|
|
||||||
except Exception:
|
|
||||||
return None
|
|
||||||
|
|
||||||
def update_user_profile_image_url_by_id(
|
def update_user_profile_image_url_by_id(
|
||||||
self, id: str, profile_image_url: str
|
self, id: str, profile_image_url: str
|
||||||
) -> Optional[UserModel]:
|
) -> Optional[UserModel]:
|
||||||
|
|
@ -549,7 +348,7 @@ class UsersTable:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@throttle(DATABASE_USER_ACTIVE_STATUS_UPDATE_INTERVAL)
|
@throttle(DATABASE_USER_ACTIVE_STATUS_UPDATE_INTERVAL)
|
||||||
def update_last_active_by_id(self, id: str) -> Optional[UserModel]:
|
def update_user_last_active_by_id(self, id: str) -> Optional[UserModel]:
|
||||||
try:
|
try:
|
||||||
with get_db() as db:
|
with get_db() as db:
|
||||||
db.query(User).filter_by(id=id).update(
|
db.query(User).filter_by(id=id).update(
|
||||||
|
|
@ -562,35 +361,16 @@ class UsersTable:
|
||||||
except Exception:
|
except Exception:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def update_user_oauth_by_id(
|
def update_user_oauth_sub_by_id(
|
||||||
self, id: str, provider: str, sub: str
|
self, id: str, oauth_sub: str
|
||||||
) -> Optional[UserModel]:
|
) -> Optional[UserModel]:
|
||||||
"""
|
|
||||||
Update or insert an OAuth provider/sub pair into the user's oauth JSON field.
|
|
||||||
Example resulting structure:
|
|
||||||
{
|
|
||||||
"google": { "sub": "123" },
|
|
||||||
"github": { "sub": "abc" }
|
|
||||||
}
|
|
||||||
"""
|
|
||||||
try:
|
try:
|
||||||
with get_db() as db:
|
with get_db() as db:
|
||||||
user = db.query(User).filter_by(id=id).first()
|
db.query(User).filter_by(id=id).update({"oauth_sub": oauth_sub})
|
||||||
if not user:
|
|
||||||
return None
|
|
||||||
|
|
||||||
# Load existing oauth JSON or create empty
|
|
||||||
oauth = user.oauth or {}
|
|
||||||
|
|
||||||
# Update or insert provider entry
|
|
||||||
oauth[provider] = {"sub": sub}
|
|
||||||
|
|
||||||
# Persist updated JSON
|
|
||||||
db.query(User).filter_by(id=id).update({"oauth": oauth})
|
|
||||||
db.commit()
|
db.commit()
|
||||||
|
|
||||||
|
user = db.query(User).filter_by(id=id).first()
|
||||||
return UserModel.model_validate(user)
|
return UserModel.model_validate(user)
|
||||||
|
|
||||||
except Exception:
|
except Exception:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
@ -644,44 +424,22 @@ class UsersTable:
|
||||||
except Exception:
|
except Exception:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def get_user_api_key_by_id(self, id: str) -> Optional[str]:
|
|
||||||
try:
|
|
||||||
with get_db() as db:
|
|
||||||
api_key = db.query(ApiKey).filter_by(user_id=id).first()
|
|
||||||
return api_key.key if api_key else None
|
|
||||||
except Exception:
|
|
||||||
return None
|
|
||||||
|
|
||||||
def update_user_api_key_by_id(self, id: str, api_key: str) -> bool:
|
def update_user_api_key_by_id(self, id: str, api_key: str) -> bool:
|
||||||
try:
|
try:
|
||||||
with get_db() as db:
|
with get_db() as db:
|
||||||
db.query(ApiKey).filter_by(user_id=id).delete()
|
result = db.query(User).filter_by(id=id).update({"api_key": api_key})
|
||||||
db.commit()
|
db.commit()
|
||||||
|
return True if result == 1 else False
|
||||||
now = int(time.time())
|
|
||||||
new_api_key = ApiKey(
|
|
||||||
id=f"key_{id}",
|
|
||||||
user_id=id,
|
|
||||||
key=api_key,
|
|
||||||
created_at=now,
|
|
||||||
updated_at=now,
|
|
||||||
)
|
|
||||||
db.add(new_api_key)
|
|
||||||
db.commit()
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
except Exception:
|
except Exception:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def delete_user_api_key_by_id(self, id: str) -> bool:
|
def get_user_api_key_by_id(self, id: str) -> Optional[str]:
|
||||||
try:
|
try:
|
||||||
with get_db() as db:
|
with get_db() as db:
|
||||||
db.query(ApiKey).filter_by(user_id=id).delete()
|
user = db.query(User).filter_by(id=id).first()
|
||||||
db.commit()
|
return user.api_key
|
||||||
return True
|
|
||||||
except Exception:
|
except Exception:
|
||||||
return False
|
return None
|
||||||
|
|
||||||
def get_valid_user_ids(self, user_ids: list[str]) -> list[str]:
|
def get_valid_user_ids(self, user_ids: list[str]) -> list[str]:
|
||||||
with get_db() as db:
|
with get_db() as db:
|
||||||
|
|
@ -696,23 +454,5 @@ class UsersTable:
|
||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_active_user_count(self) -> int:
|
|
||||||
with get_db() as db:
|
|
||||||
# Consider user active if last_active_at within the last 3 minutes
|
|
||||||
three_minutes_ago = int(time.time()) - 180
|
|
||||||
count = (
|
|
||||||
db.query(User).filter(User.last_active_at >= three_minutes_ago).count()
|
|
||||||
)
|
|
||||||
return count
|
|
||||||
|
|
||||||
def is_user_active(self, user_id: str) -> bool:
|
|
||||||
with get_db() as db:
|
|
||||||
user = db.query(User).filter_by(id=user_id).first()
|
|
||||||
if user and user.last_active_at:
|
|
||||||
# Consider user active if last_active_at within the last 3 minutes
|
|
||||||
three_minutes_ago = int(time.time()) - 180
|
|
||||||
return user.last_active_at >= three_minutes_ago
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
Users = UsersTable()
|
Users = UsersTable()
|
||||||
|
|
|
||||||
|
|
@ -6,8 +6,10 @@ from urllib.parse import quote
|
||||||
from langchain_core.document_loaders import BaseLoader
|
from langchain_core.document_loaders import BaseLoader
|
||||||
from langchain_core.documents import Document
|
from langchain_core.documents import Document
|
||||||
from open_webui.utils.headers import include_user_info_headers
|
from open_webui.utils.headers import include_user_info_headers
|
||||||
|
from open_webui.env import SRC_LOG_LEVELS
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||||
|
|
||||||
|
|
||||||
class ExternalDocumentLoader(BaseLoader):
|
class ExternalDocumentLoader(BaseLoader):
|
||||||
|
|
|
||||||
|
|
@ -4,8 +4,10 @@ from typing import Iterator, List, Union
|
||||||
|
|
||||||
from langchain_core.document_loaders import BaseLoader
|
from langchain_core.document_loaders import BaseLoader
|
||||||
from langchain_core.documents import Document
|
from langchain_core.documents import Document
|
||||||
|
from open_webui.env import SRC_LOG_LEVELS
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||||
|
|
||||||
|
|
||||||
class ExternalWebLoader(BaseLoader):
|
class ExternalWebLoader(BaseLoader):
|
||||||
|
|
|
||||||
|
|
@ -30,10 +30,11 @@ from open_webui.retrieval.loaders.datalab_marker import DatalabMarkerLoader
|
||||||
from open_webui.retrieval.loaders.mineru import MinerULoader
|
from open_webui.retrieval.loaders.mineru import MinerULoader
|
||||||
|
|
||||||
|
|
||||||
from open_webui.env import GLOBAL_LOG_LEVEL
|
from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL
|
||||||
|
|
||||||
logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
|
logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||||
|
|
||||||
known_source_ext = [
|
known_source_ext = [
|
||||||
"go",
|
"go",
|
||||||
|
|
@ -131,9 +132,8 @@ class TikaLoader:
|
||||||
|
|
||||||
|
|
||||||
class DoclingLoader:
|
class DoclingLoader:
|
||||||
def __init__(self, url, api_key=None, file_path=None, mime_type=None, params=None):
|
def __init__(self, url, file_path=None, mime_type=None, params=None):
|
||||||
self.url = url.rstrip("/")
|
self.url = url.rstrip("/")
|
||||||
self.api_key = api_key
|
|
||||||
self.file_path = file_path
|
self.file_path = file_path
|
||||||
self.mime_type = mime_type
|
self.mime_type = mime_type
|
||||||
|
|
||||||
|
|
@ -141,25 +141,68 @@ class DoclingLoader:
|
||||||
|
|
||||||
def load(self) -> list[Document]:
|
def load(self) -> list[Document]:
|
||||||
with open(self.file_path, "rb") as f:
|
with open(self.file_path, "rb") as f:
|
||||||
headers = {}
|
files = {
|
||||||
if self.api_key:
|
"files": (
|
||||||
headers["X-Api-Key"] = f"Bearer {self.api_key}"
|
self.file_path,
|
||||||
|
f,
|
||||||
|
self.mime_type or "application/octet-stream",
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
r = requests.post(
|
params = {"image_export_mode": "placeholder"}
|
||||||
f"{self.url}/v1/convert/file",
|
|
||||||
files={
|
if self.params:
|
||||||
"files": (
|
if self.params.get("do_picture_description"):
|
||||||
self.file_path,
|
params["do_picture_description"] = self.params.get(
|
||||||
f,
|
"do_picture_description"
|
||||||
self.mime_type or "application/octet-stream",
|
|
||||||
)
|
)
|
||||||
},
|
|
||||||
data={
|
picture_description_mode = self.params.get(
|
||||||
"image_export_mode": "placeholder",
|
"picture_description_mode", ""
|
||||||
**self.params,
|
).lower()
|
||||||
},
|
|
||||||
headers=headers,
|
if picture_description_mode == "local" and self.params.get(
|
||||||
)
|
"picture_description_local", {}
|
||||||
|
):
|
||||||
|
params["picture_description_local"] = json.dumps(
|
||||||
|
self.params.get("picture_description_local", {})
|
||||||
|
)
|
||||||
|
|
||||||
|
elif picture_description_mode == "api" and self.params.get(
|
||||||
|
"picture_description_api", {}
|
||||||
|
):
|
||||||
|
params["picture_description_api"] = json.dumps(
|
||||||
|
self.params.get("picture_description_api", {})
|
||||||
|
)
|
||||||
|
|
||||||
|
params["do_ocr"] = self.params.get("do_ocr")
|
||||||
|
|
||||||
|
params["force_ocr"] = self.params.get("force_ocr")
|
||||||
|
|
||||||
|
if (
|
||||||
|
self.params.get("do_ocr")
|
||||||
|
and self.params.get("ocr_engine")
|
||||||
|
and self.params.get("ocr_lang")
|
||||||
|
):
|
||||||
|
params["ocr_engine"] = self.params.get("ocr_engine")
|
||||||
|
params["ocr_lang"] = [
|
||||||
|
lang.strip()
|
||||||
|
for lang in self.params.get("ocr_lang").split(",")
|
||||||
|
if lang.strip()
|
||||||
|
]
|
||||||
|
|
||||||
|
if self.params.get("pdf_backend"):
|
||||||
|
params["pdf_backend"] = self.params.get("pdf_backend")
|
||||||
|
|
||||||
|
if self.params.get("table_mode"):
|
||||||
|
params["table_mode"] = self.params.get("table_mode")
|
||||||
|
|
||||||
|
if self.params.get("pipeline"):
|
||||||
|
params["pipeline"] = self.params.get("pipeline")
|
||||||
|
|
||||||
|
endpoint = f"{self.url}/v1/convert/file"
|
||||||
|
r = requests.post(endpoint, files=files, data=params)
|
||||||
|
|
||||||
if r.ok:
|
if r.ok:
|
||||||
result = r.json()
|
result = r.json()
|
||||||
document_data = result.get("document", {})
|
document_data = result.get("document", {})
|
||||||
|
|
@ -168,6 +211,7 @@ class DoclingLoader:
|
||||||
metadata = {"Content-Type": self.mime_type} if self.mime_type else {}
|
metadata = {"Content-Type": self.mime_type} if self.mime_type else {}
|
||||||
|
|
||||||
log.debug("Docling extracted text: %s", text)
|
log.debug("Docling extracted text: %s", text)
|
||||||
|
|
||||||
return [Document(page_content=text, metadata=metadata)]
|
return [Document(page_content=text, metadata=metadata)]
|
||||||
else:
|
else:
|
||||||
error_msg = f"Error calling Docling API: {r.reason}"
|
error_msg = f"Error calling Docling API: {r.reason}"
|
||||||
|
|
@ -296,7 +340,6 @@ class Loader:
|
||||||
|
|
||||||
loader = DoclingLoader(
|
loader = DoclingLoader(
|
||||||
url=self.kwargs.get("DOCLING_SERVER_URL"),
|
url=self.kwargs.get("DOCLING_SERVER_URL"),
|
||||||
api_key=self.kwargs.get("DOCLING_API_KEY", None),
|
|
||||||
file_path=file_path,
|
file_path=file_path,
|
||||||
mime_type=file_content_type,
|
mime_type=file_content_type,
|
||||||
params=params,
|
params=params,
|
||||||
|
|
@ -319,33 +362,22 @@ class Loader:
|
||||||
file_path=file_path,
|
file_path=file_path,
|
||||||
api_endpoint=self.kwargs.get("DOCUMENT_INTELLIGENCE_ENDPOINT"),
|
api_endpoint=self.kwargs.get("DOCUMENT_INTELLIGENCE_ENDPOINT"),
|
||||||
api_key=self.kwargs.get("DOCUMENT_INTELLIGENCE_KEY"),
|
api_key=self.kwargs.get("DOCUMENT_INTELLIGENCE_KEY"),
|
||||||
api_model=self.kwargs.get("DOCUMENT_INTELLIGENCE_MODEL"),
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
loader = AzureAIDocumentIntelligenceLoader(
|
loader = AzureAIDocumentIntelligenceLoader(
|
||||||
file_path=file_path,
|
file_path=file_path,
|
||||||
api_endpoint=self.kwargs.get("DOCUMENT_INTELLIGENCE_ENDPOINT"),
|
api_endpoint=self.kwargs.get("DOCUMENT_INTELLIGENCE_ENDPOINT"),
|
||||||
azure_credential=DefaultAzureCredential(),
|
azure_credential=DefaultAzureCredential(),
|
||||||
api_model=self.kwargs.get("DOCUMENT_INTELLIGENCE_MODEL"),
|
|
||||||
)
|
)
|
||||||
elif self.engine == "mineru" and file_ext in [
|
elif self.engine == "mineru" and file_ext in [
|
||||||
"pdf"
|
"pdf"
|
||||||
]: # MinerU currently only supports PDF
|
]: # MinerU currently only supports PDF
|
||||||
|
|
||||||
mineru_timeout = self.kwargs.get("MINERU_API_TIMEOUT", 300)
|
|
||||||
if mineru_timeout:
|
|
||||||
try:
|
|
||||||
mineru_timeout = int(mineru_timeout)
|
|
||||||
except ValueError:
|
|
||||||
mineru_timeout = 300
|
|
||||||
|
|
||||||
loader = MinerULoader(
|
loader = MinerULoader(
|
||||||
file_path=file_path,
|
file_path=file_path,
|
||||||
api_mode=self.kwargs.get("MINERU_API_MODE", "local"),
|
api_mode=self.kwargs.get("MINERU_API_MODE", "local"),
|
||||||
api_url=self.kwargs.get("MINERU_API_URL", "http://localhost:8000"),
|
api_url=self.kwargs.get("MINERU_API_URL", "http://localhost:8000"),
|
||||||
api_key=self.kwargs.get("MINERU_API_KEY", ""),
|
api_key=self.kwargs.get("MINERU_API_KEY", ""),
|
||||||
params=self.kwargs.get("MINERU_PARAMS", {}),
|
params=self.kwargs.get("MINERU_PARAMS", {}),
|
||||||
timeout=mineru_timeout,
|
|
||||||
)
|
)
|
||||||
elif (
|
elif (
|
||||||
self.engine == "mistral_ocr"
|
self.engine == "mistral_ocr"
|
||||||
|
|
|
||||||
|
|
@ -26,23 +26,20 @@ class MinerULoader:
|
||||||
api_url: str = "http://localhost:8000",
|
api_url: str = "http://localhost:8000",
|
||||||
api_key: str = "",
|
api_key: str = "",
|
||||||
params: dict = None,
|
params: dict = None,
|
||||||
timeout: Optional[int] = 300,
|
|
||||||
):
|
):
|
||||||
self.file_path = file_path
|
self.file_path = file_path
|
||||||
self.api_mode = api_mode.lower()
|
self.api_mode = api_mode.lower()
|
||||||
self.api_url = api_url.rstrip("/")
|
self.api_url = api_url.rstrip("/")
|
||||||
self.api_key = api_key
|
self.api_key = api_key
|
||||||
self.timeout = timeout
|
|
||||||
|
|
||||||
# Parse params dict with defaults
|
# Parse params dict with defaults
|
||||||
self.params = params or {}
|
params = params or {}
|
||||||
self.enable_ocr = params.get("enable_ocr", False)
|
self.enable_ocr = params.get("enable_ocr", False)
|
||||||
self.enable_formula = params.get("enable_formula", True)
|
self.enable_formula = params.get("enable_formula", True)
|
||||||
self.enable_table = params.get("enable_table", True)
|
self.enable_table = params.get("enable_table", True)
|
||||||
self.language = params.get("language", "en")
|
self.language = params.get("language", "en")
|
||||||
self.model_version = params.get("model_version", "pipeline")
|
self.model_version = params.get("model_version", "pipeline")
|
||||||
|
self.page_ranges = params.get("page_ranges", "")
|
||||||
self.page_ranges = self.params.pop("page_ranges", "")
|
|
||||||
|
|
||||||
# Validate API mode
|
# Validate API mode
|
||||||
if self.api_mode not in ["local", "cloud"]:
|
if self.api_mode not in ["local", "cloud"]:
|
||||||
|
|
@ -79,10 +76,27 @@ class MinerULoader:
|
||||||
|
|
||||||
# Build form data for Local API
|
# Build form data for Local API
|
||||||
form_data = {
|
form_data = {
|
||||||
**self.params,
|
|
||||||
"return_md": "true",
|
"return_md": "true",
|
||||||
|
"formula_enable": str(self.enable_formula).lower(),
|
||||||
|
"table_enable": str(self.enable_table).lower(),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# Parse method based on OCR setting
|
||||||
|
if self.enable_ocr:
|
||||||
|
form_data["parse_method"] = "ocr"
|
||||||
|
else:
|
||||||
|
form_data["parse_method"] = "auto"
|
||||||
|
|
||||||
|
# Language configuration (Local API uses lang_list array)
|
||||||
|
if self.language:
|
||||||
|
form_data["lang_list"] = self.language
|
||||||
|
|
||||||
|
# Backend/model version (Local API uses "backend" parameter)
|
||||||
|
if self.model_version == "vlm":
|
||||||
|
form_data["backend"] = "vlm-vllm-engine"
|
||||||
|
else:
|
||||||
|
form_data["backend"] = "pipeline"
|
||||||
|
|
||||||
# Page ranges (Local API uses start_page_id and end_page_id)
|
# Page ranges (Local API uses start_page_id and end_page_id)
|
||||||
if self.page_ranges:
|
if self.page_ranges:
|
||||||
# For simplicity, if page_ranges is specified, log a warning
|
# For simplicity, if page_ranges is specified, log a warning
|
||||||
|
|
@ -103,7 +117,7 @@ class MinerULoader:
|
||||||
f"{self.api_url}/file_parse",
|
f"{self.api_url}/file_parse",
|
||||||
data=form_data,
|
data=form_data,
|
||||||
files=files,
|
files=files,
|
||||||
timeout=self.timeout,
|
timeout=300, # 5 minute timeout for large documents
|
||||||
)
|
)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
|
|
||||||
|
|
@ -222,7 +236,10 @@ class MinerULoader:
|
||||||
|
|
||||||
# Build request body
|
# Build request body
|
||||||
request_body = {
|
request_body = {
|
||||||
**self.params,
|
"enable_formula": self.enable_formula,
|
||||||
|
"enable_table": self.enable_table,
|
||||||
|
"language": self.language,
|
||||||
|
"model_version": self.model_version,
|
||||||
"files": [
|
"files": [
|
||||||
{
|
{
|
||||||
"name": filename,
|
"name": filename,
|
||||||
|
|
@ -302,7 +319,7 @@ class MinerULoader:
|
||||||
response = requests.put(
|
response = requests.put(
|
||||||
upload_url,
|
upload_url,
|
||||||
data=f,
|
data=f,
|
||||||
timeout=self.timeout,
|
timeout=300, # 5 minute timeout for large files
|
||||||
)
|
)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
|
|
|
||||||
|
|
@ -9,10 +9,11 @@ from typing import List, Dict, Any
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
|
|
||||||
from langchain_core.documents import Document
|
from langchain_core.documents import Document
|
||||||
from open_webui.env import GLOBAL_LOG_LEVEL
|
from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL
|
||||||
|
|
||||||
logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
|
logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||||
|
|
||||||
|
|
||||||
class MistralLoader:
|
class MistralLoader:
|
||||||
|
|
|
||||||
|
|
@ -4,8 +4,10 @@ from typing import Iterator, List, Literal, Union
|
||||||
|
|
||||||
from langchain_core.document_loaders import BaseLoader
|
from langchain_core.document_loaders import BaseLoader
|
||||||
from langchain_core.documents import Document
|
from langchain_core.documents import Document
|
||||||
|
from open_webui.env import SRC_LOG_LEVELS
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||||
|
|
||||||
|
|
||||||
class TavilyLoader(BaseLoader):
|
class TavilyLoader(BaseLoader):
|
||||||
|
|
|
||||||
|
|
@ -4,8 +4,10 @@ from xml.etree.ElementTree import ParseError
|
||||||
from typing import Any, Dict, Generator, List, Optional, Sequence, Union
|
from typing import Any, Dict, Generator, List, Optional, Sequence, Union
|
||||||
from urllib.parse import parse_qs, urlparse
|
from urllib.parse import parse_qs, urlparse
|
||||||
from langchain_core.documents import Document
|
from langchain_core.documents import Document
|
||||||
|
from open_webui.env import SRC_LOG_LEVELS
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||||
|
|
||||||
ALLOWED_SCHEMES = {"http", "https"}
|
ALLOWED_SCHEMES = {"http", "https"}
|
||||||
ALLOWED_NETLOCS = {
|
ALLOWED_NETLOCS = {
|
||||||
|
|
|
||||||
|
|
@ -5,10 +5,12 @@ import numpy as np
|
||||||
from colbert.infra import ColBERTConfig
|
from colbert.infra import ColBERTConfig
|
||||||
from colbert.modeling.checkpoint import Checkpoint
|
from colbert.modeling.checkpoint import Checkpoint
|
||||||
|
|
||||||
|
from open_webui.env import SRC_LOG_LEVELS
|
||||||
|
|
||||||
from open_webui.retrieval.models.base_reranker import BaseReranker
|
from open_webui.retrieval.models.base_reranker import BaseReranker
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||||
|
|
||||||
|
|
||||||
class ColBERT(BaseReranker):
|
class ColBERT(BaseReranker):
|
||||||
|
|
|
||||||
|
|
@ -4,12 +4,12 @@ from typing import Optional, List, Tuple
|
||||||
from urllib.parse import quote
|
from urllib.parse import quote
|
||||||
|
|
||||||
|
|
||||||
from open_webui.env import ENABLE_FORWARD_USER_INFO_HEADERS
|
from open_webui.env import ENABLE_FORWARD_USER_INFO_HEADERS, SRC_LOG_LEVELS
|
||||||
from open_webui.retrieval.models.base_reranker import BaseReranker
|
from open_webui.retrieval.models.base_reranker import BaseReranker
|
||||||
from open_webui.utils.headers import include_user_info_headers
|
|
||||||
|
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||||
|
|
||||||
|
|
||||||
class ExternalReranker(BaseReranker):
|
class ExternalReranker(BaseReranker):
|
||||||
|
|
@ -18,12 +18,10 @@ class ExternalReranker(BaseReranker):
|
||||||
api_key: str,
|
api_key: str,
|
||||||
url: str = "http://localhost:8080/v1/rerank",
|
url: str = "http://localhost:8080/v1/rerank",
|
||||||
model: str = "reranker",
|
model: str = "reranker",
|
||||||
timeout: Optional[int] = None,
|
|
||||||
):
|
):
|
||||||
self.api_key = api_key
|
self.api_key = api_key
|
||||||
self.url = url
|
self.url = url
|
||||||
self.model = model
|
self.model = model
|
||||||
self.timeout = timeout
|
|
||||||
|
|
||||||
def predict(
|
def predict(
|
||||||
self, sentences: List[Tuple[str, str]], user=None
|
self, sentences: List[Tuple[str, str]], user=None
|
||||||
|
|
@ -42,19 +40,23 @@ class ExternalReranker(BaseReranker):
|
||||||
log.info(f"ExternalReranker:predict:model {self.model}")
|
log.info(f"ExternalReranker:predict:model {self.model}")
|
||||||
log.info(f"ExternalReranker:predict:query {query}")
|
log.info(f"ExternalReranker:predict:query {query}")
|
||||||
|
|
||||||
headers = {
|
|
||||||
"Content-Type": "application/json",
|
|
||||||
"Authorization": f"Bearer {self.api_key}",
|
|
||||||
}
|
|
||||||
|
|
||||||
if ENABLE_FORWARD_USER_INFO_HEADERS and user:
|
|
||||||
headers = include_user_info_headers(headers, user)
|
|
||||||
|
|
||||||
r = requests.post(
|
r = requests.post(
|
||||||
f"{self.url}",
|
f"{self.url}",
|
||||||
headers=headers,
|
headers={
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"Authorization": f"Bearer {self.api_key}",
|
||||||
|
**(
|
||||||
|
{
|
||||||
|
"X-OpenWebUI-User-Name": quote(user.name, safe=" "),
|
||||||
|
"X-OpenWebUI-User-Id": user.id,
|
||||||
|
"X-OpenWebUI-User-Email": user.email,
|
||||||
|
"X-OpenWebUI-User-Role": user.role,
|
||||||
|
}
|
||||||
|
if ENABLE_FORWARD_USER_INFO_HEADERS and user
|
||||||
|
else {}
|
||||||
|
),
|
||||||
|
},
|
||||||
json=payload,
|
json=payload,
|
||||||
timeout=self.timeout,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
r.raise_for_status()
|
r.raise_for_status()
|
||||||
|
|
|
||||||
|
|
@ -1,10 +1,8 @@
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from typing import Awaitable, Optional, Union
|
from typing import Optional, Union
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
import aiohttp
|
|
||||||
import asyncio
|
|
||||||
import hashlib
|
import hashlib
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
import time
|
import time
|
||||||
|
|
@ -12,10 +10,7 @@ import re
|
||||||
|
|
||||||
from urllib.parse import quote
|
from urllib.parse import quote
|
||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import snapshot_download
|
||||||
from langchain_classic.retrievers import (
|
from langchain.retrievers import ContextualCompressionRetriever, EnsembleRetriever
|
||||||
ContextualCompressionRetriever,
|
|
||||||
EnsembleRetriever,
|
|
||||||
)
|
|
||||||
from langchain_community.retrievers import BM25Retriever
|
from langchain_community.retrievers import BM25Retriever
|
||||||
from langchain_core.documents import Document
|
from langchain_core.documents import Document
|
||||||
|
|
||||||
|
|
@ -32,7 +27,6 @@ from open_webui.models.notes import Notes
|
||||||
|
|
||||||
from open_webui.retrieval.vector.main import GetResult
|
from open_webui.retrieval.vector.main import GetResult
|
||||||
from open_webui.utils.access_control import has_access
|
from open_webui.utils.access_control import has_access
|
||||||
from open_webui.utils.headers import include_user_info_headers
|
|
||||||
from open_webui.utils.misc import get_message_list
|
from open_webui.utils.misc import get_message_list
|
||||||
|
|
||||||
from open_webui.retrieval.web.utils import get_web_loader
|
from open_webui.retrieval.web.utils import get_web_loader
|
||||||
|
|
@ -40,6 +34,7 @@ from open_webui.retrieval.loaders.youtube import YoutubeLoader
|
||||||
|
|
||||||
|
|
||||||
from open_webui.env import (
|
from open_webui.env import (
|
||||||
|
SRC_LOG_LEVELS,
|
||||||
OFFLINE_MODE,
|
OFFLINE_MODE,
|
||||||
ENABLE_FORWARD_USER_INFO_HEADERS,
|
ENABLE_FORWARD_USER_INFO_HEADERS,
|
||||||
)
|
)
|
||||||
|
|
@ -50,6 +45,7 @@ from open_webui.config import (
|
||||||
)
|
)
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||||
|
|
||||||
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
@ -92,29 +88,14 @@ class VectorSearchRetriever(BaseRetriever):
|
||||||
top_k: int
|
top_k: int
|
||||||
|
|
||||||
def _get_relevant_documents(
|
def _get_relevant_documents(
|
||||||
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
|
|
||||||
) -> list[Document]:
|
|
||||||
"""Get documents relevant to a query.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
query: String to find relevant documents for.
|
|
||||||
run_manager: The callback handler to use.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of relevant documents.
|
|
||||||
"""
|
|
||||||
return []
|
|
||||||
|
|
||||||
async def _aget_relevant_documents(
|
|
||||||
self,
|
self,
|
||||||
query: str,
|
query: str,
|
||||||
*,
|
*,
|
||||||
run_manager: CallbackManagerForRetrieverRun,
|
run_manager: CallbackManagerForRetrieverRun,
|
||||||
) -> list[Document]:
|
) -> list[Document]:
|
||||||
embedding = await self.embedding_function(query, RAG_EMBEDDING_QUERY_PREFIX)
|
|
||||||
result = VECTOR_DB_CLIENT.search(
|
result = VECTOR_DB_CLIENT.search(
|
||||||
collection_name=self.collection_name,
|
collection_name=self.collection_name,
|
||||||
vectors=[embedding],
|
vectors=[self.embedding_function(query, RAG_EMBEDDING_QUERY_PREFIX)],
|
||||||
limit=self.top_k,
|
limit=self.top_k,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -167,45 +148,7 @@ def get_doc(collection_name: str, user: UserModel = None):
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
|
|
||||||
def get_enriched_texts(collection_result: GetResult) -> list[str]:
|
def query_doc_with_hybrid_search(
|
||||||
enriched_texts = []
|
|
||||||
for idx, text in enumerate(collection_result.documents[0]):
|
|
||||||
metadata = collection_result.metadatas[0][idx]
|
|
||||||
metadata_parts = [text]
|
|
||||||
|
|
||||||
# Add filename (repeat twice for extra weight in BM25 scoring)
|
|
||||||
if metadata.get("name"):
|
|
||||||
filename = metadata["name"]
|
|
||||||
filename_tokens = (
|
|
||||||
filename.replace("_", " ").replace("-", " ").replace(".", " ")
|
|
||||||
)
|
|
||||||
metadata_parts.append(
|
|
||||||
f"Filename: {filename} {filename_tokens} {filename_tokens}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Add title if available
|
|
||||||
if metadata.get("title"):
|
|
||||||
metadata_parts.append(f"Title: {metadata['title']}")
|
|
||||||
|
|
||||||
# Add document section headings if available (from markdown splitter)
|
|
||||||
if metadata.get("headings") and isinstance(metadata["headings"], list):
|
|
||||||
headings = " > ".join(str(h) for h in metadata["headings"])
|
|
||||||
metadata_parts.append(f"Section: {headings}")
|
|
||||||
|
|
||||||
# Add source URL/path if available
|
|
||||||
if metadata.get("source"):
|
|
||||||
metadata_parts.append(f"Source: {metadata['source']}")
|
|
||||||
|
|
||||||
# Add snippet for web search results
|
|
||||||
if metadata.get("snippet"):
|
|
||||||
metadata_parts.append(f"Snippet: {metadata['snippet']}")
|
|
||||||
|
|
||||||
enriched_texts.append(" ".join(metadata_parts))
|
|
||||||
|
|
||||||
return enriched_texts
|
|
||||||
|
|
||||||
|
|
||||||
async def query_doc_with_hybrid_search(
|
|
||||||
collection_name: str,
|
collection_name: str,
|
||||||
collection_result: GetResult,
|
collection_result: GetResult,
|
||||||
query: str,
|
query: str,
|
||||||
|
|
@ -215,21 +158,12 @@ async def query_doc_with_hybrid_search(
|
||||||
k_reranker: int,
|
k_reranker: int,
|
||||||
r: float,
|
r: float,
|
||||||
hybrid_bm25_weight: float,
|
hybrid_bm25_weight: float,
|
||||||
enable_enriched_texts: bool = False,
|
|
||||||
) -> dict:
|
) -> dict:
|
||||||
try:
|
try:
|
||||||
# First check if collection_result has the required attributes
|
|
||||||
if (
|
if (
|
||||||
not collection_result
|
not collection_result
|
||||||
or not hasattr(collection_result, "documents")
|
or not hasattr(collection_result, "documents")
|
||||||
or not hasattr(collection_result, "metadatas")
|
or not collection_result.documents
|
||||||
):
|
|
||||||
log.warning(f"query_doc_with_hybrid_search:no_docs {collection_name}")
|
|
||||||
return {"documents": [], "metadatas": [], "distances": []}
|
|
||||||
|
|
||||||
# Now safely check the documents content after confirming attributes exist
|
|
||||||
if (
|
|
||||||
not collection_result.documents
|
|
||||||
or len(collection_result.documents) == 0
|
or len(collection_result.documents) == 0
|
||||||
or not collection_result.documents[0]
|
or not collection_result.documents[0]
|
||||||
):
|
):
|
||||||
|
|
@ -238,14 +172,8 @@ async def query_doc_with_hybrid_search(
|
||||||
|
|
||||||
log.debug(f"query_doc_with_hybrid_search:doc {collection_name}")
|
log.debug(f"query_doc_with_hybrid_search:doc {collection_name}")
|
||||||
|
|
||||||
bm25_texts = (
|
|
||||||
get_enriched_texts(collection_result)
|
|
||||||
if enable_enriched_texts
|
|
||||||
else collection_result.documents[0]
|
|
||||||
)
|
|
||||||
|
|
||||||
bm25_retriever = BM25Retriever.from_texts(
|
bm25_retriever = BM25Retriever.from_texts(
|
||||||
texts=bm25_texts,
|
texts=collection_result.documents[0],
|
||||||
metadatas=collection_result.metadatas[0],
|
metadatas=collection_result.metadatas[0],
|
||||||
)
|
)
|
||||||
bm25_retriever.k = k
|
bm25_retriever.k = k
|
||||||
|
|
@ -281,7 +209,7 @@ async def query_doc_with_hybrid_search(
|
||||||
base_compressor=compressor, base_retriever=ensemble_retriever
|
base_compressor=compressor, base_retriever=ensemble_retriever
|
||||||
)
|
)
|
||||||
|
|
||||||
result = await compression_retriever.ainvoke(query)
|
result = compression_retriever.invoke(query)
|
||||||
|
|
||||||
distances = [d.metadata.get("score") for d in result]
|
distances = [d.metadata.get("score") for d in result]
|
||||||
documents = [d.page_content for d in result]
|
documents = [d.page_content for d in result]
|
||||||
|
|
@ -400,7 +328,7 @@ def get_all_items_from_collections(collection_names: list[str]) -> dict:
|
||||||
return merge_get_results(results)
|
return merge_get_results(results)
|
||||||
|
|
||||||
|
|
||||||
async def query_collection(
|
def query_collection(
|
||||||
collection_names: list[str],
|
collection_names: list[str],
|
||||||
queries: list[str],
|
queries: list[str],
|
||||||
embedding_function,
|
embedding_function,
|
||||||
|
|
@ -425,9 +353,7 @@ async def query_collection(
|
||||||
return None, e
|
return None, e
|
||||||
|
|
||||||
# Generate all query embeddings (in one call)
|
# Generate all query embeddings (in one call)
|
||||||
query_embeddings = await embedding_function(
|
query_embeddings = embedding_function(queries, prefix=RAG_EMBEDDING_QUERY_PREFIX)
|
||||||
queries, prefix=RAG_EMBEDDING_QUERY_PREFIX
|
|
||||||
)
|
|
||||||
log.debug(
|
log.debug(
|
||||||
f"query_collection: processing {len(queries)} queries across {len(collection_names)} collections"
|
f"query_collection: processing {len(queries)} queries across {len(collection_names)} collections"
|
||||||
)
|
)
|
||||||
|
|
@ -454,7 +380,7 @@ async def query_collection(
|
||||||
return merge_and_sort_query_results(results, k=k)
|
return merge_and_sort_query_results(results, k=k)
|
||||||
|
|
||||||
|
|
||||||
async def query_collection_with_hybrid_search(
|
def query_collection_with_hybrid_search(
|
||||||
collection_names: list[str],
|
collection_names: list[str],
|
||||||
queries: list[str],
|
queries: list[str],
|
||||||
embedding_function,
|
embedding_function,
|
||||||
|
|
@ -463,7 +389,6 @@ async def query_collection_with_hybrid_search(
|
||||||
k_reranker: int,
|
k_reranker: int,
|
||||||
r: float,
|
r: float,
|
||||||
hybrid_bm25_weight: float,
|
hybrid_bm25_weight: float,
|
||||||
enable_enriched_texts: bool = False,
|
|
||||||
) -> dict:
|
) -> dict:
|
||||||
results = []
|
results = []
|
||||||
error = False
|
error = False
|
||||||
|
|
@ -486,9 +411,9 @@ async def query_collection_with_hybrid_search(
|
||||||
f"Starting hybrid search for {len(queries)} queries in {len(collection_names)} collections..."
|
f"Starting hybrid search for {len(queries)} queries in {len(collection_names)} collections..."
|
||||||
)
|
)
|
||||||
|
|
||||||
async def process_query(collection_name, query):
|
def process_query(collection_name, query):
|
||||||
try:
|
try:
|
||||||
result = await query_doc_with_hybrid_search(
|
result = query_doc_with_hybrid_search(
|
||||||
collection_name=collection_name,
|
collection_name=collection_name,
|
||||||
collection_result=collection_results[collection_name],
|
collection_result=collection_results[collection_name],
|
||||||
query=query,
|
query=query,
|
||||||
|
|
@ -498,7 +423,6 @@ async def query_collection_with_hybrid_search(
|
||||||
k_reranker=k_reranker,
|
k_reranker=k_reranker,
|
||||||
r=r,
|
r=r,
|
||||||
hybrid_bm25_weight=hybrid_bm25_weight,
|
hybrid_bm25_weight=hybrid_bm25_weight,
|
||||||
enable_enriched_texts=enable_enriched_texts,
|
|
||||||
)
|
)
|
||||||
return result, None
|
return result, None
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
@ -508,16 +432,15 @@ async def query_collection_with_hybrid_search(
|
||||||
# Prepare tasks for all collections and queries
|
# Prepare tasks for all collections and queries
|
||||||
# Avoid running any tasks for collections that failed to fetch data (have assigned None)
|
# Avoid running any tasks for collections that failed to fetch data (have assigned None)
|
||||||
tasks = [
|
tasks = [
|
||||||
(collection_name, query)
|
(cn, q)
|
||||||
for collection_name in collection_names
|
for cn in collection_names
|
||||||
if collection_results[collection_name] is not None
|
if collection_results[cn] is not None
|
||||||
for query in queries
|
for q in queries
|
||||||
]
|
]
|
||||||
|
|
||||||
# Run all queries in parallel using asyncio.gather
|
with ThreadPoolExecutor() as executor:
|
||||||
task_results = await asyncio.gather(
|
future_results = [executor.submit(process_query, cn, q) for cn, q in tasks]
|
||||||
*[process_query(collection_name, query) for collection_name, query in tasks]
|
task_results = [future.result() for future in future_results]
|
||||||
)
|
|
||||||
|
|
||||||
for result, err in task_results:
|
for result, err in task_results:
|
||||||
if err is not None:
|
if err is not None:
|
||||||
|
|
@ -533,248 +456,6 @@ async def query_collection_with_hybrid_search(
|
||||||
return merge_and_sort_query_results(results, k=k)
|
return merge_and_sort_query_results(results, k=k)
|
||||||
|
|
||||||
|
|
||||||
def generate_openai_batch_embeddings(
|
|
||||||
model: str,
|
|
||||||
texts: list[str],
|
|
||||||
url: str = "https://api.openai.com/v1",
|
|
||||||
key: str = "",
|
|
||||||
prefix: str = None,
|
|
||||||
user: UserModel = None,
|
|
||||||
) -> Optional[list[list[float]]]:
|
|
||||||
try:
|
|
||||||
log.debug(
|
|
||||||
f"generate_openai_batch_embeddings:model {model} batch size: {len(texts)}"
|
|
||||||
)
|
|
||||||
json_data = {"input": texts, "model": model}
|
|
||||||
if isinstance(RAG_EMBEDDING_PREFIX_FIELD_NAME, str) and isinstance(prefix, str):
|
|
||||||
json_data[RAG_EMBEDDING_PREFIX_FIELD_NAME] = prefix
|
|
||||||
|
|
||||||
headers = {
|
|
||||||
"Content-Type": "application/json",
|
|
||||||
"Authorization": f"Bearer {key}",
|
|
||||||
}
|
|
||||||
if ENABLE_FORWARD_USER_INFO_HEADERS and user:
|
|
||||||
headers = include_user_info_headers(headers, user)
|
|
||||||
|
|
||||||
r = requests.post(
|
|
||||||
f"{url}/embeddings",
|
|
||||||
headers=headers,
|
|
||||||
json=json_data,
|
|
||||||
)
|
|
||||||
r.raise_for_status()
|
|
||||||
data = r.json()
|
|
||||||
if "data" in data:
|
|
||||||
return [elem["embedding"] for elem in data["data"]]
|
|
||||||
else:
|
|
||||||
raise "Something went wrong :/"
|
|
||||||
except Exception as e:
|
|
||||||
log.exception(f"Error generating openai batch embeddings: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
async def agenerate_openai_batch_embeddings(
|
|
||||||
model: str,
|
|
||||||
texts: list[str],
|
|
||||||
url: str = "https://api.openai.com/v1",
|
|
||||||
key: str = "",
|
|
||||||
prefix: str = None,
|
|
||||||
user: UserModel = None,
|
|
||||||
) -> Optional[list[list[float]]]:
|
|
||||||
try:
|
|
||||||
log.debug(
|
|
||||||
f"agenerate_openai_batch_embeddings:model {model} batch size: {len(texts)}"
|
|
||||||
)
|
|
||||||
form_data = {"input": texts, "model": model}
|
|
||||||
if isinstance(RAG_EMBEDDING_PREFIX_FIELD_NAME, str) and isinstance(prefix, str):
|
|
||||||
form_data[RAG_EMBEDDING_PREFIX_FIELD_NAME] = prefix
|
|
||||||
|
|
||||||
headers = {
|
|
||||||
"Content-Type": "application/json",
|
|
||||||
"Authorization": f"Bearer {key}",
|
|
||||||
}
|
|
||||||
if ENABLE_FORWARD_USER_INFO_HEADERS and user:
|
|
||||||
headers = include_user_info_headers(headers, user)
|
|
||||||
|
|
||||||
async with aiohttp.ClientSession(trust_env=True) as session:
|
|
||||||
async with session.post(
|
|
||||||
f"{url}/embeddings", headers=headers, json=form_data
|
|
||||||
) as r:
|
|
||||||
r.raise_for_status()
|
|
||||||
data = await r.json()
|
|
||||||
if "data" in data:
|
|
||||||
return [item["embedding"] for item in data["data"]]
|
|
||||||
else:
|
|
||||||
raise Exception("Something went wrong :/")
|
|
||||||
except Exception as e:
|
|
||||||
log.exception(f"Error generating openai batch embeddings: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def generate_azure_openai_batch_embeddings(
|
|
||||||
model: str,
|
|
||||||
texts: list[str],
|
|
||||||
url: str,
|
|
||||||
key: str = "",
|
|
||||||
version: str = "",
|
|
||||||
prefix: str = None,
|
|
||||||
user: UserModel = None,
|
|
||||||
) -> Optional[list[list[float]]]:
|
|
||||||
try:
|
|
||||||
log.debug(
|
|
||||||
f"generate_azure_openai_batch_embeddings:deployment {model} batch size: {len(texts)}"
|
|
||||||
)
|
|
||||||
json_data = {"input": texts}
|
|
||||||
if isinstance(RAG_EMBEDDING_PREFIX_FIELD_NAME, str) and isinstance(prefix, str):
|
|
||||||
json_data[RAG_EMBEDDING_PREFIX_FIELD_NAME] = prefix
|
|
||||||
|
|
||||||
url = f"{url}/openai/deployments/{model}/embeddings?api-version={version}"
|
|
||||||
|
|
||||||
for _ in range(5):
|
|
||||||
headers = {
|
|
||||||
"Content-Type": "application/json",
|
|
||||||
"api-key": key,
|
|
||||||
}
|
|
||||||
if ENABLE_FORWARD_USER_INFO_HEADERS and user:
|
|
||||||
headers = include_user_info_headers(headers, user)
|
|
||||||
|
|
||||||
r = requests.post(
|
|
||||||
url,
|
|
||||||
headers=headers,
|
|
||||||
json=json_data,
|
|
||||||
)
|
|
||||||
if r.status_code == 429:
|
|
||||||
retry = float(r.headers.get("Retry-After", "1"))
|
|
||||||
time.sleep(retry)
|
|
||||||
continue
|
|
||||||
r.raise_for_status()
|
|
||||||
data = r.json()
|
|
||||||
if "data" in data:
|
|
||||||
return [elem["embedding"] for elem in data["data"]]
|
|
||||||
else:
|
|
||||||
raise Exception("Something went wrong :/")
|
|
||||||
return None
|
|
||||||
except Exception as e:
|
|
||||||
log.exception(f"Error generating azure openai batch embeddings: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
async def agenerate_azure_openai_batch_embeddings(
|
|
||||||
model: str,
|
|
||||||
texts: list[str],
|
|
||||||
url: str,
|
|
||||||
key: str = "",
|
|
||||||
version: str = "",
|
|
||||||
prefix: str = None,
|
|
||||||
user: UserModel = None,
|
|
||||||
) -> Optional[list[list[float]]]:
|
|
||||||
try:
|
|
||||||
log.debug(
|
|
||||||
f"agenerate_azure_openai_batch_embeddings:deployment {model} batch size: {len(texts)}"
|
|
||||||
)
|
|
||||||
form_data = {"input": texts}
|
|
||||||
if isinstance(RAG_EMBEDDING_PREFIX_FIELD_NAME, str) and isinstance(prefix, str):
|
|
||||||
form_data[RAG_EMBEDDING_PREFIX_FIELD_NAME] = prefix
|
|
||||||
|
|
||||||
full_url = f"{url}/openai/deployments/{model}/embeddings?api-version={version}"
|
|
||||||
|
|
||||||
headers = {
|
|
||||||
"Content-Type": "application/json",
|
|
||||||
"api-key": key,
|
|
||||||
}
|
|
||||||
if ENABLE_FORWARD_USER_INFO_HEADERS and user:
|
|
||||||
headers = include_user_info_headers(headers, user)
|
|
||||||
|
|
||||||
async with aiohttp.ClientSession(trust_env=True) as session:
|
|
||||||
async with session.post(full_url, headers=headers, json=form_data) as r:
|
|
||||||
r.raise_for_status()
|
|
||||||
data = await r.json()
|
|
||||||
if "data" in data:
|
|
||||||
return [item["embedding"] for item in data["data"]]
|
|
||||||
else:
|
|
||||||
raise Exception("Something went wrong :/")
|
|
||||||
except Exception as e:
|
|
||||||
log.exception(f"Error generating azure openai batch embeddings: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def generate_ollama_batch_embeddings(
|
|
||||||
model: str,
|
|
||||||
texts: list[str],
|
|
||||||
url: str,
|
|
||||||
key: str = "",
|
|
||||||
prefix: str = None,
|
|
||||||
user: UserModel = None,
|
|
||||||
) -> Optional[list[list[float]]]:
|
|
||||||
try:
|
|
||||||
log.debug(
|
|
||||||
f"generate_ollama_batch_embeddings:model {model} batch size: {len(texts)}"
|
|
||||||
)
|
|
||||||
json_data = {"input": texts, "model": model}
|
|
||||||
if isinstance(RAG_EMBEDDING_PREFIX_FIELD_NAME, str) and isinstance(prefix, str):
|
|
||||||
json_data[RAG_EMBEDDING_PREFIX_FIELD_NAME] = prefix
|
|
||||||
|
|
||||||
headers = {
|
|
||||||
"Content-Type": "application/json",
|
|
||||||
"Authorization": f"Bearer {key}",
|
|
||||||
}
|
|
||||||
if ENABLE_FORWARD_USER_INFO_HEADERS and user:
|
|
||||||
headers = include_user_info_headers(headers, user)
|
|
||||||
|
|
||||||
r = requests.post(
|
|
||||||
f"{url}/api/embed",
|
|
||||||
headers=headers,
|
|
||||||
json=json_data,
|
|
||||||
)
|
|
||||||
r.raise_for_status()
|
|
||||||
data = r.json()
|
|
||||||
|
|
||||||
if "embeddings" in data:
|
|
||||||
return data["embeddings"]
|
|
||||||
else:
|
|
||||||
raise "Something went wrong :/"
|
|
||||||
except Exception as e:
|
|
||||||
log.exception(f"Error generating ollama batch embeddings: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
async def agenerate_ollama_batch_embeddings(
|
|
||||||
model: str,
|
|
||||||
texts: list[str],
|
|
||||||
url: str,
|
|
||||||
key: str = "",
|
|
||||||
prefix: str = None,
|
|
||||||
user: UserModel = None,
|
|
||||||
) -> Optional[list[list[float]]]:
|
|
||||||
try:
|
|
||||||
log.debug(
|
|
||||||
f"agenerate_ollama_batch_embeddings:model {model} batch size: {len(texts)}"
|
|
||||||
)
|
|
||||||
form_data = {"input": texts, "model": model}
|
|
||||||
if isinstance(RAG_EMBEDDING_PREFIX_FIELD_NAME, str) and isinstance(prefix, str):
|
|
||||||
form_data[RAG_EMBEDDING_PREFIX_FIELD_NAME] = prefix
|
|
||||||
|
|
||||||
headers = {
|
|
||||||
"Content-Type": "application/json",
|
|
||||||
"Authorization": f"Bearer {key}",
|
|
||||||
}
|
|
||||||
if ENABLE_FORWARD_USER_INFO_HEADERS and user:
|
|
||||||
headers = include_user_info_headers(headers, user)
|
|
||||||
|
|
||||||
async with aiohttp.ClientSession(trust_env=True) as session:
|
|
||||||
async with session.post(
|
|
||||||
f"{url}/api/embed", headers=headers, json=form_data
|
|
||||||
) as r:
|
|
||||||
r.raise_for_status()
|
|
||||||
data = await r.json()
|
|
||||||
if "embeddings" in data:
|
|
||||||
return data["embeddings"]
|
|
||||||
else:
|
|
||||||
raise Exception("Something went wrong :/")
|
|
||||||
except Exception as e:
|
|
||||||
log.exception(f"Error generating ollama batch embeddings: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def get_embedding_function(
|
def get_embedding_function(
|
||||||
embedding_engine,
|
embedding_engine,
|
||||||
embedding_model,
|
embedding_model,
|
||||||
|
|
@ -783,24 +464,13 @@ def get_embedding_function(
|
||||||
key,
|
key,
|
||||||
embedding_batch_size,
|
embedding_batch_size,
|
||||||
azure_api_version=None,
|
azure_api_version=None,
|
||||||
enable_async=True,
|
):
|
||||||
) -> Awaitable:
|
|
||||||
if embedding_engine == "":
|
if embedding_engine == "":
|
||||||
# Sentence transformers: CPU-bound sync operation
|
return lambda query, prefix=None, user=None: embedding_function.encode(
|
||||||
async def async_embedding_function(query, prefix=None, user=None):
|
query, **({"prompt": prefix} if prefix else {})
|
||||||
return await asyncio.to_thread(
|
).tolist()
|
||||||
(
|
|
||||||
lambda query, prefix=None: embedding_function.encode(
|
|
||||||
query, **({"prompt": prefix} if prefix else {})
|
|
||||||
).tolist()
|
|
||||||
),
|
|
||||||
query,
|
|
||||||
prefix,
|
|
||||||
)
|
|
||||||
|
|
||||||
return async_embedding_function
|
|
||||||
elif embedding_engine in ["ollama", "openai", "azure_openai"]:
|
elif embedding_engine in ["ollama", "openai", "azure_openai"]:
|
||||||
embedding_function = lambda query, prefix=None, user=None: generate_embeddings(
|
func = lambda query, prefix=None, user=None: generate_embeddings(
|
||||||
engine=embedding_engine,
|
engine=embedding_engine,
|
||||||
model=embedding_model,
|
model=embedding_model,
|
||||||
text=query,
|
text=query,
|
||||||
|
|
@ -811,114 +481,41 @@ def get_embedding_function(
|
||||||
azure_api_version=azure_api_version,
|
azure_api_version=azure_api_version,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def async_embedding_function(query, prefix=None, user=None):
|
def generate_multiple(query, prefix, user, func):
|
||||||
if isinstance(query, list):
|
if isinstance(query, list):
|
||||||
# Create batches
|
|
||||||
batches = [
|
|
||||||
query[i : i + embedding_batch_size]
|
|
||||||
for i in range(0, len(query), embedding_batch_size)
|
|
||||||
]
|
|
||||||
|
|
||||||
if enable_async:
|
|
||||||
log.debug(
|
|
||||||
f"generate_multiple_async: Processing {len(batches)} batches in parallel"
|
|
||||||
)
|
|
||||||
# Execute all batches in parallel
|
|
||||||
tasks = [
|
|
||||||
embedding_function(batch, prefix=prefix, user=user)
|
|
||||||
for batch in batches
|
|
||||||
]
|
|
||||||
batch_results = await asyncio.gather(*tasks)
|
|
||||||
else:
|
|
||||||
log.debug(
|
|
||||||
f"generate_multiple_async: Processing {len(batches)} batches sequentially"
|
|
||||||
)
|
|
||||||
batch_results = []
|
|
||||||
for batch in batches:
|
|
||||||
batch_results.append(
|
|
||||||
await embedding_function(batch, prefix=prefix, user=user)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Flatten results
|
|
||||||
embeddings = []
|
embeddings = []
|
||||||
for batch_embeddings in batch_results:
|
for i in range(0, len(query), embedding_batch_size):
|
||||||
|
batch_embeddings = func(
|
||||||
|
query[i : i + embedding_batch_size],
|
||||||
|
prefix=prefix,
|
||||||
|
user=user,
|
||||||
|
)
|
||||||
|
|
||||||
if isinstance(batch_embeddings, list):
|
if isinstance(batch_embeddings, list):
|
||||||
embeddings.extend(batch_embeddings)
|
embeddings.extend(batch_embeddings)
|
||||||
|
|
||||||
log.debug(
|
|
||||||
f"generate_multiple_async: Generated {len(embeddings)} embeddings from {len(batches)} parallel batches"
|
|
||||||
)
|
|
||||||
return embeddings
|
return embeddings
|
||||||
else:
|
else:
|
||||||
return await embedding_function(query, prefix, user)
|
return func(query, prefix, user)
|
||||||
|
|
||||||
return async_embedding_function
|
return lambda query, prefix=None, user=None: generate_multiple(
|
||||||
|
query, prefix, user, func
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown embedding engine: {embedding_engine}")
|
raise ValueError(f"Unknown embedding engine: {embedding_engine}")
|
||||||
|
|
||||||
|
|
||||||
async def generate_embeddings(
|
|
||||||
engine: str,
|
|
||||||
model: str,
|
|
||||||
text: Union[str, list[str]],
|
|
||||||
prefix: Union[str, None] = None,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
url = kwargs.get("url", "")
|
|
||||||
key = kwargs.get("key", "")
|
|
||||||
user = kwargs.get("user")
|
|
||||||
|
|
||||||
if prefix is not None and RAG_EMBEDDING_PREFIX_FIELD_NAME is None:
|
|
||||||
if isinstance(text, list):
|
|
||||||
text = [f"{prefix}{text_element}" for text_element in text]
|
|
||||||
else:
|
|
||||||
text = f"{prefix}{text}"
|
|
||||||
|
|
||||||
if engine == "ollama":
|
|
||||||
embeddings = await agenerate_ollama_batch_embeddings(
|
|
||||||
**{
|
|
||||||
"model": model,
|
|
||||||
"texts": text if isinstance(text, list) else [text],
|
|
||||||
"url": url,
|
|
||||||
"key": key,
|
|
||||||
"prefix": prefix,
|
|
||||||
"user": user,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
return embeddings[0] if isinstance(text, str) else embeddings
|
|
||||||
elif engine == "openai":
|
|
||||||
embeddings = await agenerate_openai_batch_embeddings(
|
|
||||||
model, text if isinstance(text, list) else [text], url, key, prefix, user
|
|
||||||
)
|
|
||||||
return embeddings[0] if isinstance(text, str) else embeddings
|
|
||||||
elif engine == "azure_openai":
|
|
||||||
azure_api_version = kwargs.get("azure_api_version", "")
|
|
||||||
embeddings = await agenerate_azure_openai_batch_embeddings(
|
|
||||||
model,
|
|
||||||
text if isinstance(text, list) else [text],
|
|
||||||
url,
|
|
||||||
key,
|
|
||||||
azure_api_version,
|
|
||||||
prefix,
|
|
||||||
user,
|
|
||||||
)
|
|
||||||
return embeddings[0] if isinstance(text, str) else embeddings
|
|
||||||
|
|
||||||
|
|
||||||
def get_reranking_function(reranking_engine, reranking_model, reranking_function):
|
def get_reranking_function(reranking_engine, reranking_model, reranking_function):
|
||||||
if reranking_function is None:
|
if reranking_function is None:
|
||||||
return None
|
return None
|
||||||
if reranking_engine == "external":
|
if reranking_engine == "external":
|
||||||
return lambda query, documents, user=None: reranking_function.predict(
|
return lambda sentences, user=None: reranking_function.predict(
|
||||||
[(query, doc.page_content) for doc in documents], user=user
|
sentences, user=user
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return lambda query, documents, user=None: reranking_function.predict(
|
return lambda sentences, user=None: reranking_function.predict(sentences)
|
||||||
[(query, doc.page_content) for doc in documents]
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def get_sources_from_items(
|
def get_sources_from_items(
|
||||||
request,
|
request,
|
||||||
items,
|
items,
|
||||||
queries,
|
queries,
|
||||||
|
|
@ -1089,19 +686,23 @@ async def get_sources_from_items(
|
||||||
or knowledge_base.user_id == user.id
|
or knowledge_base.user_id == user.id
|
||||||
or has_access(user.id, "read", knowledge_base.access_control)
|
or has_access(user.id, "read", knowledge_base.access_control)
|
||||||
):
|
):
|
||||||
files = Knowledges.get_files_by_id(knowledge_base.id)
|
|
||||||
|
file_ids = knowledge_base.data.get("file_ids", [])
|
||||||
|
|
||||||
documents = []
|
documents = []
|
||||||
metadatas = []
|
metadatas = []
|
||||||
for file in files:
|
for file_id in file_ids:
|
||||||
documents.append(file.data.get("content", ""))
|
file_object = Files.get_file_by_id(file_id)
|
||||||
metadatas.append(
|
|
||||||
{
|
if file_object:
|
||||||
"file_id": file.id,
|
documents.append(file_object.data.get("content", ""))
|
||||||
"name": file.filename,
|
metadatas.append(
|
||||||
"source": file.filename,
|
{
|
||||||
}
|
"file_id": file_id,
|
||||||
)
|
"name": file_object.filename,
|
||||||
|
"source": file_object.filename,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
query_result = {
|
query_result = {
|
||||||
"documents": [documents],
|
"documents": [documents],
|
||||||
|
|
@ -1142,7 +743,7 @@ async def get_sources_from_items(
|
||||||
query_result = None # Initialize to None
|
query_result = None # Initialize to None
|
||||||
if hybrid_search:
|
if hybrid_search:
|
||||||
try:
|
try:
|
||||||
query_result = await query_collection_with_hybrid_search(
|
query_result = query_collection_with_hybrid_search(
|
||||||
collection_names=collection_names,
|
collection_names=collection_names,
|
||||||
queries=queries,
|
queries=queries,
|
||||||
embedding_function=embedding_function,
|
embedding_function=embedding_function,
|
||||||
|
|
@ -1151,7 +752,6 @@ async def get_sources_from_items(
|
||||||
k_reranker=k_reranker,
|
k_reranker=k_reranker,
|
||||||
r=r,
|
r=r,
|
||||||
hybrid_bm25_weight=hybrid_bm25_weight,
|
hybrid_bm25_weight=hybrid_bm25_weight,
|
||||||
enable_enriched_texts=request.app.state.config.ENABLE_RAG_HYBRID_SEARCH_ENRICHED_TEXTS,
|
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.debug(
|
log.debug(
|
||||||
|
|
@ -1160,7 +760,7 @@ async def get_sources_from_items(
|
||||||
|
|
||||||
# fallback to non-hybrid search
|
# fallback to non-hybrid search
|
||||||
if not hybrid_search and query_result is None:
|
if not hybrid_search and query_result is None:
|
||||||
query_result = await query_collection(
|
query_result = query_collection(
|
||||||
collection_names=collection_names,
|
collection_names=collection_names,
|
||||||
queries=queries,
|
queries=queries,
|
||||||
embedding_function=embedding_function,
|
embedding_function=embedding_function,
|
||||||
|
|
@ -1236,6 +836,199 @@ def get_model_path(model: str, update_model: bool = False):
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def generate_openai_batch_embeddings(
|
||||||
|
model: str,
|
||||||
|
texts: list[str],
|
||||||
|
url: str = "https://api.openai.com/v1",
|
||||||
|
key: str = "",
|
||||||
|
prefix: str = None,
|
||||||
|
user: UserModel = None,
|
||||||
|
) -> Optional[list[list[float]]]:
|
||||||
|
try:
|
||||||
|
log.debug(
|
||||||
|
f"generate_openai_batch_embeddings:model {model} batch size: {len(texts)}"
|
||||||
|
)
|
||||||
|
json_data = {"input": texts, "model": model}
|
||||||
|
if isinstance(RAG_EMBEDDING_PREFIX_FIELD_NAME, str) and isinstance(prefix, str):
|
||||||
|
json_data[RAG_EMBEDDING_PREFIX_FIELD_NAME] = prefix
|
||||||
|
|
||||||
|
r = requests.post(
|
||||||
|
f"{url}/embeddings",
|
||||||
|
headers={
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"Authorization": f"Bearer {key}",
|
||||||
|
**(
|
||||||
|
{
|
||||||
|
"X-OpenWebUI-User-Name": quote(user.name, safe=" "),
|
||||||
|
"X-OpenWebUI-User-Id": user.id,
|
||||||
|
"X-OpenWebUI-User-Email": user.email,
|
||||||
|
"X-OpenWebUI-User-Role": user.role,
|
||||||
|
}
|
||||||
|
if ENABLE_FORWARD_USER_INFO_HEADERS and user
|
||||||
|
else {}
|
||||||
|
),
|
||||||
|
},
|
||||||
|
json=json_data,
|
||||||
|
)
|
||||||
|
r.raise_for_status()
|
||||||
|
data = r.json()
|
||||||
|
if "data" in data:
|
||||||
|
return [elem["embedding"] for elem in data["data"]]
|
||||||
|
else:
|
||||||
|
raise "Something went wrong :/"
|
||||||
|
except Exception as e:
|
||||||
|
log.exception(f"Error generating openai batch embeddings: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def generate_azure_openai_batch_embeddings(
|
||||||
|
model: str,
|
||||||
|
texts: list[str],
|
||||||
|
url: str,
|
||||||
|
key: str = "",
|
||||||
|
version: str = "",
|
||||||
|
prefix: str = None,
|
||||||
|
user: UserModel = None,
|
||||||
|
) -> Optional[list[list[float]]]:
|
||||||
|
try:
|
||||||
|
log.debug(
|
||||||
|
f"generate_azure_openai_batch_embeddings:deployment {model} batch size: {len(texts)}"
|
||||||
|
)
|
||||||
|
json_data = {"input": texts}
|
||||||
|
if isinstance(RAG_EMBEDDING_PREFIX_FIELD_NAME, str) and isinstance(prefix, str):
|
||||||
|
json_data[RAG_EMBEDDING_PREFIX_FIELD_NAME] = prefix
|
||||||
|
|
||||||
|
url = f"{url}/openai/deployments/{model}/embeddings?api-version={version}"
|
||||||
|
|
||||||
|
for _ in range(5):
|
||||||
|
r = requests.post(
|
||||||
|
url,
|
||||||
|
headers={
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"api-key": key,
|
||||||
|
**(
|
||||||
|
{
|
||||||
|
"X-OpenWebUI-User-Name": quote(user.name, safe=" "),
|
||||||
|
"X-OpenWebUI-User-Id": user.id,
|
||||||
|
"X-OpenWebUI-User-Email": user.email,
|
||||||
|
"X-OpenWebUI-User-Role": user.role,
|
||||||
|
}
|
||||||
|
if ENABLE_FORWARD_USER_INFO_HEADERS and user
|
||||||
|
else {}
|
||||||
|
),
|
||||||
|
},
|
||||||
|
json=json_data,
|
||||||
|
)
|
||||||
|
if r.status_code == 429:
|
||||||
|
retry = float(r.headers.get("Retry-After", "1"))
|
||||||
|
time.sleep(retry)
|
||||||
|
continue
|
||||||
|
r.raise_for_status()
|
||||||
|
data = r.json()
|
||||||
|
if "data" in data:
|
||||||
|
return [elem["embedding"] for elem in data["data"]]
|
||||||
|
else:
|
||||||
|
raise Exception("Something went wrong :/")
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
log.exception(f"Error generating azure openai batch embeddings: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def generate_ollama_batch_embeddings(
|
||||||
|
model: str,
|
||||||
|
texts: list[str],
|
||||||
|
url: str,
|
||||||
|
key: str = "",
|
||||||
|
prefix: str = None,
|
||||||
|
user: UserModel = None,
|
||||||
|
) -> Optional[list[list[float]]]:
|
||||||
|
try:
|
||||||
|
log.debug(
|
||||||
|
f"generate_ollama_batch_embeddings:model {model} batch size: {len(texts)}"
|
||||||
|
)
|
||||||
|
json_data = {"input": texts, "model": model}
|
||||||
|
if isinstance(RAG_EMBEDDING_PREFIX_FIELD_NAME, str) and isinstance(prefix, str):
|
||||||
|
json_data[RAG_EMBEDDING_PREFIX_FIELD_NAME] = prefix
|
||||||
|
|
||||||
|
r = requests.post(
|
||||||
|
f"{url}/api/embed",
|
||||||
|
headers={
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"Authorization": f"Bearer {key}",
|
||||||
|
**(
|
||||||
|
{
|
||||||
|
"X-OpenWebUI-User-Name": quote(user.name, safe=" "),
|
||||||
|
"X-OpenWebUI-User-Id": user.id,
|
||||||
|
"X-OpenWebUI-User-Email": user.email,
|
||||||
|
"X-OpenWebUI-User-Role": user.role,
|
||||||
|
}
|
||||||
|
if ENABLE_FORWARD_USER_INFO_HEADERS
|
||||||
|
else {}
|
||||||
|
),
|
||||||
|
},
|
||||||
|
json=json_data,
|
||||||
|
)
|
||||||
|
r.raise_for_status()
|
||||||
|
data = r.json()
|
||||||
|
|
||||||
|
if "embeddings" in data:
|
||||||
|
return data["embeddings"]
|
||||||
|
else:
|
||||||
|
raise "Something went wrong :/"
|
||||||
|
except Exception as e:
|
||||||
|
log.exception(f"Error generating ollama batch embeddings: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def generate_embeddings(
|
||||||
|
engine: str,
|
||||||
|
model: str,
|
||||||
|
text: Union[str, list[str]],
|
||||||
|
prefix: Union[str, None] = None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
url = kwargs.get("url", "")
|
||||||
|
key = kwargs.get("key", "")
|
||||||
|
user = kwargs.get("user")
|
||||||
|
|
||||||
|
if prefix is not None and RAG_EMBEDDING_PREFIX_FIELD_NAME is None:
|
||||||
|
if isinstance(text, list):
|
||||||
|
text = [f"{prefix}{text_element}" for text_element in text]
|
||||||
|
else:
|
||||||
|
text = f"{prefix}{text}"
|
||||||
|
|
||||||
|
if engine == "ollama":
|
||||||
|
embeddings = generate_ollama_batch_embeddings(
|
||||||
|
**{
|
||||||
|
"model": model,
|
||||||
|
"texts": text if isinstance(text, list) else [text],
|
||||||
|
"url": url,
|
||||||
|
"key": key,
|
||||||
|
"prefix": prefix,
|
||||||
|
"user": user,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return embeddings[0] if isinstance(text, str) else embeddings
|
||||||
|
elif engine == "openai":
|
||||||
|
embeddings = generate_openai_batch_embeddings(
|
||||||
|
model, text if isinstance(text, list) else [text], url, key, prefix, user
|
||||||
|
)
|
||||||
|
return embeddings[0] if isinstance(text, str) else embeddings
|
||||||
|
elif engine == "azure_openai":
|
||||||
|
azure_api_version = kwargs.get("azure_api_version", "")
|
||||||
|
embeddings = generate_azure_openai_batch_embeddings(
|
||||||
|
model,
|
||||||
|
text if isinstance(text, list) else [text],
|
||||||
|
url,
|
||||||
|
key,
|
||||||
|
azure_api_version,
|
||||||
|
prefix,
|
||||||
|
user,
|
||||||
|
)
|
||||||
|
return embeddings[0] if isinstance(text, str) else embeddings
|
||||||
|
|
||||||
|
|
||||||
import operator
|
import operator
|
||||||
from typing import Optional, Sequence
|
from typing import Optional, Sequence
|
||||||
|
|
||||||
|
|
@ -1258,38 +1051,19 @@ class RerankCompressor(BaseDocumentCompressor):
|
||||||
documents: Sequence[Document],
|
documents: Sequence[Document],
|
||||||
query: str,
|
query: str,
|
||||||
callbacks: Optional[Callbacks] = None,
|
callbacks: Optional[Callbacks] = None,
|
||||||
) -> Sequence[Document]:
|
|
||||||
"""Compress retrieved documents given the query context.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
documents: The retrieved documents.
|
|
||||||
query: The query context.
|
|
||||||
callbacks: Optional callbacks to run during compression.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The compressed documents.
|
|
||||||
|
|
||||||
"""
|
|
||||||
return []
|
|
||||||
|
|
||||||
async def acompress_documents(
|
|
||||||
self,
|
|
||||||
documents: Sequence[Document],
|
|
||||||
query: str,
|
|
||||||
callbacks: Optional[Callbacks] = None,
|
|
||||||
) -> Sequence[Document]:
|
) -> Sequence[Document]:
|
||||||
reranking = self.reranking_function is not None
|
reranking = self.reranking_function is not None
|
||||||
|
|
||||||
scores = None
|
scores = None
|
||||||
if reranking:
|
if reranking:
|
||||||
scores = await asyncio.to_thread(self.reranking_function, query, documents)
|
scores = self.reranking_function(
|
||||||
|
[(query, doc.page_content) for doc in documents]
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
from sentence_transformers import util
|
from sentence_transformers import util
|
||||||
|
|
||||||
query_embedding = await self.embedding_function(
|
query_embedding = self.embedding_function(query, RAG_EMBEDDING_QUERY_PREFIX)
|
||||||
query, RAG_EMBEDDING_QUERY_PREFIX
|
document_embedding = self.embedding_function(
|
||||||
)
|
|
||||||
document_embedding = await self.embedding_function(
|
|
||||||
[doc.page_content for doc in documents], RAG_EMBEDDING_CONTENT_PREFIX
|
[doc.page_content for doc in documents], RAG_EMBEDDING_CONTENT_PREFIX
|
||||||
)
|
)
|
||||||
scores = util.cos_sim(query_embedding, document_embedding)[0]
|
scores = util.cos_sim(query_embedding, document_embedding)[0]
|
||||||
|
|
|
||||||
|
|
@ -24,8 +24,10 @@ from open_webui.config import (
|
||||||
CHROMA_CLIENT_AUTH_PROVIDER,
|
CHROMA_CLIENT_AUTH_PROVIDER,
|
||||||
CHROMA_CLIENT_AUTH_CREDENTIALS,
|
CHROMA_CLIENT_AUTH_CREDENTIALS,
|
||||||
)
|
)
|
||||||
|
from open_webui.env import SRC_LOG_LEVELS
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||||
|
|
||||||
|
|
||||||
class ChromaClient(VectorDBBase):
|
class ChromaClient(VectorDBBase):
|
||||||
|
|
|
||||||
|
|
@ -25,8 +25,10 @@ from open_webui.config import (
|
||||||
MILVUS_DISKANN_MAX_DEGREE,
|
MILVUS_DISKANN_MAX_DEGREE,
|
||||||
MILVUS_DISKANN_SEARCH_LIST_SIZE,
|
MILVUS_DISKANN_SEARCH_LIST_SIZE,
|
||||||
)
|
)
|
||||||
|
from open_webui.env import SRC_LOG_LEVELS
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||||
|
|
||||||
|
|
||||||
class MilvusClient(VectorDBBase):
|
class MilvusClient(VectorDBBase):
|
||||||
|
|
@ -198,24 +200,23 @@ class MilvusClient(VectorDBBase):
|
||||||
def query(self, collection_name: str, filter: dict, limit: int = -1):
|
def query(self, collection_name: str, filter: dict, limit: int = -1):
|
||||||
connections.connect(uri=MILVUS_URI, token=MILVUS_TOKEN, db_name=MILVUS_DB)
|
connections.connect(uri=MILVUS_URI, token=MILVUS_TOKEN, db_name=MILVUS_DB)
|
||||||
|
|
||||||
|
# Construct the filter string for querying
|
||||||
collection_name = collection_name.replace("-", "_")
|
collection_name = collection_name.replace("-", "_")
|
||||||
if not self.has_collection(collection_name):
|
if not self.has_collection(collection_name):
|
||||||
log.warning(
|
log.warning(
|
||||||
f"Query attempted on non-existent collection: {self.collection_prefix}_{collection_name}"
|
f"Query attempted on non-existent collection: {self.collection_prefix}_{collection_name}"
|
||||||
)
|
)
|
||||||
return None
|
return None
|
||||||
|
filter_string = " && ".join(
|
||||||
filter_expressions = []
|
[
|
||||||
for key, value in filter.items():
|
f'metadata["{key}"] == {json.dumps(value)}'
|
||||||
if isinstance(value, str):
|
for key, value in filter.items()
|
||||||
filter_expressions.append(f'metadata["{key}"] == "{value}"')
|
]
|
||||||
else:
|
)
|
||||||
filter_expressions.append(f'metadata["{key}"] == {value}')
|
|
||||||
|
|
||||||
filter_string = " && ".join(filter_expressions)
|
|
||||||
|
|
||||||
collection = Collection(f"{self.collection_prefix}_{collection_name}")
|
collection = Collection(f"{self.collection_prefix}_{collection_name}")
|
||||||
collection.load()
|
collection.load()
|
||||||
|
all_results = []
|
||||||
|
|
||||||
try:
|
try:
|
||||||
log.info(
|
log.info(
|
||||||
|
|
@ -223,25 +224,24 @@ class MilvusClient(VectorDBBase):
|
||||||
)
|
)
|
||||||
|
|
||||||
iterator = collection.query_iterator(
|
iterator = collection.query_iterator(
|
||||||
expr=filter_string,
|
filter=filter_string,
|
||||||
output_fields=[
|
output_fields=[
|
||||||
"id",
|
"id",
|
||||||
"data",
|
"data",
|
||||||
"metadata",
|
"metadata",
|
||||||
],
|
],
|
||||||
limit=limit if limit > 0 else -1,
|
limit=limit, # Pass the limit directly; -1 means no limit.
|
||||||
)
|
)
|
||||||
|
|
||||||
all_results = []
|
|
||||||
while True:
|
while True:
|
||||||
batch = iterator.next()
|
result = iterator.next()
|
||||||
if not batch:
|
if not result:
|
||||||
iterator.close()
|
iterator.close()
|
||||||
break
|
break
|
||||||
all_results.extend(batch)
|
all_results += result
|
||||||
|
|
||||||
log.debug(f"Total results from query: {len(all_results)}")
|
log.info(f"Total results from query: {len(all_results)}")
|
||||||
return self._result_to_get_result([all_results] if all_results else [[]])
|
return self._result_to_get_result([all_results])
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.exception(
|
log.exception(
|
||||||
|
|
|
||||||
|
|
@ -12,6 +12,7 @@ from open_webui.config import (
|
||||||
MILVUS_HNSW_EFCONSTRUCTION,
|
MILVUS_HNSW_EFCONSTRUCTION,
|
||||||
MILVUS_IVF_FLAT_NLIST,
|
MILVUS_IVF_FLAT_NLIST,
|
||||||
)
|
)
|
||||||
|
from open_webui.env import SRC_LOG_LEVELS
|
||||||
from open_webui.retrieval.vector.main import (
|
from open_webui.retrieval.vector.main import (
|
||||||
GetResult,
|
GetResult,
|
||||||
SearchResult,
|
SearchResult,
|
||||||
|
|
@ -28,6 +29,7 @@ from pymilvus import (
|
||||||
)
|
)
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||||
|
|
||||||
RESOURCE_ID_FIELD = "resource_id"
|
RESOURCE_ID_FIELD = "resource_id"
|
||||||
|
|
||||||
|
|
@ -155,6 +157,7 @@ class MilvusClient(VectorDBBase):
|
||||||
for item in items
|
for item in items
|
||||||
]
|
]
|
||||||
collection.insert(entities)
|
collection.insert(entities)
|
||||||
|
collection.flush()
|
||||||
|
|
||||||
def search(
|
def search(
|
||||||
self, collection_name: str, vectors: List[List[float]], limit: int
|
self, collection_name: str, vectors: List[List[float]], limit: int
|
||||||
|
|
@ -260,23 +263,15 @@ class MilvusClient(VectorDBBase):
|
||||||
else:
|
else:
|
||||||
expr.append(f"metadata['{key}'] == {value}")
|
expr.append(f"metadata['{key}'] == {value}")
|
||||||
|
|
||||||
iterator = collection.query_iterator(
|
results = collection.query(
|
||||||
expr=" and ".join(expr),
|
expr=" and ".join(expr),
|
||||||
output_fields=["id", "text", "metadata"],
|
output_fields=["id", "text", "metadata"],
|
||||||
limit=limit if limit else -1,
|
limit=limit,
|
||||||
)
|
)
|
||||||
|
|
||||||
all_results = []
|
ids = [res["id"] for res in results]
|
||||||
while True:
|
documents = [res["text"] for res in results]
|
||||||
batch = iterator.next()
|
metadatas = [res["metadata"] for res in results]
|
||||||
if not batch:
|
|
||||||
iterator.close()
|
|
||||||
break
|
|
||||||
all_results.extend(batch)
|
|
||||||
|
|
||||||
ids = [res["id"] for res in all_results]
|
|
||||||
documents = [res["text"] for res in all_results]
|
|
||||||
metadatas = [res["metadata"] for res in all_results]
|
|
||||||
|
|
||||||
return GetResult(ids=[ids], documents=[documents], metadatas=[metadatas])
|
return GetResult(ids=[ids], documents=[documents], metadatas=[metadatas])
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -55,8 +55,10 @@ from open_webui.config import (
|
||||||
ORACLE_DB_POOL_MAX,
|
ORACLE_DB_POOL_MAX,
|
||||||
ORACLE_DB_POOL_INCREMENT,
|
ORACLE_DB_POOL_INCREMENT,
|
||||||
)
|
)
|
||||||
|
from open_webui.env import SRC_LOG_LEVELS
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||||
|
|
||||||
|
|
||||||
class Oracle23aiClient(VectorDBBase):
|
class Oracle23aiClient(VectorDBBase):
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
from typing import Optional, List, Dict, Any, Tuple
|
from typing import Optional, List, Dict, Any
|
||||||
import logging
|
import logging
|
||||||
import json
|
import json
|
||||||
from sqlalchemy import (
|
from sqlalchemy import (
|
||||||
|
|
@ -22,7 +22,7 @@ from sqlalchemy.pool import NullPool, QueuePool
|
||||||
|
|
||||||
from sqlalchemy.orm import declarative_base, scoped_session, sessionmaker
|
from sqlalchemy.orm import declarative_base, scoped_session, sessionmaker
|
||||||
from sqlalchemy.dialects.postgresql import JSONB, array
|
from sqlalchemy.dialects.postgresql import JSONB, array
|
||||||
from pgvector.sqlalchemy import Vector, HALFVEC
|
from pgvector.sqlalchemy import Vector
|
||||||
from sqlalchemy.ext.mutable import MutableDict
|
from sqlalchemy.ext.mutable import MutableDict
|
||||||
from sqlalchemy.exc import NoSuchTableError
|
from sqlalchemy.exc import NoSuchTableError
|
||||||
|
|
||||||
|
|
@ -44,22 +44,15 @@ from open_webui.config import (
|
||||||
PGVECTOR_POOL_MAX_OVERFLOW,
|
PGVECTOR_POOL_MAX_OVERFLOW,
|
||||||
PGVECTOR_POOL_TIMEOUT,
|
PGVECTOR_POOL_TIMEOUT,
|
||||||
PGVECTOR_POOL_RECYCLE,
|
PGVECTOR_POOL_RECYCLE,
|
||||||
PGVECTOR_INDEX_METHOD,
|
|
||||||
PGVECTOR_HNSW_M,
|
|
||||||
PGVECTOR_HNSW_EF_CONSTRUCTION,
|
|
||||||
PGVECTOR_IVFFLAT_LISTS,
|
|
||||||
PGVECTOR_USE_HALFVEC,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from open_webui.env import SRC_LOG_LEVELS
|
||||||
|
|
||||||
VECTOR_LENGTH = PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH
|
VECTOR_LENGTH = PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH
|
||||||
USE_HALFVEC = PGVECTOR_USE_HALFVEC
|
|
||||||
|
|
||||||
VECTOR_TYPE_FACTORY = HALFVEC if USE_HALFVEC else Vector
|
|
||||||
VECTOR_OPCLASS = "halfvec_cosine_ops" if USE_HALFVEC else "vector_cosine_ops"
|
|
||||||
Base = declarative_base()
|
Base = declarative_base()
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||||
|
|
||||||
|
|
||||||
def pgcrypto_encrypt(val, key):
|
def pgcrypto_encrypt(val, key):
|
||||||
|
|
@ -74,7 +67,7 @@ class DocumentChunk(Base):
|
||||||
__tablename__ = "document_chunk"
|
__tablename__ = "document_chunk"
|
||||||
|
|
||||||
id = Column(Text, primary_key=True)
|
id = Column(Text, primary_key=True)
|
||||||
vector = Column(VECTOR_TYPE_FACTORY(dim=VECTOR_LENGTH), nullable=True)
|
vector = Column(Vector(dim=VECTOR_LENGTH), nullable=True)
|
||||||
collection_name = Column(Text, nullable=False)
|
collection_name = Column(Text, nullable=False)
|
||||||
|
|
||||||
if PGVECTOR_PGCRYPTO:
|
if PGVECTOR_PGCRYPTO:
|
||||||
|
|
@ -164,9 +157,13 @@ class PgvectorClient(VectorDBBase):
|
||||||
connection = self.session.connection()
|
connection = self.session.connection()
|
||||||
Base.metadata.create_all(bind=connection)
|
Base.metadata.create_all(bind=connection)
|
||||||
|
|
||||||
index_method, index_options = self._vector_index_configuration()
|
# Create an index on the vector column if it doesn't exist
|
||||||
self._ensure_vector_index(index_method, index_options)
|
self.session.execute(
|
||||||
|
text(
|
||||||
|
"CREATE INDEX IF NOT EXISTS idx_document_chunk_vector "
|
||||||
|
"ON document_chunk USING ivfflat (vector vector_cosine_ops) WITH (lists = 100);"
|
||||||
|
)
|
||||||
|
)
|
||||||
self.session.execute(
|
self.session.execute(
|
||||||
text(
|
text(
|
||||||
"CREATE INDEX IF NOT EXISTS idx_document_chunk_collection_name "
|
"CREATE INDEX IF NOT EXISTS idx_document_chunk_collection_name "
|
||||||
|
|
@ -180,78 +177,6 @@ class PgvectorClient(VectorDBBase):
|
||||||
log.exception(f"Error during initialization: {e}")
|
log.exception(f"Error during initialization: {e}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _extract_index_method(index_def: Optional[str]) -> Optional[str]:
|
|
||||||
if not index_def:
|
|
||||||
return None
|
|
||||||
try:
|
|
||||||
after_using = index_def.lower().split("using ", 1)[1]
|
|
||||||
return after_using.split()[0]
|
|
||||||
except (IndexError, AttributeError):
|
|
||||||
return None
|
|
||||||
|
|
||||||
def _vector_index_configuration(self) -> Tuple[str, str]:
|
|
||||||
if PGVECTOR_INDEX_METHOD:
|
|
||||||
index_method = PGVECTOR_INDEX_METHOD
|
|
||||||
log.info(
|
|
||||||
"Using vector index method '%s' from PGVECTOR_INDEX_METHOD.",
|
|
||||||
index_method,
|
|
||||||
)
|
|
||||||
elif USE_HALFVEC:
|
|
||||||
index_method = "hnsw"
|
|
||||||
log.info(
|
|
||||||
"VECTOR_LENGTH=%s exceeds 2000; using halfvec column type with hnsw index.",
|
|
||||||
VECTOR_LENGTH,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
index_method = "ivfflat"
|
|
||||||
|
|
||||||
if index_method == "hnsw":
|
|
||||||
index_options = f"WITH (m = {PGVECTOR_HNSW_M}, ef_construction = {PGVECTOR_HNSW_EF_CONSTRUCTION})"
|
|
||||||
else:
|
|
||||||
index_options = f"WITH (lists = {PGVECTOR_IVFFLAT_LISTS})"
|
|
||||||
|
|
||||||
return index_method, index_options
|
|
||||||
|
|
||||||
def _ensure_vector_index(self, index_method: str, index_options: str) -> None:
|
|
||||||
index_name = "idx_document_chunk_vector"
|
|
||||||
existing_index_def = self.session.execute(
|
|
||||||
text(
|
|
||||||
"""
|
|
||||||
SELECT indexdef
|
|
||||||
FROM pg_indexes
|
|
||||||
WHERE schemaname = current_schema()
|
|
||||||
AND tablename = 'document_chunk'
|
|
||||||
AND indexname = :index_name
|
|
||||||
"""
|
|
||||||
),
|
|
||||||
{"index_name": index_name},
|
|
||||||
).scalar()
|
|
||||||
|
|
||||||
existing_method = self._extract_index_method(existing_index_def)
|
|
||||||
if existing_method and existing_method != index_method:
|
|
||||||
raise RuntimeError(
|
|
||||||
f"Existing pgvector index '{index_name}' uses method '{existing_method}' but configuration now "
|
|
||||||
f"requires '{index_method}'. Automatic rebuild is disabled to prevent long-running maintenance. "
|
|
||||||
"Drop the index manually (optionally after tuning maintenance_work_mem/max_parallel_maintenance_workers) "
|
|
||||||
"and recreate it with the new method before restarting Open WebUI."
|
|
||||||
)
|
|
||||||
|
|
||||||
if not existing_index_def:
|
|
||||||
index_sql = (
|
|
||||||
f"CREATE INDEX IF NOT EXISTS {index_name} "
|
|
||||||
f"ON document_chunk USING {index_method} (vector {VECTOR_OPCLASS})"
|
|
||||||
)
|
|
||||||
if index_options:
|
|
||||||
index_sql = f"{index_sql} {index_options}"
|
|
||||||
self.session.execute(text(index_sql))
|
|
||||||
log.info(
|
|
||||||
"Ensured vector index '%s' using %s%s.",
|
|
||||||
index_name,
|
|
||||||
index_method,
|
|
||||||
f" {index_options}" if index_options else "",
|
|
||||||
)
|
|
||||||
|
|
||||||
def check_vector_length(self) -> None:
|
def check_vector_length(self) -> None:
|
||||||
"""
|
"""
|
||||||
Check if the VECTOR_LENGTH matches the existing vector column dimension in the database.
|
Check if the VECTOR_LENGTH matches the existing vector column dimension in the database.
|
||||||
|
|
@ -271,19 +196,16 @@ class PgvectorClient(VectorDBBase):
|
||||||
if "vector" in document_chunk_table.columns:
|
if "vector" in document_chunk_table.columns:
|
||||||
vector_column = document_chunk_table.columns["vector"]
|
vector_column = document_chunk_table.columns["vector"]
|
||||||
vector_type = vector_column.type
|
vector_type = vector_column.type
|
||||||
expected_type = HALFVEC if USE_HALFVEC else Vector
|
if isinstance(vector_type, Vector):
|
||||||
|
db_vector_length = vector_type.dim
|
||||||
if not isinstance(vector_type, expected_type):
|
if db_vector_length != VECTOR_LENGTH:
|
||||||
|
raise Exception(
|
||||||
|
f"VECTOR_LENGTH {VECTOR_LENGTH} does not match existing vector column dimension {db_vector_length}. "
|
||||||
|
"Cannot change vector size after initialization without migrating the data."
|
||||||
|
)
|
||||||
|
else:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
"The 'vector' column type does not match the expected type "
|
"The 'vector' column exists but is not of type 'Vector'."
|
||||||
f"('{expected_type.__name__}') for VECTOR_LENGTH {VECTOR_LENGTH}."
|
|
||||||
)
|
|
||||||
|
|
||||||
db_vector_length = getattr(vector_type, "dim", None)
|
|
||||||
if db_vector_length is not None and db_vector_length != VECTOR_LENGTH:
|
|
||||||
raise Exception(
|
|
||||||
f"VECTOR_LENGTH {VECTOR_LENGTH} does not match existing vector column dimension {db_vector_length}. "
|
|
||||||
"Cannot change vector size after initialization without migrating the data."
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
|
|
@ -438,11 +360,11 @@ class PgvectorClient(VectorDBBase):
|
||||||
num_queries = len(vectors)
|
num_queries = len(vectors)
|
||||||
|
|
||||||
def vector_expr(vector):
|
def vector_expr(vector):
|
||||||
return cast(array(vector), VECTOR_TYPE_FACTORY(VECTOR_LENGTH))
|
return cast(array(vector), Vector(VECTOR_LENGTH))
|
||||||
|
|
||||||
# Create the values for query vectors
|
# Create the values for query vectors
|
||||||
qid_col = column("qid", Integer)
|
qid_col = column("qid", Integer)
|
||||||
q_vector_col = column("q_vector", VECTOR_TYPE_FACTORY(VECTOR_LENGTH))
|
q_vector_col = column("q_vector", Vector(VECTOR_LENGTH))
|
||||||
query_vectors = (
|
query_vectors = (
|
||||||
values(qid_col, q_vector_col)
|
values(qid_col, q_vector_col)
|
||||||
.data(
|
.data(
|
||||||
|
|
|
||||||
|
|
@ -31,6 +31,7 @@ from open_webui.config import (
|
||||||
PINECONE_METRIC,
|
PINECONE_METRIC,
|
||||||
PINECONE_CLOUD,
|
PINECONE_CLOUD,
|
||||||
)
|
)
|
||||||
|
from open_webui.env import SRC_LOG_LEVELS
|
||||||
from open_webui.retrieval.vector.utils import process_metadata
|
from open_webui.retrieval.vector.utils import process_metadata
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -38,6 +39,7 @@ NO_LIMIT = 10000 # Reasonable limit to avoid overwhelming the system
|
||||||
BATCH_SIZE = 100 # Recommended batch size for Pinecone operations
|
BATCH_SIZE = 100 # Recommended batch size for Pinecone operations
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||||
|
|
||||||
|
|
||||||
class PineconeClient(VectorDBBase):
|
class PineconeClient(VectorDBBase):
|
||||||
|
|
|
||||||
|
|
@ -22,10 +22,12 @@ from open_webui.config import (
|
||||||
QDRANT_TIMEOUT,
|
QDRANT_TIMEOUT,
|
||||||
QDRANT_HNSW_M,
|
QDRANT_HNSW_M,
|
||||||
)
|
)
|
||||||
|
from open_webui.env import SRC_LOG_LEVELS
|
||||||
|
|
||||||
NO_LIMIT = 999999999
|
NO_LIMIT = 999999999
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||||
|
|
||||||
|
|
||||||
class QdrantClient(VectorDBBase):
|
class QdrantClient(VectorDBBase):
|
||||||
|
|
|
||||||
|
|
@ -13,6 +13,7 @@ from open_webui.config import (
|
||||||
QDRANT_TIMEOUT,
|
QDRANT_TIMEOUT,
|
||||||
QDRANT_HNSW_M,
|
QDRANT_HNSW_M,
|
||||||
)
|
)
|
||||||
|
from open_webui.env import SRC_LOG_LEVELS
|
||||||
from open_webui.retrieval.vector.main import (
|
from open_webui.retrieval.vector.main import (
|
||||||
GetResult,
|
GetResult,
|
||||||
SearchResult,
|
SearchResult,
|
||||||
|
|
@ -29,6 +30,7 @@ TENANT_ID_FIELD = "tenant_id"
|
||||||
DEFAULT_DIMENSION = 384
|
DEFAULT_DIMENSION = 384
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||||
|
|
||||||
|
|
||||||
def _tenant_filter(tenant_id: str) -> models.FieldCondition:
|
def _tenant_filter(tenant_id: str) -> models.FieldCondition:
|
||||||
|
|
|
||||||
|
|
@ -6,11 +6,13 @@ from open_webui.retrieval.vector.main import (
|
||||||
SearchResult,
|
SearchResult,
|
||||||
)
|
)
|
||||||
from open_webui.config import S3_VECTOR_BUCKET_NAME, S3_VECTOR_REGION
|
from open_webui.config import S3_VECTOR_BUCKET_NAME, S3_VECTOR_REGION
|
||||||
|
from open_webui.env import SRC_LOG_LEVELS
|
||||||
from typing import List, Optional, Dict, Any, Union
|
from typing import List, Optional, Dict, Any, Union
|
||||||
import logging
|
import logging
|
||||||
import boto3
|
import boto3
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||||
|
|
||||||
|
|
||||||
class S3VectorClient(VectorDBBase):
|
class S3VectorClient(VectorDBBase):
|
||||||
|
|
@ -115,16 +117,15 @@ class S3VectorClient(VectorDBBase):
|
||||||
|
|
||||||
def has_collection(self, collection_name: str) -> bool:
|
def has_collection(self, collection_name: str) -> bool:
|
||||||
"""
|
"""
|
||||||
Check if a vector index exists using direct lookup.
|
Check if a vector index (collection) exists in the S3 vector bucket.
|
||||||
This avoids pagination issues with list_indexes() and is significantly faster.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self.client.get_index(
|
response = self.client.list_indexes(vectorBucketName=self.bucket_name)
|
||||||
vectorBucketName=self.bucket_name, indexName=collection_name
|
indexes = response.get("indexes", [])
|
||||||
)
|
return any(idx.get("indexName") == collection_name for idx in indexes)
|
||||||
return True
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.error(f"Error checking if index '{collection_name}' exists: {e}")
|
log.error(f"Error listing indexes: {e}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def delete_collection(self, collection_name: str) -> None:
|
def delete_collection(self, collection_name: str) -> None:
|
||||||
|
|
|
||||||
|
|
@ -1,340 +0,0 @@
|
||||||
import weaviate
|
|
||||||
import re
|
|
||||||
import uuid
|
|
||||||
from typing import Any, Dict, List, Optional, Union
|
|
||||||
|
|
||||||
from open_webui.retrieval.vector.main import (
|
|
||||||
VectorDBBase,
|
|
||||||
VectorItem,
|
|
||||||
SearchResult,
|
|
||||||
GetResult,
|
|
||||||
)
|
|
||||||
from open_webui.retrieval.vector.utils import process_metadata
|
|
||||||
from open_webui.config import (
|
|
||||||
WEAVIATE_HTTP_HOST,
|
|
||||||
WEAVIATE_HTTP_PORT,
|
|
||||||
WEAVIATE_GRPC_PORT,
|
|
||||||
WEAVIATE_API_KEY,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _convert_uuids_to_strings(obj: Any) -> Any:
|
|
||||||
"""
|
|
||||||
Recursively convert UUID objects to strings in nested data structures.
|
|
||||||
|
|
||||||
This function handles:
|
|
||||||
- UUID objects -> string
|
|
||||||
- Dictionaries with UUID values
|
|
||||||
- Lists/Tuples with UUID values
|
|
||||||
- Nested combinations of the above
|
|
||||||
|
|
||||||
Args:
|
|
||||||
obj: Any object that might contain UUIDs
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The same object structure with UUIDs converted to strings
|
|
||||||
"""
|
|
||||||
if isinstance(obj, uuid.UUID):
|
|
||||||
return str(obj)
|
|
||||||
elif isinstance(obj, dict):
|
|
||||||
return {key: _convert_uuids_to_strings(value) for key, value in obj.items()}
|
|
||||||
elif isinstance(obj, (list, tuple)):
|
|
||||||
return type(obj)(_convert_uuids_to_strings(item) for item in obj)
|
|
||||||
elif isinstance(obj, (str, int, float, bool, type(None))):
|
|
||||||
return obj
|
|
||||||
else:
|
|
||||||
return obj
|
|
||||||
|
|
||||||
|
|
||||||
class WeaviateClient(VectorDBBase):
|
|
||||||
def __init__(self):
|
|
||||||
self.url = WEAVIATE_HTTP_HOST
|
|
||||||
try:
|
|
||||||
# Build connection parameters
|
|
||||||
connection_params = {
|
|
||||||
"host": WEAVIATE_HTTP_HOST,
|
|
||||||
"port": WEAVIATE_HTTP_PORT,
|
|
||||||
"grpc_port": WEAVIATE_GRPC_PORT,
|
|
||||||
}
|
|
||||||
|
|
||||||
# Only add auth_credentials if WEAVIATE_API_KEY exists and is not empty
|
|
||||||
if WEAVIATE_API_KEY:
|
|
||||||
connection_params["auth_credentials"] = (
|
|
||||||
weaviate.classes.init.Auth.api_key(WEAVIATE_API_KEY)
|
|
||||||
)
|
|
||||||
|
|
||||||
self.client = weaviate.connect_to_local(**connection_params)
|
|
||||||
self.client.connect()
|
|
||||||
except Exception as e:
|
|
||||||
raise ConnectionError(f"Failed to connect to Weaviate: {e}") from e
|
|
||||||
|
|
||||||
def _sanitize_collection_name(self, collection_name: str) -> str:
|
|
||||||
"""Sanitize collection name to be a valid Weaviate class name."""
|
|
||||||
if not isinstance(collection_name, str) or not collection_name.strip():
|
|
||||||
raise ValueError("Collection name must be a non-empty string")
|
|
||||||
|
|
||||||
# Requirements for a valid Weaviate class name:
|
|
||||||
# The collection name must begin with a capital letter.
|
|
||||||
# The name can only contain letters, numbers, and the underscore (_) character. Spaces are not allowed.
|
|
||||||
|
|
||||||
# Replace hyphens with underscores and keep only alphanumeric characters
|
|
||||||
name = re.sub(r"[^a-zA-Z0-9_]", "", collection_name.replace("-", "_"))
|
|
||||||
name = name.strip("_")
|
|
||||||
|
|
||||||
if not name:
|
|
||||||
raise ValueError(
|
|
||||||
"Could not sanitize collection name to be a valid Weaviate class name"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Ensure it starts with a letter and is capitalized
|
|
||||||
if not name[0].isalpha():
|
|
||||||
name = "C" + name
|
|
||||||
|
|
||||||
return name[0].upper() + name[1:]
|
|
||||||
|
|
||||||
def has_collection(self, collection_name: str) -> bool:
|
|
||||||
sane_collection_name = self._sanitize_collection_name(collection_name)
|
|
||||||
return self.client.collections.exists(sane_collection_name)
|
|
||||||
|
|
||||||
def delete_collection(self, collection_name: str) -> None:
|
|
||||||
sane_collection_name = self._sanitize_collection_name(collection_name)
|
|
||||||
if self.client.collections.exists(sane_collection_name):
|
|
||||||
self.client.collections.delete(sane_collection_name)
|
|
||||||
|
|
||||||
def _create_collection(self, collection_name: str) -> None:
|
|
||||||
self.client.collections.create(
|
|
||||||
name=collection_name,
|
|
||||||
vector_config=weaviate.classes.config.Configure.Vectors.self_provided(),
|
|
||||||
properties=[
|
|
||||||
weaviate.classes.config.Property(
|
|
||||||
name="text", data_type=weaviate.classes.config.DataType.TEXT
|
|
||||||
),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
def insert(self, collection_name: str, items: List[VectorItem]) -> None:
|
|
||||||
sane_collection_name = self._sanitize_collection_name(collection_name)
|
|
||||||
if not self.client.collections.exists(sane_collection_name):
|
|
||||||
self._create_collection(sane_collection_name)
|
|
||||||
|
|
||||||
collection = self.client.collections.get(sane_collection_name)
|
|
||||||
|
|
||||||
with collection.batch.fixed_size(batch_size=100) as batch:
|
|
||||||
for item in items:
|
|
||||||
item_uuid = str(uuid.uuid4()) if not item["id"] else str(item["id"])
|
|
||||||
|
|
||||||
properties = {"text": item["text"]}
|
|
||||||
if item["metadata"]:
|
|
||||||
clean_metadata = _convert_uuids_to_strings(
|
|
||||||
process_metadata(item["metadata"])
|
|
||||||
)
|
|
||||||
clean_metadata.pop("text", None)
|
|
||||||
properties.update(clean_metadata)
|
|
||||||
|
|
||||||
batch.add_object(
|
|
||||||
properties=properties, uuid=item_uuid, vector=item["vector"]
|
|
||||||
)
|
|
||||||
|
|
||||||
def upsert(self, collection_name: str, items: List[VectorItem]) -> None:
|
|
||||||
sane_collection_name = self._sanitize_collection_name(collection_name)
|
|
||||||
if not self.client.collections.exists(sane_collection_name):
|
|
||||||
self._create_collection(sane_collection_name)
|
|
||||||
|
|
||||||
collection = self.client.collections.get(sane_collection_name)
|
|
||||||
|
|
||||||
with collection.batch.fixed_size(batch_size=100) as batch:
|
|
||||||
for item in items:
|
|
||||||
item_uuid = str(item["id"]) if item["id"] else None
|
|
||||||
|
|
||||||
properties = {"text": item["text"]}
|
|
||||||
if item["metadata"]:
|
|
||||||
clean_metadata = _convert_uuids_to_strings(
|
|
||||||
process_metadata(item["metadata"])
|
|
||||||
)
|
|
||||||
clean_metadata.pop("text", None)
|
|
||||||
properties.update(clean_metadata)
|
|
||||||
|
|
||||||
batch.add_object(
|
|
||||||
properties=properties, uuid=item_uuid, vector=item["vector"]
|
|
||||||
)
|
|
||||||
|
|
||||||
def search(
|
|
||||||
self, collection_name: str, vectors: List[List[Union[float, int]]], limit: int
|
|
||||||
) -> Optional[SearchResult]:
|
|
||||||
sane_collection_name = self._sanitize_collection_name(collection_name)
|
|
||||||
if not self.client.collections.exists(sane_collection_name):
|
|
||||||
return None
|
|
||||||
|
|
||||||
collection = self.client.collections.get(sane_collection_name)
|
|
||||||
|
|
||||||
result_ids, result_documents, result_metadatas, result_distances = (
|
|
||||||
[],
|
|
||||||
[],
|
|
||||||
[],
|
|
||||||
[],
|
|
||||||
)
|
|
||||||
|
|
||||||
for vector_embedding in vectors:
|
|
||||||
try:
|
|
||||||
response = collection.query.near_vector(
|
|
||||||
near_vector=vector_embedding,
|
|
||||||
limit=limit,
|
|
||||||
return_metadata=weaviate.classes.query.MetadataQuery(distance=True),
|
|
||||||
)
|
|
||||||
|
|
||||||
ids = [str(obj.uuid) for obj in response.objects]
|
|
||||||
documents = []
|
|
||||||
metadatas = []
|
|
||||||
distances = []
|
|
||||||
|
|
||||||
for obj in response.objects:
|
|
||||||
properties = dict(obj.properties) if obj.properties else {}
|
|
||||||
documents.append(properties.pop("text", ""))
|
|
||||||
metadatas.append(_convert_uuids_to_strings(properties))
|
|
||||||
|
|
||||||
# Weaviate has cosine distance, 2 (worst) -> 0 (best). Re-ordering to 0 -> 1
|
|
||||||
raw_distances = [
|
|
||||||
(
|
|
||||||
obj.metadata.distance
|
|
||||||
if obj.metadata and obj.metadata.distance
|
|
||||||
else 2.0
|
|
||||||
)
|
|
||||||
for obj in response.objects
|
|
||||||
]
|
|
||||||
distances = [(2 - dist) / 2 for dist in raw_distances]
|
|
||||||
|
|
||||||
result_ids.append(ids)
|
|
||||||
result_documents.append(documents)
|
|
||||||
result_metadatas.append(metadatas)
|
|
||||||
result_distances.append(distances)
|
|
||||||
except Exception:
|
|
||||||
result_ids.append([])
|
|
||||||
result_documents.append([])
|
|
||||||
result_metadatas.append([])
|
|
||||||
result_distances.append([])
|
|
||||||
|
|
||||||
return SearchResult(
|
|
||||||
**{
|
|
||||||
"ids": result_ids,
|
|
||||||
"documents": result_documents,
|
|
||||||
"metadatas": result_metadatas,
|
|
||||||
"distances": result_distances,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
def query(
|
|
||||||
self, collection_name: str, filter: Dict, limit: Optional[int] = None
|
|
||||||
) -> Optional[GetResult]:
|
|
||||||
sane_collection_name = self._sanitize_collection_name(collection_name)
|
|
||||||
if not self.client.collections.exists(sane_collection_name):
|
|
||||||
return None
|
|
||||||
|
|
||||||
collection = self.client.collections.get(sane_collection_name)
|
|
||||||
|
|
||||||
weaviate_filter = None
|
|
||||||
if filter:
|
|
||||||
for key, value in filter.items():
|
|
||||||
prop_filter = weaviate.classes.query.Filter.by_property(name=key).equal(
|
|
||||||
value
|
|
||||||
)
|
|
||||||
weaviate_filter = (
|
|
||||||
prop_filter
|
|
||||||
if weaviate_filter is None
|
|
||||||
else weaviate.classes.query.Filter.all_of(
|
|
||||||
[weaviate_filter, prop_filter]
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
response = collection.query.fetch_objects(
|
|
||||||
filters=weaviate_filter, limit=limit
|
|
||||||
)
|
|
||||||
|
|
||||||
ids = [str(obj.uuid) for obj in response.objects]
|
|
||||||
documents = []
|
|
||||||
metadatas = []
|
|
||||||
|
|
||||||
for obj in response.objects:
|
|
||||||
properties = dict(obj.properties) if obj.properties else {}
|
|
||||||
documents.append(properties.pop("text", ""))
|
|
||||||
metadatas.append(_convert_uuids_to_strings(properties))
|
|
||||||
|
|
||||||
return GetResult(
|
|
||||||
**{
|
|
||||||
"ids": [ids],
|
|
||||||
"documents": [documents],
|
|
||||||
"metadatas": [metadatas],
|
|
||||||
}
|
|
||||||
)
|
|
||||||
except Exception:
|
|
||||||
return None
|
|
||||||
|
|
||||||
def get(self, collection_name: str) -> Optional[GetResult]:
|
|
||||||
sane_collection_name = self._sanitize_collection_name(collection_name)
|
|
||||||
if not self.client.collections.exists(sane_collection_name):
|
|
||||||
return None
|
|
||||||
|
|
||||||
collection = self.client.collections.get(sane_collection_name)
|
|
||||||
ids, documents, metadatas = [], [], []
|
|
||||||
|
|
||||||
try:
|
|
||||||
for item in collection.iterator():
|
|
||||||
ids.append(str(item.uuid))
|
|
||||||
properties = dict(item.properties) if item.properties else {}
|
|
||||||
documents.append(properties.pop("text", ""))
|
|
||||||
metadatas.append(_convert_uuids_to_strings(properties))
|
|
||||||
|
|
||||||
if not ids:
|
|
||||||
return None
|
|
||||||
|
|
||||||
return GetResult(
|
|
||||||
**{
|
|
||||||
"ids": [ids],
|
|
||||||
"documents": [documents],
|
|
||||||
"metadatas": [metadatas],
|
|
||||||
}
|
|
||||||
)
|
|
||||||
except Exception:
|
|
||||||
return None
|
|
||||||
|
|
||||||
def delete(
|
|
||||||
self,
|
|
||||||
collection_name: str,
|
|
||||||
ids: Optional[List[str]] = None,
|
|
||||||
filter: Optional[Dict] = None,
|
|
||||||
) -> None:
|
|
||||||
sane_collection_name = self._sanitize_collection_name(collection_name)
|
|
||||||
if not self.client.collections.exists(sane_collection_name):
|
|
||||||
return
|
|
||||||
|
|
||||||
collection = self.client.collections.get(sane_collection_name)
|
|
||||||
|
|
||||||
try:
|
|
||||||
if ids:
|
|
||||||
for item_id in ids:
|
|
||||||
collection.data.delete_by_id(uuid=item_id)
|
|
||||||
elif filter:
|
|
||||||
weaviate_filter = None
|
|
||||||
for key, value in filter.items():
|
|
||||||
prop_filter = weaviate.classes.query.Filter.by_property(
|
|
||||||
name=key
|
|
||||||
).equal(value)
|
|
||||||
weaviate_filter = (
|
|
||||||
prop_filter
|
|
||||||
if weaviate_filter is None
|
|
||||||
else weaviate.classes.query.Filter.all_of(
|
|
||||||
[weaviate_filter, prop_filter]
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
if weaviate_filter:
|
|
||||||
collection.data.delete_many(where=weaviate_filter)
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def reset(self) -> None:
|
|
||||||
try:
|
|
||||||
for collection_name in self.client.collections.list_all().keys():
|
|
||||||
self.client.collections.delete(collection_name)
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
@ -67,10 +67,6 @@ class Vector:
|
||||||
from open_webui.retrieval.vector.dbs.oracle23ai import Oracle23aiClient
|
from open_webui.retrieval.vector.dbs.oracle23ai import Oracle23aiClient
|
||||||
|
|
||||||
return Oracle23aiClient()
|
return Oracle23aiClient()
|
||||||
case VectorType.WEAVIATE:
|
|
||||||
from open_webui.retrieval.vector.dbs.weaviate import WeaviateClient
|
|
||||||
|
|
||||||
return WeaviateClient()
|
|
||||||
case _:
|
case _:
|
||||||
raise ValueError(f"Unsupported vector type: {vector_type}")
|
raise ValueError(f"Unsupported vector type: {vector_type}")
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -11,4 +11,3 @@ class VectorType(StrEnum):
|
||||||
PGVECTOR = "pgvector"
|
PGVECTOR = "pgvector"
|
||||||
ORACLE23AI = "oracle23ai"
|
ORACLE23AI = "oracle23ai"
|
||||||
S3VECTOR = "s3vector"
|
S3VECTOR = "s3vector"
|
||||||
WEAVIATE = "weaviate"
|
|
||||||
|
|
|
||||||
|
|
@ -1,126 +0,0 @@
|
||||||
import logging
|
|
||||||
from typing import Optional
|
|
||||||
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
|
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
"""
|
|
||||||
Azure AI Search integration for Open WebUI.
|
|
||||||
Documentation: https://learn.microsoft.com/en-us/python/api/overview/azure/search-documents-readme?view=azure-python
|
|
||||||
|
|
||||||
Required package: azure-search-documents
|
|
||||||
Install: pip install azure-search-documents
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
def search_azure(
|
|
||||||
api_key: str,
|
|
||||||
endpoint: str,
|
|
||||||
index_name: str,
|
|
||||||
query: str,
|
|
||||||
count: int,
|
|
||||||
filter_list: Optional[list[str]] = None,
|
|
||||||
) -> list[SearchResult]:
|
|
||||||
"""
|
|
||||||
Search using Azure AI Search.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
api_key: Azure Search API key (query key or admin key)
|
|
||||||
endpoint: Azure Search service endpoint (e.g., https://myservice.search.windows.net)
|
|
||||||
index_name: Name of the search index to query
|
|
||||||
query: Search query string
|
|
||||||
count: Number of results to return
|
|
||||||
filter_list: Optional list of domains to filter results
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of SearchResult objects with link, title, and snippet
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
from azure.core.credentials import AzureKeyCredential
|
|
||||||
from azure.search.documents import SearchClient
|
|
||||||
except ImportError:
|
|
||||||
log.error(
|
|
||||||
"azure-search-documents package is not installed. "
|
|
||||||
"Install it with: pip install azure-search-documents"
|
|
||||||
)
|
|
||||||
raise ImportError(
|
|
||||||
"azure-search-documents is required for Azure AI Search. "
|
|
||||||
"Install it with: pip install azure-search-documents"
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Create search client with API key authentication
|
|
||||||
credential = AzureKeyCredential(api_key)
|
|
||||||
search_client = SearchClient(
|
|
||||||
endpoint=endpoint, index_name=index_name, credential=credential
|
|
||||||
)
|
|
||||||
|
|
||||||
# Perform the search
|
|
||||||
results = search_client.search(search_text=query, top=count)
|
|
||||||
|
|
||||||
# Convert results to list and extract fields
|
|
||||||
search_results = []
|
|
||||||
for result in results:
|
|
||||||
# Azure AI Search returns documents with custom schemas
|
|
||||||
# We need to extract common fields that might represent URL, title, and content
|
|
||||||
# Common field names to look for:
|
|
||||||
result_dict = dict(result)
|
|
||||||
|
|
||||||
# Try to find URL field (common names)
|
|
||||||
link = (
|
|
||||||
result_dict.get("url")
|
|
||||||
or result_dict.get("link")
|
|
||||||
or result_dict.get("uri")
|
|
||||||
or result_dict.get("metadata_storage_path")
|
|
||||||
or ""
|
|
||||||
)
|
|
||||||
|
|
||||||
# Try to find title field (common names)
|
|
||||||
title = (
|
|
||||||
result_dict.get("title")
|
|
||||||
or result_dict.get("name")
|
|
||||||
or result_dict.get("metadata_title")
|
|
||||||
or result_dict.get("metadata_storage_name")
|
|
||||||
or None
|
|
||||||
)
|
|
||||||
|
|
||||||
# Try to find content/snippet field (common names)
|
|
||||||
snippet = (
|
|
||||||
result_dict.get("content")
|
|
||||||
or result_dict.get("snippet")
|
|
||||||
or result_dict.get("description")
|
|
||||||
or result_dict.get("summary")
|
|
||||||
or result_dict.get("text")
|
|
||||||
or None
|
|
||||||
)
|
|
||||||
|
|
||||||
# Truncate snippet if too long
|
|
||||||
if snippet and len(snippet) > 500:
|
|
||||||
snippet = snippet[:497] + "..."
|
|
||||||
|
|
||||||
if link: # Only add if we found a valid link
|
|
||||||
search_results.append(
|
|
||||||
{
|
|
||||||
"link": link,
|
|
||||||
"title": title,
|
|
||||||
"snippet": snippet,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
# Apply domain filtering if specified
|
|
||||||
if filter_list:
|
|
||||||
search_results = get_filtered_results(search_results, filter_list)
|
|
||||||
|
|
||||||
# Convert to SearchResult objects
|
|
||||||
return [
|
|
||||||
SearchResult(
|
|
||||||
link=result["link"],
|
|
||||||
title=result.get("title"),
|
|
||||||
snippet=result.get("snippet"),
|
|
||||||
)
|
|
||||||
for result in search_results
|
|
||||||
]
|
|
||||||
|
|
||||||
except Exception as ex:
|
|
||||||
log.error(f"Azure AI Search error: {ex}")
|
|
||||||
raise ex
|
|
||||||
|
|
@ -4,9 +4,11 @@ from pprint import pprint
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
import requests
|
import requests
|
||||||
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
|
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
|
||||||
|
from open_webui.env import SRC_LOG_LEVELS
|
||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||||
"""
|
"""
|
||||||
Documentation: https://docs.microsoft.com/en-us/bing/search-apis/bing-web-search/overview
|
Documentation: https://docs.microsoft.com/en-us/bing/search-apis/bing-web-search/overview
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@ -4,18 +4,20 @@ from typing import Optional
|
||||||
import requests
|
import requests
|
||||||
import json
|
import json
|
||||||
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
|
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
|
||||||
|
from open_webui.env import SRC_LOG_LEVELS
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||||
|
|
||||||
|
|
||||||
def _parse_response(response):
|
def _parse_response(response):
|
||||||
results = []
|
result = {}
|
||||||
if "data" in response:
|
if "data" in response:
|
||||||
data = response["data"]
|
data = response["data"]
|
||||||
if "webPages" in data:
|
if "webPages" in data:
|
||||||
webPages = data["webPages"]
|
webPages = data["webPages"]
|
||||||
if "value" in webPages:
|
if "value" in webPages:
|
||||||
results = [
|
result["webpage"] = [
|
||||||
{
|
{
|
||||||
"id": item.get("id", ""),
|
"id": item.get("id", ""),
|
||||||
"name": item.get("name", ""),
|
"name": item.get("name", ""),
|
||||||
|
|
@ -29,7 +31,7 @@ def _parse_response(response):
|
||||||
}
|
}
|
||||||
for item in webPages["value"]
|
for item in webPages["value"]
|
||||||
]
|
]
|
||||||
return results
|
return result
|
||||||
|
|
||||||
|
|
||||||
def search_bocha(
|
def search_bocha(
|
||||||
|
|
@ -51,7 +53,7 @@ def search_bocha(
|
||||||
response = requests.post(url, headers=headers, data=payload, timeout=5)
|
response = requests.post(url, headers=headers, data=payload, timeout=5)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
results = _parse_response(response.json())
|
results = _parse_response(response.json())
|
||||||
|
print(results)
|
||||||
if filter_list:
|
if filter_list:
|
||||||
results = get_filtered_results(results, filter_list)
|
results = get_filtered_results(results, filter_list)
|
||||||
|
|
||||||
|
|
@ -59,5 +61,5 @@ def search_bocha(
|
||||||
SearchResult(
|
SearchResult(
|
||||||
link=result["url"], title=result.get("name"), snippet=result.get("summary")
|
link=result["url"], title=result.get("name"), snippet=result.get("summary")
|
||||||
)
|
)
|
||||||
for result in results[:count]
|
for result in results.get("webpage", [])[:count]
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -3,8 +3,10 @@ from typing import Optional
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
|
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
|
||||||
|
from open_webui.env import SRC_LOG_LEVELS
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||||
|
|
||||||
|
|
||||||
def search_brave(
|
def search_brave(
|
||||||
|
|
|
||||||
|
|
@ -4,8 +4,10 @@ from typing import Optional
|
||||||
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
|
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
|
||||||
from ddgs import DDGS
|
from ddgs import DDGS
|
||||||
from ddgs.exceptions import RatelimitException
|
from ddgs.exceptions import RatelimitException
|
||||||
|
from open_webui.env import SRC_LOG_LEVELS
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||||
|
|
||||||
|
|
||||||
def search_duckduckgo(
|
def search_duckduckgo(
|
||||||
|
|
|
||||||
|
|
@ -3,9 +3,11 @@ from dataclasses import dataclass
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
|
from open_webui.env import SRC_LOG_LEVELS
|
||||||
from open_webui.retrieval.web.main import SearchResult
|
from open_webui.retrieval.web.main import SearchResult
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||||
|
|
||||||
EXA_API_BASE = "https://api.exa.ai"
|
EXA_API_BASE = "https://api.exa.ai"
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -2,40 +2,27 @@ import logging
|
||||||
from typing import Optional, List
|
from typing import Optional, List
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
from fastapi import Request
|
|
||||||
|
|
||||||
|
|
||||||
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
|
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
|
||||||
from open_webui.utils.headers import include_user_info_headers
|
from open_webui.env import SRC_LOG_LEVELS
|
||||||
|
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||||
|
|
||||||
|
|
||||||
def search_external(
|
def search_external(
|
||||||
request: Request,
|
|
||||||
external_url: str,
|
external_url: str,
|
||||||
external_api_key: str,
|
external_api_key: str,
|
||||||
query: str,
|
query: str,
|
||||||
count: int,
|
count: int,
|
||||||
filter_list: Optional[List[str]] = None,
|
filter_list: Optional[List[str]] = None,
|
||||||
user=None,
|
|
||||||
) -> List[SearchResult]:
|
) -> List[SearchResult]:
|
||||||
try:
|
try:
|
||||||
headers = {
|
|
||||||
"User-Agent": "Open WebUI (https://github.com/open-webui/open-webui) RAG Bot",
|
|
||||||
"Authorization": f"Bearer {external_api_key}",
|
|
||||||
}
|
|
||||||
headers = include_user_info_headers(headers, user)
|
|
||||||
|
|
||||||
chat_id = getattr(request.state, "chat_id", None)
|
|
||||||
if chat_id:
|
|
||||||
headers["X-OpenWebUI-Chat-Id"] = str(chat_id)
|
|
||||||
|
|
||||||
response = requests.post(
|
response = requests.post(
|
||||||
external_url,
|
external_url,
|
||||||
headers=headers,
|
headers={
|
||||||
|
"User-Agent": "Open WebUI (https://github.com/open-webui/open-webui) RAG Bot",
|
||||||
|
"Authorization": f"Bearer {external_api_key}",
|
||||||
|
},
|
||||||
json={
|
json={
|
||||||
"query": query,
|
"query": query,
|
||||||
"count": count,
|
"count": count,
|
||||||
|
|
|
||||||
|
|
@ -2,9 +2,11 @@ import logging
|
||||||
from typing import Optional, List
|
from typing import Optional, List
|
||||||
|
|
||||||
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
|
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
|
||||||
|
from open_webui.env import SRC_LOG_LEVELS
|
||||||
|
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||||
|
|
||||||
|
|
||||||
def search_firecrawl(
|
def search_firecrawl(
|
||||||
|
|
|
||||||
|
|
@ -3,8 +3,10 @@ from typing import Optional
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
|
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
|
||||||
|
from open_webui.env import SRC_LOG_LEVELS
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||||
|
|
||||||
|
|
||||||
def search_google_pse(
|
def search_google_pse(
|
||||||
|
|
|
||||||
|
|
@ -2,9 +2,11 @@ import logging
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
from open_webui.retrieval.web.main import SearchResult
|
from open_webui.retrieval.web.main import SearchResult
|
||||||
|
from open_webui.env import SRC_LOG_LEVELS
|
||||||
from yarl import URL
|
from yarl import URL
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||||
|
|
||||||
|
|
||||||
def search_jina(api_key: str, query: str, count: int) -> list[SearchResult]:
|
def search_jina(api_key: str, query: str, count: int) -> list[SearchResult]:
|
||||||
|
|
|
||||||
|
|
@ -3,8 +3,10 @@ from typing import Optional
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
|
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
|
||||||
|
from open_webui.env import SRC_LOG_LEVELS
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||||
|
|
||||||
|
|
||||||
def search_kagi(
|
def search_kagi(
|
||||||
|
|
|
||||||
|
|
@ -5,38 +5,18 @@ from urllib.parse import urlparse
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from open_webui.retrieval.web.utils import resolve_hostname
|
|
||||||
from open_webui.utils.misc import is_string_allowed
|
|
||||||
|
|
||||||
|
|
||||||
def get_filtered_results(results, filter_list):
|
def get_filtered_results(results, filter_list):
|
||||||
if not filter_list:
|
if not filter_list:
|
||||||
return results
|
return results
|
||||||
|
|
||||||
filtered_results = []
|
filtered_results = []
|
||||||
|
|
||||||
for result in results:
|
for result in results:
|
||||||
url = result.get("url") or result.get("link", "") or result.get("href", "")
|
url = result.get("url") or result.get("link", "") or result.get("href", "")
|
||||||
if not validators.url(url):
|
if not validators.url(url):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
domain = urlparse(url).netloc
|
domain = urlparse(url).netloc
|
||||||
if not domain:
|
if any(domain.endswith(filtered_domain) for filtered_domain in filter_list):
|
||||||
continue
|
|
||||||
|
|
||||||
hostnames = [domain]
|
|
||||||
|
|
||||||
try:
|
|
||||||
ipv4_addresses, ipv6_addresses = resolve_hostname(domain)
|
|
||||||
hostnames.extend(ipv4_addresses)
|
|
||||||
hostnames.extend(ipv6_addresses)
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
if is_string_allowed(hostnames, filter_list):
|
|
||||||
filtered_results.append(result)
|
filtered_results.append(result)
|
||||||
continue
|
|
||||||
|
|
||||||
return filtered_results
|
return filtered_results
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -3,8 +3,10 @@ from typing import Optional
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
|
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
|
||||||
|
from open_webui.env import SRC_LOG_LEVELS
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||||
|
|
||||||
|
|
||||||
def search_mojeek(
|
def search_mojeek(
|
||||||
|
|
|
||||||
|
|
@ -3,9 +3,11 @@ from dataclasses import dataclass
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
|
from open_webui.env import SRC_LOG_LEVELS
|
||||||
|
from open_webui.retrieval.web.main import SearchResult
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||||
|
|
||||||
|
|
||||||
def search_ollama_cloud(
|
def search_ollama_cloud(
|
||||||
|
|
@ -36,9 +38,6 @@ def search_ollama_cloud(
|
||||||
results = data.get("results", [])
|
results = data.get("results", [])
|
||||||
log.info(f"Found {len(results)} results")
|
log.info(f"Found {len(results)} results")
|
||||||
|
|
||||||
if filter_list:
|
|
||||||
results = get_filtered_results(results, filter_list)
|
|
||||||
|
|
||||||
return [
|
return [
|
||||||
SearchResult(
|
SearchResult(
|
||||||
link=result.get("url", ""),
|
link=result.get("url", ""),
|
||||||
|
|
|
||||||
|
|
@ -3,6 +3,7 @@ from typing import Optional, Literal
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
|
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
|
||||||
|
from open_webui.env import SRC_LOG_LEVELS
|
||||||
|
|
||||||
MODELS = Literal[
|
MODELS = Literal[
|
||||||
"sonar",
|
"sonar",
|
||||||
|
|
@ -15,6 +16,7 @@ SEARCH_CONTEXT_USAGE_LEVELS = Literal["low", "medium", "high"]
|
||||||
|
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||||
|
|
||||||
|
|
||||||
def search_perplexity(
|
def search_perplexity(
|
||||||
|
|
|
||||||
|
|
@ -3,10 +3,11 @@ from typing import Optional, Literal
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
|
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
|
||||||
from open_webui.utils.headers import include_user_info_headers
|
from open_webui.env import SRC_LOG_LEVELS
|
||||||
|
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||||
|
|
||||||
|
|
||||||
def search_perplexity_search(
|
def search_perplexity_search(
|
||||||
|
|
@ -14,8 +15,6 @@ def search_perplexity_search(
|
||||||
query: str,
|
query: str,
|
||||||
count: int,
|
count: int,
|
||||||
filter_list: Optional[list[str]] = None,
|
filter_list: Optional[list[str]] = None,
|
||||||
api_url: str = "https://api.perplexity.ai/search",
|
|
||||||
user=None,
|
|
||||||
) -> list[SearchResult]:
|
) -> list[SearchResult]:
|
||||||
"""Search using Perplexity API and return the results as a list of SearchResult objects.
|
"""Search using Perplexity API and return the results as a list of SearchResult objects.
|
||||||
|
|
||||||
|
|
@ -24,8 +23,6 @@ def search_perplexity_search(
|
||||||
query (str): The query to search for
|
query (str): The query to search for
|
||||||
count (int): Maximum number of results to return
|
count (int): Maximum number of results to return
|
||||||
filter_list (Optional[list[str]]): List of domains to filter results
|
filter_list (Optional[list[str]]): List of domains to filter results
|
||||||
api_url (str): Custom API URL (defaults to https://api.perplexity.ai/search)
|
|
||||||
user: Optional user object for forwarding user info headers
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
@ -33,11 +30,8 @@ def search_perplexity_search(
|
||||||
if hasattr(api_key, "__str__"):
|
if hasattr(api_key, "__str__"):
|
||||||
api_key = str(api_key)
|
api_key = str(api_key)
|
||||||
|
|
||||||
if hasattr(api_url, "__str__"):
|
|
||||||
api_url = str(api_url)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
url = api_url
|
url = "https://api.perplexity.ai/search"
|
||||||
|
|
||||||
# Create payload for the API call
|
# Create payload for the API call
|
||||||
payload = {
|
payload = {
|
||||||
|
|
@ -50,10 +44,6 @@ def search_perplexity_search(
|
||||||
"Content-Type": "application/json",
|
"Content-Type": "application/json",
|
||||||
}
|
}
|
||||||
|
|
||||||
# Forward user info headers if user is provided
|
|
||||||
if user is not None:
|
|
||||||
headers = include_user_info_headers(headers, user)
|
|
||||||
|
|
||||||
# Make the API request
|
# Make the API request
|
||||||
response = requests.request("POST", url, json=payload, headers=headers)
|
response = requests.request("POST", url, json=payload, headers=headers)
|
||||||
# Parse the JSON response
|
# Parse the JSON response
|
||||||
|
|
|
||||||
|
|
@ -4,8 +4,10 @@ from urllib.parse import urlencode
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
|
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
|
||||||
|
from open_webui.env import SRC_LOG_LEVELS
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||||
|
|
||||||
|
|
||||||
def search_searchapi(
|
def search_searchapi(
|
||||||
|
|
|
||||||
|
|
@ -3,8 +3,10 @@ from typing import Optional
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
|
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
|
||||||
|
from open_webui.env import SRC_LOG_LEVELS
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||||
|
|
||||||
|
|
||||||
def search_searxng(
|
def search_searxng(
|
||||||
|
|
@ -25,7 +27,7 @@ def search_searxng(
|
||||||
count (int): The maximum number of results to retrieve from the search.
|
count (int): The maximum number of results to retrieve from the search.
|
||||||
|
|
||||||
Keyword Args:
|
Keyword Args:
|
||||||
language (str): Language filter for the search results; e.g., "all", "en-US", "es". Defaults to "all".
|
language (str): Language filter for the search results; e.g., "en-US". Defaults to an empty string.
|
||||||
safesearch (int): Safe search filter for safer web results; 0 = off, 1 = moderate, 2 = strict. Defaults to 1 (moderate).
|
safesearch (int): Safe search filter for safer web results; 0 = off, 1 = moderate, 2 = strict. Defaults to 1 (moderate).
|
||||||
time_range (str): Time range for filtering results by date; e.g., "2023-04-05..today" or "all-time". Defaults to ''.
|
time_range (str): Time range for filtering results by date; e.g., "2023-04-05..today" or "all-time". Defaults to ''.
|
||||||
categories: (Optional[list[str]]): Specific categories within which the search should be performed, defaulting to an empty string if not provided.
|
categories: (Optional[list[str]]): Specific categories within which the search should be performed, defaulting to an empty string if not provided.
|
||||||
|
|
@ -38,7 +40,7 @@ def search_searxng(
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Default values for optional parameters are provided as empty strings or None when not specified.
|
# Default values for optional parameters are provided as empty strings or None when not specified.
|
||||||
language = kwargs.get("language", "all")
|
language = kwargs.get("language", "en-US")
|
||||||
safesearch = kwargs.get("safesearch", "1")
|
safesearch = kwargs.get("safesearch", "1")
|
||||||
time_range = kwargs.get("time_range", "")
|
time_range = kwargs.get("time_range", "")
|
||||||
categories = "".join(kwargs.get("categories", []))
|
categories = "".join(kwargs.get("categories", []))
|
||||||
|
|
|
||||||
|
|
@ -4,8 +4,10 @@ from urllib.parse import urlencode
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
|
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
|
||||||
|
from open_webui.env import SRC_LOG_LEVELS
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||||
|
|
||||||
|
|
||||||
def search_serpapi(
|
def search_serpapi(
|
||||||
|
|
|
||||||
|
|
@ -4,8 +4,10 @@ from typing import Optional
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
|
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
|
||||||
|
from open_webui.env import SRC_LOG_LEVELS
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||||
|
|
||||||
|
|
||||||
def search_serper(
|
def search_serper(
|
||||||
|
|
|
||||||
|
|
@ -4,8 +4,10 @@ from urllib.parse import urlencode
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
|
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
|
||||||
|
from open_webui.env import SRC_LOG_LEVELS
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||||
|
|
||||||
|
|
||||||
def search_serply(
|
def search_serply(
|
||||||
|
|
|
||||||
|
|
@ -3,8 +3,10 @@ from typing import Optional
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
|
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
|
||||||
|
from open_webui.env import SRC_LOG_LEVELS
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||||
|
|
||||||
|
|
||||||
def search_serpstack(
|
def search_serpstack(
|
||||||
|
|
|
||||||
|
|
@ -4,8 +4,10 @@ from typing import Optional, List
|
||||||
|
|
||||||
|
|
||||||
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
|
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
|
||||||
|
from open_webui.env import SRC_LOG_LEVELS
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||||
|
|
||||||
|
|
||||||
def search_sougou(
|
def search_sougou(
|
||||||
|
|
|
||||||
|
|
@ -3,8 +3,10 @@ from typing import Optional
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
|
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
|
||||||
|
from open_webui.env import SRC_LOG_LEVELS
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||||
|
|
||||||
|
|
||||||
def search_tavily(
|
def search_tavily(
|
||||||
|
|
|
||||||
|
|
@ -24,7 +24,6 @@ import validators
|
||||||
from langchain_community.document_loaders import PlaywrightURLLoader, WebBaseLoader
|
from langchain_community.document_loaders import PlaywrightURLLoader, WebBaseLoader
|
||||||
from langchain_community.document_loaders.base import BaseLoader
|
from langchain_community.document_loaders.base import BaseLoader
|
||||||
from langchain_core.documents import Document
|
from langchain_core.documents import Document
|
||||||
|
|
||||||
from open_webui.retrieval.loaders.tavily import TavilyLoader
|
from open_webui.retrieval.loaders.tavily import TavilyLoader
|
||||||
from open_webui.retrieval.loaders.external_web import ExternalWebLoader
|
from open_webui.retrieval.loaders.external_web import ExternalWebLoader
|
||||||
from open_webui.constants import ERROR_MESSAGES
|
from open_webui.constants import ERROR_MESSAGES
|
||||||
|
|
@ -33,51 +32,24 @@ from open_webui.config import (
|
||||||
PLAYWRIGHT_WS_URL,
|
PLAYWRIGHT_WS_URL,
|
||||||
PLAYWRIGHT_TIMEOUT,
|
PLAYWRIGHT_TIMEOUT,
|
||||||
WEB_LOADER_ENGINE,
|
WEB_LOADER_ENGINE,
|
||||||
WEB_LOADER_TIMEOUT,
|
|
||||||
FIRECRAWL_API_BASE_URL,
|
FIRECRAWL_API_BASE_URL,
|
||||||
FIRECRAWL_API_KEY,
|
FIRECRAWL_API_KEY,
|
||||||
TAVILY_API_KEY,
|
TAVILY_API_KEY,
|
||||||
TAVILY_EXTRACT_DEPTH,
|
TAVILY_EXTRACT_DEPTH,
|
||||||
EXTERNAL_WEB_LOADER_URL,
|
EXTERNAL_WEB_LOADER_URL,
|
||||||
EXTERNAL_WEB_LOADER_API_KEY,
|
EXTERNAL_WEB_LOADER_API_KEY,
|
||||||
WEB_FETCH_FILTER_LIST,
|
|
||||||
)
|
)
|
||||||
from open_webui.utils.misc import is_string_allowed
|
from open_webui.env import SRC_LOG_LEVELS
|
||||||
|
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||||
|
|
||||||
def resolve_hostname(hostname):
|
|
||||||
# Get address information
|
|
||||||
addr_info = socket.getaddrinfo(hostname, None)
|
|
||||||
|
|
||||||
# Extract IP addresses from address information
|
|
||||||
ipv4_addresses = [info[4][0] for info in addr_info if info[0] == socket.AF_INET]
|
|
||||||
ipv6_addresses = [info[4][0] for info in addr_info if info[0] == socket.AF_INET6]
|
|
||||||
|
|
||||||
return ipv4_addresses, ipv6_addresses
|
|
||||||
|
|
||||||
|
|
||||||
def validate_url(url: Union[str, Sequence[str]]):
|
def validate_url(url: Union[str, Sequence[str]]):
|
||||||
if isinstance(url, str):
|
if isinstance(url, str):
|
||||||
if isinstance(validators.url(url), validators.ValidationError):
|
if isinstance(validators.url(url), validators.ValidationError):
|
||||||
raise ValueError(ERROR_MESSAGES.INVALID_URL)
|
raise ValueError(ERROR_MESSAGES.INVALID_URL)
|
||||||
|
|
||||||
parsed_url = urllib.parse.urlparse(url)
|
|
||||||
|
|
||||||
# Protocol validation - only allow http/https
|
|
||||||
if parsed_url.scheme not in ["http", "https"]:
|
|
||||||
log.warning(
|
|
||||||
f"Blocked non-HTTP(S) protocol: {parsed_url.scheme} in URL: {url}"
|
|
||||||
)
|
|
||||||
raise ValueError(ERROR_MESSAGES.INVALID_URL)
|
|
||||||
|
|
||||||
# Blocklist check using unified filtering logic
|
|
||||||
if WEB_FETCH_FILTER_LIST:
|
|
||||||
if not is_string_allowed(url, WEB_FETCH_FILTER_LIST):
|
|
||||||
log.warning(f"URL blocked by filter list: {url}")
|
|
||||||
raise ValueError(ERROR_MESSAGES.INVALID_URL)
|
|
||||||
|
|
||||||
if not ENABLE_RAG_LOCAL_WEB_FETCH:
|
if not ENABLE_RAG_LOCAL_WEB_FETCH:
|
||||||
# Local web fetch is disabled, filter out any URLs that resolve to private IP addresses
|
# Local web fetch is disabled, filter out any URLs that resolve to private IP addresses
|
||||||
parsed_url = urllib.parse.urlparse(url)
|
parsed_url = urllib.parse.urlparse(url)
|
||||||
|
|
@ -110,6 +82,17 @@ def safe_validate_urls(url: Sequence[str]) -> Sequence[str]:
|
||||||
return valid_urls
|
return valid_urls
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_hostname(hostname):
|
||||||
|
# Get address information
|
||||||
|
addr_info = socket.getaddrinfo(hostname, None)
|
||||||
|
|
||||||
|
# Extract IP addresses from address information
|
||||||
|
ipv4_addresses = [info[4][0] for info in addr_info if info[0] == socket.AF_INET]
|
||||||
|
ipv6_addresses = [info[4][0] for info in addr_info if info[0] == socket.AF_INET6]
|
||||||
|
|
||||||
|
return ipv4_addresses, ipv6_addresses
|
||||||
|
|
||||||
|
|
||||||
def extract_metadata(soup, url):
|
def extract_metadata(soup, url):
|
||||||
metadata = {"source": url}
|
metadata = {"source": url}
|
||||||
if title := soup.find("title"):
|
if title := soup.find("title"):
|
||||||
|
|
@ -659,10 +642,6 @@ def get_web_loader(
|
||||||
# Check if the URLs are valid
|
# Check if the URLs are valid
|
||||||
safe_urls = safe_validate_urls([urls] if isinstance(urls, str) else urls)
|
safe_urls = safe_validate_urls([urls] if isinstance(urls, str) else urls)
|
||||||
|
|
||||||
if not safe_urls:
|
|
||||||
log.warning(f"All provided URLs were blocked or invalid: {urls}")
|
|
||||||
raise ValueError(ERROR_MESSAGES.INVALID_URL)
|
|
||||||
|
|
||||||
web_loader_args = {
|
web_loader_args = {
|
||||||
"web_paths": safe_urls,
|
"web_paths": safe_urls,
|
||||||
"verify_ssl": verify_ssl,
|
"verify_ssl": verify_ssl,
|
||||||
|
|
@ -673,20 +652,6 @@ def get_web_loader(
|
||||||
|
|
||||||
if WEB_LOADER_ENGINE.value == "" or WEB_LOADER_ENGINE.value == "safe_web":
|
if WEB_LOADER_ENGINE.value == "" or WEB_LOADER_ENGINE.value == "safe_web":
|
||||||
WebLoaderClass = SafeWebBaseLoader
|
WebLoaderClass = SafeWebBaseLoader
|
||||||
|
|
||||||
request_kwargs = {}
|
|
||||||
if WEB_LOADER_TIMEOUT.value:
|
|
||||||
try:
|
|
||||||
timeout_value = float(WEB_LOADER_TIMEOUT.value)
|
|
||||||
except ValueError:
|
|
||||||
timeout_value = None
|
|
||||||
|
|
||||||
if timeout_value:
|
|
||||||
request_kwargs["timeout"] = timeout_value
|
|
||||||
|
|
||||||
if request_kwargs:
|
|
||||||
web_loader_args["requests_kwargs"] = request_kwargs
|
|
||||||
|
|
||||||
if WEB_LOADER_ENGINE.value == "playwright":
|
if WEB_LOADER_ENGINE.value == "playwright":
|
||||||
WebLoaderClass = SafePlaywrightURLLoader
|
WebLoaderClass = SafePlaywrightURLLoader
|
||||||
web_loader_args["playwright_timeout"] = PLAYWRIGHT_TIMEOUT.value
|
web_loader_args["playwright_timeout"] = PLAYWRIGHT_TIMEOUT.value
|
||||||
|
|
|
||||||
|
|
@ -4,8 +4,10 @@ from typing import Optional
|
||||||
import requests
|
import requests
|
||||||
from requests.auth import HTTPDigestAuth
|
from requests.auth import HTTPDigestAuth
|
||||||
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
|
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
|
||||||
|
from open_webui.env import SRC_LOG_LEVELS
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||||
|
|
||||||
|
|
||||||
def search_yacy(
|
def search_yacy(
|
||||||
|
|
|
||||||
|
|
@ -16,6 +16,7 @@ import aiohttp
|
||||||
import aiofiles
|
import aiofiles
|
||||||
import requests
|
import requests
|
||||||
import mimetypes
|
import mimetypes
|
||||||
|
from urllib.parse import urljoin, quote
|
||||||
|
|
||||||
from fastapi import (
|
from fastapi import (
|
||||||
Depends,
|
Depends,
|
||||||
|
|
@ -33,9 +34,7 @@ from fastapi.responses import FileResponse
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
from open_webui.utils.misc import strict_match_mime_type
|
|
||||||
from open_webui.utils.auth import get_admin_user, get_verified_user
|
from open_webui.utils.auth import get_admin_user, get_verified_user
|
||||||
from open_webui.utils.headers import include_user_info_headers
|
|
||||||
from open_webui.config import (
|
from open_webui.config import (
|
||||||
WHISPER_MODEL_AUTO_UPDATE,
|
WHISPER_MODEL_AUTO_UPDATE,
|
||||||
WHISPER_MODEL_DIR,
|
WHISPER_MODEL_DIR,
|
||||||
|
|
@ -49,6 +48,7 @@ from open_webui.env import (
|
||||||
ENV,
|
ENV,
|
||||||
AIOHTTP_CLIENT_SESSION_SSL,
|
AIOHTTP_CLIENT_SESSION_SSL,
|
||||||
AIOHTTP_CLIENT_TIMEOUT,
|
AIOHTTP_CLIENT_TIMEOUT,
|
||||||
|
SRC_LOG_LEVELS,
|
||||||
DEVICE_TYPE,
|
DEVICE_TYPE,
|
||||||
ENABLE_FORWARD_USER_INFO_HEADERS,
|
ENABLE_FORWARD_USER_INFO_HEADERS,
|
||||||
)
|
)
|
||||||
|
|
@ -63,6 +63,7 @@ AZURE_MAX_FILE_SIZE_MB = 200
|
||||||
AZURE_MAX_FILE_SIZE = AZURE_MAX_FILE_SIZE_MB * 1024 * 1024 # Convert MB to bytes
|
AZURE_MAX_FILE_SIZE = AZURE_MAX_FILE_SIZE_MB * 1024 * 1024 # Convert MB to bytes
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
log.setLevel(SRC_LOG_LEVELS["AUDIO"])
|
||||||
|
|
||||||
SPEECH_CACHE_DIR = CACHE_DIR / "audio" / "speech"
|
SPEECH_CACHE_DIR = CACHE_DIR / "audio" / "speech"
|
||||||
SPEECH_CACHE_DIR.mkdir(parents=True, exist_ok=True)
|
SPEECH_CACHE_DIR.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
@ -363,17 +364,23 @@ async def speech(request: Request, user=Depends(get_verified_user)):
|
||||||
**(request.app.state.config.TTS_OPENAI_PARAMS or {}),
|
**(request.app.state.config.TTS_OPENAI_PARAMS or {}),
|
||||||
}
|
}
|
||||||
|
|
||||||
headers = {
|
|
||||||
"Content-Type": "application/json",
|
|
||||||
"Authorization": f"Bearer {request.app.state.config.TTS_OPENAI_API_KEY}",
|
|
||||||
}
|
|
||||||
if ENABLE_FORWARD_USER_INFO_HEADERS:
|
|
||||||
headers = include_user_info_headers(headers, user)
|
|
||||||
|
|
||||||
r = await session.post(
|
r = await session.post(
|
||||||
url=f"{request.app.state.config.TTS_OPENAI_API_BASE_URL}/audio/speech",
|
url=f"{request.app.state.config.TTS_OPENAI_API_BASE_URL}/audio/speech",
|
||||||
json=payload,
|
json=payload,
|
||||||
headers=headers,
|
headers={
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"Authorization": f"Bearer {request.app.state.config.TTS_OPENAI_API_KEY}",
|
||||||
|
**(
|
||||||
|
{
|
||||||
|
"X-OpenWebUI-User-Name": quote(user.name, safe=" "),
|
||||||
|
"X-OpenWebUI-User-Id": user.id,
|
||||||
|
"X-OpenWebUI-User-Email": user.email,
|
||||||
|
"X-OpenWebUI-User-Role": user.role,
|
||||||
|
}
|
||||||
|
if ENABLE_FORWARD_USER_INFO_HEADERS
|
||||||
|
else {}
|
||||||
|
),
|
||||||
|
},
|
||||||
ssl=AIOHTTP_CLIENT_SESSION_SSL,
|
ssl=AIOHTTP_CLIENT_SESSION_SSL,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -563,7 +570,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
|
||||||
return FileResponse(file_path)
|
return FileResponse(file_path)
|
||||||
|
|
||||||
|
|
||||||
def transcription_handler(request, file_path, metadata, user=None):
|
def transcription_handler(request, file_path, metadata):
|
||||||
filename = os.path.basename(file_path)
|
filename = os.path.basename(file_path)
|
||||||
file_dir = os.path.dirname(file_path)
|
file_dir = os.path.dirname(file_path)
|
||||||
id = filename.split(".")[0]
|
id = filename.split(".")[0]
|
||||||
|
|
@ -614,15 +621,11 @@ def transcription_handler(request, file_path, metadata, user=None):
|
||||||
if language:
|
if language:
|
||||||
payload["language"] = language
|
payload["language"] = language
|
||||||
|
|
||||||
headers = {
|
|
||||||
"Authorization": f"Bearer {request.app.state.config.STT_OPENAI_API_KEY}"
|
|
||||||
}
|
|
||||||
if user and ENABLE_FORWARD_USER_INFO_HEADERS:
|
|
||||||
headers = include_user_info_headers(headers, user)
|
|
||||||
|
|
||||||
r = requests.post(
|
r = requests.post(
|
||||||
url=f"{request.app.state.config.STT_OPENAI_API_BASE_URL}/audio/transcriptions",
|
url=f"{request.app.state.config.STT_OPENAI_API_BASE_URL}/audio/transcriptions",
|
||||||
headers=headers,
|
headers={
|
||||||
|
"Authorization": f"Bearer {request.app.state.config.STT_OPENAI_API_KEY}"
|
||||||
|
},
|
||||||
files={"file": (filename, open(file_path, "rb"))},
|
files={"file": (filename, open(file_path, "rb"))},
|
||||||
data=payload,
|
data=payload,
|
||||||
)
|
)
|
||||||
|
|
@ -1024,9 +1027,7 @@ def transcription_handler(request, file_path, metadata, user=None):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def transcribe(
|
def transcribe(request: Request, file_path: str, metadata: Optional[dict] = None):
|
||||||
request: Request, file_path: str, metadata: Optional[dict] = None, user=None
|
|
||||||
):
|
|
||||||
log.info(f"transcribe: {file_path} {metadata}")
|
log.info(f"transcribe: {file_path} {metadata}")
|
||||||
|
|
||||||
if is_audio_conversion_required(file_path):
|
if is_audio_conversion_required(file_path):
|
||||||
|
|
@ -1053,9 +1054,7 @@ def transcribe(
|
||||||
with ThreadPoolExecutor() as executor:
|
with ThreadPoolExecutor() as executor:
|
||||||
# Submit tasks for each chunk_path
|
# Submit tasks for each chunk_path
|
||||||
futures = [
|
futures = [
|
||||||
executor.submit(
|
executor.submit(transcription_handler, request, chunk_path, metadata)
|
||||||
transcription_handler, request, chunk_path, metadata, user
|
|
||||||
)
|
|
||||||
for chunk_path in chunk_paths
|
for chunk_path in chunk_paths
|
||||||
]
|
]
|
||||||
# Gather results as they complete
|
# Gather results as they complete
|
||||||
|
|
@ -1151,11 +1150,20 @@ def transcription(
|
||||||
user=Depends(get_verified_user),
|
user=Depends(get_verified_user),
|
||||||
):
|
):
|
||||||
log.info(f"file.content_type: {file.content_type}")
|
log.info(f"file.content_type: {file.content_type}")
|
||||||
|
|
||||||
stt_supported_content_types = getattr(
|
stt_supported_content_types = getattr(
|
||||||
request.app.state.config, "STT_SUPPORTED_CONTENT_TYPES", []
|
request.app.state.config, "STT_SUPPORTED_CONTENT_TYPES", []
|
||||||
)
|
)
|
||||||
|
|
||||||
if not strict_match_mime_type(stt_supported_content_types, file.content_type):
|
if not any(
|
||||||
|
fnmatch(file.content_type, content_type)
|
||||||
|
for content_type in (
|
||||||
|
stt_supported_content_types
|
||||||
|
if stt_supported_content_types
|
||||||
|
and any(t.strip() for t in stt_supported_content_types)
|
||||||
|
else ["audio/*", "video/webm"]
|
||||||
|
)
|
||||||
|
):
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
detail=ERROR_MESSAGES.FILE_NOT_SUPPORTED,
|
detail=ERROR_MESSAGES.FILE_NOT_SUPPORTED,
|
||||||
|
|
@ -1181,7 +1189,7 @@ def transcription(
|
||||||
if language:
|
if language:
|
||||||
metadata = {"language": language}
|
metadata = {"language": language}
|
||||||
|
|
||||||
result = transcribe(request, file_path, metadata, user)
|
result = transcribe(request, file_path, metadata)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
**result,
|
**result,
|
||||||
|
|
|
||||||
|
|
@ -4,8 +4,6 @@ import time
|
||||||
import datetime
|
import datetime
|
||||||
import logging
|
import logging
|
||||||
from aiohttp import ClientSession
|
from aiohttp import ClientSession
|
||||||
import urllib
|
|
||||||
|
|
||||||
|
|
||||||
from open_webui.models.auths import (
|
from open_webui.models.auths import (
|
||||||
AddUserForm,
|
AddUserForm,
|
||||||
|
|
@ -17,13 +15,9 @@ from open_webui.models.auths import (
|
||||||
SigninResponse,
|
SigninResponse,
|
||||||
SignupForm,
|
SignupForm,
|
||||||
UpdatePasswordForm,
|
UpdatePasswordForm,
|
||||||
|
UserResponse,
|
||||||
)
|
)
|
||||||
from open_webui.models.users import (
|
from open_webui.models.users import Users, UpdateProfileForm
|
||||||
UserProfileImageResponse,
|
|
||||||
Users,
|
|
||||||
UpdateProfileForm,
|
|
||||||
UserStatus,
|
|
||||||
)
|
|
||||||
from open_webui.models.groups import Groups
|
from open_webui.models.groups import Groups
|
||||||
from open_webui.models.oauth_sessions import OAuthSessions
|
from open_webui.models.oauth_sessions import OAuthSessions
|
||||||
|
|
||||||
|
|
@ -37,23 +31,16 @@ from open_webui.env import (
|
||||||
WEBUI_AUTH_COOKIE_SECURE,
|
WEBUI_AUTH_COOKIE_SECURE,
|
||||||
WEBUI_AUTH_SIGNOUT_REDIRECT_URL,
|
WEBUI_AUTH_SIGNOUT_REDIRECT_URL,
|
||||||
ENABLE_INITIAL_ADMIN_SIGNUP,
|
ENABLE_INITIAL_ADMIN_SIGNUP,
|
||||||
|
SRC_LOG_LEVELS,
|
||||||
)
|
)
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||||
from fastapi.responses import RedirectResponse, Response, JSONResponse
|
from fastapi.responses import RedirectResponse, Response, JSONResponse
|
||||||
from open_webui.config import (
|
from open_webui.config import OPENID_PROVIDER_URL, ENABLE_OAUTH_SIGNUP, ENABLE_LDAP
|
||||||
OPENID_PROVIDER_URL,
|
|
||||||
ENABLE_OAUTH_SIGNUP,
|
|
||||||
ENABLE_LDAP,
|
|
||||||
ENABLE_PASSWORD_AUTH,
|
|
||||||
)
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from open_webui.utils.misc import parse_duration, validate_email_format
|
from open_webui.utils.misc import parse_duration, validate_email_format
|
||||||
from open_webui.utils.auth import (
|
from open_webui.utils.auth import (
|
||||||
validate_password,
|
|
||||||
verify_password,
|
|
||||||
decode_token,
|
decode_token,
|
||||||
invalidate_token,
|
|
||||||
create_api_key,
|
create_api_key,
|
||||||
create_token,
|
create_token,
|
||||||
get_admin_user,
|
get_admin_user,
|
||||||
|
|
@ -63,12 +50,7 @@ from open_webui.utils.auth import (
|
||||||
get_http_authorization_cred,
|
get_http_authorization_cred,
|
||||||
)
|
)
|
||||||
from open_webui.utils.webhook import post_webhook
|
from open_webui.utils.webhook import post_webhook
|
||||||
from open_webui.utils.access_control import get_permissions, has_permission
|
from open_webui.utils.access_control import get_permissions
|
||||||
from open_webui.utils.groups import apply_default_group_assignment
|
|
||||||
|
|
||||||
from open_webui.utils.redis import get_redis_client
|
|
||||||
from open_webui.utils.rate_limit import RateLimiter
|
|
||||||
|
|
||||||
|
|
||||||
from typing import Optional, List
|
from typing import Optional, List
|
||||||
|
|
||||||
|
|
@ -80,22 +62,19 @@ from ldap3.utils.conv import escape_filter_chars
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
log.setLevel(SRC_LOG_LEVELS["MAIN"])
|
||||||
signin_rate_limiter = RateLimiter(
|
|
||||||
redis_client=get_redis_client(), limit=5 * 3, window=60 * 3
|
|
||||||
)
|
|
||||||
|
|
||||||
############################
|
############################
|
||||||
# GetSessionUser
|
# GetSessionUser
|
||||||
############################
|
############################
|
||||||
|
|
||||||
|
|
||||||
class SessionUserResponse(Token, UserProfileImageResponse):
|
class SessionUserResponse(Token, UserResponse):
|
||||||
expires_at: Optional[int] = None
|
expires_at: Optional[int] = None
|
||||||
permissions: Optional[dict] = None
|
permissions: Optional[dict] = None
|
||||||
|
|
||||||
|
|
||||||
class SessionUserInfoResponse(SessionUserResponse, UserStatus):
|
class SessionUserInfoResponse(SessionUserResponse):
|
||||||
bio: Optional[str] = None
|
bio: Optional[str] = None
|
||||||
gender: Optional[str] = None
|
gender: Optional[str] = None
|
||||||
date_of_birth: Optional[datetime.date] = None
|
date_of_birth: Optional[datetime.date] = None
|
||||||
|
|
@ -152,9 +131,6 @@ async def get_session_user(
|
||||||
"bio": user.bio,
|
"bio": user.bio,
|
||||||
"gender": user.gender,
|
"gender": user.gender,
|
||||||
"date_of_birth": user.date_of_birth,
|
"date_of_birth": user.date_of_birth,
|
||||||
"status_emoji": user.status_emoji,
|
|
||||||
"status_message": user.status_message,
|
|
||||||
"status_expires_at": user.status_expires_at,
|
|
||||||
"permissions": user_permissions,
|
"permissions": user_permissions,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -164,7 +140,7 @@ async def get_session_user(
|
||||||
############################
|
############################
|
||||||
|
|
||||||
|
|
||||||
@router.post("/update/profile", response_model=UserProfileImageResponse)
|
@router.post("/update/profile", response_model=UserResponse)
|
||||||
async def update_profile(
|
async def update_profile(
|
||||||
form_data: UpdateProfileForm, session_user=Depends(get_verified_user)
|
form_data: UpdateProfileForm, session_user=Depends(get_verified_user)
|
||||||
):
|
):
|
||||||
|
|
@ -193,19 +169,13 @@ async def update_password(
|
||||||
if WEBUI_AUTH_TRUSTED_EMAIL_HEADER:
|
if WEBUI_AUTH_TRUSTED_EMAIL_HEADER:
|
||||||
raise HTTPException(400, detail=ERROR_MESSAGES.ACTION_PROHIBITED)
|
raise HTTPException(400, detail=ERROR_MESSAGES.ACTION_PROHIBITED)
|
||||||
if session_user:
|
if session_user:
|
||||||
user = Auths.authenticate_user(
|
user = Auths.authenticate_user(session_user.email, form_data.password)
|
||||||
session_user.email, lambda pw: verify_password(form_data.password, pw)
|
|
||||||
)
|
|
||||||
|
|
||||||
if user:
|
if user:
|
||||||
try:
|
|
||||||
validate_password(form_data.password)
|
|
||||||
except Exception as e:
|
|
||||||
raise HTTPException(400, detail=str(e))
|
|
||||||
hashed = get_password_hash(form_data.new_password)
|
hashed = get_password_hash(form_data.new_password)
|
||||||
return Auths.update_user_password_by_id(user.id, hashed)
|
return Auths.update_user_password_by_id(user.id, hashed)
|
||||||
else:
|
else:
|
||||||
raise HTTPException(400, detail=ERROR_MESSAGES.INCORRECT_PASSWORD)
|
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_PASSWORD)
|
||||||
else:
|
else:
|
||||||
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
|
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
|
||||||
|
|
||||||
|
|
@ -215,17 +185,7 @@ async def update_password(
|
||||||
############################
|
############################
|
||||||
@router.post("/ldap", response_model=SessionUserResponse)
|
@router.post("/ldap", response_model=SessionUserResponse)
|
||||||
async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
|
async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
|
||||||
# Security checks FIRST - before loading any config
|
ENABLE_LDAP = request.app.state.config.ENABLE_LDAP
|
||||||
if not request.app.state.config.ENABLE_LDAP:
|
|
||||||
raise HTTPException(400, detail="LDAP authentication is not enabled")
|
|
||||||
|
|
||||||
if not ENABLE_PASSWORD_AUTH:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_403_FORBIDDEN,
|
|
||||||
detail=ERROR_MESSAGES.ACTION_PROHIBITED,
|
|
||||||
)
|
|
||||||
|
|
||||||
# NOW load LDAP config variables
|
|
||||||
LDAP_SERVER_LABEL = request.app.state.config.LDAP_SERVER_LABEL
|
LDAP_SERVER_LABEL = request.app.state.config.LDAP_SERVER_LABEL
|
||||||
LDAP_SERVER_HOST = request.app.state.config.LDAP_SERVER_HOST
|
LDAP_SERVER_HOST = request.app.state.config.LDAP_SERVER_HOST
|
||||||
LDAP_SERVER_PORT = request.app.state.config.LDAP_SERVER_PORT
|
LDAP_SERVER_PORT = request.app.state.config.LDAP_SERVER_PORT
|
||||||
|
|
@ -246,6 +206,9 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
|
||||||
else "ALL"
|
else "ALL"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if not ENABLE_LDAP:
|
||||||
|
raise HTTPException(400, detail="LDAP authentication is not enabled")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
tls = Tls(
|
tls = Tls(
|
||||||
validate=LDAP_VALIDATE_CERT,
|
validate=LDAP_VALIDATE_CERT,
|
||||||
|
|
@ -286,11 +249,13 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
|
||||||
f"{LDAP_ATTRIBUTE_FOR_MAIL}",
|
f"{LDAP_ATTRIBUTE_FOR_MAIL}",
|
||||||
"cn",
|
"cn",
|
||||||
]
|
]
|
||||||
|
|
||||||
if ENABLE_LDAP_GROUP_MANAGEMENT:
|
if ENABLE_LDAP_GROUP_MANAGEMENT:
|
||||||
search_attributes.append(f"{LDAP_ATTRIBUTE_FOR_GROUPS}")
|
search_attributes.append(f"{LDAP_ATTRIBUTE_FOR_GROUPS}")
|
||||||
log.info(
|
log.info(
|
||||||
f"LDAP Group Management enabled. Adding {LDAP_ATTRIBUTE_FOR_GROUPS} to search attributes"
|
f"LDAP Group Management enabled. Adding {LDAP_ATTRIBUTE_FOR_GROUPS} to search attributes"
|
||||||
)
|
)
|
||||||
|
|
||||||
log.info(f"LDAP search attributes: {search_attributes}")
|
log.info(f"LDAP search attributes: {search_attributes}")
|
||||||
|
|
||||||
search_success = connection_app.search(
|
search_success = connection_app.search(
|
||||||
|
|
@ -298,22 +263,15 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
|
||||||
search_filter=f"(&({LDAP_ATTRIBUTE_FOR_USERNAME}={escape_filter_chars(form_data.user.lower())}){LDAP_SEARCH_FILTERS})",
|
search_filter=f"(&({LDAP_ATTRIBUTE_FOR_USERNAME}={escape_filter_chars(form_data.user.lower())}){LDAP_SEARCH_FILTERS})",
|
||||||
attributes=search_attributes,
|
attributes=search_attributes,
|
||||||
)
|
)
|
||||||
|
|
||||||
if not search_success or not connection_app.entries:
|
if not search_success or not connection_app.entries:
|
||||||
raise HTTPException(400, detail="User not found in the LDAP server")
|
raise HTTPException(400, detail="User not found in the LDAP server")
|
||||||
|
|
||||||
entry = connection_app.entries[0]
|
entry = connection_app.entries[0]
|
||||||
entry_username = entry[f"{LDAP_ATTRIBUTE_FOR_USERNAME}"].value
|
username = str(entry[f"{LDAP_ATTRIBUTE_FOR_USERNAME}"]).lower()
|
||||||
email = entry[
|
email = entry[
|
||||||
f"{LDAP_ATTRIBUTE_FOR_MAIL}"
|
f"{LDAP_ATTRIBUTE_FOR_MAIL}"
|
||||||
].value # retrieve the Attribute value
|
].value # retrieve the Attribute value
|
||||||
|
|
||||||
username_list = [] # list of usernames from LDAP attribute
|
|
||||||
if isinstance(entry_username, list):
|
|
||||||
username_list = [str(name).lower() for name in entry_username]
|
|
||||||
else:
|
|
||||||
username_list = [str(entry_username).lower()]
|
|
||||||
|
|
||||||
# TODO: support multiple emails if LDAP returns a list
|
|
||||||
if not email:
|
if not email:
|
||||||
raise HTTPException(400, "User does not have a valid email address.")
|
raise HTTPException(400, "User does not have a valid email address.")
|
||||||
elif isinstance(email, str):
|
elif isinstance(email, str):
|
||||||
|
|
@ -323,13 +281,13 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
|
||||||
else:
|
else:
|
||||||
email = str(email).lower()
|
email = str(email).lower()
|
||||||
|
|
||||||
cn = str(entry["cn"]) # common name
|
cn = str(entry["cn"])
|
||||||
user_dn = entry.entry_dn # user distinguished name
|
user_dn = entry.entry_dn
|
||||||
|
|
||||||
user_groups = []
|
user_groups = []
|
||||||
if ENABLE_LDAP_GROUP_MANAGEMENT and LDAP_ATTRIBUTE_FOR_GROUPS in entry:
|
if ENABLE_LDAP_GROUP_MANAGEMENT and LDAP_ATTRIBUTE_FOR_GROUPS in entry:
|
||||||
group_dns = entry[LDAP_ATTRIBUTE_FOR_GROUPS]
|
group_dns = entry[LDAP_ATTRIBUTE_FOR_GROUPS]
|
||||||
log.info(f"LDAP raw group DNs for user {username_list}: {group_dns}")
|
log.info(f"LDAP raw group DNs for user {username}: {group_dns}")
|
||||||
|
|
||||||
if group_dns:
|
if group_dns:
|
||||||
log.info(f"LDAP group_dns original: {group_dns}")
|
log.info(f"LDAP group_dns original: {group_dns}")
|
||||||
|
|
@ -380,16 +338,16 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
|
||||||
)
|
)
|
||||||
|
|
||||||
log.info(
|
log.info(
|
||||||
f"LDAP groups for user {username_list}: {user_groups} (total: {len(user_groups)})"
|
f"LDAP groups for user {username}: {user_groups} (total: {len(user_groups)})"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
log.info(f"No groups found for user {username_list}")
|
log.info(f"No groups found for user {username}")
|
||||||
elif ENABLE_LDAP_GROUP_MANAGEMENT:
|
elif ENABLE_LDAP_GROUP_MANAGEMENT:
|
||||||
log.warning(
|
log.warning(
|
||||||
f"LDAP Group Management enabled but {LDAP_ATTRIBUTE_FOR_GROUPS} attribute not found in user entry"
|
f"LDAP Group Management enabled but {LDAP_ATTRIBUTE_FOR_GROUPS} attribute not found in user entry"
|
||||||
)
|
)
|
||||||
|
|
||||||
if username_list and form_data.user.lower() in username_list:
|
if username == form_data.user.lower():
|
||||||
connection_user = Connection(
|
connection_user = Connection(
|
||||||
server,
|
server,
|
||||||
user_dn,
|
user_dn,
|
||||||
|
|
@ -421,11 +379,6 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
|
||||||
500, detail=ERROR_MESSAGES.CREATE_USER_ERROR
|
500, detail=ERROR_MESSAGES.CREATE_USER_ERROR
|
||||||
)
|
)
|
||||||
|
|
||||||
apply_default_group_assignment(
|
|
||||||
request.app.state.config.DEFAULT_GROUP_ID,
|
|
||||||
user.id,
|
|
||||||
)
|
|
||||||
|
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
raise
|
raise
|
||||||
except Exception as err:
|
except Exception as err:
|
||||||
|
|
@ -474,6 +427,7 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
|
||||||
):
|
):
|
||||||
if ENABLE_LDAP_GROUP_CREATION:
|
if ENABLE_LDAP_GROUP_CREATION:
|
||||||
Groups.create_groups_by_group_names(user.id, user_groups)
|
Groups.create_groups_by_group_names(user.id, user_groups)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
Groups.sync_groups_by_group_names(user.id, user_groups)
|
Groups.sync_groups_by_group_names(user.id, user_groups)
|
||||||
log.info(
|
log.info(
|
||||||
|
|
@ -509,12 +463,6 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
|
||||||
|
|
||||||
@router.post("/signin", response_model=SessionUserResponse)
|
@router.post("/signin", response_model=SessionUserResponse)
|
||||||
async def signin(request: Request, response: Response, form_data: SigninForm):
|
async def signin(request: Request, response: Response, form_data: SigninForm):
|
||||||
if not ENABLE_PASSWORD_AUTH:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_403_FORBIDDEN,
|
|
||||||
detail=ERROR_MESSAGES.ACTION_PROHIBITED,
|
|
||||||
)
|
|
||||||
|
|
||||||
if WEBUI_AUTH_TRUSTED_EMAIL_HEADER:
|
if WEBUI_AUTH_TRUSTED_EMAIL_HEADER:
|
||||||
if WEBUI_AUTH_TRUSTED_EMAIL_HEADER not in request.headers:
|
if WEBUI_AUTH_TRUSTED_EMAIL_HEADER not in request.headers:
|
||||||
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_TRUSTED_HEADER)
|
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_TRUSTED_HEADER)
|
||||||
|
|
@ -524,10 +472,6 @@ async def signin(request: Request, response: Response, form_data: SigninForm):
|
||||||
|
|
||||||
if WEBUI_AUTH_TRUSTED_NAME_HEADER:
|
if WEBUI_AUTH_TRUSTED_NAME_HEADER:
|
||||||
name = request.headers.get(WEBUI_AUTH_TRUSTED_NAME_HEADER, email)
|
name = request.headers.get(WEBUI_AUTH_TRUSTED_NAME_HEADER, email)
|
||||||
try:
|
|
||||||
name = urllib.parse.unquote(name, encoding="utf-8")
|
|
||||||
except Exception as e:
|
|
||||||
pass
|
|
||||||
|
|
||||||
if not Users.get_user_by_email(email.lower()):
|
if not Users.get_user_by_email(email.lower()):
|
||||||
await signup(
|
await signup(
|
||||||
|
|
@ -551,9 +495,7 @@ async def signin(request: Request, response: Response, form_data: SigninForm):
|
||||||
admin_password = "admin"
|
admin_password = "admin"
|
||||||
|
|
||||||
if Users.get_user_by_email(admin_email.lower()):
|
if Users.get_user_by_email(admin_email.lower()):
|
||||||
user = Auths.authenticate_user(
|
user = Auths.authenticate_user(admin_email.lower(), admin_password)
|
||||||
admin_email.lower(), lambda pw: verify_password(admin_password, pw)
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
if Users.has_users():
|
if Users.has_users():
|
||||||
raise HTTPException(400, detail=ERROR_MESSAGES.EXISTING_USERS)
|
raise HTTPException(400, detail=ERROR_MESSAGES.EXISTING_USERS)
|
||||||
|
|
@ -564,16 +506,8 @@ async def signin(request: Request, response: Response, form_data: SigninForm):
|
||||||
SignupForm(email=admin_email, password=admin_password, name="User"),
|
SignupForm(email=admin_email, password=admin_password, name="User"),
|
||||||
)
|
)
|
||||||
|
|
||||||
user = Auths.authenticate_user(
|
user = Auths.authenticate_user(admin_email.lower(), admin_password)
|
||||||
admin_email.lower(), lambda pw: verify_password(admin_password, pw)
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
if signin_rate_limiter.is_limited(form_data.email.lower()):
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
|
|
||||||
detail=ERROR_MESSAGES.RATE_LIMIT_EXCEEDED,
|
|
||||||
)
|
|
||||||
|
|
||||||
password_bytes = form_data.password.encode("utf-8")
|
password_bytes = form_data.password.encode("utf-8")
|
||||||
if len(password_bytes) > 72:
|
if len(password_bytes) > 72:
|
||||||
# TODO: Implement other hashing algorithms that support longer passwords
|
# TODO: Implement other hashing algorithms that support longer passwords
|
||||||
|
|
@ -583,9 +517,7 @@ async def signin(request: Request, response: Response, form_data: SigninForm):
|
||||||
# decode safely — ignore incomplete UTF-8 sequences
|
# decode safely — ignore incomplete UTF-8 sequences
|
||||||
form_data.password = password_bytes.decode("utf-8", errors="ignore")
|
form_data.password = password_bytes.decode("utf-8", errors="ignore")
|
||||||
|
|
||||||
user = Auths.authenticate_user(
|
user = Auths.authenticate_user(form_data.email.lower(), form_data.password)
|
||||||
form_data.email.lower(), lambda pw: verify_password(form_data.password, pw)
|
|
||||||
)
|
|
||||||
|
|
||||||
if user:
|
if user:
|
||||||
|
|
||||||
|
|
@ -667,14 +599,16 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
|
||||||
raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
|
raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
try:
|
role = "admin" if not has_users else request.app.state.config.DEFAULT_USER_ROLE
|
||||||
validate_password(form_data.password)
|
|
||||||
except Exception as e:
|
# The password passed to bcrypt must be 72 bytes or fewer. If it is longer, it will be truncated before hashing.
|
||||||
raise HTTPException(400, detail=str(e))
|
if len(form_data.password.encode("utf-8")) > 72:
|
||||||
|
raise HTTPException(
|
||||||
|
status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail=ERROR_MESSAGES.PASSWORD_TOO_LONG,
|
||||||
|
)
|
||||||
|
|
||||||
hashed = get_password_hash(form_data.password)
|
hashed = get_password_hash(form_data.password)
|
||||||
|
|
||||||
role = "admin" if not has_users else request.app.state.config.DEFAULT_USER_ROLE
|
|
||||||
user = Auths.insert_new_auth(
|
user = Auths.insert_new_auth(
|
||||||
form_data.email.lower(),
|
form_data.email.lower(),
|
||||||
hashed,
|
hashed,
|
||||||
|
|
@ -730,11 +664,6 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
|
||||||
# Disable signup after the first user is created
|
# Disable signup after the first user is created
|
||||||
request.app.state.config.ENABLE_SIGNUP = False
|
request.app.state.config.ENABLE_SIGNUP = False
|
||||||
|
|
||||||
apply_default_group_assignment(
|
|
||||||
request.app.state.config.DEFAULT_GROUP_ID,
|
|
||||||
user.id,
|
|
||||||
)
|
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"token": token,
|
"token": token,
|
||||||
"token_type": "Bearer",
|
"token_type": "Bearer",
|
||||||
|
|
@ -755,19 +684,6 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
|
||||||
|
|
||||||
@router.get("/signout")
|
@router.get("/signout")
|
||||||
async def signout(request: Request, response: Response):
|
async def signout(request: Request, response: Response):
|
||||||
|
|
||||||
# get auth token from headers or cookies
|
|
||||||
token = None
|
|
||||||
auth_header = request.headers.get("Authorization")
|
|
||||||
if auth_header:
|
|
||||||
auth_cred = get_http_authorization_cred(auth_header)
|
|
||||||
token = auth_cred.credentials
|
|
||||||
else:
|
|
||||||
token = request.cookies.get("token")
|
|
||||||
|
|
||||||
if token:
|
|
||||||
await invalidate_token(request, token)
|
|
||||||
|
|
||||||
response.delete_cookie("token")
|
response.delete_cookie("token")
|
||||||
response.delete_cookie("oui-session")
|
response.delete_cookie("oui-session")
|
||||||
response.delete_cookie("oauth_id_token")
|
response.delete_cookie("oauth_id_token")
|
||||||
|
|
@ -838,9 +754,7 @@ async def signout(request: Request, response: Response):
|
||||||
|
|
||||||
|
|
||||||
@router.post("/add", response_model=SigninResponse)
|
@router.post("/add", response_model=SigninResponse)
|
||||||
async def add_user(
|
async def add_user(form_data: AddUserForm, user=Depends(get_admin_user)):
|
||||||
request: Request, form_data: AddUserForm, user=Depends(get_admin_user)
|
|
||||||
):
|
|
||||||
if not validate_email_format(form_data.email.lower()):
|
if not validate_email_format(form_data.email.lower()):
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT
|
status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT
|
||||||
|
|
@ -850,11 +764,6 @@ async def add_user(
|
||||||
raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
|
raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
try:
|
|
||||||
validate_password(form_data.password)
|
|
||||||
except Exception as e:
|
|
||||||
raise HTTPException(400, detail=str(e))
|
|
||||||
|
|
||||||
hashed = get_password_hash(form_data.password)
|
hashed = get_password_hash(form_data.password)
|
||||||
user = Auths.insert_new_auth(
|
user = Auths.insert_new_auth(
|
||||||
form_data.email.lower(),
|
form_data.email.lower(),
|
||||||
|
|
@ -865,11 +774,6 @@ async def add_user(
|
||||||
)
|
)
|
||||||
|
|
||||||
if user:
|
if user:
|
||||||
apply_default_group_assignment(
|
|
||||||
request.app.state.config.DEFAULT_GROUP_ID,
|
|
||||||
user.id,
|
|
||||||
)
|
|
||||||
|
|
||||||
token = create_token(data={"id": user.id})
|
token = create_token(data={"id": user.id})
|
||||||
return {
|
return {
|
||||||
"token": token,
|
"token": token,
|
||||||
|
|
@ -931,15 +835,13 @@ async def get_admin_config(request: Request, user=Depends(get_admin_user)):
|
||||||
"SHOW_ADMIN_DETAILS": request.app.state.config.SHOW_ADMIN_DETAILS,
|
"SHOW_ADMIN_DETAILS": request.app.state.config.SHOW_ADMIN_DETAILS,
|
||||||
"WEBUI_URL": request.app.state.config.WEBUI_URL,
|
"WEBUI_URL": request.app.state.config.WEBUI_URL,
|
||||||
"ENABLE_SIGNUP": request.app.state.config.ENABLE_SIGNUP,
|
"ENABLE_SIGNUP": request.app.state.config.ENABLE_SIGNUP,
|
||||||
"ENABLE_API_KEYS": request.app.state.config.ENABLE_API_KEYS,
|
"ENABLE_API_KEY": request.app.state.config.ENABLE_API_KEY,
|
||||||
"ENABLE_API_KEYS_ENDPOINT_RESTRICTIONS": request.app.state.config.ENABLE_API_KEYS_ENDPOINT_RESTRICTIONS,
|
"ENABLE_API_KEY_ENDPOINT_RESTRICTIONS": request.app.state.config.ENABLE_API_KEY_ENDPOINT_RESTRICTIONS,
|
||||||
"API_KEYS_ALLOWED_ENDPOINTS": request.app.state.config.API_KEYS_ALLOWED_ENDPOINTS,
|
"API_KEY_ALLOWED_ENDPOINTS": request.app.state.config.API_KEY_ALLOWED_ENDPOINTS,
|
||||||
"DEFAULT_USER_ROLE": request.app.state.config.DEFAULT_USER_ROLE,
|
"DEFAULT_USER_ROLE": request.app.state.config.DEFAULT_USER_ROLE,
|
||||||
"DEFAULT_GROUP_ID": request.app.state.config.DEFAULT_GROUP_ID,
|
|
||||||
"JWT_EXPIRES_IN": request.app.state.config.JWT_EXPIRES_IN,
|
"JWT_EXPIRES_IN": request.app.state.config.JWT_EXPIRES_IN,
|
||||||
"ENABLE_COMMUNITY_SHARING": request.app.state.config.ENABLE_COMMUNITY_SHARING,
|
"ENABLE_COMMUNITY_SHARING": request.app.state.config.ENABLE_COMMUNITY_SHARING,
|
||||||
"ENABLE_MESSAGE_RATING": request.app.state.config.ENABLE_MESSAGE_RATING,
|
"ENABLE_MESSAGE_RATING": request.app.state.config.ENABLE_MESSAGE_RATING,
|
||||||
"ENABLE_FOLDERS": request.app.state.config.ENABLE_FOLDERS,
|
|
||||||
"ENABLE_CHANNELS": request.app.state.config.ENABLE_CHANNELS,
|
"ENABLE_CHANNELS": request.app.state.config.ENABLE_CHANNELS,
|
||||||
"ENABLE_NOTES": request.app.state.config.ENABLE_NOTES,
|
"ENABLE_NOTES": request.app.state.config.ENABLE_NOTES,
|
||||||
"ENABLE_USER_WEBHOOKS": request.app.state.config.ENABLE_USER_WEBHOOKS,
|
"ENABLE_USER_WEBHOOKS": request.app.state.config.ENABLE_USER_WEBHOOKS,
|
||||||
|
|
@ -953,15 +855,13 @@ class AdminConfig(BaseModel):
|
||||||
SHOW_ADMIN_DETAILS: bool
|
SHOW_ADMIN_DETAILS: bool
|
||||||
WEBUI_URL: str
|
WEBUI_URL: str
|
||||||
ENABLE_SIGNUP: bool
|
ENABLE_SIGNUP: bool
|
||||||
ENABLE_API_KEYS: bool
|
ENABLE_API_KEY: bool
|
||||||
ENABLE_API_KEYS_ENDPOINT_RESTRICTIONS: bool
|
ENABLE_API_KEY_ENDPOINT_RESTRICTIONS: bool
|
||||||
API_KEYS_ALLOWED_ENDPOINTS: str
|
API_KEY_ALLOWED_ENDPOINTS: str
|
||||||
DEFAULT_USER_ROLE: str
|
DEFAULT_USER_ROLE: str
|
||||||
DEFAULT_GROUP_ID: str
|
|
||||||
JWT_EXPIRES_IN: str
|
JWT_EXPIRES_IN: str
|
||||||
ENABLE_COMMUNITY_SHARING: bool
|
ENABLE_COMMUNITY_SHARING: bool
|
||||||
ENABLE_MESSAGE_RATING: bool
|
ENABLE_MESSAGE_RATING: bool
|
||||||
ENABLE_FOLDERS: bool
|
|
||||||
ENABLE_CHANNELS: bool
|
ENABLE_CHANNELS: bool
|
||||||
ENABLE_NOTES: bool
|
ENABLE_NOTES: bool
|
||||||
ENABLE_USER_WEBHOOKS: bool
|
ENABLE_USER_WEBHOOKS: bool
|
||||||
|
|
@ -978,23 +878,20 @@ async def update_admin_config(
|
||||||
request.app.state.config.WEBUI_URL = form_data.WEBUI_URL
|
request.app.state.config.WEBUI_URL = form_data.WEBUI_URL
|
||||||
request.app.state.config.ENABLE_SIGNUP = form_data.ENABLE_SIGNUP
|
request.app.state.config.ENABLE_SIGNUP = form_data.ENABLE_SIGNUP
|
||||||
|
|
||||||
request.app.state.config.ENABLE_API_KEYS = form_data.ENABLE_API_KEYS
|
request.app.state.config.ENABLE_API_KEY = form_data.ENABLE_API_KEY
|
||||||
request.app.state.config.ENABLE_API_KEYS_ENDPOINT_RESTRICTIONS = (
|
request.app.state.config.ENABLE_API_KEY_ENDPOINT_RESTRICTIONS = (
|
||||||
form_data.ENABLE_API_KEYS_ENDPOINT_RESTRICTIONS
|
form_data.ENABLE_API_KEY_ENDPOINT_RESTRICTIONS
|
||||||
)
|
)
|
||||||
request.app.state.config.API_KEYS_ALLOWED_ENDPOINTS = (
|
request.app.state.config.API_KEY_ALLOWED_ENDPOINTS = (
|
||||||
form_data.API_KEYS_ALLOWED_ENDPOINTS
|
form_data.API_KEY_ALLOWED_ENDPOINTS
|
||||||
)
|
)
|
||||||
|
|
||||||
request.app.state.config.ENABLE_FOLDERS = form_data.ENABLE_FOLDERS
|
|
||||||
request.app.state.config.ENABLE_CHANNELS = form_data.ENABLE_CHANNELS
|
request.app.state.config.ENABLE_CHANNELS = form_data.ENABLE_CHANNELS
|
||||||
request.app.state.config.ENABLE_NOTES = form_data.ENABLE_NOTES
|
request.app.state.config.ENABLE_NOTES = form_data.ENABLE_NOTES
|
||||||
|
|
||||||
if form_data.DEFAULT_USER_ROLE in ["pending", "user", "admin"]:
|
if form_data.DEFAULT_USER_ROLE in ["pending", "user", "admin"]:
|
||||||
request.app.state.config.DEFAULT_USER_ROLE = form_data.DEFAULT_USER_ROLE
|
request.app.state.config.DEFAULT_USER_ROLE = form_data.DEFAULT_USER_ROLE
|
||||||
|
|
||||||
request.app.state.config.DEFAULT_GROUP_ID = form_data.DEFAULT_GROUP_ID
|
|
||||||
|
|
||||||
pattern = r"^(-1|0|(-?\d+(\.\d+)?)(ms|s|m|h|d|w))$"
|
pattern = r"^(-1|0|(-?\d+(\.\d+)?)(ms|s|m|h|d|w))$"
|
||||||
|
|
||||||
# Check if the input string matches the pattern
|
# Check if the input string matches the pattern
|
||||||
|
|
@ -1021,15 +918,13 @@ async def update_admin_config(
|
||||||
"SHOW_ADMIN_DETAILS": request.app.state.config.SHOW_ADMIN_DETAILS,
|
"SHOW_ADMIN_DETAILS": request.app.state.config.SHOW_ADMIN_DETAILS,
|
||||||
"WEBUI_URL": request.app.state.config.WEBUI_URL,
|
"WEBUI_URL": request.app.state.config.WEBUI_URL,
|
||||||
"ENABLE_SIGNUP": request.app.state.config.ENABLE_SIGNUP,
|
"ENABLE_SIGNUP": request.app.state.config.ENABLE_SIGNUP,
|
||||||
"ENABLE_API_KEYS": request.app.state.config.ENABLE_API_KEYS,
|
"ENABLE_API_KEY": request.app.state.config.ENABLE_API_KEY,
|
||||||
"ENABLE_API_KEYS_ENDPOINT_RESTRICTIONS": request.app.state.config.ENABLE_API_KEYS_ENDPOINT_RESTRICTIONS,
|
"ENABLE_API_KEY_ENDPOINT_RESTRICTIONS": request.app.state.config.ENABLE_API_KEY_ENDPOINT_RESTRICTIONS,
|
||||||
"API_KEYS_ALLOWED_ENDPOINTS": request.app.state.config.API_KEYS_ALLOWED_ENDPOINTS,
|
"API_KEY_ALLOWED_ENDPOINTS": request.app.state.config.API_KEY_ALLOWED_ENDPOINTS,
|
||||||
"DEFAULT_USER_ROLE": request.app.state.config.DEFAULT_USER_ROLE,
|
"DEFAULT_USER_ROLE": request.app.state.config.DEFAULT_USER_ROLE,
|
||||||
"DEFAULT_GROUP_ID": request.app.state.config.DEFAULT_GROUP_ID,
|
|
||||||
"JWT_EXPIRES_IN": request.app.state.config.JWT_EXPIRES_IN,
|
"JWT_EXPIRES_IN": request.app.state.config.JWT_EXPIRES_IN,
|
||||||
"ENABLE_COMMUNITY_SHARING": request.app.state.config.ENABLE_COMMUNITY_SHARING,
|
"ENABLE_COMMUNITY_SHARING": request.app.state.config.ENABLE_COMMUNITY_SHARING,
|
||||||
"ENABLE_MESSAGE_RATING": request.app.state.config.ENABLE_MESSAGE_RATING,
|
"ENABLE_MESSAGE_RATING": request.app.state.config.ENABLE_MESSAGE_RATING,
|
||||||
"ENABLE_FOLDERS": request.app.state.config.ENABLE_FOLDERS,
|
|
||||||
"ENABLE_CHANNELS": request.app.state.config.ENABLE_CHANNELS,
|
"ENABLE_CHANNELS": request.app.state.config.ENABLE_CHANNELS,
|
||||||
"ENABLE_NOTES": request.app.state.config.ENABLE_NOTES,
|
"ENABLE_NOTES": request.app.state.config.ENABLE_NOTES,
|
||||||
"ENABLE_USER_WEBHOOKS": request.app.state.config.ENABLE_USER_WEBHOOKS,
|
"ENABLE_USER_WEBHOOKS": request.app.state.config.ENABLE_USER_WEBHOOKS,
|
||||||
|
|
@ -1150,11 +1045,9 @@ async def update_ldap_config(
|
||||||
# create api key
|
# create api key
|
||||||
@router.post("/api_key", response_model=ApiKey)
|
@router.post("/api_key", response_model=ApiKey)
|
||||||
async def generate_api_key(request: Request, user=Depends(get_current_user)):
|
async def generate_api_key(request: Request, user=Depends(get_current_user)):
|
||||||
if not request.app.state.config.ENABLE_API_KEYS or not has_permission(
|
if not request.app.state.config.ENABLE_API_KEY:
|
||||||
user.id, "features.api_keys", request.app.state.config.USER_PERMISSIONS
|
|
||||||
):
|
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_403_FORBIDDEN,
|
status.HTTP_403_FORBIDDEN,
|
||||||
detail=ERROR_MESSAGES.API_KEY_CREATION_NOT_ALLOWED,
|
detail=ERROR_MESSAGES.API_KEY_CREATION_NOT_ALLOWED,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -1172,7 +1065,8 @@ async def generate_api_key(request: Request, user=Depends(get_current_user)):
|
||||||
# delete api key
|
# delete api key
|
||||||
@router.delete("/api_key", response_model=bool)
|
@router.delete("/api_key", response_model=bool)
|
||||||
async def delete_api_key(user=Depends(get_current_user)):
|
async def delete_api_key(user=Depends(get_current_user)):
|
||||||
return Users.delete_user_api_key_by_id(user.id)
|
success = Users.update_user_api_key_by_id(user.id, None)
|
||||||
|
return success
|
||||||
|
|
||||||
|
|
||||||
# get api key
|
# get api key
|
||||||
|
|
|
||||||
File diff suppressed because it is too large
Load diff
|
|
@ -3,13 +3,10 @@ import logging
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
from open_webui.utils.misc import get_message_list
|
|
||||||
from open_webui.socket.main import get_event_emitter
|
from open_webui.socket.main import get_event_emitter
|
||||||
from open_webui.models.chats import (
|
from open_webui.models.chats import (
|
||||||
ChatForm,
|
ChatForm,
|
||||||
ChatImportForm,
|
ChatImportForm,
|
||||||
ChatUsageStatsListResponse,
|
|
||||||
ChatsImportForm,
|
|
||||||
ChatResponse,
|
ChatResponse,
|
||||||
Chats,
|
Chats,
|
||||||
ChatTitleIdResponse,
|
ChatTitleIdResponse,
|
||||||
|
|
@ -19,6 +16,7 @@ from open_webui.models.folders import Folders
|
||||||
|
|
||||||
from open_webui.config import ENABLE_ADMIN_CHAT_ACCESS, ENABLE_ADMIN_EXPORT
|
from open_webui.config import ENABLE_ADMIN_CHAT_ACCESS, ENABLE_ADMIN_EXPORT
|
||||||
from open_webui.constants import ERROR_MESSAGES
|
from open_webui.constants import ERROR_MESSAGES
|
||||||
|
from open_webui.env import SRC_LOG_LEVELS
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
@ -27,6 +25,7 @@ from open_webui.utils.auth import get_admin_user, get_verified_user
|
||||||
from open_webui.utils.access_control import has_permission
|
from open_webui.utils.access_control import has_permission
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
log.setLevel(SRC_LOG_LEVELS["MODELS"])
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
@ -66,132 +65,6 @@ def get_session_user_chat_list(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
############################
|
|
||||||
# GetChatUsageStats
|
|
||||||
# EXPERIMENTAL: may be removed in future releases
|
|
||||||
############################
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/stats/usage", response_model=ChatUsageStatsListResponse)
|
|
||||||
def get_session_user_chat_usage_stats(
|
|
||||||
items_per_page: Optional[int] = 50,
|
|
||||||
page: Optional[int] = 1,
|
|
||||||
user=Depends(get_verified_user),
|
|
||||||
):
|
|
||||||
try:
|
|
||||||
limit = items_per_page
|
|
||||||
skip = (page - 1) * limit
|
|
||||||
|
|
||||||
result = Chats.get_chats_by_user_id(user.id, skip=skip, limit=limit)
|
|
||||||
|
|
||||||
chats = result.items
|
|
||||||
total = result.total
|
|
||||||
|
|
||||||
chat_stats = []
|
|
||||||
for chat in chats:
|
|
||||||
messages_map = chat.chat.get("history", {}).get("messages", {})
|
|
||||||
message_id = chat.chat.get("history", {}).get("currentId")
|
|
||||||
|
|
||||||
if messages_map and message_id:
|
|
||||||
try:
|
|
||||||
history_models = {}
|
|
||||||
history_message_count = len(messages_map)
|
|
||||||
history_user_messages = []
|
|
||||||
history_assistant_messages = []
|
|
||||||
|
|
||||||
for message in messages_map.values():
|
|
||||||
if message.get("role", "") == "user":
|
|
||||||
history_user_messages.append(message)
|
|
||||||
elif message.get("role", "") == "assistant":
|
|
||||||
history_assistant_messages.append(message)
|
|
||||||
model = message.get("model", None)
|
|
||||||
if model:
|
|
||||||
if model not in history_models:
|
|
||||||
history_models[model] = 0
|
|
||||||
history_models[model] += 1
|
|
||||||
|
|
||||||
average_user_message_content_length = (
|
|
||||||
sum(
|
|
||||||
len(message.get("content", ""))
|
|
||||||
for message in history_user_messages
|
|
||||||
)
|
|
||||||
/ len(history_user_messages)
|
|
||||||
if len(history_user_messages) > 0
|
|
||||||
else 0
|
|
||||||
)
|
|
||||||
average_assistant_message_content_length = (
|
|
||||||
sum(
|
|
||||||
len(message.get("content", ""))
|
|
||||||
for message in history_assistant_messages
|
|
||||||
)
|
|
||||||
/ len(history_assistant_messages)
|
|
||||||
if len(history_assistant_messages) > 0
|
|
||||||
else 0
|
|
||||||
)
|
|
||||||
|
|
||||||
response_times = []
|
|
||||||
for message in history_assistant_messages:
|
|
||||||
user_message_id = message.get("parentId", None)
|
|
||||||
if user_message_id and user_message_id in messages_map:
|
|
||||||
user_message = messages_map[user_message_id]
|
|
||||||
response_time = message.get(
|
|
||||||
"timestamp", 0
|
|
||||||
) - user_message.get("timestamp", 0)
|
|
||||||
|
|
||||||
response_times.append(response_time)
|
|
||||||
|
|
||||||
average_response_time = (
|
|
||||||
sum(response_times) / len(response_times)
|
|
||||||
if len(response_times) > 0
|
|
||||||
else 0
|
|
||||||
)
|
|
||||||
|
|
||||||
message_list = get_message_list(messages_map, message_id)
|
|
||||||
message_count = len(message_list)
|
|
||||||
|
|
||||||
models = {}
|
|
||||||
for message in reversed(message_list):
|
|
||||||
if message.get("role") == "assistant":
|
|
||||||
model = message.get("model", None)
|
|
||||||
if model:
|
|
||||||
if model not in models:
|
|
||||||
models[model] = 0
|
|
||||||
models[model] += 1
|
|
||||||
|
|
||||||
annotation = message.get("annotation", {})
|
|
||||||
|
|
||||||
chat_stats.append(
|
|
||||||
{
|
|
||||||
"id": chat.id,
|
|
||||||
"models": models,
|
|
||||||
"message_count": message_count,
|
|
||||||
"history_models": history_models,
|
|
||||||
"history_message_count": history_message_count,
|
|
||||||
"history_user_message_count": len(history_user_messages),
|
|
||||||
"history_assistant_message_count": len(
|
|
||||||
history_assistant_messages
|
|
||||||
),
|
|
||||||
"average_response_time": average_response_time,
|
|
||||||
"average_user_message_content_length": average_user_message_content_length,
|
|
||||||
"average_assistant_message_content_length": average_assistant_message_content_length,
|
|
||||||
"tags": chat.meta.get("tags", []),
|
|
||||||
"last_message_at": message_list[-1].get("timestamp", None),
|
|
||||||
"updated_at": chat.updated_at,
|
|
||||||
"created_at": chat.created_at,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
pass
|
|
||||||
|
|
||||||
return ChatUsageStatsListResponse(items=chat_stats, total=total)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
log.exception(e)
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
############################
|
############################
|
||||||
# DeleteAllChats
|
# DeleteAllChats
|
||||||
############################
|
############################
|
||||||
|
|
@ -269,15 +142,26 @@ async def create_new_chat(form_data: ChatForm, user=Depends(get_verified_user)):
|
||||||
|
|
||||||
|
|
||||||
############################
|
############################
|
||||||
# ImportChats
|
# ImportChat
|
||||||
############################
|
############################
|
||||||
|
|
||||||
|
|
||||||
@router.post("/import", response_model=list[ChatResponse])
|
@router.post("/import", response_model=Optional[ChatResponse])
|
||||||
async def import_chats(form_data: ChatsImportForm, user=Depends(get_verified_user)):
|
async def import_chat(form_data: ChatImportForm, user=Depends(get_verified_user)):
|
||||||
try:
|
try:
|
||||||
chats = Chats.import_chats(user.id, form_data.chats)
|
chat = Chats.import_chat(user.id, form_data)
|
||||||
return chats
|
if chat:
|
||||||
|
tags = chat.meta.get("tags", [])
|
||||||
|
for tag_id in tags:
|
||||||
|
tag_id = tag_id.replace(" ", "_").lower()
|
||||||
|
tag_name = " ".join([word.capitalize() for word in tag_id.split("_")])
|
||||||
|
if (
|
||||||
|
tag_id != "none"
|
||||||
|
and Tags.get_tag_by_name_and_user_id(tag_name, user.id) is None
|
||||||
|
):
|
||||||
|
Tags.insert_new_tag(tag_name, user.id)
|
||||||
|
|
||||||
|
return ChatResponse(**chat.model_dump())
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.exception(e)
|
log.exception(e)
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
|
|
@ -344,7 +228,7 @@ async def get_chat_list_by_folder_id(
|
||||||
folder_id: str, page: Optional[int] = 1, user=Depends(get_verified_user)
|
folder_id: str, page: Optional[int] = 1, user=Depends(get_verified_user)
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
limit = 10
|
limit = 60
|
||||||
skip = (page - 1) * limit
|
skip = (page - 1) * limit
|
||||||
|
|
||||||
return [
|
return [
|
||||||
|
|
@ -774,28 +658,19 @@ async def clone_chat_by_id(
|
||||||
"title": form_data.title if form_data.title else f"Clone of {chat.title}",
|
"title": form_data.title if form_data.title else f"Clone of {chat.title}",
|
||||||
}
|
}
|
||||||
|
|
||||||
chats = Chats.import_chats(
|
chat = Chats.import_chat(
|
||||||
user.id,
|
user.id,
|
||||||
[
|
ChatImportForm(
|
||||||
ChatImportForm(
|
**{
|
||||||
**{
|
"chat": updated_chat,
|
||||||
"chat": updated_chat,
|
"meta": chat.meta,
|
||||||
"meta": chat.meta,
|
"pinned": chat.pinned,
|
||||||
"pinned": chat.pinned,
|
"folder_id": chat.folder_id,
|
||||||
"folder_id": chat.folder_id,
|
}
|
||||||
}
|
),
|
||||||
)
|
|
||||||
],
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if chats:
|
return ChatResponse(**chat.model_dump())
|
||||||
chat = chats[0]
|
|
||||||
return ChatResponse(**chat.model_dump())
|
|
||||||
else:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
||||||
detail=ERROR_MESSAGES.DEFAULT(),
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT()
|
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT()
|
||||||
|
|
@ -823,28 +698,18 @@ async def clone_shared_chat_by_id(id: str, user=Depends(get_verified_user)):
|
||||||
"title": f"Clone of {chat.title}",
|
"title": f"Clone of {chat.title}",
|
||||||
}
|
}
|
||||||
|
|
||||||
chats = Chats.import_chats(
|
chat = Chats.import_chat(
|
||||||
user.id,
|
user.id,
|
||||||
[
|
ChatImportForm(
|
||||||
ChatImportForm(
|
**{
|
||||||
**{
|
"chat": updated_chat,
|
||||||
"chat": updated_chat,
|
"meta": chat.meta,
|
||||||
"meta": chat.meta,
|
"pinned": chat.pinned,
|
||||||
"pinned": chat.pinned,
|
"folder_id": chat.folder_id,
|
||||||
"folder_id": chat.folder_id,
|
}
|
||||||
}
|
),
|
||||||
)
|
|
||||||
],
|
|
||||||
)
|
)
|
||||||
|
return ChatResponse(**chat.model_dump())
|
||||||
if chats:
|
|
||||||
chat = chats[0]
|
|
||||||
return ChatResponse(**chat.model_dump())
|
|
||||||
else:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
||||||
detail=ERROR_MESSAGES.DEFAULT(),
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT()
|
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT()
|
||||||
|
|
|
||||||
|
|
@ -18,6 +18,7 @@ from open_webui.utils.tools import (
|
||||||
from open_webui.utils.mcp.client import MCPClient
|
from open_webui.utils.mcp.client import MCPClient
|
||||||
from open_webui.models.oauth_sessions import OAuthSessions
|
from open_webui.models.oauth_sessions import OAuthSessions
|
||||||
|
|
||||||
|
from open_webui.env import SRC_LOG_LEVELS
|
||||||
|
|
||||||
from open_webui.utils.oauth import (
|
from open_webui.utils.oauth import (
|
||||||
get_discovery_urls,
|
get_discovery_urls,
|
||||||
|
|
@ -31,6 +32,7 @@ from mcp.shared.auth import OAuthMetadata
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
log.setLevel(SRC_LOG_LEVELS["MAIN"])
|
||||||
|
|
||||||
|
|
||||||
############################
|
############################
|
||||||
|
|
@ -142,7 +144,6 @@ class ToolServerConnection(BaseModel):
|
||||||
path: str
|
path: str
|
||||||
type: Optional[str] = "openapi" # openapi, mcp
|
type: Optional[str] = "openapi" # openapi, mcp
|
||||||
auth_type: Optional[str]
|
auth_type: Optional[str]
|
||||||
headers: Optional[dict | str] = None
|
|
||||||
key: Optional[str]
|
key: Optional[str]
|
||||||
config: Optional[dict]
|
config: Optional[dict]
|
||||||
|
|
||||||
|
|
@ -269,26 +270,18 @@ async def verify_tool_servers_config(
|
||||||
elif form_data.auth_type == "session":
|
elif form_data.auth_type == "session":
|
||||||
token = request.state.token.credentials
|
token = request.state.token.credentials
|
||||||
elif form_data.auth_type == "system_oauth":
|
elif form_data.auth_type == "system_oauth":
|
||||||
oauth_token = None
|
|
||||||
try:
|
try:
|
||||||
if request.cookies.get("oauth_session_id", None):
|
if request.cookies.get("oauth_session_id", None):
|
||||||
oauth_token = await request.app.state.oauth_manager.get_oauth_token(
|
token = await request.app.state.oauth_manager.get_oauth_token(
|
||||||
user.id,
|
user.id,
|
||||||
request.cookies.get("oauth_session_id", None),
|
request.cookies.get("oauth_session_id", None),
|
||||||
)
|
)
|
||||||
|
|
||||||
if oauth_token:
|
|
||||||
token = oauth_token.get("access_token", "")
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
if token:
|
if token:
|
||||||
headers = {"Authorization": f"Bearer {token}"}
|
headers = {"Authorization": f"Bearer {token}"}
|
||||||
|
|
||||||
if form_data.headers and isinstance(form_data.headers, dict):
|
|
||||||
if headers is None:
|
|
||||||
headers = {}
|
|
||||||
headers.update(form_data.headers)
|
|
||||||
|
|
||||||
await client.connect(form_data.url, headers=headers)
|
await client.connect(form_data.url, headers=headers)
|
||||||
specs = await client.list_tool_specs()
|
specs = await client.list_tool_specs()
|
||||||
return {
|
return {
|
||||||
|
|
@ -306,7 +299,6 @@ async def verify_tool_servers_config(
|
||||||
await client.disconnect()
|
await client.disconnect()
|
||||||
else: # openapi
|
else: # openapi
|
||||||
token = None
|
token = None
|
||||||
headers = None
|
|
||||||
if form_data.auth_type == "bearer":
|
if form_data.auth_type == "bearer":
|
||||||
token = form_data.key
|
token = form_data.key
|
||||||
elif form_data.auth_type == "session":
|
elif form_data.auth_type == "session":
|
||||||
|
|
@ -314,29 +306,15 @@ async def verify_tool_servers_config(
|
||||||
elif form_data.auth_type == "system_oauth":
|
elif form_data.auth_type == "system_oauth":
|
||||||
try:
|
try:
|
||||||
if request.cookies.get("oauth_session_id", None):
|
if request.cookies.get("oauth_session_id", None):
|
||||||
oauth_token = (
|
token = await request.app.state.oauth_manager.get_oauth_token(
|
||||||
await request.app.state.oauth_manager.get_oauth_token(
|
user.id,
|
||||||
user.id,
|
request.cookies.get("oauth_session_id", None),
|
||||||
request.cookies.get("oauth_session_id", None),
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if oauth_token:
|
|
||||||
token = oauth_token.get("access_token", "")
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
if token:
|
|
||||||
headers = {"Authorization": f"Bearer {token}"}
|
|
||||||
|
|
||||||
if form_data.headers and isinstance(form_data.headers, dict):
|
|
||||||
if headers is None:
|
|
||||||
headers = {}
|
|
||||||
headers.update(form_data.headers)
|
|
||||||
|
|
||||||
url = get_tool_server_url(form_data.url, form_data.path)
|
url = get_tool_server_url(form_data.url, form_data.path)
|
||||||
return await get_tool_server_data(url, headers=headers)
|
return await get_tool_server_data(token, url)
|
||||||
except HTTPException as e:
|
except HTTPException as e:
|
||||||
raise e
|
raise e
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
@ -461,7 +439,6 @@ async def set_code_execution_config(
|
||||||
############################
|
############################
|
||||||
class ModelsConfigForm(BaseModel):
|
class ModelsConfigForm(BaseModel):
|
||||||
DEFAULT_MODELS: Optional[str]
|
DEFAULT_MODELS: Optional[str]
|
||||||
DEFAULT_PINNED_MODELS: Optional[str]
|
|
||||||
MODEL_ORDER_LIST: Optional[list[str]]
|
MODEL_ORDER_LIST: Optional[list[str]]
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -469,7 +446,6 @@ class ModelsConfigForm(BaseModel):
|
||||||
async def get_models_config(request: Request, user=Depends(get_admin_user)):
|
async def get_models_config(request: Request, user=Depends(get_admin_user)):
|
||||||
return {
|
return {
|
||||||
"DEFAULT_MODELS": request.app.state.config.DEFAULT_MODELS,
|
"DEFAULT_MODELS": request.app.state.config.DEFAULT_MODELS,
|
||||||
"DEFAULT_PINNED_MODELS": request.app.state.config.DEFAULT_PINNED_MODELS,
|
|
||||||
"MODEL_ORDER_LIST": request.app.state.config.MODEL_ORDER_LIST,
|
"MODEL_ORDER_LIST": request.app.state.config.MODEL_ORDER_LIST,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -479,11 +455,9 @@ async def set_models_config(
|
||||||
request: Request, form_data: ModelsConfigForm, user=Depends(get_admin_user)
|
request: Request, form_data: ModelsConfigForm, user=Depends(get_admin_user)
|
||||||
):
|
):
|
||||||
request.app.state.config.DEFAULT_MODELS = form_data.DEFAULT_MODELS
|
request.app.state.config.DEFAULT_MODELS = form_data.DEFAULT_MODELS
|
||||||
request.app.state.config.DEFAULT_PINNED_MODELS = form_data.DEFAULT_PINNED_MODELS
|
|
||||||
request.app.state.config.MODEL_ORDER_LIST = form_data.MODEL_ORDER_LIST
|
request.app.state.config.MODEL_ORDER_LIST = form_data.MODEL_ORDER_LIST
|
||||||
return {
|
return {
|
||||||
"DEFAULT_MODELS": request.app.state.config.DEFAULT_MODELS,
|
"DEFAULT_MODELS": request.app.state.config.DEFAULT_MODELS,
|
||||||
"DEFAULT_PINNED_MODELS": request.app.state.config.DEFAULT_PINNED_MODELS,
|
|
||||||
"MODEL_ORDER_LIST": request.app.state.config.MODEL_ORDER_LIST,
|
"MODEL_ORDER_LIST": request.app.state.config.MODEL_ORDER_LIST,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -4,12 +4,9 @@ from pydantic import BaseModel
|
||||||
|
|
||||||
from open_webui.models.users import Users, UserModel
|
from open_webui.models.users import Users, UserModel
|
||||||
from open_webui.models.feedbacks import (
|
from open_webui.models.feedbacks import (
|
||||||
FeedbackIdResponse,
|
|
||||||
FeedbackModel,
|
FeedbackModel,
|
||||||
FeedbackResponse,
|
FeedbackResponse,
|
||||||
FeedbackForm,
|
FeedbackForm,
|
||||||
FeedbackUserResponse,
|
|
||||||
FeedbackListResponse,
|
|
||||||
Feedbacks,
|
Feedbacks,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -59,16 +56,35 @@ async def update_config(
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@router.get("/feedbacks/all", response_model=list[FeedbackResponse])
|
class UserResponse(BaseModel):
|
||||||
|
id: str
|
||||||
|
name: str
|
||||||
|
email: str
|
||||||
|
role: str = "pending"
|
||||||
|
|
||||||
|
last_active_at: int # timestamp in epoch
|
||||||
|
updated_at: int # timestamp in epoch
|
||||||
|
created_at: int # timestamp in epoch
|
||||||
|
|
||||||
|
|
||||||
|
class FeedbackUserResponse(FeedbackResponse):
|
||||||
|
user: Optional[UserResponse] = None
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/feedbacks/all", response_model=list[FeedbackUserResponse])
|
||||||
async def get_all_feedbacks(user=Depends(get_admin_user)):
|
async def get_all_feedbacks(user=Depends(get_admin_user)):
|
||||||
feedbacks = Feedbacks.get_all_feedbacks()
|
feedbacks = Feedbacks.get_all_feedbacks()
|
||||||
return feedbacks
|
|
||||||
|
|
||||||
|
feedback_list = []
|
||||||
@router.get("/feedbacks/all/ids", response_model=list[FeedbackIdResponse])
|
for feedback in feedbacks:
|
||||||
async def get_all_feedback_ids(user=Depends(get_admin_user)):
|
user = Users.get_user_by_id(feedback.user_id)
|
||||||
feedbacks = Feedbacks.get_all_feedbacks()
|
feedback_list.append(
|
||||||
return feedbacks
|
FeedbackUserResponse(
|
||||||
|
**feedback.model_dump(),
|
||||||
|
user=UserResponse(**user.model_dump()) if user else None,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return feedback_list
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/feedbacks/all")
|
@router.delete("/feedbacks/all")
|
||||||
|
|
@ -78,7 +94,7 @@ async def delete_all_feedbacks(user=Depends(get_admin_user)):
|
||||||
|
|
||||||
|
|
||||||
@router.get("/feedbacks/all/export", response_model=list[FeedbackModel])
|
@router.get("/feedbacks/all/export", response_model=list[FeedbackModel])
|
||||||
async def export_all_feedbacks(user=Depends(get_admin_user)):
|
async def get_all_feedbacks(user=Depends(get_admin_user)):
|
||||||
feedbacks = Feedbacks.get_all_feedbacks()
|
feedbacks = Feedbacks.get_all_feedbacks()
|
||||||
return feedbacks
|
return feedbacks
|
||||||
|
|
||||||
|
|
@ -95,31 +111,6 @@ async def delete_feedbacks(user=Depends(get_verified_user)):
|
||||||
return success
|
return success
|
||||||
|
|
||||||
|
|
||||||
PAGE_ITEM_COUNT = 30
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/feedbacks/list", response_model=FeedbackListResponse)
|
|
||||||
async def get_feedbacks(
|
|
||||||
order_by: Optional[str] = None,
|
|
||||||
direction: Optional[str] = None,
|
|
||||||
page: Optional[int] = 1,
|
|
||||||
user=Depends(get_admin_user),
|
|
||||||
):
|
|
||||||
limit = PAGE_ITEM_COUNT
|
|
||||||
|
|
||||||
page = max(1, page)
|
|
||||||
skip = (page - 1) * limit
|
|
||||||
|
|
||||||
filter = {}
|
|
||||||
if order_by:
|
|
||||||
filter["order_by"] = order_by
|
|
||||||
if direction:
|
|
||||||
filter["direction"] = direction
|
|
||||||
|
|
||||||
result = Feedbacks.get_feedback_items(filter=filter, skip=skip, limit=limit)
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/feedback", response_model=FeedbackModel)
|
@router.post("/feedback", response_model=FeedbackModel)
|
||||||
async def create_feedback(
|
async def create_feedback(
|
||||||
request: Request,
|
request: Request,
|
||||||
|
|
|
||||||
|
|
@ -22,11 +22,10 @@ from fastapi import (
|
||||||
)
|
)
|
||||||
|
|
||||||
from fastapi.responses import FileResponse, StreamingResponse
|
from fastapi.responses import FileResponse, StreamingResponse
|
||||||
|
|
||||||
from open_webui.constants import ERROR_MESSAGES
|
from open_webui.constants import ERROR_MESSAGES
|
||||||
|
from open_webui.env import SRC_LOG_LEVELS
|
||||||
from open_webui.retrieval.vector.factory import VECTOR_DB_CLIENT
|
from open_webui.retrieval.vector.factory import VECTOR_DB_CLIENT
|
||||||
|
|
||||||
from open_webui.models.channels import Channels
|
|
||||||
from open_webui.models.users import Users
|
from open_webui.models.users import Users
|
||||||
from open_webui.models.files import (
|
from open_webui.models.files import (
|
||||||
FileForm,
|
FileForm,
|
||||||
|
|
@ -34,23 +33,17 @@ from open_webui.models.files import (
|
||||||
FileModelResponse,
|
FileModelResponse,
|
||||||
Files,
|
Files,
|
||||||
)
|
)
|
||||||
from open_webui.models.chats import Chats
|
|
||||||
from open_webui.models.knowledge import Knowledges
|
from open_webui.models.knowledge import Knowledges
|
||||||
from open_webui.models.groups import Groups
|
|
||||||
|
|
||||||
|
|
||||||
|
from open_webui.routers.knowledge import get_knowledge, get_knowledge_list
|
||||||
from open_webui.routers.retrieval import ProcessFileForm, process_file
|
from open_webui.routers.retrieval import ProcessFileForm, process_file
|
||||||
from open_webui.routers.audio import transcribe
|
from open_webui.routers.audio import transcribe
|
||||||
|
|
||||||
from open_webui.storage.provider import Storage
|
from open_webui.storage.provider import Storage
|
||||||
|
|
||||||
|
|
||||||
from open_webui.utils.auth import get_admin_user, get_verified_user
|
from open_webui.utils.auth import get_admin_user, get_verified_user
|
||||||
from open_webui.utils.access_control import has_access
|
|
||||||
from open_webui.utils.misc import strict_match_mime_type
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
log.setLevel(SRC_LOG_LEVELS["MODELS"])
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
@ -60,48 +53,31 @@ router = APIRouter()
|
||||||
############################
|
############################
|
||||||
|
|
||||||
|
|
||||||
# TODO: Optimize this function to use the knowledge_file table for faster lookups.
|
|
||||||
def has_access_to_file(
|
def has_access_to_file(
|
||||||
file_id: Optional[str], access_type: str, user=Depends(get_verified_user)
|
file_id: Optional[str], access_type: str, user=Depends(get_verified_user)
|
||||||
) -> bool:
|
) -> bool:
|
||||||
file = Files.get_file_by_id(file_id)
|
file = Files.get_file_by_id(file_id)
|
||||||
log.debug(f"Checking if user has {access_type} access to file")
|
log.debug(f"Checking if user has {access_type} access to file")
|
||||||
|
|
||||||
if not file:
|
if not file:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
detail=ERROR_MESSAGES.NOT_FOUND,
|
detail=ERROR_MESSAGES.NOT_FOUND,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Check if the file is associated with any knowledge bases the user has access to
|
has_access = False
|
||||||
knowledge_bases = Knowledges.get_knowledges_by_file_id(file_id)
|
|
||||||
user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user.id)}
|
|
||||||
for knowledge_base in knowledge_bases:
|
|
||||||
if knowledge_base.user_id == user.id or has_access(
|
|
||||||
user.id, access_type, knowledge_base.access_control, user_group_ids
|
|
||||||
):
|
|
||||||
return True
|
|
||||||
|
|
||||||
knowledge_base_id = file.meta.get("collection_name") if file.meta else None
|
knowledge_base_id = file.meta.get("collection_name") if file.meta else None
|
||||||
|
|
||||||
if knowledge_base_id:
|
if knowledge_base_id:
|
||||||
knowledge_bases = Knowledges.get_knowledge_bases_by_user_id(
|
knowledge_bases = Knowledges.get_knowledge_bases_by_user_id(
|
||||||
user.id, access_type
|
user.id, access_type
|
||||||
)
|
)
|
||||||
for knowledge_base in knowledge_bases:
|
for knowledge_base in knowledge_bases:
|
||||||
if knowledge_base.id == knowledge_base_id:
|
if knowledge_base.id == knowledge_base_id:
|
||||||
return True
|
has_access = True
|
||||||
|
break
|
||||||
|
|
||||||
# Check if the file is associated with any channels the user has access to
|
return has_access
|
||||||
channels = Channels.get_channels_by_file_id_and_user_id(file_id, user.id)
|
|
||||||
if access_type == "read" and channels:
|
|
||||||
return True
|
|
||||||
|
|
||||||
# Check if the file is associated with any chats the user has access to
|
|
||||||
# TODO: Granular access control for chats
|
|
||||||
chats = Chats.get_shared_chats_by_file_id(file_id)
|
|
||||||
if chats:
|
|
||||||
return True
|
|
||||||
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
############################
|
############################
|
||||||
|
|
@ -116,9 +92,17 @@ def process_uploaded_file(request, file, file_path, file_item, file_metadata, us
|
||||||
request.app.state.config, "STT_SUPPORTED_CONTENT_TYPES", []
|
request.app.state.config, "STT_SUPPORTED_CONTENT_TYPES", []
|
||||||
)
|
)
|
||||||
|
|
||||||
if strict_match_mime_type(stt_supported_content_types, file.content_type):
|
if any(
|
||||||
|
fnmatch(file.content_type, content_type)
|
||||||
|
for content_type in (
|
||||||
|
stt_supported_content_types
|
||||||
|
if stt_supported_content_types
|
||||||
|
and any(t.strip() for t in stt_supported_content_types)
|
||||||
|
else ["audio/*", "video/webm"]
|
||||||
|
)
|
||||||
|
):
|
||||||
file_path = Storage.get_file(file_path)
|
file_path = Storage.get_file(file_path)
|
||||||
result = transcribe(request, file_path, file_metadata, user)
|
result = transcribe(request, file_path, file_metadata)
|
||||||
|
|
||||||
process_file(
|
process_file(
|
||||||
request,
|
request,
|
||||||
|
|
@ -140,7 +124,6 @@ def process_uploaded_file(request, file, file_path, file_item, file_metadata, us
|
||||||
f"File type {file.content_type} is not provided, but trying to process anyway"
|
f"File type {file.content_type} is not provided, but trying to process anyway"
|
||||||
)
|
)
|
||||||
process_file(request, ProcessFileForm(file_id=file_item.id), user=user)
|
process_file(request, ProcessFileForm(file_id=file_item.id), user=user)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.error(f"Error processing file: {file_item.id}")
|
log.error(f"Error processing file: {file_item.id}")
|
||||||
Files.update_file_data_by_id(
|
Files.update_file_data_by_id(
|
||||||
|
|
@ -182,7 +165,7 @@ def upload_file_handler(
|
||||||
user=Depends(get_verified_user),
|
user=Depends(get_verified_user),
|
||||||
background_tasks: Optional[BackgroundTasks] = None,
|
background_tasks: Optional[BackgroundTasks] = None,
|
||||||
):
|
):
|
||||||
log.info(f"file.content_type: {file.content_type} {process}")
|
log.info(f"file.content_type: {file.content_type}")
|
||||||
|
|
||||||
if isinstance(metadata, str):
|
if isinstance(metadata, str):
|
||||||
try:
|
try:
|
||||||
|
|
@ -250,13 +233,6 @@ def upload_file_handler(
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
if "channel_id" in file_metadata:
|
|
||||||
channel = Channels.get_channel_by_id_and_user_id(
|
|
||||||
file_metadata["channel_id"], user.id
|
|
||||||
)
|
|
||||||
if channel:
|
|
||||||
Channels.add_file_to_channel_by_id(channel.id, file_item.id, user.id)
|
|
||||||
|
|
||||||
if process:
|
if process:
|
||||||
if background_tasks and process_in_background:
|
if background_tasks and process_in_background:
|
||||||
background_tasks.add_task(
|
background_tasks.add_task(
|
||||||
|
|
|
||||||
|
|
@ -21,6 +21,7 @@ from open_webui.models.knowledge import Knowledges
|
||||||
|
|
||||||
|
|
||||||
from open_webui.config import UPLOAD_DIR
|
from open_webui.config import UPLOAD_DIR
|
||||||
|
from open_webui.env import SRC_LOG_LEVELS
|
||||||
from open_webui.constants import ERROR_MESSAGES
|
from open_webui.constants import ERROR_MESSAGES
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -33,6 +34,7 @@ from open_webui.utils.access_control import has_permission
|
||||||
|
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
log.setLevel(SRC_LOG_LEVELS["MODELS"])
|
||||||
|
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
@ -44,23 +46,7 @@ router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
@router.get("/", response_model=list[FolderNameIdResponse])
|
@router.get("/", response_model=list[FolderNameIdResponse])
|
||||||
async def get_folders(request: Request, user=Depends(get_verified_user)):
|
async def get_folders(user=Depends(get_verified_user)):
|
||||||
if request.app.state.config.ENABLE_FOLDERS is False:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_403_FORBIDDEN,
|
|
||||||
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
|
||||||
)
|
|
||||||
|
|
||||||
if user.role != "admin" and not has_permission(
|
|
||||||
user.id,
|
|
||||||
"features.folders",
|
|
||||||
request.app.state.config.USER_PERMISSIONS,
|
|
||||||
):
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_403_FORBIDDEN,
|
|
||||||
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
|
||||||
)
|
|
||||||
|
|
||||||
folders = Folders.get_folders_by_user_id(user.id)
|
folders = Folders.get_folders_by_user_id(user.id)
|
||||||
|
|
||||||
# Verify folder data integrity
|
# Verify folder data integrity
|
||||||
|
|
@ -272,10 +258,7 @@ async def update_folder_is_expanded_by_id(
|
||||||
|
|
||||||
@router.delete("/{id}")
|
@router.delete("/{id}")
|
||||||
async def delete_folder_by_id(
|
async def delete_folder_by_id(
|
||||||
request: Request,
|
request: Request, id: str, user=Depends(get_verified_user)
|
||||||
id: str,
|
|
||||||
delete_contents: Optional[bool] = True,
|
|
||||||
user=Depends(get_verified_user),
|
|
||||||
):
|
):
|
||||||
if Chats.count_chats_by_folder_id_and_user_id(id, user.id):
|
if Chats.count_chats_by_folder_id_and_user_id(id, user.id):
|
||||||
chat_delete_permission = has_permission(
|
chat_delete_permission = has_permission(
|
||||||
|
|
@ -294,14 +277,8 @@ async def delete_folder_by_id(
|
||||||
if folder:
|
if folder:
|
||||||
try:
|
try:
|
||||||
folder_ids = Folders.delete_folder_by_id_and_user_id(id, user.id)
|
folder_ids = Folders.delete_folder_by_id_and_user_id(id, user.id)
|
||||||
|
|
||||||
for folder_id in folder_ids:
|
for folder_id in folder_ids:
|
||||||
if delete_contents:
|
Chats.delete_chats_by_user_id_and_folder_id(user.id, folder_id)
|
||||||
Chats.delete_chats_by_user_id_and_folder_id(user.id, folder_id)
|
|
||||||
else:
|
|
||||||
Chats.move_chats_by_user_id_and_folder_id(
|
|
||||||
user.id, folder_id, None
|
|
||||||
)
|
|
||||||
|
|
||||||
return True
|
return True
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
|
||||||
|
|
@ -23,10 +23,12 @@ from open_webui.config import CACHE_DIR
|
||||||
from open_webui.constants import ERROR_MESSAGES
|
from open_webui.constants import ERROR_MESSAGES
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||||
from open_webui.utils.auth import get_admin_user, get_verified_user
|
from open_webui.utils.auth import get_admin_user, get_verified_user
|
||||||
|
from open_webui.env import SRC_LOG_LEVELS
|
||||||
from pydantic import BaseModel, HttpUrl
|
from pydantic import BaseModel, HttpUrl
|
||||||
|
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
log.setLevel(SRC_LOG_LEVELS["MAIN"])
|
||||||
|
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,7 @@ from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from open_webui.models.users import Users, UserInfoResponse
|
from open_webui.models.users import Users
|
||||||
from open_webui.models.groups import (
|
from open_webui.models.groups import (
|
||||||
Groups,
|
Groups,
|
||||||
GroupForm,
|
GroupForm,
|
||||||
|
|
@ -17,9 +17,11 @@ from open_webui.constants import ERROR_MESSAGES
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||||
|
|
||||||
from open_webui.utils.auth import get_admin_user, get_verified_user
|
from open_webui.utils.auth import get_admin_user, get_verified_user
|
||||||
|
from open_webui.env import SRC_LOG_LEVELS
|
||||||
|
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
log.setLevel(SRC_LOG_LEVELS["MAIN"])
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
@ -29,18 +31,11 @@ router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
@router.get("/", response_model=list[GroupResponse])
|
@router.get("/", response_model=list[GroupResponse])
|
||||||
async def get_groups(share: Optional[bool] = None, user=Depends(get_verified_user)):
|
async def get_groups(user=Depends(get_verified_user)):
|
||||||
|
if user.role == "admin":
|
||||||
filter = {}
|
return Groups.get_groups()
|
||||||
if user.role != "admin":
|
else:
|
||||||
filter["member_id"] = user.id
|
return Groups.get_groups_by_member_id(user.id)
|
||||||
|
|
||||||
if share is not None:
|
|
||||||
filter["share"] = share
|
|
||||||
|
|
||||||
groups = Groups.get_groups(filter=filter)
|
|
||||||
|
|
||||||
return groups
|
|
||||||
|
|
||||||
|
|
||||||
############################
|
############################
|
||||||
|
|
@ -53,10 +48,7 @@ async def create_new_group(form_data: GroupForm, user=Depends(get_admin_user)):
|
||||||
try:
|
try:
|
||||||
group = Groups.insert_new_group(user.id, form_data)
|
group = Groups.insert_new_group(user.id, form_data)
|
||||||
if group:
|
if group:
|
||||||
return GroupResponse(
|
return group
|
||||||
**group.model_dump(),
|
|
||||||
member_count=Groups.get_group_member_count_by_id(group.id),
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
|
@ -79,10 +71,7 @@ async def create_new_group(form_data: GroupForm, user=Depends(get_admin_user)):
|
||||||
async def get_group_by_id(id: str, user=Depends(get_admin_user)):
|
async def get_group_by_id(id: str, user=Depends(get_admin_user)):
|
||||||
group = Groups.get_group_by_id(id)
|
group = Groups.get_group_by_id(id)
|
||||||
if group:
|
if group:
|
||||||
return GroupResponse(
|
return group
|
||||||
**group.model_dump(),
|
|
||||||
member_count=Groups.get_group_member_count_by_id(group.id),
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
|
@ -90,50 +79,6 @@ async def get_group_by_id(id: str, user=Depends(get_admin_user)):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
############################
|
|
||||||
# ExportGroupById
|
|
||||||
############################
|
|
||||||
|
|
||||||
|
|
||||||
class GroupExportResponse(GroupResponse):
|
|
||||||
user_ids: list[str] = []
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/id/{id}/export", response_model=Optional[GroupExportResponse])
|
|
||||||
async def export_group_by_id(id: str, user=Depends(get_admin_user)):
|
|
||||||
group = Groups.get_group_by_id(id)
|
|
||||||
if group:
|
|
||||||
return GroupExportResponse(
|
|
||||||
**group.model_dump(),
|
|
||||||
member_count=Groups.get_group_member_count_by_id(group.id),
|
|
||||||
user_ids=Groups.get_group_user_ids_by_id(group.id),
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
||||||
detail=ERROR_MESSAGES.NOT_FOUND,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
############################
|
|
||||||
# GetUsersInGroupById
|
|
||||||
############################
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/id/{id}/users", response_model=list[UserInfoResponse])
|
|
||||||
async def get_users_in_group(id: str, user=Depends(get_admin_user)):
|
|
||||||
try:
|
|
||||||
users = Users.get_users_by_group_id(id)
|
|
||||||
return users
|
|
||||||
except Exception as e:
|
|
||||||
log.exception(f"Error adding users to group {id}: {e}")
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
|
||||||
detail=ERROR_MESSAGES.DEFAULT(e),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
############################
|
############################
|
||||||
# UpdateGroupById
|
# UpdateGroupById
|
||||||
############################
|
############################
|
||||||
|
|
@ -144,12 +89,12 @@ async def update_group_by_id(
|
||||||
id: str, form_data: GroupUpdateForm, user=Depends(get_admin_user)
|
id: str, form_data: GroupUpdateForm, user=Depends(get_admin_user)
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
|
if form_data.user_ids:
|
||||||
|
form_data.user_ids = Users.get_valid_user_ids(form_data.user_ids)
|
||||||
|
|
||||||
group = Groups.update_group_by_id(id, form_data)
|
group = Groups.update_group_by_id(id, form_data)
|
||||||
if group:
|
if group:
|
||||||
return GroupResponse(
|
return group
|
||||||
**group.model_dump(),
|
|
||||||
member_count=Groups.get_group_member_count_by_id(group.id),
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
|
@ -178,10 +123,7 @@ async def add_user_to_group(
|
||||||
|
|
||||||
group = Groups.add_users_to_group(id, form_data.user_ids)
|
group = Groups.add_users_to_group(id, form_data.user_ids)
|
||||||
if group:
|
if group:
|
||||||
return GroupResponse(
|
return group
|
||||||
**group.model_dump(),
|
|
||||||
member_count=Groups.get_group_member_count_by_id(group.id),
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
|
@ -202,10 +144,7 @@ async def remove_users_from_group(
|
||||||
try:
|
try:
|
||||||
group = Groups.remove_users_from_group(id, form_data.user_ids)
|
group = Groups.remove_users_from_group(id, form_data.user_ids)
|
||||||
if group:
|
if group:
|
||||||
return GroupResponse(
|
return group
|
||||||
**group.model_dump(),
|
|
||||||
member_count=Groups.get_group_member_count_by_id(group.id),
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
|
|
||||||
|
|
@ -16,9 +16,7 @@ from fastapi.responses import FileResponse
|
||||||
|
|
||||||
from open_webui.config import CACHE_DIR
|
from open_webui.config import CACHE_DIR
|
||||||
from open_webui.constants import ERROR_MESSAGES
|
from open_webui.constants import ERROR_MESSAGES
|
||||||
from open_webui.env import ENABLE_FORWARD_USER_INFO_HEADERS
|
from open_webui.env import ENABLE_FORWARD_USER_INFO_HEADERS, SRC_LOG_LEVELS
|
||||||
|
|
||||||
from open_webui.models.chats import Chats
|
|
||||||
from open_webui.routers.files import upload_file_handler, get_file_content_by_id
|
from open_webui.routers.files import upload_file_handler, get_file_content_by_id
|
||||||
from open_webui.utils.auth import get_admin_user, get_verified_user
|
from open_webui.utils.auth import get_admin_user, get_verified_user
|
||||||
from open_webui.utils.headers import include_user_info_headers
|
from open_webui.utils.headers import include_user_info_headers
|
||||||
|
|
@ -33,6 +31,7 @@ from open_webui.utils.images.comfyui import (
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
log.setLevel(SRC_LOG_LEVELS["IMAGES"])
|
||||||
|
|
||||||
IMAGE_CACHE_DIR = CACHE_DIR / "image" / "generations"
|
IMAGE_CACHE_DIR = CACHE_DIR / "image" / "generations"
|
||||||
IMAGE_CACHE_DIR.mkdir(parents=True, exist_ok=True)
|
IMAGE_CACHE_DIR.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
@ -45,23 +44,18 @@ def set_image_model(request: Request, model: str):
|
||||||
request.app.state.config.IMAGE_GENERATION_MODEL = model
|
request.app.state.config.IMAGE_GENERATION_MODEL = model
|
||||||
if request.app.state.config.IMAGE_GENERATION_ENGINE in ["", "automatic1111"]:
|
if request.app.state.config.IMAGE_GENERATION_ENGINE in ["", "automatic1111"]:
|
||||||
api_auth = get_automatic1111_api_auth(request)
|
api_auth = get_automatic1111_api_auth(request)
|
||||||
|
r = requests.get(
|
||||||
try:
|
url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
|
||||||
r = requests.get(
|
headers={"authorization": api_auth},
|
||||||
|
)
|
||||||
|
options = r.json()
|
||||||
|
if model != options["sd_model_checkpoint"]:
|
||||||
|
options["sd_model_checkpoint"] = model
|
||||||
|
r = requests.post(
|
||||||
url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
|
url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
|
||||||
|
json=options,
|
||||||
headers={"authorization": api_auth},
|
headers={"authorization": api_auth},
|
||||||
)
|
)
|
||||||
options = r.json()
|
|
||||||
if model != options["sd_model_checkpoint"]:
|
|
||||||
options["sd_model_checkpoint"] = model
|
|
||||||
r = requests.post(
|
|
||||||
url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
|
|
||||||
json=options,
|
|
||||||
headers={"authorization": api_auth},
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
log.debug(f"{e}")
|
|
||||||
|
|
||||||
return request.app.state.config.IMAGE_GENERATION_MODEL
|
return request.app.state.config.IMAGE_GENERATION_MODEL
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -112,10 +106,9 @@ class ImagesConfig(BaseModel):
|
||||||
IMAGES_OPENAI_API_BASE_URL: str
|
IMAGES_OPENAI_API_BASE_URL: str
|
||||||
IMAGES_OPENAI_API_KEY: str
|
IMAGES_OPENAI_API_KEY: str
|
||||||
IMAGES_OPENAI_API_VERSION: str
|
IMAGES_OPENAI_API_VERSION: str
|
||||||
IMAGES_OPENAI_API_PARAMS: Optional[dict | str]
|
|
||||||
|
|
||||||
AUTOMATIC1111_BASE_URL: str
|
AUTOMATIC1111_BASE_URL: str
|
||||||
AUTOMATIC1111_API_AUTH: Optional[dict | str]
|
AUTOMATIC1111_API_AUTH: str
|
||||||
AUTOMATIC1111_PARAMS: Optional[dict | str]
|
AUTOMATIC1111_PARAMS: Optional[dict | str]
|
||||||
|
|
||||||
COMFYUI_BASE_URL: str
|
COMFYUI_BASE_URL: str
|
||||||
|
|
@ -127,7 +120,6 @@ class ImagesConfig(BaseModel):
|
||||||
IMAGES_GEMINI_API_KEY: str
|
IMAGES_GEMINI_API_KEY: str
|
||||||
IMAGES_GEMINI_ENDPOINT_METHOD: str
|
IMAGES_GEMINI_ENDPOINT_METHOD: str
|
||||||
|
|
||||||
ENABLE_IMAGE_EDIT: bool
|
|
||||||
IMAGE_EDIT_ENGINE: str
|
IMAGE_EDIT_ENGINE: str
|
||||||
IMAGE_EDIT_MODEL: str
|
IMAGE_EDIT_MODEL: str
|
||||||
IMAGE_EDIT_SIZE: Optional[str]
|
IMAGE_EDIT_SIZE: Optional[str]
|
||||||
|
|
@ -155,7 +147,6 @@ async def get_config(request: Request, user=Depends(get_admin_user)):
|
||||||
"IMAGES_OPENAI_API_BASE_URL": request.app.state.config.IMAGES_OPENAI_API_BASE_URL,
|
"IMAGES_OPENAI_API_BASE_URL": request.app.state.config.IMAGES_OPENAI_API_BASE_URL,
|
||||||
"IMAGES_OPENAI_API_KEY": request.app.state.config.IMAGES_OPENAI_API_KEY,
|
"IMAGES_OPENAI_API_KEY": request.app.state.config.IMAGES_OPENAI_API_KEY,
|
||||||
"IMAGES_OPENAI_API_VERSION": request.app.state.config.IMAGES_OPENAI_API_VERSION,
|
"IMAGES_OPENAI_API_VERSION": request.app.state.config.IMAGES_OPENAI_API_VERSION,
|
||||||
"IMAGES_OPENAI_API_PARAMS": request.app.state.config.IMAGES_OPENAI_API_PARAMS,
|
|
||||||
"AUTOMATIC1111_BASE_URL": request.app.state.config.AUTOMATIC1111_BASE_URL,
|
"AUTOMATIC1111_BASE_URL": request.app.state.config.AUTOMATIC1111_BASE_URL,
|
||||||
"AUTOMATIC1111_API_AUTH": request.app.state.config.AUTOMATIC1111_API_AUTH,
|
"AUTOMATIC1111_API_AUTH": request.app.state.config.AUTOMATIC1111_API_AUTH,
|
||||||
"AUTOMATIC1111_PARAMS": request.app.state.config.AUTOMATIC1111_PARAMS,
|
"AUTOMATIC1111_PARAMS": request.app.state.config.AUTOMATIC1111_PARAMS,
|
||||||
|
|
@ -166,7 +157,6 @@ async def get_config(request: Request, user=Depends(get_admin_user)):
|
||||||
"IMAGES_GEMINI_API_BASE_URL": request.app.state.config.IMAGES_GEMINI_API_BASE_URL,
|
"IMAGES_GEMINI_API_BASE_URL": request.app.state.config.IMAGES_GEMINI_API_BASE_URL,
|
||||||
"IMAGES_GEMINI_API_KEY": request.app.state.config.IMAGES_GEMINI_API_KEY,
|
"IMAGES_GEMINI_API_KEY": request.app.state.config.IMAGES_GEMINI_API_KEY,
|
||||||
"IMAGES_GEMINI_ENDPOINT_METHOD": request.app.state.config.IMAGES_GEMINI_ENDPOINT_METHOD,
|
"IMAGES_GEMINI_ENDPOINT_METHOD": request.app.state.config.IMAGES_GEMINI_ENDPOINT_METHOD,
|
||||||
"ENABLE_IMAGE_EDIT": request.app.state.config.ENABLE_IMAGE_EDIT,
|
|
||||||
"IMAGE_EDIT_ENGINE": request.app.state.config.IMAGE_EDIT_ENGINE,
|
"IMAGE_EDIT_ENGINE": request.app.state.config.IMAGE_EDIT_ENGINE,
|
||||||
"IMAGE_EDIT_MODEL": request.app.state.config.IMAGE_EDIT_MODEL,
|
"IMAGE_EDIT_MODEL": request.app.state.config.IMAGE_EDIT_MODEL,
|
||||||
"IMAGE_EDIT_SIZE": request.app.state.config.IMAGE_EDIT_SIZE,
|
"IMAGE_EDIT_SIZE": request.app.state.config.IMAGE_EDIT_SIZE,
|
||||||
|
|
@ -197,12 +187,12 @@ async def update_config(
|
||||||
set_image_model(request, form_data.IMAGE_GENERATION_MODEL)
|
set_image_model(request, form_data.IMAGE_GENERATION_MODEL)
|
||||||
if (
|
if (
|
||||||
form_data.IMAGE_SIZE == "auto"
|
form_data.IMAGE_SIZE == "auto"
|
||||||
and not form_data.IMAGE_GENERATION_MODEL.startswith("gpt-image")
|
and form_data.IMAGE_GENERATION_MODEL != "gpt-image-1"
|
||||||
):
|
):
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=400,
|
status_code=400,
|
||||||
detail=ERROR_MESSAGES.INCORRECT_FORMAT(
|
detail=ERROR_MESSAGES.INCORRECT_FORMAT(
|
||||||
" (auto is only allowed with gpt-image models)."
|
" (auto is only allowed with gpt-image-1)."
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -234,9 +224,6 @@ async def update_config(
|
||||||
request.app.state.config.IMAGES_OPENAI_API_VERSION = (
|
request.app.state.config.IMAGES_OPENAI_API_VERSION = (
|
||||||
form_data.IMAGES_OPENAI_API_VERSION
|
form_data.IMAGES_OPENAI_API_VERSION
|
||||||
)
|
)
|
||||||
request.app.state.config.IMAGES_OPENAI_API_PARAMS = (
|
|
||||||
form_data.IMAGES_OPENAI_API_PARAMS
|
|
||||||
)
|
|
||||||
|
|
||||||
request.app.state.config.AUTOMATIC1111_BASE_URL = form_data.AUTOMATIC1111_BASE_URL
|
request.app.state.config.AUTOMATIC1111_BASE_URL = form_data.AUTOMATIC1111_BASE_URL
|
||||||
request.app.state.config.AUTOMATIC1111_API_AUTH = form_data.AUTOMATIC1111_API_AUTH
|
request.app.state.config.AUTOMATIC1111_API_AUTH = form_data.AUTOMATIC1111_API_AUTH
|
||||||
|
|
@ -256,16 +243,15 @@ async def update_config(
|
||||||
)
|
)
|
||||||
|
|
||||||
# Edit Image
|
# Edit Image
|
||||||
request.app.state.config.ENABLE_IMAGE_EDIT = form_data.ENABLE_IMAGE_EDIT
|
|
||||||
request.app.state.config.IMAGE_EDIT_ENGINE = form_data.IMAGE_EDIT_ENGINE
|
request.app.state.config.IMAGE_EDIT_ENGINE = form_data.IMAGE_EDIT_ENGINE
|
||||||
request.app.state.config.IMAGE_EDIT_MODEL = form_data.IMAGE_EDIT_MODEL
|
request.app.state.config.IMAGE_EDIT_MODEL = form_data.IMAGE_EDIT_MODEL
|
||||||
request.app.state.config.IMAGE_EDIT_SIZE = form_data.IMAGE_EDIT_SIZE
|
request.app.state.config.IMAGE_EDIT_SIZE = form_data.IMAGE_EDIT_SIZE
|
||||||
|
|
||||||
request.app.state.config.IMAGES_EDIT_OPENAI_API_BASE_URL = (
|
request.app.state.config.IMAGES_EDIT_OPENAI_API_BASE_URL = (
|
||||||
form_data.IMAGES_EDIT_OPENAI_API_BASE_URL
|
form_data.IMAGES_OPENAI_API_BASE_URL
|
||||||
)
|
)
|
||||||
request.app.state.config.IMAGES_EDIT_OPENAI_API_KEY = (
|
request.app.state.config.IMAGES_EDIT_OPENAI_API_KEY = (
|
||||||
form_data.IMAGES_EDIT_OPENAI_API_KEY
|
form_data.IMAGES_OPENAI_API_KEY
|
||||||
)
|
)
|
||||||
request.app.state.config.IMAGES_EDIT_OPENAI_API_VERSION = (
|
request.app.state.config.IMAGES_EDIT_OPENAI_API_VERSION = (
|
||||||
form_data.IMAGES_EDIT_OPENAI_API_VERSION
|
form_data.IMAGES_EDIT_OPENAI_API_VERSION
|
||||||
|
|
@ -301,7 +287,6 @@ async def update_config(
|
||||||
"IMAGES_OPENAI_API_BASE_URL": request.app.state.config.IMAGES_OPENAI_API_BASE_URL,
|
"IMAGES_OPENAI_API_BASE_URL": request.app.state.config.IMAGES_OPENAI_API_BASE_URL,
|
||||||
"IMAGES_OPENAI_API_KEY": request.app.state.config.IMAGES_OPENAI_API_KEY,
|
"IMAGES_OPENAI_API_KEY": request.app.state.config.IMAGES_OPENAI_API_KEY,
|
||||||
"IMAGES_OPENAI_API_VERSION": request.app.state.config.IMAGES_OPENAI_API_VERSION,
|
"IMAGES_OPENAI_API_VERSION": request.app.state.config.IMAGES_OPENAI_API_VERSION,
|
||||||
"IMAGES_OPENAI_API_PARAMS": request.app.state.config.IMAGES_OPENAI_API_PARAMS,
|
|
||||||
"AUTOMATIC1111_BASE_URL": request.app.state.config.AUTOMATIC1111_BASE_URL,
|
"AUTOMATIC1111_BASE_URL": request.app.state.config.AUTOMATIC1111_BASE_URL,
|
||||||
"AUTOMATIC1111_API_AUTH": request.app.state.config.AUTOMATIC1111_API_AUTH,
|
"AUTOMATIC1111_API_AUTH": request.app.state.config.AUTOMATIC1111_API_AUTH,
|
||||||
"AUTOMATIC1111_PARAMS": request.app.state.config.AUTOMATIC1111_PARAMS,
|
"AUTOMATIC1111_PARAMS": request.app.state.config.AUTOMATIC1111_PARAMS,
|
||||||
|
|
@ -312,7 +297,6 @@ async def update_config(
|
||||||
"IMAGES_GEMINI_API_BASE_URL": request.app.state.config.IMAGES_GEMINI_API_BASE_URL,
|
"IMAGES_GEMINI_API_BASE_URL": request.app.state.config.IMAGES_GEMINI_API_BASE_URL,
|
||||||
"IMAGES_GEMINI_API_KEY": request.app.state.config.IMAGES_GEMINI_API_KEY,
|
"IMAGES_GEMINI_API_KEY": request.app.state.config.IMAGES_GEMINI_API_KEY,
|
||||||
"IMAGES_GEMINI_ENDPOINT_METHOD": request.app.state.config.IMAGES_GEMINI_ENDPOINT_METHOD,
|
"IMAGES_GEMINI_ENDPOINT_METHOD": request.app.state.config.IMAGES_GEMINI_ENDPOINT_METHOD,
|
||||||
"ENABLE_IMAGE_EDIT": request.app.state.config.ENABLE_IMAGE_EDIT,
|
|
||||||
"IMAGE_EDIT_ENGINE": request.app.state.config.IMAGE_EDIT_ENGINE,
|
"IMAGE_EDIT_ENGINE": request.app.state.config.IMAGE_EDIT_ENGINE,
|
||||||
"IMAGE_EDIT_MODEL": request.app.state.config.IMAGE_EDIT_MODEL,
|
"IMAGE_EDIT_MODEL": request.app.state.config.IMAGE_EDIT_MODEL,
|
||||||
"IMAGE_EDIT_SIZE": request.app.state.config.IMAGE_EDIT_SIZE,
|
"IMAGE_EDIT_SIZE": request.app.state.config.IMAGE_EDIT_SIZE,
|
||||||
|
|
@ -381,7 +365,6 @@ def get_models(request: Request, user=Depends(get_verified_user)):
|
||||||
{"id": "dall-e-2", "name": "DALL·E 2"},
|
{"id": "dall-e-2", "name": "DALL·E 2"},
|
||||||
{"id": "dall-e-3", "name": "DALL·E 3"},
|
{"id": "dall-e-3", "name": "DALL·E 3"},
|
||||||
{"id": "gpt-image-1", "name": "GPT-IMAGE 1"},
|
{"id": "gpt-image-1", "name": "GPT-IMAGE 1"},
|
||||||
{"id": "gpt-image-1.5", "name": "GPT-IMAGE 1.5"},
|
|
||||||
]
|
]
|
||||||
elif request.app.state.config.IMAGE_GENERATION_ENGINE == "gemini":
|
elif request.app.state.config.IMAGE_GENERATION_ENGINE == "gemini":
|
||||||
return [
|
return [
|
||||||
|
|
@ -512,36 +495,15 @@ def upload_image(request, image_data, content_type, metadata, user):
|
||||||
process=False,
|
process=False,
|
||||||
user=user,
|
user=user,
|
||||||
)
|
)
|
||||||
|
|
||||||
if file_item and file_item.id:
|
|
||||||
# If chat_id and message_id are provided in metadata, link the file to the chat message
|
|
||||||
chat_id = metadata.get("chat_id")
|
|
||||||
message_id = metadata.get("message_id")
|
|
||||||
|
|
||||||
if chat_id and message_id:
|
|
||||||
Chats.insert_chat_files(
|
|
||||||
chat_id=chat_id,
|
|
||||||
message_id=message_id,
|
|
||||||
file_ids=[file_item.id],
|
|
||||||
user_id=user.id,
|
|
||||||
)
|
|
||||||
|
|
||||||
url = request.app.url_path_for("get_file_content_by_id", id=file_item.id)
|
url = request.app.url_path_for("get_file_content_by_id", id=file_item.id)
|
||||||
return file_item, url
|
return url
|
||||||
|
|
||||||
|
|
||||||
@router.post("/generations")
|
@router.post("/generations")
|
||||||
async def generate_images(
|
|
||||||
request: Request, form_data: CreateImageForm, user=Depends(get_verified_user)
|
|
||||||
):
|
|
||||||
return await image_generations(request, form_data, user=user)
|
|
||||||
|
|
||||||
|
|
||||||
async def image_generations(
|
async def image_generations(
|
||||||
request: Request,
|
request: Request,
|
||||||
form_data: CreateImageForm,
|
form_data: CreateImageForm,
|
||||||
metadata: Optional[dict] = None,
|
user=Depends(get_verified_user),
|
||||||
user=None,
|
|
||||||
):
|
):
|
||||||
# if IMAGE_SIZE = 'auto', default WidthxHeight to the 512x512 default
|
# if IMAGE_SIZE = 'auto', default WidthxHeight to the 512x512 default
|
||||||
# This is only relevant when the user has set IMAGE_SIZE to 'auto' with an
|
# This is only relevant when the user has set IMAGE_SIZE to 'auto' with an
|
||||||
|
|
@ -558,9 +520,6 @@ async def image_generations(
|
||||||
size = form_data.size
|
size = form_data.size
|
||||||
|
|
||||||
width, height = tuple(map(int, size.split("x")))
|
width, height = tuple(map(int, size.split("x")))
|
||||||
|
|
||||||
metadata = metadata or {}
|
|
||||||
|
|
||||||
model = get_image_model(request)
|
model = get_image_model(request)
|
||||||
|
|
||||||
r = None
|
r = None
|
||||||
|
|
@ -575,10 +534,6 @@ async def image_generations(
|
||||||
if ENABLE_FORWARD_USER_INFO_HEADERS:
|
if ENABLE_FORWARD_USER_INFO_HEADERS:
|
||||||
headers = include_user_info_headers(headers, user)
|
headers = include_user_info_headers(headers, user)
|
||||||
|
|
||||||
url = f"{request.app.state.config.IMAGES_OPENAI_API_BASE_URL}/images/generations"
|
|
||||||
if request.app.state.config.IMAGES_OPENAI_API_VERSION:
|
|
||||||
url = f"{url}?api-version={request.app.state.config.IMAGES_OPENAI_API_VERSION}"
|
|
||||||
|
|
||||||
data = {
|
data = {
|
||||||
"model": model,
|
"model": model,
|
||||||
"prompt": form_data.prompt,
|
"prompt": form_data.prompt,
|
||||||
|
|
@ -590,22 +545,21 @@ async def image_generations(
|
||||||
),
|
),
|
||||||
**(
|
**(
|
||||||
{}
|
{}
|
||||||
if request.app.state.config.IMAGE_GENERATION_MODEL.startswith(
|
if "gpt-image-1" in request.app.state.config.IMAGE_GENERATION_MODEL
|
||||||
"gpt-image"
|
|
||||||
)
|
|
||||||
else {"response_format": "b64_json"}
|
else {"response_format": "b64_json"}
|
||||||
),
|
),
|
||||||
**(
|
|
||||||
{}
|
|
||||||
if not request.app.state.config.IMAGES_OPENAI_API_PARAMS
|
|
||||||
else request.app.state.config.IMAGES_OPENAI_API_PARAMS
|
|
||||||
),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
api_version_query_param = ""
|
||||||
|
if request.app.state.config.IMAGES_OPENAI_API_VERSION:
|
||||||
|
api_version_query_param = (
|
||||||
|
f"?api-version={request.app.state.config.IMAGES_OPENAI_API_VERSION}"
|
||||||
|
)
|
||||||
|
|
||||||
# Use asyncio.to_thread for the requests.post call
|
# Use asyncio.to_thread for the requests.post call
|
||||||
r = await asyncio.to_thread(
|
r = await asyncio.to_thread(
|
||||||
requests.post,
|
requests.post,
|
||||||
url=url,
|
url=f"{request.app.state.config.IMAGES_OPENAI_API_BASE_URL}/images/generations{api_version_query_param}",
|
||||||
json=data,
|
json=data,
|
||||||
headers=headers,
|
headers=headers,
|
||||||
)
|
)
|
||||||
|
|
@ -621,9 +575,7 @@ async def image_generations(
|
||||||
else:
|
else:
|
||||||
image_data, content_type = get_image_data(image["b64_json"])
|
image_data, content_type = get_image_data(image["b64_json"])
|
||||||
|
|
||||||
_, url = upload_image(
|
url = upload_image(request, image_data, content_type, data, user)
|
||||||
request, image_data, content_type, {**data, **metadata}, user
|
|
||||||
)
|
|
||||||
images.append({"url": url})
|
images.append({"url": url})
|
||||||
return images
|
return images
|
||||||
|
|
||||||
|
|
@ -673,9 +625,7 @@ async def image_generations(
|
||||||
image_data, content_type = get_image_data(
|
image_data, content_type = get_image_data(
|
||||||
image["bytesBase64Encoded"]
|
image["bytesBase64Encoded"]
|
||||||
)
|
)
|
||||||
_, url = upload_image(
|
url = upload_image(request, image_data, content_type, data, user)
|
||||||
request, image_data, content_type, {**data, **metadata}, user
|
|
||||||
)
|
|
||||||
images.append({"url": url})
|
images.append({"url": url})
|
||||||
elif model.endswith(":generateContent"):
|
elif model.endswith(":generateContent"):
|
||||||
for image in res["candidates"]:
|
for image in res["candidates"]:
|
||||||
|
|
@ -684,12 +634,8 @@ async def image_generations(
|
||||||
image_data, content_type = get_image_data(
|
image_data, content_type = get_image_data(
|
||||||
part["inlineData"]["data"]
|
part["inlineData"]["data"]
|
||||||
)
|
)
|
||||||
_, url = upload_image(
|
url = upload_image(
|
||||||
request,
|
request, image_data, content_type, data, user
|
||||||
image_data,
|
|
||||||
content_type,
|
|
||||||
{**data, **metadata},
|
|
||||||
user,
|
|
||||||
)
|
)
|
||||||
images.append({"url": url})
|
images.append({"url": url})
|
||||||
|
|
||||||
|
|
@ -739,11 +685,11 @@ async def image_generations(
|
||||||
}
|
}
|
||||||
|
|
||||||
image_data, content_type = get_image_data(image["url"], headers)
|
image_data, content_type = get_image_data(image["url"], headers)
|
||||||
_, url = upload_image(
|
url = upload_image(
|
||||||
request,
|
request,
|
||||||
image_data,
|
image_data,
|
||||||
content_type,
|
content_type,
|
||||||
{**form_data.model_dump(exclude_none=True), **metadata},
|
form_data.model_dump(exclude_none=True),
|
||||||
user,
|
user,
|
||||||
)
|
)
|
||||||
images.append({"url": url})
|
images.append({"url": url})
|
||||||
|
|
@ -786,11 +732,11 @@ async def image_generations(
|
||||||
|
|
||||||
for image in res["images"]:
|
for image in res["images"]:
|
||||||
image_data, content_type = get_image_data(image)
|
image_data, content_type = get_image_data(image)
|
||||||
_, url = upload_image(
|
url = upload_image(
|
||||||
request,
|
request,
|
||||||
image_data,
|
image_data,
|
||||||
content_type,
|
content_type,
|
||||||
{**data, "info": res["info"], **metadata},
|
{**data, "info": res["info"]},
|
||||||
user,
|
user,
|
||||||
)
|
)
|
||||||
images.append({"url": url})
|
images.append({"url": url})
|
||||||
|
|
@ -817,13 +763,10 @@ class EditImageForm(BaseModel):
|
||||||
async def image_edits(
|
async def image_edits(
|
||||||
request: Request,
|
request: Request,
|
||||||
form_data: EditImageForm,
|
form_data: EditImageForm,
|
||||||
metadata: Optional[dict] = None,
|
|
||||||
user=Depends(get_verified_user),
|
user=Depends(get_verified_user),
|
||||||
):
|
):
|
||||||
size = None
|
size = None
|
||||||
width, height = None, None
|
width, height = None, None
|
||||||
metadata = metadata or {}
|
|
||||||
|
|
||||||
if (
|
if (
|
||||||
request.app.state.config.IMAGE_EDIT_SIZE
|
request.app.state.config.IMAGE_EDIT_SIZE
|
||||||
and "x" in request.app.state.config.IMAGE_EDIT_SIZE
|
and "x" in request.app.state.config.IMAGE_EDIT_SIZE
|
||||||
|
|
@ -875,13 +818,13 @@ async def image_edits(
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
|
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
|
||||||
|
|
||||||
def get_image_file_item(base64_string, param_name="image"):
|
def get_image_file_item(base64_string):
|
||||||
data = base64_string
|
data = base64_string
|
||||||
header, encoded = data.split(",", 1)
|
header, encoded = data.split(",", 1)
|
||||||
mime_type = header.split(";")[0].lstrip("data:")
|
mime_type = header.split(";")[0].lstrip("data:")
|
||||||
image_data = base64.b64decode(encoded)
|
image_data = base64.b64decode(encoded)
|
||||||
return (
|
return (
|
||||||
param_name,
|
"image",
|
||||||
(
|
(
|
||||||
f"{uuid.uuid4()}.png",
|
f"{uuid.uuid4()}.png",
|
||||||
io.BytesIO(image_data),
|
io.BytesIO(image_data),
|
||||||
|
|
@ -906,7 +849,7 @@ async def image_edits(
|
||||||
**({"size": size} if size else {}),
|
**({"size": size} if size else {}),
|
||||||
**(
|
**(
|
||||||
{}
|
{}
|
||||||
if request.app.state.config.IMAGE_EDIT_MODEL.startswith("gpt-image")
|
if "gpt-image-1" in request.app.state.config.IMAGE_EDIT_MODEL
|
||||||
else {"response_format": "b64_json"}
|
else {"response_format": "b64_json"}
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
|
|
@ -916,7 +859,7 @@ async def image_edits(
|
||||||
files = [get_image_file_item(form_data.image)]
|
files = [get_image_file_item(form_data.image)]
|
||||||
elif isinstance(form_data.image, list):
|
elif isinstance(form_data.image, list):
|
||||||
for img in form_data.image:
|
for img in form_data.image:
|
||||||
files.append(get_image_file_item(img, "image[]"))
|
files.append(get_image_file_item(img))
|
||||||
|
|
||||||
url_search_params = ""
|
url_search_params = ""
|
||||||
if request.app.state.config.IMAGES_EDIT_OPENAI_API_VERSION:
|
if request.app.state.config.IMAGES_EDIT_OPENAI_API_VERSION:
|
||||||
|
|
@ -941,9 +884,7 @@ async def image_edits(
|
||||||
else:
|
else:
|
||||||
image_data, content_type = get_image_data(image["b64_json"])
|
image_data, content_type = get_image_data(image["b64_json"])
|
||||||
|
|
||||||
_, url = upload_image(
|
url = upload_image(request, image_data, content_type, data, user)
|
||||||
request, image_data, content_type, {**data, **metadata}, user
|
|
||||||
)
|
|
||||||
images.append({"url": url})
|
images.append({"url": url})
|
||||||
return images
|
return images
|
||||||
|
|
||||||
|
|
@ -996,12 +937,8 @@ async def image_edits(
|
||||||
image_data, content_type = get_image_data(
|
image_data, content_type = get_image_data(
|
||||||
part["inlineData"]["data"]
|
part["inlineData"]["data"]
|
||||||
)
|
)
|
||||||
_, url = upload_image(
|
url = upload_image(
|
||||||
request,
|
request, image_data, content_type, data, user
|
||||||
image_data,
|
|
||||||
content_type,
|
|
||||||
{**data, **metadata},
|
|
||||||
user,
|
|
||||||
)
|
)
|
||||||
images.append({"url": url})
|
images.append({"url": url})
|
||||||
|
|
||||||
|
|
@ -1078,11 +1015,11 @@ async def image_edits(
|
||||||
}
|
}
|
||||||
|
|
||||||
image_data, content_type = get_image_data(image_url, headers)
|
image_data, content_type = get_image_data(image_url, headers)
|
||||||
_, url = upload_image(
|
url = upload_image(
|
||||||
request,
|
request,
|
||||||
image_data,
|
image_data,
|
||||||
content_type,
|
content_type,
|
||||||
{**form_data.model_dump(exclude_none=True), **metadata},
|
form_data.model_dump(exclude_none=True),
|
||||||
user,
|
user,
|
||||||
)
|
)
|
||||||
images.append({"url": url})
|
images.append({"url": url})
|
||||||
|
|
|
||||||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Reference in a new issue