Merge branch 'open-webui:main' into main

Henry 2025-06-04 15:37:33 +02:00 committed by GitHub
commit 415114504f
478 changed files with 64194 additions and 22041 deletions

@@ -1,80 +0,0 @@
---
name: Bug report
about: Create a report to help us improve
title: ''
labels: ''
assignees: ''
---
# Bug Report
## Important Notes
- **Before submitting a bug report**: Please check the Issues or Discussions section to see if a similar issue or feature request has already been posted. It's likely we're already tracking it! If you're unsure, start a discussion post first. This will help us efficiently focus on improving the project.
- **Collaborate respectfully**: We value a constructive attitude, so please be mindful of your communication. If negativity is part of your approach, our capacity to engage may be limited. We're here to help if you're open to learning and communicating positively. Remember, Open WebUI is a volunteer-driven project managed by a single maintainer and supported by contributors who also have full-time jobs. We appreciate your time and ask that you respect ours.
- **Contributing**: If you encounter an issue, we highly encourage you to submit a pull request or fork the project. We actively work to prevent contributor burnout to maintain the quality and continuity of Open WebUI.
- **Bug reproducibility**: If a bug cannot be reproduced with a `:main` or `:dev` Docker setup, or a pip install with Python 3.11, it may require additional help from the community. In such cases, we will move it to the "issues" Discussions section due to our limited resources. We encourage the community to assist with these issues. Remember, it's not that the issue doesn't exist; we need your help!
Note: Please remove the notes above when submitting your post. Thank you for your understanding and support!
---
## Installation Method
[Describe the method you used to install the project, e.g., git clone, Docker, pip, etc.]
## Environment
- **Open WebUI Version:** [e.g., v0.3.11]
- **Ollama (if applicable):** [e.g., v0.2.0, v0.1.32-rc1]
- **Operating System:** [e.g., Windows 10, macOS Big Sur, Ubuntu 20.04]
- **Browser (if applicable):** [e.g., Chrome 100.0, Firefox 98.0]
**Confirmation:**
- [ ] I have read and followed all the instructions provided in the README.md.
- [ ] I am on the latest version of both Open WebUI and Ollama.
- [ ] I have included the browser console logs.
- [ ] I have included the Docker container logs.
- [ ] I have provided the exact steps to reproduce the bug in the "Steps to Reproduce" section below.
## Expected Behavior:
[Describe what you expected to happen.]
## Actual Behavior:
[Describe what actually happened.]
## Description
**Bug Summary:**
[Provide a brief but clear summary of the bug]
## Reproduction Details
**Steps to Reproduce:**
[Outline the steps to reproduce the bug. Be as detailed as possible.]
## Logs and Screenshots
**Browser Console Logs:**
[Include relevant browser console logs, if applicable]
**Docker Container Logs:**
[Include relevant Docker container logs, if applicable]
**Screenshots/Screen Recordings (if applicable):**
[Attach any relevant screenshots to help illustrate the issue]
## Additional Information
[Include any additional details that may help in understanding and reproducing the issue. This could include specific configurations, error messages, or anything else relevant to the bug.]
## Note
If the bug report is incomplete or does not follow the provided instructions, it may not be addressed. Please ensure that you have followed the steps outlined in the README.md and troubleshooting.md documents, and provide all necessary information for us to reproduce and address the issue. Thank you!

.github/ISSUE_TEMPLATE/bug_report.yaml (new file, 167 lines)

@@ -0,0 +1,167 @@
name: Bug Report
description: Create a detailed bug report to help us improve Open WebUI.
title: 'issue: '
labels: ['bug', 'triage']
assignees: []
body:
- type: markdown
attributes:
value: |
# Bug Report
## Important Notes
- **Before submitting a bug report**: Please check the [Issues](https://github.com/open-webui/open-webui/issues) or [Discussions](https://github.com/open-webui/open-webui/discussions) sections to see if a similar issue has already been reported. If unsure, start a discussion first, as this helps us efficiently focus on improving the project.
- **Respectful collaboration**: Open WebUI is a volunteer-driven project with a single maintainer and contributors who also have full-time jobs. Please be constructive and respectful in your communication.
- **Contributing**: If you encounter an issue, consider submitting a pull request or forking the project. We prioritize preventing contributor burnout to maintain Open WebUI's quality.
- **Bug Reproducibility**: If a bug cannot be reproduced using a `:main` or `:dev` Docker setup or with `pip install` on Python 3.11, community assistance may be required. In such cases, we will move it to the "[Issues](https://github.com/open-webui/open-webui/discussions/categories/issues)" Discussions section. Your help is appreciated!
- type: checkboxes
id: issue-check
attributes:
label: Check Existing Issues
description: Confirm that you've checked for existing reports before submitting a new one.
options:
- label: I have searched the existing issues and discussions.
required: true
- label: I am using the latest version of Open WebUI.
required: true
- type: dropdown
id: installation-method
attributes:
label: Installation Method
description: How did you install Open WebUI?
options:
- Git Clone
- Pip Install
- Docker
- Other
validations:
required: true
- type: input
id: open-webui-version
attributes:
label: Open WebUI Version
description: Specify the version (e.g., v0.3.11)
validations:
required: true
- type: input
id: ollama-version
attributes:
label: Ollama Version (if applicable)
description: Specify the version (e.g., v0.2.0, or v0.1.32-rc1)
validations:
required: false
- type: input
id: operating-system
attributes:
label: Operating System
description: Specify the OS (e.g., Windows 10, macOS Sonoma, Ubuntu 22.04)
validations:
required: true
- type: input
id: browser
attributes:
label: Browser (if applicable)
description: Specify the browser/version (e.g., Chrome 100.0, Firefox 98.0)
validations:
required: false
- type: checkboxes
id: confirmation
attributes:
label: Confirmation
description: Ensure the following prerequisites have been met.
options:
- label: I have read and followed all instructions in `README.md`.
required: true
- label: I am using the latest version of **both** Open WebUI and Ollama.
required: true
- label: I have included the browser console logs.
required: true
- label: I have included the Docker container logs.
required: true
- label: I have **provided every relevant configuration, setting, and environment variable used in my setup.**
required: true
- label: I have clearly **listed every relevant configuration, custom setting, environment variable, and command-line option that influences my setup** (such as Docker Compose overrides, .env values, browser settings, authentication configurations, etc).
required: true
- label: |
I have documented **step-by-step reproduction instructions that are precise, sequential, and leave nothing to interpretation**. My steps:
- Start with the initial platform/version/OS and dependencies used,
- Specify exact install/launch/configure commands,
- List URLs visited, user input (incl. example values/emails/passwords if needed),
- Describe all options and toggles enabled or changed,
- Include any files or environmental changes,
- Identify the expected and actual result at each stage,
- Ensure any reasonably skilled user can follow and hit the same issue.
required: true
- type: textarea
id: expected-behavior
attributes:
label: Expected Behavior
description: Describe what should have happened.
validations:
required: true
- type: textarea
id: actual-behavior
attributes:
label: Actual Behavior
description: Describe what actually happened.
validations:
required: true
- type: textarea
id: reproduction-steps
attributes:
label: Steps to Reproduce
description: |
Please provide a **very detailed, step-by-step guide** to reproduce the issue. Your instructions should be so clear and precise that anyone can follow them without guesswork. Include every relevant detail—settings, configuration options, exact commands used, values entered, and any prerequisites or environment variables.
**If full reproduction steps and all relevant settings are not provided, your issue may not be addressed.**
placeholder: |
Example (include every detail):
1. Start with a clean Ubuntu 22.04 install.
2. Install Docker v24.0.5 and start the service.
3. Clone the Open WebUI repo (git clone ...).
4. Use the Docker Compose file without modifications.
5. Open browser Chrome 115.0 in incognito mode.
6. Go to http://localhost:8080 and log in with user "test@example.com".
7. Set the language to "English" and theme to "Dark".
8. Attempt to connect to Ollama at "http://localhost:11434".
9. Observe that the error message "Connection refused" appears at the top right.
Please list each step carefully and include all relevant configuration, settings, and options.
validations:
required: true
- type: textarea
id: logs-screenshots
attributes:
label: Logs & Screenshots
description: Include relevant logs, errors, or screenshots to help diagnose the issue.
placeholder: 'Attach logs from the browser console, Docker logs, or error messages.'
validations:
required: true
- type: textarea
id: additional-info
attributes:
label: Additional Information
description: Provide any extra details that may assist in understanding the issue.
validations:
required: false
- type: markdown
attributes:
value: |
## Note
If the bug report is incomplete or does not follow instructions, it may not be addressed. Ensure that you've followed all the **README.md** and **troubleshooting.md** guidelines, and provide all necessary information for us to reproduce the issue.
Thank you for contributing to Open WebUI!

.github/ISSUE_TEMPLATE/config.yml (new file, 1 line)

@@ -0,0 +1 @@
blank_issues_enabled: false

@@ -1,35 +0,0 @@
---
name: Feature request
about: Suggest an idea for this project
title: ''
labels: ''
assignees: ''
---
# Feature Request
## Important Notes
- **Before submitting a report**: Please check the Issues or Discussions section to see if a similar issue or feature request has already been posted. It's likely we're already tracking it! If you're unsure, start a discussion post first. This will help us efficiently focus on improving the project.
- **Collaborate respectfully**: We value a constructive attitude, so please be mindful of your communication. If negativity is part of your approach, our capacity to engage may be limited. We're here to help if you're open to learning and communicating positively. Remember, Open WebUI is a volunteer-driven project managed by a single maintainer and supported by contributors who also have full-time jobs. We appreciate your time and ask that you respect ours.
- **Contributing**: If you encounter an issue, we highly encourage you to submit a pull request or fork the project. We actively work to prevent contributor burnout to maintain the quality and continuity of Open WebUI.
- **Bug reproducibility**: If a bug cannot be reproduced with a `:main` or `:dev` Docker setup, or a pip install with Python 3.11, it may require additional help from the community. In such cases, we will move it to the "issues" Discussions section due to our limited resources. We encourage the community to assist with these issues. Remember, it's not that the issue doesn't exist; we need your help!
Note: Please remove the notes above when submitting your post. Thank you for your understanding and support!
---
**Is your feature request related to a problem? Please describe.**
A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]
**Describe the solution you'd like**
A clear and concise description of what you want to happen.
**Describe alternatives you've considered**
A clear and concise description of any alternative solutions or features you've considered.
**Additional context**
Add any other context or screenshots about the feature request here.

@@ -0,0 +1,64 @@
name: Feature Request
description: Suggest an idea for this project
title: 'feat: '
labels: ['triage']
body:
- type: markdown
attributes:
value: |
## Important Notes
### Before submitting
Please check the [Issues](https://github.com/open-webui/open-webui/issues) or [Discussions](https://github.com/open-webui/open-webui/discussions) to see if a similar request has been posted.
It's likely we're already tracking it! If you're unsure, start a discussion post first.
This will help us efficiently focus on improving the project.
### Collaborate respectfully
We value a **constructive attitude**, so please be mindful of your communication. If negativity is part of your approach, our capacity to engage may be limited. We're here to help if you're **open to learning** and **communicating positively**.
Remember:
- Open WebUI is a **volunteer-driven project**
- It's managed by a **single maintainer**
- It's supported by contributors who also have **full-time jobs**
We appreciate your time and ask that you **respect ours**.
### Contributing
If you encounter an issue, we highly encourage you to submit a pull request or fork the project. We actively work to prevent contributor burnout to maintain the quality and continuity of Open WebUI.
### Bug reproducibility
If a bug cannot be reproduced with a `:main` or `:dev` Docker setup, or a `pip install` with Python 3.11, it may require additional help from the community. In such cases, we will move it to the "[issues](https://github.com/open-webui/open-webui/discussions/categories/issues)" Discussions section due to our limited resources. We encourage the community to assist with these issues. Remember, it's not that the issue doesn't exist; we need your help!
- type: checkboxes
id: existing-issue
attributes:
label: Check Existing Issues
description: Please confirm that you've checked for existing similar requests
options:
- label: I have searched the existing issues and discussions.
required: true
- type: textarea
id: problem-description
attributes:
label: Problem Description
description: Is your feature request related to a problem? Please provide a clear and concise description of what the problem is.
placeholder: "Ex. I'm always frustrated when..."
validations:
required: true
- type: textarea
id: solution-description
attributes:
label: Desired Solution you'd like
description: Clearly describe what you want to happen.
validations:
required: true
- type: textarea
id: alternatives-considered
attributes:
label: Alternatives Considered
description: A clear and concise description of any alternative solutions or features you've considered.
- type: textarea
id: additional-context
attributes:
label: Additional Context
description: Add any other context or screenshots about the feature request here.

@@ -1,12 +1,26 @@
 version: 2
 updates:
+  - package-ecosystem: uv
+    directory: '/'
+    schedule:
+      interval: monthly
+    target-branch: 'dev'
   - package-ecosystem: pip
     directory: '/backend'
     schedule:
       interval: monthly
     target-branch: 'dev'
+  - package-ecosystem: npm
+    directory: '/'
+    schedule:
+      interval: monthly
+    target-branch: 'dev'
   - package-ecosystem: 'github-actions'
     directory: '/'
     schedule:
       # Check for updates to GitHub Actions every week
       interval: monthly
+    target-branch: 'dev'
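For readability, the merged `.github/dependabot.yml` implied by this hunk would look roughly like the sketch below (reconstructed from the diff; exact blank lines and quoting may differ from the file in the repository):

```yaml
version: 2
updates:
  - package-ecosystem: uv          # Python lockfile managed with uv
    directory: '/'
    schedule:
      interval: monthly
    target-branch: 'dev'
  - package-ecosystem: pip         # backend requirements
    directory: '/backend'
    schedule:
      interval: monthly
    target-branch: 'dev'
  - package-ecosystem: npm         # frontend dependencies
    directory: '/'
    schedule:
      interval: monthly
    target-branch: 'dev'
  - package-ecosystem: 'github-actions'
    directory: '/'
    schedule:
      # Check for updates to GitHub Actions every week
      interval: monthly
    target-branch: 'dev'
```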

@@ -9,9 +9,9 @@
 - [ ] **Changelog:** Ensure a changelog entry following the format of [Keep a Changelog](https://keepachangelog.com/) is added at the bottom of the PR description.
 - [ ] **Documentation:** Have you updated relevant documentation [Open WebUI Docs](https://github.com/open-webui/docs), or other documentation sources?
 - [ ] **Dependencies:** Are there any new dependencies? Have you updated the dependency versions in the documentation?
-- [ ] **Testing:** Have you written and run sufficient tests for validating the changes?
+- [ ] **Testing:** Have you written and run sufficient tests to validate the changes?
 - [ ] **Code review:** Have you performed a self-review of your code, addressing any coding standard issues and ensuring adherence to the project's coding standards?
-- [ ] **Prefix:** To cleary categorize this pull request, prefix the pull request title, using one of the following:
+- [ ] **Prefix:** To clearly categorize this pull request, prefix the pull request title using one of the following:
   - **BREAKING CHANGE**: Significant changes that may affect compatibility
   - **build**: Changes that affect the build system or external dependencies
   - **ci**: Changes to our continuous integration processes or workflows
@@ -22,7 +22,7 @@
   - **i18n**: Internationalization or localization changes
   - **perf**: Performance improvement
   - **refactor**: Code restructuring for better maintainability, readability, or scalability
-  - **style**: Changes that do not affect the meaning of the code (white-space, formatting, missing semi-colons, etc.)
+  - **style**: Changes that do not affect the meaning of the code (white space, formatting, missing semi-colons, etc.)
   - **test**: Adding missing tests or correcting existing tests
   - **WIP**: Work in progress, a temporary label for incomplete or ongoing work
@@ -70,3 +70,7 @@
 ### Screenshots or Videos
 - [Attach any relevant screenshots or videos demonstrating the changes]
+### Contributor License Agreement
+By submitting this pull request, I confirm that I have read and fully agree to the [Contributor License Agreement (CLA)](/CONTRIBUTOR_LICENSE_AGREEMENT), and I am providing my contributions under its terms.

@@ -14,7 +14,7 @@ env:
 jobs:
   build-main-image:
-    runs-on: ubuntu-latest
+    runs-on: ${{ matrix.platform == 'linux/arm64' && 'ubuntu-24.04-arm' || 'ubuntu-latest' }}
     permissions:
       contents: read
       packages: write
@@ -111,7 +111,7 @@ jobs:
       retention-days: 1
   build-cuda-image:
-    runs-on: ubuntu-latest
+    runs-on: ${{ matrix.platform == 'linux/arm64' && 'ubuntu-24.04-arm' || 'ubuntu-latest' }}
     permissions:
       contents: read
       packages: write
@@ -211,7 +211,7 @@ jobs:
       retention-days: 1
   build-ollama-image:
-    runs-on: ubuntu-latest
+    runs-on: ${{ matrix.platform == 'linux/arm64' && 'ubuntu-24.04-arm' || 'ubuntu-latest' }}
     permissions:
       contents: read
       packages: write
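The expression above picks a native ARM runner when the build targets `linux/arm64` and falls back to the default runner otherwise. A minimal sketch of the pattern, assuming a `platform` matrix with `linux/amd64` and `linux/arm64` entries (the matrix itself is outside this hunk, so those values are an assumption):

```yaml
jobs:
  build-main-image:
    strategy:
      matrix:
        platform:
          - linux/amd64   # assumed matrix values; not shown in this hunk
          - linux/arm64
    # Use GitHub's hosted arm64 runner for arm64 builds, the x86 runner otherwise
    runs-on: ${{ matrix.platform == 'linux/arm64' && 'ubuntu-24.04-arm' || 'ubuntu-latest' }}
    steps:
      - uses: actions/checkout@v4
```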

@@ -5,10 +5,18 @@ on:
     branches:
       - main
       - dev
+    paths:
+      - 'backend/**'
+      - 'pyproject.toml'
+      - 'uv.lock'
   pull_request:
     branches:
       - main
       - dev
+    paths:
+      - 'backend/**'
+      - 'pyproject.toml'
+      - 'uv.lock'
 jobs:
   build:
@@ -17,7 +25,9 @@ jobs:
     strategy:
       matrix:
-        python-version: [3.11]
+        python-version:
+          - 3.11.x
+          - 3.12.x
     steps:
       - uses: actions/checkout@v4
@@ -25,7 +35,7 @@ jobs:
       - name: Set up Python
         uses: actions/setup-python@v5
         with:
-          python-version: ${{ matrix.python-version }}
+          python-version: '${{ matrix.python-version }}'
       - name: Install dependencies
         run: |
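A side note on the quoting introduced above (a general YAML/Actions observation, not something stated in the diff): passing the version through `'${{ ... }}'` and listing versions as `3.11.x`/`3.12.x` strings keeps the `setup-python` input a plain string and avoids YAML's numeric coercion, e.g.:

```yaml
# Unquoted, YAML parses this as the float 3.1 (the trailing zero is dropped):
python-version: 3.10
# Quoted, the value stays exactly "3.10":
python-version: '3.10'
# Quoting the matrix expression likewise keeps the resolved value a string:
python-version: '${{ matrix.python-version }}'
```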

@@ -5,10 +5,18 @@ on:
     branches:
       - main
       - dev
+    paths-ignore:
+      - 'backend/**'
+      - 'pyproject.toml'
+      - 'uv.lock'
   pull_request:
     branches:
       - main
       - dev
+    paths-ignore:
+      - 'backend/**'
+      - 'pyproject.toml'
+      - 'uv.lock'
 jobs:
   build:
@@ -21,7 +29,7 @@ jobs:
       - name: Setup Node.js
         uses: actions/setup-node@v4
         with:
-          node-version: '22' # Or specify any other version you want to use
+          node-version: '22'
       - name: Install Dependencies
         run: npm install

@@ -17,9 +17,13 @@ jobs:
     steps:
       - name: Checkout repository
         uses: actions/checkout@v4
+        with:
+          fetch-depth: 0
+      - name: Install Git
+        run: sudo apt-get update && sudo apt-get install -y git
       - uses: actions/setup-node@v4
         with:
-          node-version: 18
+          node-version: 22
       - uses: actions/setup-python@v5
         with:
           python-version: 3.11
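As context for the `fetch-depth: 0` addition (the motivation is an inference, not stated in the diff): by default `actions/checkout` performs a shallow, single-commit checkout, while `fetch-depth: 0` fetches the full history and tags, which release and changelog tooling typically relies on. A minimal sketch:

```yaml
- name: Checkout repository
  uses: actions/checkout@v4
  with:
    # 0 = fetch all history and tags instead of the default shallow checkout
    fetch-depth: 0
```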

@@ -5,6 +5,497 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## [0.6.13] - 2025-05-30
### Added
- 🟦 **Azure OpenAI Embedding Support**: You can now select Azure OpenAI endpoints for text embeddings, unlocking seamless integration with enterprise-scale Azure AI for powerful RAG and knowledge workflows—no more workarounds, connect and scale effortlessly.
- 🧩 **Smarter Custom Parameter Handling**: Instantly enjoy more flexible model setup—any JSON pasted into custom parameter fields is now parsed automatically, so you can define rich, nested parameters without tedious manual adjustment. This streamlines advanced configuration for all models and accelerates experimentation.
- ⚙️ **General Backend Refactoring**: Significant backend improvements deliver a cleaner codebase for better maintainability, faster performance, and even greater platform reliability—making all your workflows run more smoothly.
- 🌏 **Localization Upgrades**: Experience highly improved user interface translations and clarity in Simplified, Traditional Chinese, Korean, and Finnish, offering a more natural, accurate, and accessible experience for global users.
### Fixed
- 🛡️ **Robust Message Handling on Chat Load**: Fixed an issue where chat pages could fail to load if a referenced message was missing or undefined; now, chats always load smoothly and missing IDs no longer disrupt your workflow.
- 📝 **Correct Prompt Access Control**: Ensured that the prompt access controls register properly, restoring reliable permissioning and safeguarding your prompt workflows.
- 🛠 **Open WebUI-Specific Params No Longer Sent to Models**: Fixed a bug that sent internal WebUI parameters to APIs, ensuring only intended model options are transmitted—restoring predictable, error-free model operation.
- 🧠 **Refined Memory Error Handling**: Enhanced stability during memory-related operations, so even uncommon memory errors are gracefully managed without disrupting your session—resulting in a more reliable, worry-free experience.
## [0.6.12] - 2025-05-29
### Added
- 🧩 **Custom Advanced Model Parameters**: You can now add your own tailor-made advanced parameters to any model, empowering you to fine-tune behavior and unlock greater flexibility beyond just the built-in options—accelerate your experimentation.
- 🪧 **Datalab Marker API Content Extraction Support**: Seamlessly extract content from files and documents using the Datalab Marker API directly in your workflows, enabling more robust structured data extraction for RAG and document processing with just a simple engine switch in the UI.
- ⚡ **Parallelized Base Model Fetching**: Experience noticeably faster startup and model refresh times—base model data now loads in parallel, drastically shortening delays in busy or large-scale deployments.
- 🧠 **Efficient Function Loading and Caching**: Functions are now only reloaded if their content changes, preventing unnecessary duplicate loads, saving bandwidth, and boosting performance.
- 🌍 **Localization & Translation Enhancements**: Improved and expanded Simplified, Traditional Chinese, and Russian translations, providing smoother, more accurate, and context-aware experiences for global users.
### Fixed
- 💬 **Stable Message Input Box**: Fixed an issue where the message input box would shift unexpectedly (especially on mobile or with screen reader support), ensuring a smooth and reliable typing experience for every user.
- 🔊 **Reliable Read Aloud (Text-to-Speech)**: Read aloud now works seamlessly across messages, so users depending on TTS for accessibility or multitasking will experience uninterrupted and clear voice playback.
- 🖼 **Image Preview and Download Restored**: Fixed problems with image preview and downloads, ensuring frictionless creation, previewing, and downloading of images in your chats—no more interruptions in creative or documentation workflows.
- 📱 **Improved Mobile Styling for Workspace Capabilities**: Capabilities management is now readable and easy-to-use even on mobile devices, empowering admins and users to manage access quickly on the go.
- 🔁 **/api/v1/retrieval/query/collection Endpoint Reliability**: Queries to retrieval collections now return the expected results, bolstering the reliability of your knowledge workflows and citation-ready responses.
### Removed
- 🧹 **Duplicate CSS Elements**: Streamlined the UI by removing redundant CSS, reducing clutter and improving load times for a smoother visual experience.
## [0.6.11] - 2025-05-27
### Added
- 🟢 **Ollama Model Status Indicator in Model Selector**: Instantly see which Ollama models are currently loaded with a clear indicator in the model selector, helping you stay organized and optimize local model usage.
- 🗑️ **Unload Ollama Model Directly from Model Selector**: Easily release memory and resources by unloading any loaded Ollama model right in the model selector—streamline hardware management without switching pages.
- 🗣️ **User-Configurable Speech-to-Text Language Setting**: Improve transcription accuracy by letting individual users explicitly set their preferred STT language in their settings—ideal for multilingual teams and clear audio capture.
- ⚡ **Granular Audio Playback Speed Control**: Instead of just presets, you can now choose granular audio speed using a numeric input, giving you complete control over playback pace in transcriptions and media reviews.
- 📦 **GZip, Brotli, ZStd Compression Middleware**: Enjoy significantly faster page loads and reduced bandwidth usage with new server-side compression—giving users a snappier, more efficient experience.
- 🏷️ **Configurable Weight for BM25 in Hybrid Search**: Fine-tune search relevance by adjusting the weight for BM25 inside hybrid search from the UI, letting you tailor knowledge search results to your workflow.
- 🧪 **Bypass File Creation with CTRL + SHIFT + V**: When “Paste Large Text as File” is enabled, use CTRL + SHIFT + V to skip the file creation dialog and instantly upload text as a file—perfect for rapid document prep.
- 🌐 **Bypass Web Loader in Web Search**: Choose to bypass web content loading and use snippets directly in web search for faster, more reliable results when page loads are slow or blocked.
- 🚀 **Environment Variable: WEBUI_AUTH_TRUSTED_GROUPS_HEADER**: Now sync and manage user groups directly via a trusted HTTP header, unlocking smoother single sign-on and identity integrations for organizations (a minimal configuration sketch follows this list).
- 🏢 **Workspace Models Visibility Controls**: You can now hide workspace-level models from both the model selector and shared environments—keep your team focused and reduce clutter from rarely-used endpoints.
- 🛡️ **Copy Model Link**: You can now copy a direct link to any model—including those hidden from the selector—making sharing and onboarding others more seamless.
- 🔗 **Load Function Directly from URL**: Simplify custom function management—just paste any GitHub function URL into Open WebUI and import new functions in seconds.
- ⚙️ **Custom Name/Description for External Tool Servers**: Personalize and clarify external tool servers by assigning custom names and descriptions, making it easier to manage integrations in large-scale workspaces.
- 🌍 **Custom OpenAPI JSON URL Support for Tool Servers**: Supports specifying any custom OpenAPI JSON URL, unlocking more flexible integration with any backend for tool calls.
- 📊 **Source Field Now Displays in Non-Streaming Responses with Attachments**: When files or knowledge are attached, the "source" field now appears for all responses, even in non-streaming mode—enabling improved citation workflow.
- 🎛 **Pinned Chats**: Reduced payload size on pinned chat requests—leading to faster load times and less data usage, especially on busy warehouses.
- 🛠 **Import/Export Default Prompt Suggestions**: Enjoy one-click import/export of prompt suggestions, making it much easier to share, reuse, and manage best practices across teams or deployments.
- 🍰 **Banners Now Sortable from Admin Settings**: Quickly re-order or prioritize banners, letting you highlight the most critical info for your team.
- 🛠 **Advanced Chat Parameters—Clearer Ollama Support Labels**: Parameters and advanced settings now explicitly indicate if they are Ollama-specific, reducing confusion and improving setup accuracy.
- 🤏 **Scroll Bar Thumb Improved for Better Visibility**: Enhanced scrollbar styling makes navigation more accessible and visually intuitive.
- 🗄️ **Modal Redesign for Archived and User Chat Listings**: Clean, modern modal interface for browsing archived and user-specific chats makes locating conversations faster and more pleasant.
- 📝 **Add/Edit Memory Modal UX**: Memory modals are now larger and have resizable input fields, supporting easier editing of long or complex memory content.
- 🏆 **Translation & Localization Enhancements**: Major upgrades to Chinese (Simplified & Traditional), Korean, Russian, German, Danish, Finnish—not just fixing typos, but consistency, tone, and terminology for a more natural native-language experience.
- ⚡ **General Backend Stability & Security Enhancements**: Various backend refinements ensure a more resilient, reliable, and secure platform for smoother operation and peace of mind.
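A minimal sketch of how the new `WEBUI_AUTH_TRUSTED_GROUPS_HEADER` variable might be wired up behind an authenticating reverse proxy (the header names, the companion email-header variable, and the Compose layout are illustrative assumptions, not taken from these release notes):

```yaml
services:
  open-webui:
    image: ghcr.io/open-webui/open-webui:main
    environment:
      # Trust identity headers injected by the reverse proxy in front of Open WebUI
      # (header names are assumptions; match whatever your proxy actually sets).
      - WEBUI_AUTH_TRUSTED_EMAIL_HEADER=X-User-Email
      - WEBUI_AUTH_TRUSTED_GROUPS_HEADER=X-User-Groups
```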
### Fixed
- 🖼️ **Image Generation with Allowed File Extensions Now Works Reliably**: Ensure seamless image generation even when strict file extension rules are set—no more blocked creative workflows due to technical hiccups.
- 🗂 **Remove Leading Dot for File Extension Check**: Fixed an issue where file validation failed because of a leading dot, making file uploads and knowledge management more robust.
- 🏷️ **Correct Local/External Model Classification**: The platform now accurately distinguishes between local and external models—preventing local models from showing up as external (and vice versa)—ensuring seamless setup, clarity, and management of your AI model endpoints.
- 📄 **External Document Loader Now Functions as Intended**: External document loaders are reliably invoked, ensuring smoother knowledge ingestion from external sources—expanding your RAG and knowledge workflows.
- 🎯 **Correct Handling of Toggle Filters**: Toggle filters are now robustly managed, preventing accidental auto-activation and ensuring user preferences are always respected.
- 🗃 **S3 Tagging Character Restrictions Fixed**: Tags for files in S3 now automatically meet Amazon's allowed character set, avoiding upload errors and ensuring cross-cloud compatibility.
- 🛡️ **Authentication Now Uses Password Hash When Duplicate Emails Exist**: Ensures account security and prevents access issues if duplicate emails are present in your system.
### Changed
- 🧩 **Admin Settings: OAuth Redirects Now Use WEBUI_URL**: The OAuth redirect URL is now based on the explicitly set WEBUI_URL, ensuring single sign-on and identity provider integrations always send users to the correct frontend.
### Removed
- 💡 **Duplicate/Typo Component Removals**: Obsolete components have been cleaned up, reducing confusion and improving overall code quality for the team.
- 🚫 **Streaming Upsert in Pinecone Removed**: Removed streaming upsert references for better compatibility and future-proofing with latest Pinecone SDK updates.
## [0.6.10] - 2025-05-19
### Added
- 🧩 **Experimental Azure OpenAI Support**: Instantly connect to Azure OpenAI endpoints by simply pasting your Azure OpenAI URL into the model connections—bringing flexible, enterprise-grade AI integration directly to your workflow.
- 💧 **Watermark AI Responses**: Easily add a visible watermark to AI-generated responses for clear content provenance and compliance with EU AI regulations—perfect for regulated environments and transparent communication.
- 🔍 **Enhanced Search Experience with Dedicated Modal**: Enjoy a modern, powerful search UI in a dedicated modal (open with Ctrl/Cmd + K) accessible from anywhere—quickly find chats, models, or content and boost your productivity.
- 🔲 **"Toggle" Filter Type for Chat Input**: Add interactive toggle filters (e.g. Web Search, Image, Code Interpreter) right into the chat input—giving you one-click control to activate features and simplifying the chat configuration process.
- 🧰 **Granular Model Capabilities Editor**: Define detailed capabilities and feature access for each AI model directly from the model editor—enabling tailored behavior, modular control, and a more customized AI environment for every team or use case.
- 🌐 **Flexible Local and External Connection Support**: Now you can designate any AI connection—whether OpenAI, Ollama, or others—as local or external, enabling seamless setup for on-premise, self-hosted, or cloud configurations and giving you maximum control and flexibility.
- 🗂️ **Allowed File Extensions for RAG**: Gain full control over your Retrieval-Augmented Generation (RAG) workflows by specifying exactly which file extensions are permitted for upload, improving security and relevance of indexed documents.
- 🔊 **Enhanced Audio Transcription Logic**: Experience smoother, more reliable audio transcription—very long audio files are now automatically split and processed in segments, preventing errors and ensuring even challenging files are transcribed seamlessly, all part of a broader stability upgrade for robust media workflows.
- 🦾 **External Document Loader Support**: Enhance knowledge base building by integrating documents using external loaders from a wide range of data sources, expanding what your AI can read and process.
- 📝 **Preview Button for Code Artifacts**: Instantly jump from an HTML code block to its associated artifacts page with the click of a new preview button—speeding up review and streamlining analysis.
- 🦻 **Screen Reader Support for Response Messages**: All chat responses are now fully compatible with screen readers, making the platform more inclusive and accessible for everyone.
- 🧑‍💼 **Customizable Pending User Overlay**: You can now tailor the overlay title and content shown to pending users, ensuring onboarding messaging is perfectly aligned with your organization's tone and requirements.
- 🔐 **Option to Disable LDAP Certificate Validation**: You can now disable LDAP certificate validation for maximum flexibility in diverse IT environments—making integrations and troubleshooting much easier.
- 🎯 **Workspace Search by Username or Email**: Easily search across workspace pages using any username or email address, streamlining user and resource management.
- 🎨 **High Contrast & Dark Mode Enhancements**: Further improved placeholder, input, suggestion, toast, and model selector contrasts—including a dedicated placeholder dark mode—for more comfortable viewing in all lighting conditions.
- 🛡️ **Refined Security for Pipelines & Model Uploads**: Strengthened safeguards against path traversal vulnerabilities during uploads—ensuring your platform's document and model management remains secure.
- 🌏 **Major Localization Upgrades**: Comprehensive translation updates and improvements across Korean, Bulgarian, Catalan, Japanese, Italian, Traditional Chinese, and Spanish—including more accurate AI terminology for clarity; your experience is now more natural, inclusive, and professional everywhere.
- 🦾 **General Backend Stability & Security**: Multiple backend improvements (including file upload, command navigation, and logging refactorings) deliver increased resilience, better error handling, and a more robust platform for all users.
### Fixed
- ✅ **Evaluation Feedback Endpoint Reliability**: Addressed issues with feedback submission endpoints to ensure seamless user feedback collection on model responses.
- 🫰 **Model List State Fixes**: Resolved issues where model status toggles in the workspace page might inadvertently switch or confuse state, making the management of active/inactive models more dependable.
- ✍️ **Admin Signup Logic**: Admin self-signup experience and validation flow is smoother and more robust.
- 🔁 **Signout Redirect Flow Improved**: Logging out now redirects more reliably, reducing confusion and making session management seamless.
## [0.6.9] - 2025-05-10
### Added
- 📝 **Edit Attached Images/Files in Messages**: You can now easily edit your sent messages by removing attached files—streamlining document management, correcting mistakes on the fly, and keeping your chats clutter-free.
- 🚨 **Clear Alerts for Private Task Models**: When interacting with private task models, the UI now clearly alerts you—making it easier to understand resource availability and access, reducing confusion during workflow setup.
### Fixed
- 🛡️ **Confirm Dialog Focus Trap Reliability**: The focus now stays correctly within confirmation dialogs, ensuring keyboard navigation and accessibility is seamless and preventing accidental operations—especially helpful during critical or rapid workflows.
- 💬 **Temporary Chat Admin Controls & Session Cleanliness**: Admins are now able to properly enable temporary chat mode without errors, and previous session prompts or tool selections no longer carry over—delivering a fresh, predictable, and consistent temporary chat experience every time.
- 🤖 **External Reranker Integration Functionality Restored**: External reranker integrations now work correctly, allowing you to fully leverage advanced ranking services for sharper, more relevant search results in your RAG and knowledge base workflows.
## [0.6.8] - 2025-05-10
### Added
- 🏆 **External Reranker Support for Knowledge Base Search**: Supercharge your Retrieval-Augmented Generation (RAG) workflows with the new External Reranker integration; easily plug in advanced reranking services via the UI to deliver sharper and more relevant search results, accelerating research and insight discovery.
- 📤 **Unstylized PDF Export Option (Reduced File Size)**: When exporting chat transcripts or documents, you can now choose an unstylized PDF export for snappier downloads, minimal file size, and clean data archiving—perfect for large-scale storage or sharing.
- 📝 **Vazirmatn Font for Persian & Arabic**: Arabic and Persian users will now see their text beautifully rendered with the specialized Vazirmatn font for an improved localized reading experience.
- 🏷️ **SharePoint Tenant ID Support for OneDrive**: You can now specify a SharePoint tenant ID in OneDrive settings for seamless authentication and granular enterprise integration.
- 👤 **Refresh OAuth Profile Picture**: Your OAuth profile picture now updates in real-time, ensuring your presence and avatar always match your latest identity across integrated platforms.
- 🔧 **Milvus Configuration Improvements**: Configure index and metric types for Milvus directly within settings; take full control of your vector database for more accurate and robust AI search experiences.
- 🛡️ **S3 Tagging Toggle for Compatibility**: Optional S3 tagging via an environment toggle grants full compatibility with all storage backends—including those that don't support tagging like Cloudflare R2—ensuring error-free attachment and document management.
- 👨‍🦯 **Icon Button Accessibility Improvements**: Key interactive icon-buttons now include aria-labels and ARIA descriptions, so screen readers provide precise guidance about what action each button performs for improved accessibility.
- ♿ **Enhanced Accessibility with Modal Focus Trap**: Modal dialogs and pop-ups now feature a focus trap and improved ARIA roles, ensuring seamless navigation and screen reader support—making the interface friendlier for everyone, including keyboard and assistive tech users.
- 🏃 **Improved Admin User List Loading Indicator**: The user list loading experience is now clearer and more responsive in the admin panel.
- 🧑‍🤝‍🧑 **Larger Admin User List Page Size**: Admins can now manage up to 30 users per page in the admin interface, drastically reducing pagination and making large user teams easier and faster to manage.
- 🌠 **Default Code Interpreter Prompt Clarified**: The built-in code interpreter prompt is now more explicit, preventing AI from wrapping code in Markdown blocks when not needed—ensuring properly formatted code runs as intended every time.
- 🧾 **Improved Default Title Generation Prompt Template**: Title generation now uses a robust template for reliable JSON output, improving chat organization and searchability.
- 🔗 **Support Jupyter Notebooks with Non-Root Base URLs**: Notebook-based code execution now supports non-root deployed Jupyter servers, granting full flexibility for hybrid or multi-user setups.
- 📰 **UI Scrollbar Always Visible for Overflow Tools**: When available tools overflow the display, the scrollbar is now always visible and there's a handy "show all" toggle, making navigation of large toolsets snappier and more intuitive.
- 🛠️ **General Backend Refactoring for Stability**: Multiple under-the-hood improvements have been made across backend components, ensuring smoother performance, fewer errors, and a more reliable overall experience for all users.
- 🚀 **Optimized Web Search for Faster Results**: Web search speed and performance have been significantly enhanced, delivering answers and sources in record time to accelerate your research-heavy workflows.
- 💡 **More Supported Languages**: Expanded language support ensures an even wider range of users can enjoy an intuitive and natural interface in their native tongue.
### Fixed
- 🏃‍♂️ **Exhausting Workers in Nginx Reverse Proxy Due to Websocket Fix**: Websocket sessions are now fully compatible behind Nginx, eliminating worker exhaustion and restoring 24/7 reliability for real-time chats even in complex deployments.
- 🎤 **Audio Transcription Issue with OpenAI Resolved**: OpenAI-based audio transcription now handles WebM and newer formats without error, ensuring seamless voice-to-text workflows every time.
- 👉 **Message Input RTL Issue Fixed**: The chat message input now displays correctly for right-to-left languages, creating a flawless typing and reading experience for Arabic, Hebrew, and more.
- 🀄 **Katex: Proper Rendering of Chinese Characters Next to Math**: Math formulas now render perfectly even when directly adjacent to Chinese (CJK) characters, improving visual clarity for multilingual teams and cross-language documents.
- 🔂 **Duplicate Web Search URLs Eliminated**: Search results now reliably filter out URL duplicates, so your knowledge and search citations are always clean, trimmed, and easy to review.
- 📄 **Markdown Rendering Fixed in Knowledge Bases**: Markdown is now displayed correctly within knowledge bases, enabling better formatting and clarity of information-rich files.
- 🗂️ **LDAP Import/Loading Issue Resolved**: LDAP user imports process correctly, ensuring smooth onboarding and access without interruption.
- 🌎 **Pinecone Batch Operations and Async Safety**: All Pinecone operations (batch insert, upsert, delete) now run efficiently and safely in an async environment, boosting performance and preventing slowdowns in large-scale RAG jobs.
## [0.6.7] - 2025-05-07
### Added
- 🌐 **Custom Azure TTS API URL Support Added**: You can now define a custom Azure Text-to-Speech endpoint—enabling flexibility for enterprise deployments and regional compliance.
- ⚙️ **TOOL_SERVER_CONNECTIONS Environment Variable Support**: Easily configure and deploy tool servers via environment variables, streamlining setup and enabling faster enterprise provisioning.
- 👥 **Enhanced OAuth Group Handling as String or List**: OAuth group data can now be passed as either a list or a comma-separated string, improving compatibility with varied identity provider formats and reducing onboarding friction.
### Fixed
- 🧠 **Embedding with Ollama Proxy Endpoints Restored**: Fixed an issue where missing API config broke embedding for proxied Ollama models—ensuring consistent performance and compatibility.
- 🔐 **OIDC OAuth Login Issue Resolved**: Users can once again sign in seamlessly using OpenID Connect-based OAuth, eliminating login interruptions and improving reliability.
- 📝 **Notes Feature Access Fixed for Non-Admins**: Fixed an issue preventing non-admin users from accessing the Notes feature, restoring full cross-role collaboration capabilities.
- 🖼️ **Tika Loader Image Extraction Problem Resolved**: Ensured TikaLoader now processes 'extract_images' parameter correctly, restoring complete file extraction functionality in document workflows.
- 🎨 **Automatic1111 Image Model Setting Applied Properly**: Fixed an issue where switching to a specific image model via the UI wasn't reflected in generation, re-enabling full visual creativity control.
- 🏷️ **Multiple XML Tags in Messages Now Parsed Correctly**: Fixed parsing issues when messages included multiple XML-style tags, ensuring clean and unbroken rendering of rich content in chats.
- 🖌️ **OpenAI Image Generation Issues Resolved**: Resolved broken image output when using OpenAI's image generation, ensuring fully functional visual creation workflows.
- 🔎 **Tool Server Settings UI Privacy Restored**: Prevented restricted users from accessing tool server settings via search—restoring tight permissions control and safeguarding sensitive configurations.
- 🎧 **WebM Audio Transcription Now Supported**: Fixed an issue where WebM files failed during audio transcription—these formats are now fully supported, ensuring smoother voice note workflows and broader file compatibility.
## [0.6.6] - 2025-05-05
### Added
- 📝 **AI-Enhanced Notes (With Audio Transcription)**: Effortlessly create notes, attach meeting or voice audio, and let the AI instantly enhance, summarize, or refine your notes using audio transcriptions—making your documentation smarter, cleaner, and more insightful with minimal effort.
- 🔊 **Meeting Audio Recording & Import**: Seamlessly record audio from your meetings or capture screen audio and attach it to your notes—making it easier to revisit, annotate, and extract insights from important discussions.
- 📁 **Import Markdown Notes Effortlessly**: Bring your existing knowledge library into Open WebUI by importing your Markdown notes, so you can leverage all advanced note management and AI features right away.
- 👥 **Notes Permissions by User Group**: Fine-tune access and editing rights for notes based on user roles or groups, so you can delegate writing or restrict sensitive information as needed.
- ☁️ **OneDrive & SharePoint Integration**: Keep your content in sync by connecting notes and files directly with OneDrive or SharePoint—unlocking fast enterprise import/export and seamless collaboration with your existing workflows.
- 🗂️ **Paginated User List in Admin Panel**: Effortlessly manage and search through large teams via the new paginated user list—saving time and streamlining user administration in big organizations.
- 🕹️ **Granular Chat Share & Export Permissions**: Enjoy enhanced control over who can share or export chats, enabling tighter governance and privacy in team and enterprise settings.
- 🛑 **User Role Change Confirmation Dialog**: Reduce accidental privilege changes with a required confirmation step before updating user roles—improving security and preventing costly mistakes in team management.
- 🚨 **Audit Log for Failed Login Attempts**: Quickly detect unauthorized access attempts or troubleshoot user login problems with detailed logs of failed authentication right in the audit trail.
- 💡 **Dedicated 'Generate Title' Button for Chats**: Swiftly organize every conversation—tap the new button to let AI create relevant, clear titles for all your chats, saving time and reducing clutter.
- 💬 **Notification Sound Always-On Option**: Take control of your notifications by setting sound alerts to always play—helping you stay on top of important updates in busy environments.
- 🆔 **S3 File Tagging Support**: Uploaded files to S3 now include tags for better organization, searching, and integration with your file management policies.
- 🛡️ **OAuth Blocked Groups Support**: Gain more control over group-based access by explicitly blocking specified OAuth groups—ideal for complex identity or security requirements.
- 🚀 **Optimized Faster Web Search & Multi-Threaded Queries**: Enjoy dramatically faster web search and RAG (retrieval augmented generation) with revamped multi-threaded search—get richer, more accurate results in less time.
- 🔍 **All-Knowledge Parallel Search**: Searches across your entire knowledge base now happen in parallel even in non-hybrid mode, speeding up responses and improving knowledge accuracy for every question.
- 🌐 **New Firecrawl & Yacy Web Search Integrations**: Expand your world of information with two new advanced search engines—Firecrawl for deeper web insight and Yacy for decentralized, privacy-friendly search capabilities.
- 🧠 **Configurable Docling OCR Engine & Language**: Use environment variables to fine-tune Docling OCR engine and supported languages for smarter, more tailored document extraction and RAG workflows.
- 🗝️ **Enhanced Sentence Transformers Configuration**: Added new environment variables for easier set up and advanced customization of Sentence Transformers—ensuring best fit for your embedding needs.
- 🌲 **Pinecone Vector Database Integration**: Index, search, and manage knowledge at enterprise scale with full native support for Pinecone as your vector database—effortlessly handle even the biggest document sets.
- 🔄 **Automatic Requirements Installation for Tools & Functions**: Never worry about lost dependencies on restart—external function and tool requirements are now auto-installed at boot, ensuring tools always “just work.”
- 🔒 **Automatic Sign-Out on Token Expiry**: Security is smarter—users are now automatically logged out if their authentication token expires, protecting sensitive content and ensuring compliance without disruption.
- 🎬 **Automatic YouTube Embed Detection**: Paste YouTube links and see instant in-chat video embeds—no more manual embedding, making knowledge sharing and media consumption even easier for every team.
- 🔄 **Expanded Language & Locale Support**: Translations for Danish, French, Russian, Traditional Chinese, Simplified Chinese, Thai, Catalan, German, and Korean have been upgraded, offering smoother, more natural user experiences across the platform.
### Fixed
- 🔒 **Tighter HTML Token Security**: HTML rendering is now restricted to admin-uploaded tokens only, reducing any risk of XSS and keeping your data safe.
- 🔐 **Refined HTML Security and Token Handling**: Further hardened how HTML tokens and content are handled, guaranteeing even stronger resistance to security vulnerabilities and attacks.
- 🔏 **Correct Model Usage with Ollama Proxy Prefixes**: Enhanced model reference handling so proxied models in Ollama always download and run correctly—even when using custom prefixes.
- 📥 **Video File Upload Handling**: Prevented video files from being misclassified as text, fixing bugs with uploads and ensuring media files work as expected.
- 🔄 **No More Dependent WebSocket Sequential Delays**: Streamlined WebSocket operation to prevent delays and maintain snappy real-time collaboration, especially in multi-user environments.
- 🛠️ **More Robust Action Module Execution**: Multiple actions in a module now trigger as designed, increasing automation and scripting flexibility.
- 📧 **Notification Webhooks**: Ensured that notification webhooks are always sent for user events, even when the user isn't currently active.
- 🗂️ **Smarter Knowledge Base Reindexing**: Knowledge reindexing continues even when corrupt or missing collections are encountered, keeping your search features running reliably.
- 🏷️ **User Import with Profile Images**: When importing users, their profile images now come along—making onboarding and collaboration visually clearer from day one.
- 💬 **OpenAI o-Series Universal Support**: All OpenAI o-series models are now seamlessly recognized and supported, unlocking more advanced capabilities and model choices for every workflow.
### Changed
- 📜 **Custom License Update & Contributor Agreement**: Open WebUI now operates under a custom license with Contributor License Agreement required by default—see https://docs.openwebui.com/license/ for details, ensuring sustainable open innovation for the community.
- 🔨 **CUDA Docker Images Updated to 12.8**: Upgraded CUDA image support for faster, more compatible model inference and futureproof GPU performance in your AI infrastructure.
- 🧱 **General Backend Refactoring for Reliability**: Continuous stability improvements streamline backend logic, reduce errors, and lay a stronger foundation for the next wave of feature releases—all under the hood for a more dependable WebUI.
## [0.6.5] - 2025-04-14
### Added
- 🛂 **Granular Voice Feature Permissions Per User Group**: Admins can now separately manage access to Speech-to-Text (record voice), Text-to-Speech (read aloud), and Tool Calls for each user group—giving teams tighter control over voice features and enhanced governance across roles.
- 🗣️ **Toggle Voice Activity Detection (VAD) for Whisper STT**: New environment variable lets you enable/disable VAD filtering with built-in Whisper speech-to-text, giving you flexibility to optimize for different audio quality and response accuracy levels.
- 📋 **Copy Formatted Response Mode**: You can now enable “Copy Formatted” in Settings > Interface to copy AI responses exactly as styled (with rich formatting, links, and structure preserved), making it faster and cleaner to paste into documents, emails, or reports.
- ⚙️ **Backend Stability and Performance Enhancements**: General backend refactoring improves system resilience, consistency, and overall reliability—offering smoother performance across workflows whether chatting, generating media, or using external tools.
- 🌎 **Translation Refinements Across Multiple Languages**: Updated translations deliver smoother language localization, clearer labels, and improved international usability throughout the UI—ensuring a better experience for non-English speakers.
### Fixed
- 🛠️ **LDAP Login Reliability Restored**: Resolved a critical issue where some LDAP setups failed due to attribute parsing—ensuring consistent, secure, and seamless user authentication across enterprise deployments.
- 🖼️ **Image Generation in Temporary Chats Now Works Properly**: Fixed a bug where image outputs weren't generated during temporary chats—visual content can now be used reliably in all chat modes without interruptions.
## [0.6.4] - 2025-04-12
### Fixed
- 🛠️ **RAG_TEMPLATE Display Issue Resolved**: Fixed a formatting problem where the custom RAG_TEMPLATE wasn't correctly rendered in the interface—ensuring that custom retrieval prompts now appear exactly as intended for more reliable prompt engineering.
## [0.6.3] - 2025-04-12
### Added
- 🧪 **Auto-Artifact Detection Toggle**: Automatically detects artifacts in results—but now you can disable this behavior under advanced settings for full control.
- 🖼️ **Widescreen Mode for Shared Chats**: Shared link conversations now support widescreen layouts—perfect for presentations or easier review across wider displays.
- 🔁 **Reindex Knowledge Files on Demand**: Admins can now trigger reindexing of all knowledge files after changing embeddings—ensuring immediate alignment with new models for optimal RAG performance.
- 📄 **OpenAPI YAML Format Support**: External tools can now use YAML-format OpenAPI specs—making integration simpler for developers familiar with YAML-based configurations.
- 💬 **Message Content Copy Behavior**: Copy action now excludes 'details' tags—streamlining clipboard content when sharing or pasting summaries elsewhere.
- 🧭 **Sougou Web Search Integration**: New search engine option added—enhancing global relevance and diversity of search sources for multilingual users.
- 🧰 **Frontend Web Loader Engine Configuration**: Admins can now set preferred web loader engine for RAG workflows directly from the frontend—offering more control across setups.
- 👥 **Multi-Model Chat Permission Control**: Admins can manage access to multi-model chats per user group—allowing tighter governance in team environments.
- 🧱 **Persistent Configuration Can Be Disabled**: New environment variable lets advanced users and hosts turn off persistent configs—ideal for volatile or stateless deployments.
- 🧠 **Elixir Code Highlighting Support**: Elixir syntax is now beautifully rendered in code blocks—perfect for developers using this language in AI or automation projects.
- 🌐 **PWA External Manifest URL Support**: You can now define an external manifest.json—integrate Open WebUI seamlessly in managed or proxy-based PWA environments like Cloudflare Zero Trust (see the configuration sketch after this list).
- 🧪 **Azure AI Speech-to-Text Provider Integration**: Easily transcribe large audio files (up to 200MB) with high accuracy using Microsoft's Azure STT—fully configurable in Audio Settings.
- 🔏 **PKCE (Code Challenge Method) Support for OIDC**: Enhance your OIDC login security with Proof Key for Code Exchange—ideal for zero-trust and native client apps.
- ✨ **General UI/UX Enhancements**: Numerous refinements across layout, styling, and tool interactions—reducing visual noise and improving overall usability across key workflows.
- 🌍 **Translation Updates Across Multiple Languages**: Refined Catalan, Russian, Chinese (Simplified & Traditional), Hungarian, and Spanish translations for clearer navigation and instructions globally.
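
For the external PWA manifest entry above, the backend exposes an `EXTERNAL_PWA_MANIFEST_URL` environment variable (visible in the env.py changes further down this diff). The route below is only a rough sketch of one possible wiring: when the variable is set, clients are pointed at the external manifest, otherwise the bundled static file is served; the real handler may instead fetch and return the external JSON.

```python
import os

from fastapi import FastAPI
from fastapi.responses import FileResponse, RedirectResponse

app = FastAPI()

# Mirrors env.py: unset means "use the bundled manifest".
EXTERNAL_PWA_MANIFEST_URL = os.environ.get("EXTERNAL_PWA_MANIFEST_URL")


@app.get("/manifest.json")
async def get_manifest():
    if EXTERNAL_PWA_MANIFEST_URL:
        # Managed or proxy-based PWA setups (e.g. Cloudflare Zero Trust) supply their own manifest.
        return RedirectResponse(EXTERNAL_PWA_MANIFEST_URL)
    # Path is illustrative; the built manifest ships with the static frontend assets.
    return FileResponse("static/manifest.json", media_type="application/json")
```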
### Fixed
- 💥 **Chat Completion Error with Missing Models Resolved**: Fixed internal server error when referencing a model that doesnt exist—ensuring graceful fallback and clear error guidance.
- 🔧 **Correct Knowledge Base Citations Restored**: Citations generated by RAG workflows now show accurate references—ensuring verifiability in outputs from sourced content.
- 🎙️ **Broken OGG/WebM Audio Upload Handling for OpenAI Fixed**: Uploading OGG or WebM files now converts properly to WAV before transcription—restoring accurate AI speech recognition workflows.
- 🔐 **Tool Server 'Session' Authentication Restored**: Previously broken session auth on external tool servers is now fully functional—ensuring secure and seamless access to connected tools.
- 🌐 **Folder-Based Chat Rename Now Updates Correctly**: Renaming chats in folders now reflects instantly everywhere—improving chat organization and clarity.
- 📜 **KaTeX Overflow Displays Fixed**: Math expressions now stay neatly within message bounds—preserving layout consistency even with long formulas.
- 🚫 **Stopping Ongoing Chat Fixed**: You can now return to an active (ongoing) chat and stop generation at any time—ensuring full control over sessions.
- 🔧 **TOOL_SERVERS / TOOL_SERVER_CONNECTIONS Indexing Issue Fixed**: Fixed a mismatch between tool lists and their access paths—restoring full function and preventing confusion in tool management.
- 🔐 **LDAP Login Handles Multiple Emails**: When LDAP returns multiple email attributes, the first valid one is now used—ensuring login success and account consistency.
- 🧩 **Model Visibility Toggle Fix**: Toggling model visibility now works even for untouched models—letting admins smoothly manage user access across base models.
- ⚙️ **Cross-Origin manifest.json Now Loads Properly**: Compatibility issues with Cloudflare Zero Trust (and others) resolved, allowing manifest.json to load behind authenticated proxies.
### Changed
- 🔒 **Default Access Scopes Set to Private for All Resources**: Models, tools, and knowledge are now private by default when created—ensuring better baseline security and visibility controls.
- 🧱 **General Backend Refactoring for Stability**: Numerous invisible improvements enhance backend scalability, security, and maintainability—powering upcoming features with a stronger foundation.
- 🧩 **Stable Dependency Upgrades**: Updated key platform libraries—Chromadb (0.6.3), pgvector (0.4.0), Azure Identity (1.21.0), and Youtube Transcript API (1.0.3)—for improved compatibility, functionality, and security.
## [0.6.2] - 2025-04-06
### Added
- 🌍 **Improved Global Language Support**: Expanded and refined translations across multiple languages to enhance clarity and consistency for international users.
### Fixed
- 🛠️ **Accurate Tool Descriptions from OpenAPI Servers**: External tools now use full endpoint descriptions instead of summaries when generating tool specifications—helping AI models understand tool purpose more precisely and choose the right tool more accurately in tool workflows.
- 🔧 **Precise Web Results Source Attribution**: Fixed a key issue where all web search results showed the same source ID—now each result gets its correct and distinct source, ensuring accurate citations and traceability.
- 🔍 **Clean Web Search Retrieval**: Web search now retains only results from URLs where real content was successfully fetched—improving accuracy and removing empty or broken links from citations.
- 🎵 **Audio File Upload Response Restored**: Resolved an issue where uploading audio files did not return valid responses, restoring smooth file handling for transcription and audio-based workflows.
### Changed
- 🧰 **General Backend Refactoring**: Multiple behind-the-scenes improvements streamline backend performance, reduce complexity, and ensure a more stable, maintainable system overall—making everything smoother without changing your workflow.
## [0.6.1] - 2025-04-05
### Added
- 🛠️ **Global Tool Servers Configuration**: Admins can now centrally configure global external tool servers from Admin Settings > Tools, allowing seamless sharing of tool integrations across all users without manual setup per user.
- 🔐 **Direct Tool Usage Permission for Users**: Introduced a new user-level permission toggle that grants non-admin users access to direct external tools, empowering broader team collaboration while maintaining control.
- 🧠 **Mistral OCR Content Extraction Support**: Added native support for Mistral OCR as a high-accuracy document loader, drastically improving text extraction from scanned documents in RAG workflows.
- 🖼️ **Tools Indicator UI Redesign**: Enhanced message input now smartly displays both built-in and external tools via a unified dropdown, making it simpler and more intuitive to activate tools during conversations.
- 📄 **RAG Prompt Improved and More Coherent**: Default RAG system prompt has been revised to be more clear and citation-focused—admins can leave the template field empty to use this new gold-standard prompt.
- 🧰 **Performance & Developer Improvements**: Major internal restructuring of several tool-related components, simplifying styling and merging external/internal handling logic, resulting in better maintainability and performance.
- 🌍 **Improved Translations**: Updated translations for Tibetan, Polish, Chinese (Simplified & Traditional), Arabic, Russian, Ukrainian, Dutch, Finnish, and French to improve clarity and consistency across the interface.
### Fixed
- 🔑 **External Tool Server API Key Bug Resolved**: Fixed a critical issue where authentication headers were not being sent when calling tools from external OpenAPI tool servers, ensuring full security and smooth tool operations.
- 🚫 **Conditional Export Button Visibility**: UI now gracefully hides export buttons when there's nothing to export in models, prompts, tools, or functions, improving visual clarity and reducing confusion.
- 🧪 **Hybrid Search Failure Recovery**: Resolved edge case in parallel hybrid search where empty or unindexed collections caused backend crashes—these are now cleanly skipped to ensure system stability.
- 📂 **Admin Folder Deletion Fix**: Addressed an issue where folders created in the admin workspace couldn't be deleted, restoring full organizational flexibility for admins.
- 🔐 **Improved Generic Error Feedback on Login**: Authentication errors now show simplified, non-revealing messages for privacy and improved UX, especially with federated logins.
- 📝 **Tool Message with Images Improved**: Enhanced how tool-generated messages with image outputs are shown in chat, making them more readable and consistent with the overall UI design.
- ⚙️ **Auto-Exclusion for Broken RAG Collections**: Auto-skips document collections that fail to fetch data or return "None", preventing silent errors and streamlining retrieval workflows.
- 📝 **Docling Text File Handling Fix**: Fixed file parsing inconsistency that broke docling-based RAG functionality for certain plain text files, ensuring wider file compatibility.
## [0.6.0] - 2025-03-31
### Added
- 🧩 **External Tool Server Support via OpenAPI**: Connect Open WebUI to any OpenAPI-compatible REST server instantly—offering immediate integration with thousands of developer tools, SDKs, and SaaS systems for powerful extensibility. Learn more: https://github.com/open-webui/openapi-servers
- 🛠️ **MCP Server Support via MCPO**: You can now convert and expose your internal MCP tools as interoperable OpenAPI HTTP servers within Open WebUI for seamless, plug-n-play AI toolchain creation. Learn more: https://github.com/open-webui/mcpo
- 📨 **/messages Chat API Endpoint Support**: For power users building external AI systems, new endpoints allow precise control of messages asynchronously—feed long-running external responses into Open WebUI chats without coupling with the frontend.
- 📝 **Client-Side PDF Generation**: PDF exports are now generated fully client-side for drastically improved output quality—perfect for saving conversations or documents.
- 💼 **Enforced Temporary Chats Mode**: Admins can now enforce temporary chat sessions by default to align with stringent data retention and compliance requirements.
- 🌍 **Public Resource Sharing Permission Controls**: Fine-grained user group permissions now allow enabling/disabling public sharing for models, knowledge, prompts, and tools—ideal for privacy, team control, and internal deployments.
- 📦 **Custom pip Options for Tools/Functions**: You can now specify custom pip installation options with the "PIP_OPTIONS" and "PIP_PACKAGE_INDEX_OPTIONS" environment variables—improving compatibility, support for private indexes, and better control over Python environments (see the usage sketch after this list).
- 🔢 **Editable Message Counter**: You can now double-click the message count number and jump straight to editing the index—quickly navigate complex chats or regenerate specific messages precisely.
- 🧠 **Embedding Prefix Support Added**: Add custom prefixes to your embeddings for instruct-style tokens, enabling stronger model alignment and more consistent RAG performance.
- 🙈 **Ability to Hide Base Models**: Optionally hide base models from the UI, helping users streamline model visibility and limit access to only usable endpoints.
- 📚 **Docling Content Extraction Support**: Open WebUI now supports Docling as a content extraction engine, enabling smarter and more accurate parsing of complex file formats—ideal for advanced document understanding and Retrieval-Augmented Generation (RAG) workflows.
- 🗃️ **Redis Sentinel Support Added**: Enhance deployment redundancy with support for Redis Sentinel for highly available, failover-safe Redis-based caching or pub/sub.
- 📚 **JSON Schema Format for Ollama**: Added support for defining the format using JSON schema in Ollama-compatible models, improving flexibility and validation of model outputs.
- 🔍 **Chat Sidebar Search “Clear” Button**: Quickly clear search filters in the chat sidebar using the new ✖️ button—streamline your chat navigation with one click.
- 🗂️ **Auto-Focus + Enter Submit for Folder Name**: When creating a new folder, the system automatically enters rename mode with name preselected—simplifying your org workflow.
- 🧱 **Markdown Alerts Rendering**: Blockquotes with syntax hinting (e.g. ⚠️, ✅) now render styled Markdown alert banners, making messages and documentation more visually structured.
- 🔁 **Hybrid Search Runs in Parallel Now**: Hybrid (BM25 + embedding) search components now run in parallel—dramatically reducing response times and speeding up document retrieval.
- 📋 **Cleaner UI for Tool Call Display**: Optimized the visual layout of called tools inside chat messages for better clarity and reduced visual clutter.
- 🧪 **Playwright Timeout Now Configurable**: Default timeout for Playwright processes is now shorter and adjustable via environment variables—making web scraping more robust and tunable to environments.
- 📈 **OpenTelemetry Support for Observability**: Open WebUI now integrates with OpenTelemetry, allowing you to connect with tools like Grafana, Jaeger, or Prometheus for detailed performance insights and real-time visibility—entirely opt-in and fully self-hosted. Even if enabled, no data is ever sent to us, ensuring your privacy and ownership over all telemetry data.
- 🛠 **General UI Enhancements & UX Polish**: Numerous refinements across sidebar, code blocks, modal interactions, button alignment, scrollbar visibility, and folder behavior improve overall fluidity and usability of the interface.
- 🧱 **General Backend Refactoring**: Numerous backend components have been refactored to improve stability, maintainability, and performance—ensuring a more consistent and reliable system across all features.
- 🌍 **Internationalization Language Support Updates**: Added Estonian and Galician languages, improved Spanish (fully revised), Traditional Chinese, Simplified Chinese, Turkish, Catalan, Ukrainian, and German for a more localized and inclusive interface.
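
To make the custom pip options entry above concrete, here is a minimal, hedged sketch of how the two variables can be folded into a `pip install` call. The `install_tool_requirement` helper and the exact flag placement are illustrative assumptions; only the environment variable names and the whitespace-splitting behaviour come from the backend's env handling (shown later in this diff).

```python
import os
import subprocess
import sys

# Both variables are plain strings split on whitespace, as in the backend env handling.
PIP_OPTIONS = os.getenv("PIP_OPTIONS", "").split()
PIP_PACKAGE_INDEX_OPTIONS = os.getenv("PIP_PACKAGE_INDEX_OPTIONS", "").split()


def install_tool_requirement(package: str) -> None:
    """Hypothetical helper: install one requirement for a tool or function."""
    cmd = [
        sys.executable,
        "-m",
        "pip",
        "install",
        *PIP_OPTIONS,  # e.g. --no-cache-dir
        package,
        *PIP_PACKAGE_INDEX_OPTIONS,  # e.g. --index-url https://pypi.example.org/simple
    ]
    subprocess.check_call(cmd)


if __name__ == "__main__":
    # Example: PIP_OPTIONS="--no-cache-dir" python install_example.py
    install_tool_requirement("requests")
```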
### Fixed
- 🧑‍💻 **Firefox Input Height Bug**: Text input in Firefox now maintains proper height, ensuring message boxes look consistent and behave predictably.
- 🧾 **Tika Blank Line Bug**: PDFs processed with Apache Tika 3.1.0.0 no longer introduce excessive blank lines—improving RAG output quality and visual cleanliness.
- 🧪 **CSV Loader Encoding Issues**: CSV files with unknown encodings now automatically detect character sets, resolving import errors in non-UTF-8 datasets.
- ✅ **LDAP Auth Config Fix**: Path to certificate file is now optional for LDAP setups, fixing authentication trouble for users without preconfigured cert paths.
- 📥 **File Deletion in Bypass Mode**: Resolved issue where files couldnt be deleted from knowledge when “bypass embedding” mode was enabled.
- 🧩 **Hybrid Search Result Sorting & Deduplication Fixed**: Fixed citation and sorting issues in RAG hybrid and reranker modes, ensuring retrieved documents are shown in correct order per score.
- 🧷 **Model Export/Import Broken for a Single Model**: Fixed bug where individual models couldnt be exported or re-imported, restoring full portability.
- 📫 **Auth Redirect Fix**: Logged-in users are now routed properly without unnecessary login prompts when already authenticated.
### Changed
- 🧠 **Prompt Autocompletion Disabled By Default**: Autocomplete suggestions while typing are now disabled unless explicitly re-enabled in user preferences—reduces distractions while composing prompts for advanced users.
- 🧾 **Normalize Citation Numbering**: Source citations now properly begin from "1" instead of "0"—improving consistency and professional presentation in AI outputs.
- 📚 **Improved Error Handling from Pipelines**: Pipelines now show the actual returned error message from failed tasks rather than generic "Connection closed"—making debugging far more user-friendly.
### Removed
- 🧾 **ENABLE_AUDIT_LOGS Setting Removed**: Deprecated setting “ENABLE_AUDIT_LOGS” has been fully removed—now controlled via “AUDIT_LOG_LEVEL” instead.
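
Audit logging is now switched entirely through “AUDIT_LOG_LEVEL”. The sketch below mirrors how the backend reads the setting (default “NONE”, with “METADATA”, “REQUEST”, and “REQUEST_RESPONSE” as the verbosity levels); the `audit_enabled` helper is purely illustrative.

```python
import os

# NONE disables auditing; the other values progressively capture more of each request.
AUDIT_LOG_LEVEL = os.getenv("AUDIT_LOG_LEVEL", "NONE").upper()


def audit_enabled() -> bool:
    # Illustrative check: anything other than NONE turns audit logging on.
    return AUDIT_LOG_LEVEL in {"METADATA", "REQUEST", "REQUEST_RESPONSE"}
```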
## [0.5.20] - 2025-03-05
### Added
- **⚡ Toggle Code Execution On/Off**: You can now enable or disable code execution, providing more control over security, ensuring a safer and more customizable experience.
### Fixed
- **📜 Pinyin Keyboard Enter Key Now Works Properly**: Resolved an issue where the Enter key for Pinyin keyboards was not functioning as expected, ensuring seamless input for Chinese users.
- **🖼️ Web Manifest Loading Issue Fixed**: Addressed inconsistencies with 'site.webmanifest', guaranteeing proper loading and representation of the app across different browsers and devices.
- **📦 Non-Root Container Issue Resolved**: Fixed a critical issue where the UI failed to load correctly in non-root containers, ensuring reliable deployment in various environments.
## [0.5.19] - 2025-03-04
### Added
- **📊 Logit Bias Parameter Support**: Fine-tune conversation dynamics by adjusting the Logit Bias parameter directly in chat settings, giving you more control over model responses.
- **⌨️ Customizable Enter Behavior**: You can now configure Enter to send messages only when combined with Ctrl (Ctrl+Enter) via Settings > Interface, preventing accidental message sends.
- **📝 Collapsible Code Blocks**: Easily collapse long code blocks to declutter your chat, making it easier to focus on important details.
- **🏷️ Tag Selector in Model Selector**: Quickly find and categorize models with the new tag filtering system in the Model Selector, streamlining model discovery.
- **📈 Experimental Elasticsearch Vector DB Support**: Now supports Elasticsearch as a vector database, offering more flexibility for data retrieval in Retrieval-Augmented Generation (RAG) workflows.
- **⚙️ General Reliability Enhancements**: Various stability improvements across the WebUI, ensuring a smoother, more consistent experience.
- **🌍 Updated Translations**: Refined multilingual support for better localization and accuracy across various languages.
### Fixed
- **🔄 "Stream" Hook Activation**: Fixed an issue where the "Stream" hook only worked when globally enabled, ensuring reliable real-time filtering.
- **📧 LDAP Email Case Sensitivity**: Resolved an issue where LDAP login failed due to email case sensitivity mismatches, improving authentication reliability.
- **💬 WebSocket Chat Event Registration**: Fixed a bug preventing chat event listeners from being registered upon sign-in, ensuring real-time updates work properly.
## [0.5.18] - 2025-02-27
### Fixed
- **🌐 Open WebUI Now Works Over LAN in Insecure Context**: Resolved an issue preventing Open WebUI from functioning when accessed over a local network in an insecure context, ensuring seamless connectivity.
- **🔄 UI Now Reflects Deleted Connections Instantly**: Fixed an issue where deleting a connection did not update the UI in real time, ensuring accurate system state visibility.
- **🛠️ Models Now Display Correctly with ENABLE_FORWARD_USER_INFO_HEADERS**: Addressed a bug where models were not visible when ENABLE_FORWARD_USER_INFO_HEADERS was set, restoring proper model listing.
## [0.5.17] - 2025-02-27
### Added
- **🚀 Instant Document Upload with Bypass Embedding & Retrieval**: Admins can now enable "Bypass Embedding & Retrieval" in Admin Settings > Documents, significantly speeding up document uploads and ensuring full document context is retained without chunking.
- **🔎 "Stream" Hook for Real-Time Filtering**: The new "stream" hook allows dynamic real-time message filtering. Learn more in our documentation (https://docs.openwebui.com/features/plugin/functions/filter).
- **☁️ OneDrive Integration**: Early support for OneDrive storage integration has been introduced, expanding file import options.
- **📈 Enhanced Logging with Loguru**: Backend logging has been improved with Loguru, making debugging and issue tracking far more efficient.
- **⚙️ General Stability Enhancements**: Backend and frontend refactoring improves performance, ensuring a smoother and more reliable user experience.
- **🌍 Updated Translations**: Refined multilingual support for better localization and accuracy across various languages.
### Fixed
- **🔄 Reliable Model Imports from the Community Platform**: Resolved import failures, allowing seamless integration of community-shared models without errors.
- **📊 OpenAI Usage Statistics Restored**: Fixed an issue where OpenAI usage metrics were not displaying correctly, ensuring accurate tracking of usage data.
- **🗂️ Deduplication for Retrieved Documents**: Documents retrieved during searches are now intelligently deduplicated, meaning no more redundant results—helping to keep information concise and relevant.
### Changed
- **📝 "Full Context Mode" Renamed for Clarity**: The "Full Context Mode" toggle in Web Search settings is now labeled "Bypass Embedding & Retrieval" for consistency across the UI.
## [0.5.16] - 2025-02-20
### Fixed
- **🔍 Web Search Retrieval Restored**: Resolved a critical issue that broke web search retrieval by reverting deduplication changes, ensuring complete and accurate search results once again.
## [0.5.15] - 2025-02-20
### Added
- **📄 Full Context Mode for Local Document Search (RAG)**: Toggle full context mode from Admin Settings > Documents to inject entire document content into context, improving accuracy for models with large context windows—ideal for deep context understanding.
- **🌍 Smarter Web Search with Agentic Workflows**: Web searches now intelligently gather and refine multiple relevant terms, similar to RAG handling, delivering significantly better search results for more accurate information retrieval.
- **🔎 Experimental Playwright Support for Web Loader**: Web content retrieval is taken to the next level with Playwright-powered scraping for enhanced accuracy in extracted web data.
- **☁️ Experimental Azure Storage Provider**: Early-stage support for Azure Storage allows more cloud storage flexibility directly within Open WebUI.
- **📊 Improved Jupyter Code Execution with Plots**: Interactive coding now properly displays inline plots, making data visualization more seamless inside chat interactions.
- **⏳ Adjustable Execution Timeout for Jupyter Interpreter**: Customize execution timeout (default: 60s) for Jupyter-based code execution, allowing longer or more constrained execution based on your needs.
- **▶️ "Running..." Indicator for Jupyter Code Execution**: A visual indicator now appears while code execution is in progress, providing real-time status updates on ongoing computations.
- **⚙️ General Backend & Frontend Stability Enhancements**: Extensive refactoring improves reliability, performance, and overall user experience for a more seamless Open WebUI.
- **🌍 Translation Updates**: Various international translation refinements ensure better localization and a more natural user interface experience.
### Fixed
- **📱 Mobile Hover Issue Resolved**: Users can now edit responses smoothly on mobile without interference, fixing a longstanding hover issue.
- **🔄 Temporary Chat Message Duplication Fixed**: Eliminated buggy behavior where messages were being unnecessarily repeated in temporary chat mode, ensuring a smooth and consistent conversation flow.
## [0.5.14] - 2025-02-17
### Fixed
- **🔧 Critical Import Error Resolved**: Fixed a circular import issue preventing 'override_static' from being correctly imported in 'open_webui.config', ensuring smooth system initialization and stability.
## [0.5.13] - 2025-02-17
### Added
- **🌐 Full Context Mode for Web Search**: Enable highly accurate web searches by utilizing full context mode—ideal for models with large context windows, ensuring more precise and insightful results.
- **⚡ Optimized Asynchronous Web Search**: Web searches now load significantly faster with optimized async support, providing users with quicker, more efficient information retrieval.
- **🔄 Auto Text Direction for RTL Languages**: Automatic text alignment based on language input, ensuring seamless conversation flow for Arabic, Hebrew, and other right-to-left scripts.
- **🚀 Jupyter Notebook Support for Code Execution**: The "Run" button in code blocks can now use Jupyter for execution, offering a powerful, dynamic coding experience directly in the chat.
- **🗑️ Message Delete Confirmation Dialog**: Prevent accidental deletions with a new confirmation prompt before removing messages, adding an additional layer of security to your chat history.
- **📥 Download Button for SVG Diagrams**: SVG diagrams generated within chat can now be downloaded instantly, making it easier to save and share complex visual data.
- **✨ General UI/UX Improvements and Backend Stability**: A refined interface with smoother interactions, improved layouts, and backend stability enhancements for a more reliable, polished experience.
### Fixed
- **🛠️ Temporary Chat Message Continue Button Fixed**: The "Continue Response" button for temporary chats now works as expected, ensuring an uninterrupted conversation flow.
### Changed
- **📝 Prompt Variable Update**: Deprecated square bracket '[]' indicators for prompt variables; now requires double curly brackets '{{}}' for consistency and clarity.
- **🔧 Stability Enhancements**: Error handling improved in chat history, ensuring smoother operations when reviewing previous messages.
## [0.5.12] - 2025-02-13
### Added

View file

@ -2,13 +2,13 @@
## Our Pledge
-As members, contributors, and leaders of this community, we pledge to make participation in our open-source project a harassment-free experience for everyone, regardless of age, body size, visible or invisible disability, ethnicity, sex characteristics, gender identity and expression, level of experience, education, socioeconomic status, nationality, personal appearance, race, religion, or sexual identity and orientation.
+As members, contributors, and leaders of this community, we pledge to make participation in our project a harassment-free experience for everyone, regardless of age, body size, visible or invisible disability, ethnicity, sex characteristics, gender identity and expression, level of experience, education, socioeconomic status, nationality, personal appearance, race, religion, or sexual identity and orientation.
We are committed to creating and maintaining an open, respectful, and professional environment where positive contributions and meaningful discussions can flourish. By participating in this project, you agree to uphold these values and align your behavior to the standards outlined in this Code of Conduct.
## Why These Standards Are Important
-Open-source projects rely on a community of volunteers dedicating their time, expertise, and effort toward a shared goal. These projects are inherently collaborative but also fragile, as the success of the project depends on the goodwill, energy, and productivity of those involved.
+Projects rely on a community of volunteers dedicating their time, expertise, and effort toward a shared goal. These projects are inherently collaborative but also fragile, as the success of the project depends on the goodwill, energy, and productivity of those involved.
Maintaining a positive and respectful environment is essential to safeguarding the integrity of this project and protecting contributors' efforts. Behavior that disrupts this atmosphere—whether through hostility, entitlement, or unprofessional conduct—can severely harm the morale and productivity of the community. **Strict enforcement of these standards ensures a safe and supportive space for meaningful collaboration.**
@ -79,7 +79,7 @@ This approach ensures that disruptive behaviors are addressed swiftly and decisi
## Why Zero Tolerance Is Necessary
-Open-source projects thrive on collaboration, goodwill, and mutual respect. Toxic behaviors—such as entitlement, hostility, or persistent negativity—threaten not just individual contributors but the health of the project as a whole. Allowing such behaviors to persist robs contributors of their time, energy, and enthusiasm for the work they do.
+Projects thrive on collaboration, goodwill, and mutual respect. Toxic behaviors—such as entitlement, hostility, or persistent negativity—threaten not just individual contributors but the health of the project as a whole. Allowing such behaviors to persist robs contributors of their time, energy, and enthusiasm for the work they do.
By enforcing a zero-tolerance policy, we ensure that the community remains a safe, welcoming space for all participants. These measures are not about harshness—they are about protecting contributors and fostering a productive environment where innovation can thrive.

View file

@ -0,0 +1,7 @@
# Open WebUI Contributor License Agreement
By submitting my contributions to Open WebUI, I grant Open WebUI full freedom to use my work in any way they choose, under any terms they like, both now and in the future. This approach helps ensure the project remains unified, flexible, and easy to maintain, while empowering Open WebUI to respond quickly to the needs of its users and the wider community.
Taking part in this process means my work can be seamlessly integrated and combined with others, ensuring longevity and adaptability for everyone who benefits from the Open WebUI project. This collaborative approach strengthens the projects future and helps guarantee that improvements can always be shared and distributed in the most effective way possible.
**_To the fullest extent permitted by law, my contributions are provided on an “as is” basis, with no warranties or guarantees of any kind, and I disclaim any liability for any issues or damages arising from their use or incorporation into the project, regardless of the type of legal claim._**

View file

@ -1,64 +0,0 @@
# Run with
# caddy run --envfile ./example.env --config ./Caddyfile.localhost
#
# This is configured for
# - Automatic HTTPS (even for localhost)
# - Reverse Proxying to Ollama API Base URL (http://localhost:11434/api)
# - CORS
# - HTTP Basic Auth API Tokens (uncomment basicauth section)
# CORS Preflight (OPTIONS) + Request (GET, POST, PATCH, PUT, DELETE)
(cors-api) {
	@match-cors-api-preflight method OPTIONS
	handle @match-cors-api-preflight {
		header {
			Access-Control-Allow-Origin "{http.request.header.origin}"
			Access-Control-Allow-Methods "GET, POST, PUT, PATCH, DELETE, OPTIONS"
			Access-Control-Allow-Headers "Origin, Accept, Authorization, Content-Type, X-Requested-With"
			Access-Control-Allow-Credentials "true"
			Access-Control-Max-Age "3600"
			defer
		}
		respond "" 204
	}

	@match-cors-api-request {
		not {
			header Origin "{http.request.scheme}://{http.request.host}"
		}
		header Origin "{http.request.header.origin}"
	}
	handle @match-cors-api-request {
		header {
			Access-Control-Allow-Origin "{http.request.header.origin}"
			Access-Control-Allow-Methods "GET, POST, PUT, PATCH, DELETE, OPTIONS"
			Access-Control-Allow-Headers "Origin, Accept, Authorization, Content-Type, X-Requested-With"
			Access-Control-Allow-Credentials "true"
			Access-Control-Max-Age "3600"
			defer
		}
	}
}

# replace localhost with example.com or whatever
localhost {
	## HTTP Basic Auth
	## (uncomment to enable)
	# basicauth {
	# 	# see .example.env for how to generate tokens
	# 	{env.OLLAMA_API_ID} {env.OLLAMA_API_TOKEN_DIGEST}
	# }

	handle /api/* {
		# Comment to disable CORS
		import cors-api

		reverse_proxy localhost:11434
	}

	# Same-Origin Static Web Server
	file_server {
		root ./build/
	}
}

View file

@ -4,7 +4,7 @@
ARG USE_CUDA=false
ARG USE_OLLAMA=false
# Tested with cu117 for CUDA 11 and cu121 for CUDA 12 (default)
-ARG USE_CUDA_VER=cu121
+ARG USE_CUDA_VER=cu128
# any sentence transformer model; models to use can be found at https://huggingface.co/models?library=sentence-transformers
# Leaderboard: https://huggingface.co/spaces/mteb/leaderboard
# for better performance and multilangauge support use "intfloat/multilingual-e5-large" (~2.5GB) or "intfloat/multilingual-e5-base" (~1.5GB)
@ -26,6 +26,9 @@ ARG BUILD_HASH
WORKDIR /app
+# to store git revision in build
+RUN apk add --no-cache git
COPY package.json package-lock.json ./
RUN npm ci
@ -132,7 +135,7 @@ RUN if [ "$USE_OLLAMA" = "true" ]; then \
# install python dependencies
COPY --chown=$UID:$GID ./backend/requirements.txt ./requirements.txt
-RUN pip3 install uv && \
+RUN pip3 install --no-cache-dir uv && \
    if [ "$USE_CUDA" = "true" ]; then \
    # If you use CUDA the whisper and embedding model will be downloaded on first use
    pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/$USE_CUDA_DOCKER_VER --no-cache-dir && \

View file

@ -1,4 +1,4 @@
-Copyright (c) 2023-2025 Timothy Jaeryang Baek
+Copyright (c) 2023-2025 Timothy Jaeryang Baek (Open WebUI)
All rights reserved.
Redistribution and use in source and binary forms, with or without
@ -15,6 +15,12 @@ modification, are permitted provided that the following conditions are met:
   contributors may be used to endorse or promote products derived from
   this software without specific prior written permission.
+4. Notwithstanding any other provision of this License, and as a material condition of the rights granted herein, licensees are strictly prohibited from altering, removing, obscuring, or replacing any "Open WebUI" branding, including but not limited to the name, logo, or any visual, textual, or symbolic identifiers that distinguish the software and its interfaces, in any deployment or distribution, regardless of the number of users, except as explicitly set forth in Clauses 5 and 6 below.
+5. The branding restriction enumerated in Clause 4 shall not apply in the following limited circumstances: (i) deployments or distributions where the total number of end users (defined as individual natural persons with direct access to the application) does not exceed fifty (50) within any rolling thirty (30) day period; (ii) cases in which the licensee is an official contributor to the codebase—with a substantive code change successfully merged into the main branch of the official codebase maintained by the copyright holder—who has obtained specific prior written permission for branding adjustment from the copyright holder; or (iii) where the licensee has obtained a duly executed enterprise license expressly permitting such modification. For all other cases, any removal or alteration of the "Open WebUI" branding shall constitute a material breach of license.
+6. All code, modifications, or derivative works incorporated into this project prior to the incorporation of this branding clause remain licensed under the BSD 3-Clause License, and prior contributors retain all BSD-3 rights therein; if any such contributor requests the removal of their BSD-3-licensed code, the copyright holder will do so, and any replacement code will be licensed under the project's primary license then in effect. By contributing after this clause's adoption, you agree to the project's Contributor License Agreement (CLA) and to these updated terms for all new contributions.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE

View file

@ -7,16 +7,20 @@
![GitHub language count](https://img.shields.io/github/languages/count/open-webui/open-webui)
![GitHub top language](https://img.shields.io/github/languages/top/open-webui/open-webui)
![GitHub last commit](https://img.shields.io/github/last-commit/open-webui/open-webui?color=red)
-![Hits](https://hits.seeyoufarm.com/api/count/incr/badge.svg?url=https%3A%2F%2Fgithub.com%2Follama-webui%2Follama-wbui&count_bg=%2379C83D&title_bg=%23555555&icon=&icon_color=%23E7E7E7&title=hits&edge_flat=false)
[![Discord](https://img.shields.io/badge/Discord-Open_WebUI-blue?logo=discord&logoColor=white)](https://discord.gg/5rJgQTnV4s)
[![](https://img.shields.io/static/v1?label=Sponsor&message=%E2%9D%A4&logo=GitHub&color=%23fe8e86)](https://github.com/sponsors/tjbck)
**Open WebUI is an [extensible](https://docs.openwebui.com/features/plugin/), feature-rich, and user-friendly self-hosted AI platform designed to operate entirely offline.** It supports various LLM runners like **Ollama** and **OpenAI-compatible APIs**, with **built-in inference engine** for RAG, making it a **powerful AI deployment solution**.
-For more information, be sure to check out our [Open WebUI Documentation](https://docs.openwebui.com/).
![Open WebUI Demo](./demo.gif)
+> [!TIP]
+> **Looking for an [Enterprise Plan](https://docs.openwebui.com/enterprise)?** **[Speak with Our Sales Team Today!](mailto:sales@openwebui.com)**
+>
+> Get **enhanced capabilities**, including **custom theming and branding**, **Service Level Agreement (SLA) support**, **Long-Term Support (LTS) versions**, and **more!**
+For more information, be sure to check out our [Open WebUI Documentation](https://docs.openwebui.com/).
## Key Features of Open WebUI ⭐
- 🚀 **Effortless Setup**: Install seamlessly using Docker or Kubernetes (kubectl, kustomize or helm) for a hassle-free experience with support for both `:ollama` and `:cuda` tagged images.
@ -57,9 +61,36 @@ For more information, be sure to check out our [Open WebUI Documentation](https:
Want to learn more about Open WebUI's features? Check out our [Open WebUI documentation](https://docs.openwebui.com/features) for a comprehensive overview!
-## 🔗 Also Check Out Open WebUI Community!
+## Sponsors 🙌
-Don't forget to explore our sibling project, [Open WebUI Community](https://openwebui.com/), where you can discover, download, and explore customized Modelfiles. Open WebUI Community offers a wide range of exciting possibilities for enhancing your chat interactions with Open WebUI! 🚀
+#### Emerald
+<table>
+  <tr>
+    <td>
+      <a href="https://n8n.io/" target="_blank">
+        <img src="https://docs.openwebui.com/sponsors/logos/n8n.png" alt="n8n" style="width: 8rem; height: 8rem; border-radius: .75rem;" />
+      </a>
+    </td>
+    <td>
+      N8N • Does your interface have a backend yet?<br>Try <a href="https://n8n.io/">n8n</a>
+    </td>
+  </tr>
+  <tr>
+    <td>
+      <a href="https://warp.dev/open-webui" target="_blank">
+        <img src="https://docs.openwebui.com/sponsors/logos/warp.png" alt="n8n" style="width: 8rem; height: 8rem; border-radius: .75rem;" />
+      </a>
+    </td>
+    <td>
+      <a href="https://warp.dev/open-webui">Warp</a> • The intelligent terminal for developers
+    </td>
+  </tr>
+</table>
+---
+We are incredibly grateful for the generous support of our sponsors. Their contributions help us to maintain and improve our project, ensuring we can continue to deliver quality work to our community. Thank you!
## How to Install 🚀
@ -201,7 +232,7 @@ Discover upcoming features on our roadmap in the [Open WebUI Documentation](http
## License 📜
-This project is licensed under the [BSD-3-Clause License](LICENSE) - see the [LICENSE](LICENSE) file for details. 📄
+This project is licensed under the [Open WebUI License](LICENSE), a revised BSD-3-Clause license. You receive all the same rights as the classic BSD-3 license: you can use, modify, and distribute the software, including in proprietary and commercial products, with minimal restrictions. The only additional requirement is to preserve the "Open WebUI" branding, as detailed in the LICENSE file. For full terms, see the [LICENSE](LICENSE) document. 📄
## Support 💬

View file

@ -73,8 +73,15 @@ def serve(
os.environ["LD_LIBRARY_PATH"] = ":".join(LD_LIBRARY_PATH) os.environ["LD_LIBRARY_PATH"] = ":".join(LD_LIBRARY_PATH)
import open_webui.main # we need set environment variables before importing main import open_webui.main # we need set environment variables before importing main
from open_webui.env import UVICORN_WORKERS # Import the workers setting
uvicorn.run(open_webui.main.app, host=host, port=port, forwarded_allow_ips="*") uvicorn.run(
"open_webui.main:app",
host=host,
port=port,
forwarded_allow_ips="*",
workers=UVICORN_WORKERS,
)
@app.command() @app.command()

File diff suppressed because it is too large.

View file

@ -31,6 +31,7 @@ class ERROR_MESSAGES(str, Enum):
    USERNAME_TAKEN = (
        "Uh-oh! This username is already registered. Please choose another username."
    )
+    PASSWORD_TOO_LONG = "Uh-oh! The password you entered is too long. Please make sure your password is less than 72 bytes long."
    COMMAND_TAKEN = "Uh-oh! This command is already registered. Please choose another command string."
    FILE_EXISTS = "Uh-oh! This file is already registered. Please choose another file."

View file

@ -65,10 +65,8 @@ except Exception:
# LOGGING
####################################
-log_levels = ["CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG"]
GLOBAL_LOG_LEVEL = os.environ.get("GLOBAL_LOG_LEVEL", "").upper()
-if GLOBAL_LOG_LEVEL in log_levels:
+if GLOBAL_LOG_LEVEL in logging.getLevelNamesMapping():
    logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL, force=True)
else:
    GLOBAL_LOG_LEVEL = "INFO"
@ -78,6 +76,7 @@ log.info(f"GLOBAL_LOG_LEVEL: {GLOBAL_LOG_LEVEL}")
if "cuda_error" in locals(): if "cuda_error" in locals():
log.exception(cuda_error) log.exception(cuda_error)
del cuda_error
log_sources = [ log_sources = [
"AUDIO", "AUDIO",
@ -100,19 +99,19 @@ SRC_LOG_LEVELS = {}
for source in log_sources:
    log_env_var = source + "_LOG_LEVEL"
    SRC_LOG_LEVELS[source] = os.environ.get(log_env_var, "").upper()
-    if SRC_LOG_LEVELS[source] not in log_levels:
+    if SRC_LOG_LEVELS[source] not in logging.getLevelNamesMapping():
        SRC_LOG_LEVELS[source] = GLOBAL_LOG_LEVEL
    log.info(f"{log_env_var}: {SRC_LOG_LEVELS[source]}")
log.setLevel(SRC_LOG_LEVELS["CONFIG"])
WEBUI_NAME = os.environ.get("WEBUI_NAME", "Open WebUI")
if WEBUI_NAME != "Open WebUI":
    WEBUI_NAME += " (Open WebUI)"
WEBUI_FAVICON_URL = "https://openwebui.com/favicon.png"
TRUSTED_SIGNATURE_KEY = os.environ.get("TRUSTED_SIGNATURE_KEY", "")
####################################
# ENV (dev,test,prod)
@ -130,7 +129,6 @@ else:
except Exception:
    PACKAGE_DATA = {"version": "0.0.0"}
VERSION = PACKAGE_DATA["version"]
@ -161,7 +159,6 @@ try:
except Exception:
    changelog_content = (pkgutil.get_data("open_webui", "CHANGELOG.md") or b"").decode()
# Convert markdown content to HTML
html_content = markdown.markdown(changelog_content)
@ -192,7 +189,6 @@ for version in soup.find_all("h2"):
    changelog_json[version_number] = version_data
CHANGELOG = changelog_json
####################################
@ -209,7 +205,6 @@ ENABLE_FORWARD_USER_INFO_HEADERS = (
os.environ.get("ENABLE_FORWARD_USER_INFO_HEADERS", "False").lower() == "true" os.environ.get("ENABLE_FORWARD_USER_INFO_HEADERS", "False").lower() == "true"
) )
#################################### ####################################
# WEBUI_BUILD_HASH # WEBUI_BUILD_HASH
#################################### ####################################
@ -244,7 +239,6 @@ if FROM_INIT_PY:
DATA_DIR = Path(os.getenv("DATA_DIR", OPEN_WEBUI_DIR / "data")) DATA_DIR = Path(os.getenv("DATA_DIR", OPEN_WEBUI_DIR / "data"))
STATIC_DIR = Path(os.getenv("STATIC_DIR", OPEN_WEBUI_DIR / "static")) STATIC_DIR = Path(os.getenv("STATIC_DIR", OPEN_WEBUI_DIR / "static"))
FONTS_DIR = Path(os.getenv("FONTS_DIR", OPEN_WEBUI_DIR / "static" / "fonts")) FONTS_DIR = Path(os.getenv("FONTS_DIR", OPEN_WEBUI_DIR / "static" / "fonts"))
@ -256,7 +250,6 @@ if FROM_INIT_PY:
os.getenv("FRONTEND_BUILD_DIR", OPEN_WEBUI_DIR / "frontend") os.getenv("FRONTEND_BUILD_DIR", OPEN_WEBUI_DIR / "frontend")
).resolve() ).resolve()
#################################### ####################################
# Database # Database
#################################### ####################################
@ -321,7 +314,6 @@ RESET_CONFIG_ON_START = (
os.environ.get("RESET_CONFIG_ON_START", "False").lower() == "true" os.environ.get("RESET_CONFIG_ON_START", "False").lower() == "true"
) )
ENABLE_REALTIME_CHAT_SAVE = ( ENABLE_REALTIME_CHAT_SAVE = (
os.environ.get("ENABLE_REALTIME_CHAT_SAVE", "False").lower() == "true" os.environ.get("ENABLE_REALTIME_CHAT_SAVE", "False").lower() == "true"
) )
@ -330,7 +322,23 @@ ENABLE_REALTIME_CHAT_SAVE = (
# REDIS
####################################
-REDIS_URL = os.environ.get("REDIS_URL", "redis://localhost:6379/0")
+REDIS_URL = os.environ.get("REDIS_URL", "")
+REDIS_SENTINEL_HOSTS = os.environ.get("REDIS_SENTINEL_HOSTS", "")
+REDIS_SENTINEL_PORT = os.environ.get("REDIS_SENTINEL_PORT", "26379")
####################################
# UVICORN WORKERS
####################################
# Number of uvicorn worker processes for handling requests
UVICORN_WORKERS = os.environ.get("UVICORN_WORKERS", "1")
try:
    UVICORN_WORKERS = int(UVICORN_WORKERS)
    if UVICORN_WORKERS < 1:
        UVICORN_WORKERS = 1
except ValueError:
    UVICORN_WORKERS = 1
    log.info(f"Invalid UVICORN_WORKERS value, defaulting to {UVICORN_WORKERS}")
####################################
# WEBUI_AUTH (Required for security)
@ -341,11 +349,19 @@ WEBUI_AUTH_TRUSTED_EMAIL_HEADER = os.environ.get(
"WEBUI_AUTH_TRUSTED_EMAIL_HEADER", None "WEBUI_AUTH_TRUSTED_EMAIL_HEADER", None
) )
WEBUI_AUTH_TRUSTED_NAME_HEADER = os.environ.get("WEBUI_AUTH_TRUSTED_NAME_HEADER", None) WEBUI_AUTH_TRUSTED_NAME_HEADER = os.environ.get("WEBUI_AUTH_TRUSTED_NAME_HEADER", None)
WEBUI_AUTH_TRUSTED_GROUPS_HEADER = os.environ.get(
"WEBUI_AUTH_TRUSTED_GROUPS_HEADER", None
)
BYPASS_MODEL_ACCESS_CONTROL = (
    os.environ.get("BYPASS_MODEL_ACCESS_CONTROL", "False").lower() == "true"
)
WEBUI_AUTH_SIGNOUT_REDIRECT_URL = os.environ.get(
"WEBUI_AUTH_SIGNOUT_REDIRECT_URL", None
)
####################################
# WEBUI_SECRET_KEY
####################################
@ -385,6 +401,11 @@ ENABLE_WEBSOCKET_SUPPORT = (
WEBSOCKET_MANAGER = os.environ.get("WEBSOCKET_MANAGER", "") WEBSOCKET_MANAGER = os.environ.get("WEBSOCKET_MANAGER", "")
WEBSOCKET_REDIS_URL = os.environ.get("WEBSOCKET_REDIS_URL", REDIS_URL) WEBSOCKET_REDIS_URL = os.environ.get("WEBSOCKET_REDIS_URL", REDIS_URL)
WEBSOCKET_REDIS_LOCK_TIMEOUT = os.environ.get("WEBSOCKET_REDIS_LOCK_TIMEOUT", 60)
WEBSOCKET_SENTINEL_HOSTS = os.environ.get("WEBSOCKET_SENTINEL_HOSTS", "")
WEBSOCKET_SENTINEL_PORT = os.environ.get("WEBSOCKET_SENTINEL_PORT", "26379")
AIOHTTP_CLIENT_TIMEOUT = os.environ.get("AIOHTTP_CLIENT_TIMEOUT", "")
@ -396,19 +417,88 @@ else:
except Exception:
    AIOHTTP_CLIENT_TIMEOUT = 300
-AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST = os.environ.get(
-    "AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST", ""
-)
-if AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST == "":
-    AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST = None
-else:
-    try:
-        AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST = int(
-            AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST
-        )
-    except Exception:
-        AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST = 5
+AIOHTTP_CLIENT_SESSION_SSL = (
+    os.environ.get("AIOHTTP_CLIENT_SESSION_SSL", "True").lower() == "true"
+)
+AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST = os.environ.get(
+    "AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST",
+    os.environ.get("AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST", "10"),
+)
+if AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST == "":
+    AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST = None
+else:
+    try:
+        AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST = int(AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST)
+    except Exception:
+        AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST = 10
+AIOHTTP_CLIENT_TIMEOUT_TOOL_SERVER_DATA = os.environ.get(
+    "AIOHTTP_CLIENT_TIMEOUT_TOOL_SERVER_DATA", "10"
+)
+if AIOHTTP_CLIENT_TIMEOUT_TOOL_SERVER_DATA == "":
+    AIOHTTP_CLIENT_TIMEOUT_TOOL_SERVER_DATA = None
+else:
+    try:
+        AIOHTTP_CLIENT_TIMEOUT_TOOL_SERVER_DATA = int(
+            AIOHTTP_CLIENT_TIMEOUT_TOOL_SERVER_DATA
+        )
+    except Exception:
+        AIOHTTP_CLIENT_TIMEOUT_TOOL_SERVER_DATA = 10
+AIOHTTP_CLIENT_SESSION_TOOL_SERVER_SSL = (
+    os.environ.get("AIOHTTP_CLIENT_SESSION_TOOL_SERVER_SSL", "True").lower() == "true"
+)
####################################
# SENTENCE TRANSFORMERS
####################################
SENTENCE_TRANSFORMERS_BACKEND = os.environ.get("SENTENCE_TRANSFORMERS_BACKEND", "")
if SENTENCE_TRANSFORMERS_BACKEND == "":
SENTENCE_TRANSFORMERS_BACKEND = "torch"
SENTENCE_TRANSFORMERS_MODEL_KWARGS = os.environ.get(
"SENTENCE_TRANSFORMERS_MODEL_KWARGS", ""
)
if SENTENCE_TRANSFORMERS_MODEL_KWARGS == "":
SENTENCE_TRANSFORMERS_MODEL_KWARGS = None
else:
try:
SENTENCE_TRANSFORMERS_MODEL_KWARGS = json.loads(
SENTENCE_TRANSFORMERS_MODEL_KWARGS
)
except Exception:
SENTENCE_TRANSFORMERS_MODEL_KWARGS = None
SENTENCE_TRANSFORMERS_CROSS_ENCODER_BACKEND = os.environ.get(
"SENTENCE_TRANSFORMERS_CROSS_ENCODER_BACKEND", ""
)
if SENTENCE_TRANSFORMERS_CROSS_ENCODER_BACKEND == "":
SENTENCE_TRANSFORMERS_CROSS_ENCODER_BACKEND = "torch"
SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS = os.environ.get(
"SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS", ""
)
if SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS == "":
SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS = None
else:
try:
SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS = json.loads(
SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS
)
except Exception:
SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS = None
####################################
# OFFLINE_MODE
@ -418,3 +508,56 @@ OFFLINE_MODE = os.environ.get("OFFLINE_MODE", "false").lower() == "true"
if OFFLINE_MODE:
    os.environ["HF_HUB_OFFLINE"] = "1"
####################################
# AUDIT LOGGING
####################################
# Where to store log file
AUDIT_LOGS_FILE_PATH = f"{DATA_DIR}/audit.log"
# Maximum size of a file before rotating into a new log file
AUDIT_LOG_FILE_ROTATION_SIZE = os.getenv("AUDIT_LOG_FILE_ROTATION_SIZE", "10MB")
# METADATA | REQUEST | REQUEST_RESPONSE
AUDIT_LOG_LEVEL = os.getenv("AUDIT_LOG_LEVEL", "NONE").upper()
try:
    MAX_BODY_LOG_SIZE = int(os.environ.get("MAX_BODY_LOG_SIZE") or 2048)
except ValueError:
    MAX_BODY_LOG_SIZE = 2048
# Comma separated list for urls to exclude from audit
AUDIT_EXCLUDED_PATHS = os.getenv("AUDIT_EXCLUDED_PATHS", "/chats,/chat,/folders").split(
","
)
AUDIT_EXCLUDED_PATHS = [path.strip() for path in AUDIT_EXCLUDED_PATHS]
AUDIT_EXCLUDED_PATHS = [path.lstrip("/") for path in AUDIT_EXCLUDED_PATHS]
####################################
# OPENTELEMETRY
####################################
ENABLE_OTEL = os.environ.get("ENABLE_OTEL", "False").lower() == "true"
OTEL_EXPORTER_OTLP_ENDPOINT = os.environ.get(
"OTEL_EXPORTER_OTLP_ENDPOINT", "http://localhost:4317"
)
OTEL_SERVICE_NAME = os.environ.get("OTEL_SERVICE_NAME", "open-webui")
OTEL_RESOURCE_ATTRIBUTES = os.environ.get(
"OTEL_RESOURCE_ATTRIBUTES", ""
) # e.g. key1=val1,key2=val2
OTEL_TRACES_SAMPLER = os.environ.get(
"OTEL_TRACES_SAMPLER", "parentbased_always_on"
).lower()
####################################
# TOOLS/FUNCTIONS PIP OPTIONS
####################################
PIP_OPTIONS = os.getenv("PIP_OPTIONS", "").split()
PIP_PACKAGE_INDEX_OPTIONS = os.getenv("PIP_PACKAGE_INDEX_OPTIONS", "").split()
####################################
# PROGRESSIVE WEB APP OPTIONS
####################################
EXTERNAL_PWA_MANIFEST_URL = os.environ.get("EXTERNAL_PWA_MANIFEST_URL")


@ -2,6 +2,7 @@ import logging
import sys
import inspect
import json
import asyncio
from pydantic import BaseModel
from typing import AsyncGenerator, Generator, Iterator
@ -27,7 +28,10 @@ from open_webui.socket.main import (
from open_webui.models.functions import Functions
from open_webui.models.models import Models
from open_webui.utils.plugin import load_function_module_by_id
from open_webui.utils.plugin import (
load_function_module_by_id,
get_function_module_from_cache,
)
from open_webui.utils.tools import get_tools
from open_webui.utils.access_control import has_access
@ -52,12 +56,7 @@ log.setLevel(SRC_LOG_LEVELS["MAIN"])
def get_function_module_by_id(request: Request, pipe_id: str):
# Check if function is already loaded
if pipe_id not in request.app.state.FUNCTIONS:
function_module, _, _ = load_function_module_by_id(pipe_id)
request.app.state.FUNCTIONS[pipe_id] = function_module
else:
function_module = request.app.state.FUNCTIONS[pipe_id]
function_module, _, _ = get_function_module_from_cache(request, pipe_id)
if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
valves = Functions.get_function_valves_by_id(pipe_id)
@ -76,10 +75,12 @@ async def get_function_models(request):
if hasattr(function_module, "pipes"):
sub_pipes = []
# Check if pipes is a function or a list
# Handle pipes being a list, sync function, or async function
try:
if callable(function_module.pipes):
if asyncio.iscoroutinefunction(function_module.pipes):
sub_pipes = await function_module.pipes()
else:
sub_pipes = function_module.pipes()
else:
sub_pipes = function_module.pipes
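The branch above accepts a pipe module whose `pipes` attribute is either a plain list or a callable, and the new `asyncio.iscoroutinefunction` check additionally allows an async callable. A minimal sketch of the three shapes, assuming the usual manifold convention of returning dicts with `id` and `name` (the helper `fetch_models_from_upstream` is hypothetical):

class Pipe:
    # 1) a plain list attribute
    # pipes = [{"id": "model-a", "name": "Model A"}]

    # 2) a sync callable
    def pipes(self):
        return [{"id": "model-a", "name": "Model A"}]

    # 3) an async callable, awaited via the iscoroutinefunction() branch above
    # async def pipes(self):
    #     return await fetch_models_from_upstream()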
@ -220,6 +221,9 @@ async def generate_function_chat_completion(
extra_params = {
"__event_emitter__": __event_emitter__,
"__event_call__": __event_call__,
"__chat_id__": metadata.get("chat_id", None),
"__session_id__": metadata.get("session_id", None),
"__message_id__": metadata.get("message_id", None),
"__task__": __task__,
"__task_body__": __task_body__,
"__files__": files,
@ -249,8 +253,13 @@ async def generate_function_chat_completion(
form_data["model"] = model_info.base_model_id form_data["model"] = model_info.base_model_id
params = model_info.params.model_dump() params = model_info.params.model_dump()
if params:
system = params.pop("system", None)
form_data = apply_model_params_to_body_openai(params, form_data) form_data = apply_model_params_to_body_openai(params, form_data)
form_data = apply_model_system_prompt_to_body(params, form_data, metadata, user) form_data = apply_model_system_prompt_to_body(
system, form_data, metadata, user
)
pipe_id = get_pipe_id(form_data) pipe_id = get_pipe_id(form_data)
function_module = get_function_module_by_id(request, pipe_id) function_module = get_function_module_by_id(request, pipe_id)


@ -43,7 +43,7 @@ class ReconnectingPostgresqlDatabase(CustomReconnectMixin, PostgresqlDatabase):
def register_connection(db_url):
db = connect(db_url, unquote_password=True)
db = connect(db_url, unquote_user=True, unquote_password=True)
if isinstance(db, PostgresqlDatabase):
# Enable autoconnect for SQLite databases, managed by Peewee
db.autoconnect = True
@ -51,7 +51,7 @@ def register_connection(db_url):
log.info("Connected to PostgreSQL database")
# Get the connection details
connection = parse(db_url, unquote_password=True)
connection = parse(db_url, unquote_user=True, unquote_password=True)
# Use our custom database class that supports reconnection
db = ReconnectingPostgresqlDatabase(**connection)


@ -17,6 +17,7 @@ from sqlalchemy import text
from typing import Optional from typing import Optional
from aiocache import cached from aiocache import cached
import aiohttp import aiohttp
import anyio.to_thread
import requests import requests
@ -39,12 +40,17 @@ from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, RedirectResponse from fastapi.responses import JSONResponse, RedirectResponse
from fastapi.staticfiles import StaticFiles from fastapi.staticfiles import StaticFiles
from starlette_compress import CompressMiddleware
from starlette.exceptions import HTTPException as StarletteHTTPException from starlette.exceptions import HTTPException as StarletteHTTPException
from starlette.middleware.base import BaseHTTPMiddleware from starlette.middleware.base import BaseHTTPMiddleware
from starlette.middleware.sessions import SessionMiddleware from starlette.middleware.sessions import SessionMiddleware
from starlette.responses import Response, StreamingResponse from starlette.responses import Response, StreamingResponse
from open_webui.utils import logger
from open_webui.utils.audit import AuditLevel, AuditLoggingMiddleware
from open_webui.utils.logger import start_logger
from open_webui.socket.main import ( from open_webui.socket.main import (
app as socket_app, app as socket_app,
periodic_usage_pool_cleanup, periodic_usage_pool_cleanup,
@ -60,6 +66,7 @@ from open_webui.routers import (
auths, auths,
channels, channels,
chats, chats,
notes,
folders, folders,
configs, configs,
groups, groups,
@ -81,25 +88,41 @@ from open_webui.routers.retrieval import (
get_rf, get_rf,
) )
from open_webui.internal.db import Session from open_webui.internal.db import Session, engine
from open_webui.models.functions import Functions from open_webui.models.functions import Functions
from open_webui.models.models import Models from open_webui.models.models import Models
from open_webui.models.users import UserModel, Users from open_webui.models.users import UserModel, Users
from open_webui.models.chats import Chats
from open_webui.config import ( from open_webui.config import (
LICENSE_KEY,
# Ollama # Ollama
ENABLE_OLLAMA_API, ENABLE_OLLAMA_API,
OLLAMA_BASE_URLS, OLLAMA_BASE_URLS,
OLLAMA_API_CONFIGS, OLLAMA_API_CONFIGS,
# OpenAI # OpenAI
ENABLE_OPENAI_API, ENABLE_OPENAI_API,
ONEDRIVE_CLIENT_ID,
ONEDRIVE_SHAREPOINT_URL,
ONEDRIVE_SHAREPOINT_TENANT_ID,
OPENAI_API_BASE_URLS, OPENAI_API_BASE_URLS,
OPENAI_API_KEYS, OPENAI_API_KEYS,
OPENAI_API_CONFIGS, OPENAI_API_CONFIGS,
# Direct Connections # Direct Connections
ENABLE_DIRECT_CONNECTIONS, ENABLE_DIRECT_CONNECTIONS,
# Code Interpreter # Thread pool size for FastAPI/AnyIO
THREAD_POOL_SIZE,
# Tool Server Configs
TOOL_SERVER_CONNECTIONS,
# Code Execution
ENABLE_CODE_EXECUTION,
CODE_EXECUTION_ENGINE,
CODE_EXECUTION_JUPYTER_URL,
CODE_EXECUTION_JUPYTER_AUTH,
CODE_EXECUTION_JUPYTER_AUTH_TOKEN,
CODE_EXECUTION_JUPYTER_AUTH_PASSWORD,
CODE_EXECUTION_JUPYTER_TIMEOUT,
ENABLE_CODE_INTERPRETER, ENABLE_CODE_INTERPRETER,
CODE_INTERPRETER_ENGINE, CODE_INTERPRETER_ENGINE,
CODE_INTERPRETER_PROMPT_TEMPLATE, CODE_INTERPRETER_PROMPT_TEMPLATE,
@ -107,6 +130,7 @@ from open_webui.config import (
CODE_INTERPRETER_JUPYTER_AUTH, CODE_INTERPRETER_JUPYTER_AUTH,
CODE_INTERPRETER_JUPYTER_AUTH_TOKEN, CODE_INTERPRETER_JUPYTER_AUTH_TOKEN,
CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD, CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD,
CODE_INTERPRETER_JUPYTER_TIMEOUT,
# Image # Image
AUTOMATIC1111_API_AUTH, AUTOMATIC1111_API_AUTH,
AUTOMATIC1111_BASE_URL, AUTOMATIC1111_BASE_URL,
@ -125,11 +149,18 @@ from open_webui.config import (
IMAGE_STEPS, IMAGE_STEPS,
IMAGES_OPENAI_API_BASE_URL, IMAGES_OPENAI_API_BASE_URL,
IMAGES_OPENAI_API_KEY, IMAGES_OPENAI_API_KEY,
IMAGES_GEMINI_API_BASE_URL,
IMAGES_GEMINI_API_KEY,
# Audio # Audio
AUDIO_STT_ENGINE, AUDIO_STT_ENGINE,
AUDIO_STT_MODEL, AUDIO_STT_MODEL,
AUDIO_STT_OPENAI_API_BASE_URL, AUDIO_STT_OPENAI_API_BASE_URL,
AUDIO_STT_OPENAI_API_KEY, AUDIO_STT_OPENAI_API_KEY,
AUDIO_STT_AZURE_API_KEY,
AUDIO_STT_AZURE_REGION,
AUDIO_STT_AZURE_LOCALES,
AUDIO_STT_AZURE_BASE_URL,
AUDIO_STT_AZURE_MAX_SPEAKERS,
AUDIO_TTS_API_KEY, AUDIO_TTS_API_KEY,
AUDIO_TTS_ENGINE, AUDIO_TTS_ENGINE,
AUDIO_TTS_MODEL, AUDIO_TTS_MODEL,
@ -138,59 +169,107 @@ from open_webui.config import (
AUDIO_TTS_SPLIT_ON, AUDIO_TTS_SPLIT_ON,
AUDIO_TTS_VOICE, AUDIO_TTS_VOICE,
AUDIO_TTS_AZURE_SPEECH_REGION, AUDIO_TTS_AZURE_SPEECH_REGION,
AUDIO_TTS_AZURE_SPEECH_BASE_URL,
AUDIO_TTS_AZURE_SPEECH_OUTPUT_FORMAT, AUDIO_TTS_AZURE_SPEECH_OUTPUT_FORMAT,
PLAYWRIGHT_WS_URL,
PLAYWRIGHT_TIMEOUT,
FIRECRAWL_API_BASE_URL,
FIRECRAWL_API_KEY,
WEB_LOADER_ENGINE,
WHISPER_MODEL, WHISPER_MODEL,
WHISPER_VAD_FILTER,
WHISPER_LANGUAGE,
DEEPGRAM_API_KEY, DEEPGRAM_API_KEY,
WHISPER_MODEL_AUTO_UPDATE, WHISPER_MODEL_AUTO_UPDATE,
WHISPER_MODEL_DIR, WHISPER_MODEL_DIR,
# Retrieval # Retrieval
RAG_TEMPLATE, RAG_TEMPLATE,
DEFAULT_RAG_TEMPLATE, DEFAULT_RAG_TEMPLATE,
RAG_FULL_CONTEXT,
BYPASS_EMBEDDING_AND_RETRIEVAL,
RAG_EMBEDDING_MODEL, RAG_EMBEDDING_MODEL,
RAG_EMBEDDING_MODEL_AUTO_UPDATE, RAG_EMBEDDING_MODEL_AUTO_UPDATE,
RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE, RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
RAG_RERANKING_ENGINE,
RAG_RERANKING_MODEL, RAG_RERANKING_MODEL,
RAG_EXTERNAL_RERANKER_URL,
RAG_EXTERNAL_RERANKER_API_KEY,
RAG_RERANKING_MODEL_AUTO_UPDATE, RAG_RERANKING_MODEL_AUTO_UPDATE,
RAG_RERANKING_MODEL_TRUST_REMOTE_CODE, RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
RAG_EMBEDDING_ENGINE, RAG_EMBEDDING_ENGINE,
RAG_EMBEDDING_BATCH_SIZE, RAG_EMBEDDING_BATCH_SIZE,
RAG_TOP_K,
RAG_TOP_K_RERANKER,
RAG_RELEVANCE_THRESHOLD, RAG_RELEVANCE_THRESHOLD,
RAG_HYBRID_BM25_WEIGHT,
RAG_ALLOWED_FILE_EXTENSIONS,
RAG_FILE_MAX_COUNT, RAG_FILE_MAX_COUNT,
RAG_FILE_MAX_SIZE, RAG_FILE_MAX_SIZE,
RAG_OPENAI_API_BASE_URL, RAG_OPENAI_API_BASE_URL,
RAG_OPENAI_API_KEY, RAG_OPENAI_API_KEY,
RAG_AZURE_OPENAI_BASE_URL,
RAG_AZURE_OPENAI_API_KEY,
RAG_AZURE_OPENAI_API_VERSION,
RAG_OLLAMA_BASE_URL, RAG_OLLAMA_BASE_URL,
RAG_OLLAMA_API_KEY, RAG_OLLAMA_API_KEY,
CHUNK_OVERLAP, CHUNK_OVERLAP,
CHUNK_SIZE, CHUNK_SIZE,
CONTENT_EXTRACTION_ENGINE, CONTENT_EXTRACTION_ENGINE,
DATALAB_MARKER_API_KEY,
DATALAB_MARKER_LANGS,
DATALAB_MARKER_SKIP_CACHE,
DATALAB_MARKER_FORCE_OCR,
DATALAB_MARKER_PAGINATE,
DATALAB_MARKER_STRIP_EXISTING_OCR,
DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION,
DATALAB_MARKER_OUTPUT_FORMAT,
DATALAB_MARKER_USE_LLM,
EXTERNAL_DOCUMENT_LOADER_URL,
EXTERNAL_DOCUMENT_LOADER_API_KEY,
TIKA_SERVER_URL, TIKA_SERVER_URL,
RAG_TOP_K, DOCLING_SERVER_URL,
DOCLING_OCR_ENGINE,
DOCLING_OCR_LANG,
DOCLING_DO_PICTURE_DESCRIPTION,
DOCUMENT_INTELLIGENCE_ENDPOINT,
DOCUMENT_INTELLIGENCE_KEY,
MISTRAL_OCR_API_KEY,
RAG_TEXT_SPLITTER, RAG_TEXT_SPLITTER,
TIKTOKEN_ENCODING_NAME, TIKTOKEN_ENCODING_NAME,
PDF_EXTRACT_IMAGES, PDF_EXTRACT_IMAGES,
YOUTUBE_LOADER_LANGUAGE, YOUTUBE_LOADER_LANGUAGE,
YOUTUBE_LOADER_PROXY_URL, YOUTUBE_LOADER_PROXY_URL,
# Retrieval (Web Search) # Retrieval (Web Search)
RAG_WEB_SEARCH_ENGINE, ENABLE_WEB_SEARCH,
RAG_WEB_SEARCH_RESULT_COUNT, WEB_SEARCH_ENGINE,
RAG_WEB_SEARCH_CONCURRENT_REQUESTS, BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL,
RAG_WEB_SEARCH_DOMAIN_FILTER_LIST, BYPASS_WEB_SEARCH_WEB_LOADER,
WEB_SEARCH_RESULT_COUNT,
WEB_SEARCH_CONCURRENT_REQUESTS,
WEB_SEARCH_TRUST_ENV,
WEB_SEARCH_DOMAIN_FILTER_LIST,
JINA_API_KEY, JINA_API_KEY,
SEARCHAPI_API_KEY, SEARCHAPI_API_KEY,
SEARCHAPI_ENGINE, SEARCHAPI_ENGINE,
SERPAPI_API_KEY, SERPAPI_API_KEY,
SERPAPI_ENGINE, SERPAPI_ENGINE,
SEARXNG_QUERY_URL, SEARXNG_QUERY_URL,
YACY_QUERY_URL,
YACY_USERNAME,
YACY_PASSWORD,
SERPER_API_KEY, SERPER_API_KEY,
SERPLY_API_KEY, SERPLY_API_KEY,
SERPSTACK_API_KEY, SERPSTACK_API_KEY,
SERPSTACK_HTTPS, SERPSTACK_HTTPS,
TAVILY_API_KEY, TAVILY_API_KEY,
TAVILY_EXTRACT_DEPTH,
BING_SEARCH_V7_ENDPOINT, BING_SEARCH_V7_ENDPOINT,
BING_SEARCH_V7_SUBSCRIPTION_KEY, BING_SEARCH_V7_SUBSCRIPTION_KEY,
BRAVE_SEARCH_API_KEY, BRAVE_SEARCH_API_KEY,
EXA_API_KEY, EXA_API_KEY,
PERPLEXITY_API_KEY,
SOUGOU_API_SID,
SOUGOU_API_SK,
KAGI_SEARCH_API_KEY, KAGI_SEARCH_API_KEY,
MOJEEK_SEARCH_API_KEY, MOJEEK_SEARCH_API_KEY,
BOCHA_SEARCH_API_KEY, BOCHA_SEARCH_API_KEY,
@ -198,12 +277,19 @@ from open_webui.config import (
GOOGLE_PSE_ENGINE_ID, GOOGLE_PSE_ENGINE_ID,
GOOGLE_DRIVE_CLIENT_ID, GOOGLE_DRIVE_CLIENT_ID,
GOOGLE_DRIVE_API_KEY, GOOGLE_DRIVE_API_KEY,
ONEDRIVE_CLIENT_ID,
ONEDRIVE_SHAREPOINT_URL,
ONEDRIVE_SHAREPOINT_TENANT_ID,
ENABLE_RAG_HYBRID_SEARCH, ENABLE_RAG_HYBRID_SEARCH,
ENABLE_RAG_LOCAL_WEB_FETCH, ENABLE_RAG_LOCAL_WEB_FETCH,
ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION, ENABLE_WEB_LOADER_SSL_VERIFICATION,
ENABLE_RAG_WEB_SEARCH,
ENABLE_GOOGLE_DRIVE_INTEGRATION, ENABLE_GOOGLE_DRIVE_INTEGRATION,
ENABLE_ONEDRIVE_INTEGRATION,
UPLOAD_DIR, UPLOAD_DIR,
EXTERNAL_WEB_SEARCH_URL,
EXTERNAL_WEB_SEARCH_API_KEY,
EXTERNAL_WEB_LOADER_URL,
EXTERNAL_WEB_LOADER_API_KEY,
# WebUI # WebUI
WEBUI_AUTH, WEBUI_AUTH,
WEBUI_NAME, WEBUI_NAME,
@ -218,11 +304,15 @@ from open_webui.config import (
ENABLE_API_KEY_ENDPOINT_RESTRICTIONS, ENABLE_API_KEY_ENDPOINT_RESTRICTIONS,
API_KEY_ALLOWED_ENDPOINTS, API_KEY_ALLOWED_ENDPOINTS,
ENABLE_CHANNELS, ENABLE_CHANNELS,
ENABLE_NOTES,
ENABLE_COMMUNITY_SHARING, ENABLE_COMMUNITY_SHARING,
ENABLE_MESSAGE_RATING, ENABLE_MESSAGE_RATING,
ENABLE_USER_WEBHOOKS,
ENABLE_EVALUATION_ARENA_MODELS, ENABLE_EVALUATION_ARENA_MODELS,
USER_PERMISSIONS, USER_PERMISSIONS,
DEFAULT_USER_ROLE, DEFAULT_USER_ROLE,
PENDING_USER_OVERLAY_CONTENT,
PENDING_USER_OVERLAY_TITLE,
DEFAULT_PROMPT_SUGGESTIONS, DEFAULT_PROMPT_SUGGESTIONS,
DEFAULT_MODELS, DEFAULT_MODELS,
DEFAULT_ARENA_MODEL, DEFAULT_ARENA_MODEL,
@ -249,6 +339,7 @@ from open_webui.config import (
LDAP_APP_PASSWORD, LDAP_APP_PASSWORD,
LDAP_USE_TLS, LDAP_USE_TLS,
LDAP_CA_CERT_FILE, LDAP_CA_CERT_FILE,
LDAP_VALIDATE_CERT,
LDAP_CIPHERS, LDAP_CIPHERS,
# Misc # Misc
ENV, ENV,
@ -259,6 +350,7 @@ from open_webui.config import (
DEFAULT_LOCALE, DEFAULT_LOCALE,
OAUTH_PROVIDERS, OAUTH_PROVIDERS,
WEBUI_URL, WEBUI_URL,
RESPONSE_WATERMARK,
# Admin # Admin
ENABLE_ADMIN_CHAT_ACCESS, ENABLE_ADMIN_CHAT_ACCESS,
ENABLE_ADMIN_EXPORT, ENABLE_ADMIN_EXPORT,
@ -281,8 +373,14 @@ from open_webui.config import (
reset_config, reset_config,
) )
from open_webui.env import ( from open_webui.env import (
AUDIT_EXCLUDED_PATHS,
AUDIT_LOG_LEVEL,
CHANGELOG, CHANGELOG,
REDIS_URL,
REDIS_SENTINEL_HOSTS,
REDIS_SENTINEL_PORT,
GLOBAL_LOG_LEVEL, GLOBAL_LOG_LEVEL,
MAX_BODY_LOG_SIZE,
SAFE_MODE, SAFE_MODE,
SRC_LOG_LEVELS, SRC_LOG_LEVELS,
VERSION, VERSION,
@ -292,10 +390,14 @@ from open_webui.env import (
WEBUI_SESSION_COOKIE_SECURE, WEBUI_SESSION_COOKIE_SECURE,
WEBUI_AUTH_TRUSTED_EMAIL_HEADER, WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
WEBUI_AUTH_TRUSTED_NAME_HEADER, WEBUI_AUTH_TRUSTED_NAME_HEADER,
WEBUI_AUTH_SIGNOUT_REDIRECT_URL,
ENABLE_WEBSOCKET_SUPPORT, ENABLE_WEBSOCKET_SUPPORT,
BYPASS_MODEL_ACCESS_CONTROL, BYPASS_MODEL_ACCESS_CONTROL,
RESET_CONFIG_ON_START, RESET_CONFIG_ON_START,
OFFLINE_MODE, OFFLINE_MODE,
ENABLE_OTEL,
EXTERNAL_PWA_MANIFEST_URL,
AIOHTTP_CLIENT_SESSION_SSL,
) )
@ -313,14 +415,24 @@ from open_webui.utils.middleware import process_chat_payload, process_chat_respo
from open_webui.utils.access_control import has_access from open_webui.utils.access_control import has_access
from open_webui.utils.auth import ( from open_webui.utils.auth import (
get_license_data,
get_http_authorization_cred,
decode_token, decode_token,
get_admin_user, get_admin_user,
get_verified_user, get_verified_user,
) )
from open_webui.utils.oauth import oauth_manager from open_webui.utils.plugin import install_tool_and_function_dependencies
from open_webui.utils.oauth import OAuthManager
from open_webui.utils.security_headers import SecurityHeadersMiddleware from open_webui.utils.security_headers import SecurityHeadersMiddleware
from open_webui.tasks import stop_task, list_tasks # Import from tasks.py from open_webui.tasks import (
list_task_ids_by_chat_id,
stop_task,
list_tasks,
) # Import from tasks.py
from open_webui.utils.redis import get_sentinels_from_env
if SAFE_MODE: if SAFE_MODE:
print("SAFE MODE ENABLED") print("SAFE MODE ENABLED")
@ -348,15 +460,15 @@ class SPAStaticFiles(StaticFiles):
print(
rf"""
(ASCII-art "Open WebUI" startup banner)
v{VERSION} - building the best open-source AI user interface.
v{VERSION} - building the best AI user interface.
{f"Commit: {WEBUI_BUILD_HASH}" if WEBUI_BUILD_HASH != "dev-build" else ""}
https://github.com/open-webui/open-webui
"""
@ -365,21 +477,56 @@ https://github.com/open-webui/open-webui
@asynccontextmanager
async def lifespan(app: FastAPI):
start_logger()
if RESET_CONFIG_ON_START:
reset_config()
if LICENSE_KEY:
get_license_data(app, LICENSE_KEY)
# This should be blocking (sync) so functions are not deactivated on first /get_models calls
# when the first user lands on the / route.
log.info("Installing external dependencies of functions and tools...")
install_tool_and_function_dependencies()
if THREAD_POOL_SIZE and THREAD_POOL_SIZE > 0:
limiter = anyio.to_thread.current_default_thread_limiter()
limiter.total_tokens = THREAD_POOL_SIZE
asyncio.create_task(periodic_usage_pool_cleanup())
yield
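For context, a sketch of what the THREAD_POOL_SIZE adjustment above does: FastAPI runs sync routes and sync dependencies on AnyIO's default worker-thread pool (40 threads by default), and raising `total_tokens` on the shared CapacityLimiter lifts that ceiling. The value 200 below is purely illustrative.

import anyio.to_thread

async def resize_default_thread_pool(size: int) -> None:
    # Must be called from the running event loop; the limiter is shared by all to_thread calls.
    limiter = anyio.to_thread.current_default_thread_limiter()
    limiter.total_tokens = size  # e.g. THREAD_POOL_SIZE=200 for many blocking routes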
app = FastAPI(
title="Open WebUI",
docs_url="/docs" if ENV == "dev" else None,
openapi_url="/openapi.json" if ENV == "dev" else None,
redoc_url=None,
lifespan=lifespan,
)
app.state.config = AppConfig()
oauth_manager = OAuthManager(app)
app.state.config = AppConfig(
redis_url=REDIS_URL,
redis_sentinels=get_sentinels_from_env(REDIS_SENTINEL_HOSTS, REDIS_SENTINEL_PORT),
)
app.state.WEBUI_NAME = WEBUI_NAME
app.state.LICENSE_METADATA = None
########################################
#
# OPENTELEMETRY
#
########################################
if ENABLE_OTEL:
from open_webui.utils.telemetry.setup import setup as setup_opentelemetry
setup_opentelemetry(app=app, db_engine=engine)
######################################## ########################################
@ -408,6 +555,15 @@ app.state.config.OPENAI_API_CONFIGS = OPENAI_API_CONFIGS
app.state.OPENAI_MODELS = {} app.state.OPENAI_MODELS = {}
########################################
#
# TOOL SERVERS
#
########################################
app.state.config.TOOL_SERVER_CONNECTIONS = TOOL_SERVER_CONNECTIONS
app.state.TOOL_SERVERS = []
######################################## ########################################
# #
# DIRECT CONNECTIONS # DIRECT CONNECTIONS
@ -442,6 +598,11 @@ app.state.config.DEFAULT_MODELS = DEFAULT_MODELS
app.state.config.DEFAULT_PROMPT_SUGGESTIONS = DEFAULT_PROMPT_SUGGESTIONS app.state.config.DEFAULT_PROMPT_SUGGESTIONS = DEFAULT_PROMPT_SUGGESTIONS
app.state.config.DEFAULT_USER_ROLE = DEFAULT_USER_ROLE app.state.config.DEFAULT_USER_ROLE = DEFAULT_USER_ROLE
app.state.config.PENDING_USER_OVERLAY_CONTENT = PENDING_USER_OVERLAY_CONTENT
app.state.config.PENDING_USER_OVERLAY_TITLE = PENDING_USER_OVERLAY_TITLE
app.state.config.RESPONSE_WATERMARK = RESPONSE_WATERMARK
app.state.config.USER_PERMISSIONS = USER_PERMISSIONS app.state.config.USER_PERMISSIONS = USER_PERMISSIONS
app.state.config.WEBHOOK_URL = WEBHOOK_URL app.state.config.WEBHOOK_URL = WEBHOOK_URL
app.state.config.BANNERS = WEBUI_BANNERS app.state.config.BANNERS = WEBUI_BANNERS
@ -449,8 +610,10 @@ app.state.config.MODEL_ORDER_LIST = MODEL_ORDER_LIST
app.state.config.ENABLE_CHANNELS = ENABLE_CHANNELS app.state.config.ENABLE_CHANNELS = ENABLE_CHANNELS
app.state.config.ENABLE_NOTES = ENABLE_NOTES
app.state.config.ENABLE_COMMUNITY_SHARING = ENABLE_COMMUNITY_SHARING app.state.config.ENABLE_COMMUNITY_SHARING = ENABLE_COMMUNITY_SHARING
app.state.config.ENABLE_MESSAGE_RATING = ENABLE_MESSAGE_RATING app.state.config.ENABLE_MESSAGE_RATING = ENABLE_MESSAGE_RATING
app.state.config.ENABLE_USER_WEBHOOKS = ENABLE_USER_WEBHOOKS
app.state.config.ENABLE_EVALUATION_ARENA_MODELS = ENABLE_EVALUATION_ARENA_MODELS app.state.config.ENABLE_EVALUATION_ARENA_MODELS = ENABLE_EVALUATION_ARENA_MODELS
app.state.config.EVALUATION_ARENA_MODELS = EVALUATION_ARENA_MODELS app.state.config.EVALUATION_ARENA_MODELS = EVALUATION_ARENA_MODELS
@ -476,15 +639,22 @@ app.state.config.LDAP_SEARCH_BASE = LDAP_SEARCH_BASE
app.state.config.LDAP_SEARCH_FILTERS = LDAP_SEARCH_FILTERS app.state.config.LDAP_SEARCH_FILTERS = LDAP_SEARCH_FILTERS
app.state.config.LDAP_USE_TLS = LDAP_USE_TLS app.state.config.LDAP_USE_TLS = LDAP_USE_TLS
app.state.config.LDAP_CA_CERT_FILE = LDAP_CA_CERT_FILE app.state.config.LDAP_CA_CERT_FILE = LDAP_CA_CERT_FILE
app.state.config.LDAP_VALIDATE_CERT = LDAP_VALIDATE_CERT
app.state.config.LDAP_CIPHERS = LDAP_CIPHERS app.state.config.LDAP_CIPHERS = LDAP_CIPHERS
app.state.AUTH_TRUSTED_EMAIL_HEADER = WEBUI_AUTH_TRUSTED_EMAIL_HEADER app.state.AUTH_TRUSTED_EMAIL_HEADER = WEBUI_AUTH_TRUSTED_EMAIL_HEADER
app.state.AUTH_TRUSTED_NAME_HEADER = WEBUI_AUTH_TRUSTED_NAME_HEADER app.state.AUTH_TRUSTED_NAME_HEADER = WEBUI_AUTH_TRUSTED_NAME_HEADER
app.state.WEBUI_AUTH_SIGNOUT_REDIRECT_URL = WEBUI_AUTH_SIGNOUT_REDIRECT_URL
app.state.EXTERNAL_PWA_MANIFEST_URL = EXTERNAL_PWA_MANIFEST_URL
app.state.USER_COUNT = None
app.state.TOOLS = {} app.state.TOOLS = {}
app.state.FUNCTIONS = {} app.state.TOOL_CONTENTS = {}
app.state.FUNCTIONS = {}
app.state.FUNCTION_CONTENTS = {}
######################################## ########################################
# #
@ -494,17 +664,41 @@ app.state.FUNCTIONS = {}
app.state.config.TOP_K = RAG_TOP_K app.state.config.TOP_K = RAG_TOP_K
app.state.config.TOP_K_RERANKER = RAG_TOP_K_RERANKER
app.state.config.RELEVANCE_THRESHOLD = RAG_RELEVANCE_THRESHOLD app.state.config.RELEVANCE_THRESHOLD = RAG_RELEVANCE_THRESHOLD
app.state.config.HYBRID_BM25_WEIGHT = RAG_HYBRID_BM25_WEIGHT
app.state.config.ALLOWED_FILE_EXTENSIONS = RAG_ALLOWED_FILE_EXTENSIONS
app.state.config.FILE_MAX_SIZE = RAG_FILE_MAX_SIZE app.state.config.FILE_MAX_SIZE = RAG_FILE_MAX_SIZE
app.state.config.FILE_MAX_COUNT = RAG_FILE_MAX_COUNT app.state.config.FILE_MAX_COUNT = RAG_FILE_MAX_COUNT
app.state.config.RAG_FULL_CONTEXT = RAG_FULL_CONTEXT
app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL = BYPASS_EMBEDDING_AND_RETRIEVAL
app.state.config.ENABLE_RAG_HYBRID_SEARCH = ENABLE_RAG_HYBRID_SEARCH app.state.config.ENABLE_RAG_HYBRID_SEARCH = ENABLE_RAG_HYBRID_SEARCH
app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = ( app.state.config.ENABLE_WEB_LOADER_SSL_VERIFICATION = ENABLE_WEB_LOADER_SSL_VERIFICATION
ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION
)
app.state.config.CONTENT_EXTRACTION_ENGINE = CONTENT_EXTRACTION_ENGINE app.state.config.CONTENT_EXTRACTION_ENGINE = CONTENT_EXTRACTION_ENGINE
app.state.config.DATALAB_MARKER_API_KEY = DATALAB_MARKER_API_KEY
app.state.config.DATALAB_MARKER_LANGS = DATALAB_MARKER_LANGS
app.state.config.DATALAB_MARKER_SKIP_CACHE = DATALAB_MARKER_SKIP_CACHE
app.state.config.DATALAB_MARKER_FORCE_OCR = DATALAB_MARKER_FORCE_OCR
app.state.config.DATALAB_MARKER_PAGINATE = DATALAB_MARKER_PAGINATE
app.state.config.DATALAB_MARKER_STRIP_EXISTING_OCR = DATALAB_MARKER_STRIP_EXISTING_OCR
app.state.config.DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION = (
DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION
)
app.state.config.DATALAB_MARKER_USE_LLM = DATALAB_MARKER_USE_LLM
app.state.config.DATALAB_MARKER_OUTPUT_FORMAT = DATALAB_MARKER_OUTPUT_FORMAT
app.state.config.EXTERNAL_DOCUMENT_LOADER_URL = EXTERNAL_DOCUMENT_LOADER_URL
app.state.config.EXTERNAL_DOCUMENT_LOADER_API_KEY = EXTERNAL_DOCUMENT_LOADER_API_KEY
app.state.config.TIKA_SERVER_URL = TIKA_SERVER_URL app.state.config.TIKA_SERVER_URL = TIKA_SERVER_URL
app.state.config.DOCLING_SERVER_URL = DOCLING_SERVER_URL
app.state.config.DOCLING_OCR_ENGINE = DOCLING_OCR_ENGINE
app.state.config.DOCLING_OCR_LANG = DOCLING_OCR_LANG
app.state.config.DOCLING_DO_PICTURE_DESCRIPTION = DOCLING_DO_PICTURE_DESCRIPTION
app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT = DOCUMENT_INTELLIGENCE_ENDPOINT
app.state.config.DOCUMENT_INTELLIGENCE_KEY = DOCUMENT_INTELLIGENCE_KEY
app.state.config.MISTRAL_OCR_API_KEY = MISTRAL_OCR_API_KEY
app.state.config.TEXT_SPLITTER = RAG_TEXT_SPLITTER app.state.config.TEXT_SPLITTER = RAG_TEXT_SPLITTER
app.state.config.TIKTOKEN_ENCODING_NAME = TIKTOKEN_ENCODING_NAME app.state.config.TIKTOKEN_ENCODING_NAME = TIKTOKEN_ENCODING_NAME
@ -515,12 +709,21 @@ app.state.config.CHUNK_OVERLAP = CHUNK_OVERLAP
app.state.config.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE app.state.config.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE
app.state.config.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL app.state.config.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
app.state.config.RAG_EMBEDDING_BATCH_SIZE = RAG_EMBEDDING_BATCH_SIZE app.state.config.RAG_EMBEDDING_BATCH_SIZE = RAG_EMBEDDING_BATCH_SIZE
app.state.config.RAG_RERANKING_ENGINE = RAG_RERANKING_ENGINE
app.state.config.RAG_RERANKING_MODEL = RAG_RERANKING_MODEL app.state.config.RAG_RERANKING_MODEL = RAG_RERANKING_MODEL
app.state.config.RAG_EXTERNAL_RERANKER_URL = RAG_EXTERNAL_RERANKER_URL
app.state.config.RAG_EXTERNAL_RERANKER_API_KEY = RAG_EXTERNAL_RERANKER_API_KEY
app.state.config.RAG_TEMPLATE = RAG_TEMPLATE app.state.config.RAG_TEMPLATE = RAG_TEMPLATE
app.state.config.RAG_OPENAI_API_BASE_URL = RAG_OPENAI_API_BASE_URL app.state.config.RAG_OPENAI_API_BASE_URL = RAG_OPENAI_API_BASE_URL
app.state.config.RAG_OPENAI_API_KEY = RAG_OPENAI_API_KEY app.state.config.RAG_OPENAI_API_KEY = RAG_OPENAI_API_KEY
app.state.config.RAG_AZURE_OPENAI_BASE_URL = RAG_AZURE_OPENAI_BASE_URL
app.state.config.RAG_AZURE_OPENAI_API_KEY = RAG_AZURE_OPENAI_API_KEY
app.state.config.RAG_AZURE_OPENAI_API_VERSION = RAG_AZURE_OPENAI_API_VERSION
app.state.config.RAG_OLLAMA_BASE_URL = RAG_OLLAMA_BASE_URL app.state.config.RAG_OLLAMA_BASE_URL = RAG_OLLAMA_BASE_URL
app.state.config.RAG_OLLAMA_API_KEY = RAG_OLLAMA_API_KEY app.state.config.RAG_OLLAMA_API_KEY = RAG_OLLAMA_API_KEY
@ -530,12 +733,24 @@ app.state.config.YOUTUBE_LOADER_LANGUAGE = YOUTUBE_LOADER_LANGUAGE
app.state.config.YOUTUBE_LOADER_PROXY_URL = YOUTUBE_LOADER_PROXY_URL app.state.config.YOUTUBE_LOADER_PROXY_URL = YOUTUBE_LOADER_PROXY_URL
app.state.config.ENABLE_RAG_WEB_SEARCH = ENABLE_RAG_WEB_SEARCH app.state.config.ENABLE_WEB_SEARCH = ENABLE_WEB_SEARCH
app.state.config.RAG_WEB_SEARCH_ENGINE = RAG_WEB_SEARCH_ENGINE app.state.config.WEB_SEARCH_ENGINE = WEB_SEARCH_ENGINE
app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST = RAG_WEB_SEARCH_DOMAIN_FILTER_LIST app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST = WEB_SEARCH_DOMAIN_FILTER_LIST
app.state.config.WEB_SEARCH_RESULT_COUNT = WEB_SEARCH_RESULT_COUNT
app.state.config.WEB_SEARCH_CONCURRENT_REQUESTS = WEB_SEARCH_CONCURRENT_REQUESTS
app.state.config.WEB_LOADER_ENGINE = WEB_LOADER_ENGINE
app.state.config.WEB_SEARCH_TRUST_ENV = WEB_SEARCH_TRUST_ENV
app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL = (
BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL
)
app.state.config.BYPASS_WEB_SEARCH_WEB_LOADER = BYPASS_WEB_SEARCH_WEB_LOADER
app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION = ENABLE_GOOGLE_DRIVE_INTEGRATION app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION = ENABLE_GOOGLE_DRIVE_INTEGRATION
app.state.config.ENABLE_ONEDRIVE_INTEGRATION = ENABLE_ONEDRIVE_INTEGRATION
app.state.config.SEARXNG_QUERY_URL = SEARXNG_QUERY_URL app.state.config.SEARXNG_QUERY_URL = SEARXNG_QUERY_URL
app.state.config.YACY_QUERY_URL = YACY_QUERY_URL
app.state.config.YACY_USERNAME = YACY_USERNAME
app.state.config.YACY_PASSWORD = YACY_PASSWORD
app.state.config.GOOGLE_PSE_API_KEY = GOOGLE_PSE_API_KEY app.state.config.GOOGLE_PSE_API_KEY = GOOGLE_PSE_API_KEY
app.state.config.GOOGLE_PSE_ENGINE_ID = GOOGLE_PSE_ENGINE_ID app.state.config.GOOGLE_PSE_ENGINE_ID = GOOGLE_PSE_ENGINE_ID
app.state.config.BRAVE_SEARCH_API_KEY = BRAVE_SEARCH_API_KEY app.state.config.BRAVE_SEARCH_API_KEY = BRAVE_SEARCH_API_KEY
@ -555,9 +770,20 @@ app.state.config.JINA_API_KEY = JINA_API_KEY
app.state.config.BING_SEARCH_V7_ENDPOINT = BING_SEARCH_V7_ENDPOINT app.state.config.BING_SEARCH_V7_ENDPOINT = BING_SEARCH_V7_ENDPOINT
app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY = BING_SEARCH_V7_SUBSCRIPTION_KEY app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY = BING_SEARCH_V7_SUBSCRIPTION_KEY
app.state.config.EXA_API_KEY = EXA_API_KEY app.state.config.EXA_API_KEY = EXA_API_KEY
app.state.config.PERPLEXITY_API_KEY = PERPLEXITY_API_KEY
app.state.config.SOUGOU_API_SID = SOUGOU_API_SID
app.state.config.SOUGOU_API_SK = SOUGOU_API_SK
app.state.config.EXTERNAL_WEB_SEARCH_URL = EXTERNAL_WEB_SEARCH_URL
app.state.config.EXTERNAL_WEB_SEARCH_API_KEY = EXTERNAL_WEB_SEARCH_API_KEY
app.state.config.EXTERNAL_WEB_LOADER_URL = EXTERNAL_WEB_LOADER_URL
app.state.config.EXTERNAL_WEB_LOADER_API_KEY = EXTERNAL_WEB_LOADER_API_KEY
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = RAG_WEB_SEARCH_RESULT_COUNT
app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = RAG_WEB_SEARCH_CONCURRENT_REQUESTS app.state.config.PLAYWRIGHT_WS_URL = PLAYWRIGHT_WS_URL
app.state.config.PLAYWRIGHT_TIMEOUT = PLAYWRIGHT_TIMEOUT
app.state.config.FIRECRAWL_API_BASE_URL = FIRECRAWL_API_BASE_URL
app.state.config.FIRECRAWL_API_KEY = FIRECRAWL_API_KEY
app.state.config.TAVILY_EXTRACT_DEPTH = TAVILY_EXTRACT_DEPTH
app.state.EMBEDDING_FUNCTION = None app.state.EMBEDDING_FUNCTION = None
app.state.ef = None app.state.ef = None
@ -574,7 +800,10 @@ try:
) )
app.state.rf = get_rf( app.state.rf = get_rf(
app.state.config.RAG_RERANKING_ENGINE,
app.state.config.RAG_RERANKING_MODEL, app.state.config.RAG_RERANKING_MODEL,
app.state.config.RAG_EXTERNAL_RERANKER_URL,
app.state.config.RAG_EXTERNAL_RERANKER_API_KEY,
RAG_RERANKING_MODEL_AUTO_UPDATE, RAG_RERANKING_MODEL_AUTO_UPDATE,
) )
except Exception as e: except Exception as e:
@ -589,22 +818,45 @@ app.state.EMBEDDING_FUNCTION = get_embedding_function(
( (
app.state.config.RAG_OPENAI_API_BASE_URL app.state.config.RAG_OPENAI_API_BASE_URL
if app.state.config.RAG_EMBEDDING_ENGINE == "openai" if app.state.config.RAG_EMBEDDING_ENGINE == "openai"
else app.state.config.RAG_OLLAMA_BASE_URL else (
app.state.config.RAG_OLLAMA_BASE_URL
if app.state.config.RAG_EMBEDDING_ENGINE == "ollama"
else app.state.config.RAG_AZURE_OPENAI_BASE_URL
)
), ),
( (
app.state.config.RAG_OPENAI_API_KEY app.state.config.RAG_OPENAI_API_KEY
if app.state.config.RAG_EMBEDDING_ENGINE == "openai" if app.state.config.RAG_EMBEDDING_ENGINE == "openai"
else app.state.config.RAG_OLLAMA_API_KEY else (
app.state.config.RAG_OLLAMA_API_KEY
if app.state.config.RAG_EMBEDDING_ENGINE == "ollama"
else app.state.config.RAG_AZURE_OPENAI_API_KEY
)
), ),
app.state.config.RAG_EMBEDDING_BATCH_SIZE, app.state.config.RAG_EMBEDDING_BATCH_SIZE,
azure_api_version=(
app.state.config.RAG_AZURE_OPENAI_API_VERSION
if app.state.config.RAG_EMBEDDING_ENGINE == "azure_openai"
else None
),
) )
######################################## ########################################
# #
# CODE INTERPRETER # CODE EXECUTION
# #
######################################## ########################################
app.state.config.ENABLE_CODE_EXECUTION = ENABLE_CODE_EXECUTION
app.state.config.CODE_EXECUTION_ENGINE = CODE_EXECUTION_ENGINE
app.state.config.CODE_EXECUTION_JUPYTER_URL = CODE_EXECUTION_JUPYTER_URL
app.state.config.CODE_EXECUTION_JUPYTER_AUTH = CODE_EXECUTION_JUPYTER_AUTH
app.state.config.CODE_EXECUTION_JUPYTER_AUTH_TOKEN = CODE_EXECUTION_JUPYTER_AUTH_TOKEN
app.state.config.CODE_EXECUTION_JUPYTER_AUTH_PASSWORD = (
CODE_EXECUTION_JUPYTER_AUTH_PASSWORD
)
app.state.config.CODE_EXECUTION_JUPYTER_TIMEOUT = CODE_EXECUTION_JUPYTER_TIMEOUT
app.state.config.ENABLE_CODE_INTERPRETER = ENABLE_CODE_INTERPRETER app.state.config.ENABLE_CODE_INTERPRETER = ENABLE_CODE_INTERPRETER
app.state.config.CODE_INTERPRETER_ENGINE = CODE_INTERPRETER_ENGINE app.state.config.CODE_INTERPRETER_ENGINE = CODE_INTERPRETER_ENGINE
app.state.config.CODE_INTERPRETER_PROMPT_TEMPLATE = CODE_INTERPRETER_PROMPT_TEMPLATE app.state.config.CODE_INTERPRETER_PROMPT_TEMPLATE = CODE_INTERPRETER_PROMPT_TEMPLATE
@ -617,6 +869,7 @@ app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_TOKEN = (
app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD = ( app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD = (
CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD
) )
app.state.config.CODE_INTERPRETER_JUPYTER_TIMEOUT = CODE_INTERPRETER_JUPYTER_TIMEOUT
######################################## ########################################
# #
@ -631,6 +884,9 @@ app.state.config.ENABLE_IMAGE_PROMPT_GENERATION = ENABLE_IMAGE_PROMPT_GENERATION
app.state.config.IMAGES_OPENAI_API_BASE_URL = IMAGES_OPENAI_API_BASE_URL app.state.config.IMAGES_OPENAI_API_BASE_URL = IMAGES_OPENAI_API_BASE_URL
app.state.config.IMAGES_OPENAI_API_KEY = IMAGES_OPENAI_API_KEY app.state.config.IMAGES_OPENAI_API_KEY = IMAGES_OPENAI_API_KEY
app.state.config.IMAGES_GEMINI_API_BASE_URL = IMAGES_GEMINI_API_BASE_URL
app.state.config.IMAGES_GEMINI_API_KEY = IMAGES_GEMINI_API_KEY
app.state.config.IMAGE_GENERATION_MODEL = IMAGE_GENERATION_MODEL app.state.config.IMAGE_GENERATION_MODEL = IMAGE_GENERATION_MODEL
app.state.config.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL app.state.config.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
@ -659,8 +915,15 @@ app.state.config.STT_ENGINE = AUDIO_STT_ENGINE
app.state.config.STT_MODEL = AUDIO_STT_MODEL app.state.config.STT_MODEL = AUDIO_STT_MODEL
app.state.config.WHISPER_MODEL = WHISPER_MODEL app.state.config.WHISPER_MODEL = WHISPER_MODEL
app.state.config.WHISPER_VAD_FILTER = WHISPER_VAD_FILTER
app.state.config.DEEPGRAM_API_KEY = DEEPGRAM_API_KEY app.state.config.DEEPGRAM_API_KEY = DEEPGRAM_API_KEY
app.state.config.AUDIO_STT_AZURE_API_KEY = AUDIO_STT_AZURE_API_KEY
app.state.config.AUDIO_STT_AZURE_REGION = AUDIO_STT_AZURE_REGION
app.state.config.AUDIO_STT_AZURE_LOCALES = AUDIO_STT_AZURE_LOCALES
app.state.config.AUDIO_STT_AZURE_BASE_URL = AUDIO_STT_AZURE_BASE_URL
app.state.config.AUDIO_STT_AZURE_MAX_SPEAKERS = AUDIO_STT_AZURE_MAX_SPEAKERS
app.state.config.TTS_OPENAI_API_BASE_URL = AUDIO_TTS_OPENAI_API_BASE_URL app.state.config.TTS_OPENAI_API_BASE_URL = AUDIO_TTS_OPENAI_API_BASE_URL
app.state.config.TTS_OPENAI_API_KEY = AUDIO_TTS_OPENAI_API_KEY app.state.config.TTS_OPENAI_API_KEY = AUDIO_TTS_OPENAI_API_KEY
app.state.config.TTS_ENGINE = AUDIO_TTS_ENGINE app.state.config.TTS_ENGINE = AUDIO_TTS_ENGINE
@ -671,6 +934,7 @@ app.state.config.TTS_SPLIT_ON = AUDIO_TTS_SPLIT_ON
app.state.config.TTS_AZURE_SPEECH_REGION = AUDIO_TTS_AZURE_SPEECH_REGION app.state.config.TTS_AZURE_SPEECH_REGION = AUDIO_TTS_AZURE_SPEECH_REGION
app.state.config.TTS_AZURE_SPEECH_BASE_URL = AUDIO_TTS_AZURE_SPEECH_BASE_URL
app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT = AUDIO_TTS_AZURE_SPEECH_OUTPUT_FORMAT app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT = AUDIO_TTS_AZURE_SPEECH_OUTPUT_FORMAT
@ -733,7 +997,8 @@ class RedirectMiddleware(BaseHTTPMiddleware):
# Check for the specific watch path and the presence of 'v' parameter
if path.endswith("/watch") and "v" in query_params:
# Extract the first 'v' parameter
video_id = query_params["v"][0]
encoded_video_id = urlencode({"youtube": video_id})
redirect_url = f"/?{encoded_video_id}"
return RedirectResponse(url=redirect_url)
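In effect, a pasted YouTube watch URL hitting the WebUI host is rewritten so the frontend receives the video id as a query parameter, roughly (example id only):

from urllib.parse import urlencode

video_id = "dQw4w9WgXcQ"  # taken from /watch?v=<id>
redirect_url = f"/?{urlencode({'youtube': video_id})}"  # -> /?youtube=dQw4w9WgXcQ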
@ -744,6 +1009,7 @@ class RedirectMiddleware(BaseHTTPMiddleware):
# Add the middleware to the app # Add the middleware to the app
app.add_middleware(CompressMiddleware)
app.add_middleware(RedirectMiddleware) app.add_middleware(RedirectMiddleware)
app.add_middleware(SecurityHeadersMiddleware) app.add_middleware(SecurityHeadersMiddleware)
@ -759,6 +1025,10 @@ async def commit_session_after_request(request: Request, call_next):
@app.middleware("http") @app.middleware("http")
async def check_url(request: Request, call_next): async def check_url(request: Request, call_next):
start_time = int(time.time()) start_time = int(time.time())
request.state.token = get_http_authorization_cred(
request.headers.get("Authorization")
)
request.state.enable_api_key = app.state.config.ENABLE_API_KEY request.state.enable_api_key = app.state.config.ENABLE_API_KEY
response = await call_next(request) response = await call_next(request)
process_time = int(time.time()) - start_time process_time = int(time.time()) - start_time
@ -815,6 +1085,8 @@ app.include_router(users.router, prefix="/api/v1/users", tags=["users"])
app.include_router(channels.router, prefix="/api/v1/channels", tags=["channels"]) app.include_router(channels.router, prefix="/api/v1/channels", tags=["channels"])
app.include_router(chats.router, prefix="/api/v1/chats", tags=["chats"]) app.include_router(chats.router, prefix="/api/v1/chats", tags=["chats"])
app.include_router(notes.router, prefix="/api/v1/notes", tags=["notes"])
app.include_router(models.router, prefix="/api/v1/models", tags=["models"]) app.include_router(models.router, prefix="/api/v1/models", tags=["models"])
app.include_router(knowledge.router, prefix="/api/v1/knowledge", tags=["knowledge"]) app.include_router(knowledge.router, prefix="/api/v1/knowledge", tags=["knowledge"])
@ -832,6 +1104,19 @@ app.include_router(
app.include_router(utils.router, prefix="/api/v1/utils", tags=["utils"]) app.include_router(utils.router, prefix="/api/v1/utils", tags=["utils"])
try:
audit_level = AuditLevel(AUDIT_LOG_LEVEL)
except ValueError as e:
logger.error(f"Invalid audit level: {AUDIT_LOG_LEVEL}. Error: {e}")
audit_level = AuditLevel.NONE
if audit_level != AuditLevel.NONE:
app.add_middleware(
AuditLoggingMiddleware,
audit_level=audit_level,
excluded_paths=AUDIT_EXCLUDED_PATHS,
max_body_size=MAX_BODY_LOG_SIZE,
)
################################## ##################################
# #
# Chat Endpoints # Chat Endpoints
@ -864,14 +1149,29 @@ async def get_models(request: Request, user=Depends(get_verified_user)):
return filtered_models return filtered_models
models = await get_all_models(request) all_models = await get_all_models(request, user=user)
models = []
for model in all_models:
# Filter out filter pipelines # Filter out filter pipelines
models = [ if "pipeline" in model and model["pipeline"].get("type", None) == "filter":
model continue
for model in models
if "pipeline" not in model or model["pipeline"].get("type", None) != "filter" try:
model_tags = [
tag.get("name")
for tag in model.get("info", {}).get("meta", {}).get("tags", [])
] ]
tags = [tag.get("name") for tag in model.get("tags", [])]
tags = list(set(model_tags + tags))
model["tags"] = [{"name": tag} for tag in tags]
except Exception as e:
log.debug(f"Error processing model tags: {e}")
model["tags"] = []
pass
models.append(model)
model_order_list = request.app.state.config.MODEL_ORDER_LIST model_order_list = request.app.state.config.MODEL_ORDER_LIST
if model_order_list: if model_order_list:
@ -893,7 +1193,7 @@ async def get_models(request: Request, user=Depends(get_verified_user)):
@app.get("/api/models/base") @app.get("/api/models/base")
async def get_base_models(request: Request, user=Depends(get_admin_user)): async def get_base_models(request: Request, user=Depends(get_admin_user)):
models = await get_all_base_models(request) models = await get_all_base_models(request, user=user)
return {"data": models} return {"data": models}
@ -904,11 +1204,12 @@ async def chat_completion(
user=Depends(get_verified_user), user=Depends(get_verified_user),
): ):
if not request.app.state.MODELS: if not request.app.state.MODELS:
await get_all_models(request) await get_all_models(request, user=user)
model_item = form_data.pop("model_item", {}) model_item = form_data.pop("model_item", {})
tasks = form_data.pop("background_tasks", None) tasks = form_data.pop("background_tasks", None)
metadata = {}
try: try:
if not model_item.get("direct", False): if not model_item.get("direct", False):
model_id = form_data.get("model", None) model_id = form_data.get("model", None)
@ -936,11 +1237,13 @@ async def chat_completion(
"chat_id": form_data.pop("chat_id", None), "chat_id": form_data.pop("chat_id", None),
"message_id": form_data.pop("id", None), "message_id": form_data.pop("id", None),
"session_id": form_data.pop("session_id", None), "session_id": form_data.pop("session_id", None),
"filter_ids": form_data.pop("filter_ids", []),
"tool_ids": form_data.get("tool_ids", None), "tool_ids": form_data.get("tool_ids", None),
"tool_servers": form_data.pop("tool_servers", None),
"files": form_data.get("files", None), "files": form_data.get("files", None),
"features": form_data.get("features", None), "features": form_data.get("features", {}),
"variables": form_data.get("variables", None), "variables": form_data.get("variables", {}),
"model": model_info, "model": model,
"direct": model_item.get("direct", False), "direct": model_item.get("direct", False),
**( **(
{"function_calling": "native"} {"function_calling": "native"}
@ -958,11 +1261,21 @@ async def chat_completion(
form_data["metadata"] = metadata form_data["metadata"] = metadata
form_data, metadata, events = await process_chat_payload( form_data, metadata, events = await process_chat_payload(
request, form_data, metadata, user, model request, form_data, user, metadata, model
) )
except Exception as e: except Exception as e:
log.debug(f"Error processing chat payload: {e}") log.debug(f"Error processing chat payload: {e}")
if metadata.get("chat_id") and metadata.get("message_id"):
# Update the chat message with the error
Chats.upsert_message_to_chat_by_id_and_message_id(
metadata["chat_id"],
metadata["message_id"],
{
"error": {"content": str(e)},
},
)
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,
detail=str(e), detail=str(e),
@ -972,7 +1285,7 @@ async def chat_completion(
response = await chat_completion_handler(request, form_data, user) response = await chat_completion_handler(request, form_data, user)
return await process_chat_response( return await process_chat_response(
request, response, form_data, user, events, metadata, tasks request, response, form_data, user, metadata, model, events, tasks
) )
except Exception as e: except Exception as e:
raise HTTPException( raise HTTPException(
@ -1027,7 +1340,7 @@ async def chat_action(
@app.post("/api/tasks/stop/{task_id}") @app.post("/api/tasks/stop/{task_id}")
async def stop_task_endpoint(task_id: str, user=Depends(get_verified_user)): async def stop_task_endpoint(task_id: str, user=Depends(get_verified_user)):
try: try:
result = await stop_task(task_id) # Use the function from tasks.py result = await stop_task(task_id)
return result return result
except ValueError as e: except ValueError as e:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e)) raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e))
@ -1035,7 +1348,19 @@ async def stop_task_endpoint(task_id: str, user=Depends(get_verified_user)):
@app.get("/api/tasks") @app.get("/api/tasks")
async def list_tasks_endpoint(user=Depends(get_verified_user)): async def list_tasks_endpoint(user=Depends(get_verified_user)):
return {"tasks": list_tasks()} # Use the function from tasks.py return {"tasks": list_tasks()}
@app.get("/api/tasks/chat/{chat_id}")
async def list_tasks_by_chat_id_endpoint(chat_id: str, user=Depends(get_verified_user)):
chat = Chats.get_chat_by_id(chat_id)
if chat is None or chat.user_id != user.id:
return {"task_ids": []}
task_ids = list_task_ids_by_chat_id(chat_id)
print(f"Task IDs for chat {chat_id}: {task_ids}")
return {"task_ids": task_ids}
@app.post("/api/embeddings") @app.post("/api/embeddings")
@ -1066,15 +1391,16 @@ async def get_app_config(request: Request):
if data is not None and "id" in data: if data is not None and "id" in data:
user = Users.get_user_by_id(data["id"]) user = Users.get_user_by_id(data["id"])
onboarding = False
if user is None:
user_count = Users.get_num_users() user_count = Users.get_num_users()
onboarding = False
if user is None:
onboarding = user_count == 0 onboarding = user_count == 0
return { return {
**({"onboarding": True} if onboarding else {}), **({"onboarding": True} if onboarding else {}),
"status": True, "status": True,
"name": WEBUI_NAME, "name": app.state.WEBUI_NAME,
"version": VERSION, "version": VERSION,
"default_locale": str(DEFAULT_LOCALE), "default_locale": str(DEFAULT_LOCALE),
"oauth": { "oauth": {
@ -1095,15 +1421,19 @@ async def get_app_config(request: Request):
{ {
"enable_direct_connections": app.state.config.ENABLE_DIRECT_CONNECTIONS, "enable_direct_connections": app.state.config.ENABLE_DIRECT_CONNECTIONS,
"enable_channels": app.state.config.ENABLE_CHANNELS, "enable_channels": app.state.config.ENABLE_CHANNELS,
"enable_web_search": app.state.config.ENABLE_RAG_WEB_SEARCH, "enable_notes": app.state.config.ENABLE_NOTES,
"enable_web_search": app.state.config.ENABLE_WEB_SEARCH,
"enable_code_execution": app.state.config.ENABLE_CODE_EXECUTION,
"enable_code_interpreter": app.state.config.ENABLE_CODE_INTERPRETER, "enable_code_interpreter": app.state.config.ENABLE_CODE_INTERPRETER,
"enable_image_generation": app.state.config.ENABLE_IMAGE_GENERATION, "enable_image_generation": app.state.config.ENABLE_IMAGE_GENERATION,
"enable_autocomplete_generation": app.state.config.ENABLE_AUTOCOMPLETE_GENERATION, "enable_autocomplete_generation": app.state.config.ENABLE_AUTOCOMPLETE_GENERATION,
"enable_community_sharing": app.state.config.ENABLE_COMMUNITY_SHARING, "enable_community_sharing": app.state.config.ENABLE_COMMUNITY_SHARING,
"enable_message_rating": app.state.config.ENABLE_MESSAGE_RATING, "enable_message_rating": app.state.config.ENABLE_MESSAGE_RATING,
"enable_user_webhooks": app.state.config.ENABLE_USER_WEBHOOKS,
"enable_admin_export": ENABLE_ADMIN_EXPORT, "enable_admin_export": ENABLE_ADMIN_EXPORT,
"enable_admin_chat_access": ENABLE_ADMIN_CHAT_ACCESS, "enable_admin_chat_access": ENABLE_ADMIN_CHAT_ACCESS,
"enable_google_drive_integration": app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION, "enable_google_drive_integration": app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION,
"enable_onedrive_integration": app.state.config.ENABLE_ONEDRIVE_INTEGRATION,
} }
if user is not None if user is not None
else {} else {}
@ -1113,6 +1443,10 @@ async def get_app_config(request: Request):
{ {
"default_models": app.state.config.DEFAULT_MODELS, "default_models": app.state.config.DEFAULT_MODELS,
"default_prompt_suggestions": app.state.config.DEFAULT_PROMPT_SUGGESTIONS, "default_prompt_suggestions": app.state.config.DEFAULT_PROMPT_SUGGESTIONS,
"user_count": user_count,
"code": {
"engine": app.state.config.CODE_EXECUTION_ENGINE,
},
"audio": { "audio": {
"tts": { "tts": {
"engine": app.state.config.TTS_ENGINE, "engine": app.state.config.TTS_ENGINE,
@ -1132,6 +1466,24 @@ async def get_app_config(request: Request):
"client_id": GOOGLE_DRIVE_CLIENT_ID.value, "client_id": GOOGLE_DRIVE_CLIENT_ID.value,
"api_key": GOOGLE_DRIVE_API_KEY.value, "api_key": GOOGLE_DRIVE_API_KEY.value,
}, },
"onedrive": {
"client_id": ONEDRIVE_CLIENT_ID.value,
"sharepoint_url": ONEDRIVE_SHAREPOINT_URL.value,
"sharepoint_tenant_id": ONEDRIVE_SHAREPOINT_TENANT_ID.value,
},
"ui": {
"pending_user_overlay_title": app.state.config.PENDING_USER_OVERLAY_TITLE,
"pending_user_overlay_content": app.state.config.PENDING_USER_OVERLAY_CONTENT,
"response_watermark": app.state.config.RESPONSE_WATERMARK,
},
"license_metadata": app.state.LICENSE_METADATA,
**(
{
"active_entries": app.state.USER_COUNT,
}
if user.role == "admin"
else {}
),
} }
if user is not None if user is not None
else {} else {}
@ -1175,7 +1527,8 @@ async def get_app_latest_release_version(user=Depends(get_verified_user)):
timeout = aiohttp.ClientTimeout(total=1) timeout = aiohttp.ClientTimeout(total=1)
async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session: async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
async with session.get( async with session.get(
"https://api.github.com/repos/open-webui/open-webui/releases/latest" "https://api.github.com/repos/open-webui/open-webui/releases/latest",
ssl=AIOHTTP_CLIENT_SESSION_SSL,
) as response: ) as response:
response.raise_for_status() response.raise_for_status()
data = await response.json() data = await response.json()
@ -1209,7 +1562,7 @@ if len(OAUTH_PROVIDERS) > 0:
@app.get("/oauth/{provider}/login") @app.get("/oauth/{provider}/login")
async def oauth_login(provider: str, request: Request): async def oauth_login(provider: str, request: Request):
return await oauth_manager.handle_login(provider, request) return await oauth_manager.handle_login(request, provider)
# OAuth login logic is as follows: # OAuth login logic is as follows:
@ -1220,19 +1573,22 @@ async def oauth_login(provider: str, request: Request):
# - Email addresses are considered unique, so we fail registration if the email address is already taken # - Email addresses are considered unique, so we fail registration if the email address is already taken
@app.get("/oauth/{provider}/callback") @app.get("/oauth/{provider}/callback")
async def oauth_callback(provider: str, request: Request, response: Response): async def oauth_callback(provider: str, request: Request, response: Response):
return await oauth_manager.handle_callback(provider, request, response) return await oauth_manager.handle_callback(request, provider, response)
@app.get("/manifest.json") @app.get("/manifest.json")
async def get_manifest_json(): async def get_manifest_json():
if app.state.EXTERNAL_PWA_MANIFEST_URL:
return requests.get(app.state.EXTERNAL_PWA_MANIFEST_URL).json()
else:
return { return {
"name": WEBUI_NAME, "name": app.state.WEBUI_NAME,
"short_name": WEBUI_NAME, "short_name": app.state.WEBUI_NAME,
"description": "Open WebUI is an open, extensible, user-friendly interface for AI that adapts to your workflow.", "description": "Open WebUI is an open, extensible, user-friendly interface for AI that adapts to your workflow.",
"start_url": "/", "start_url": "/",
"display": "standalone", "display": "standalone",
"background_color": "#343541", "background_color": "#343541",
"orientation": "natural", "orientation": "any",
"icons": [ "icons": [
{ {
"src": "/static/logo.png", "src": "/static/logo.png",
@ -1254,8 +1610,8 @@ async def get_manifest_json():
async def get_opensearch_xml(): async def get_opensearch_xml():
xml_content = rf""" xml_content = rf"""
<OpenSearchDescription xmlns="http://a9.com/-/spec/opensearch/1.1/" xmlns:moz="http://www.mozilla.org/2006/browser/search/"> <OpenSearchDescription xmlns="http://a9.com/-/spec/opensearch/1.1/" xmlns:moz="http://www.mozilla.org/2006/browser/search/">
<ShortName>{WEBUI_NAME}</ShortName> <ShortName>{app.state.WEBUI_NAME}</ShortName>
<Description>Search {WEBUI_NAME}</Description> <Description>Search {app.state.WEBUI_NAME}</Description>
<InputEncoding>UTF-8</InputEncoding> <InputEncoding>UTF-8</InputEncoding>
<Image width="16" height="16" type="image/x-icon">{app.state.config.WEBUI_URL}/static/favicon.png</Image> <Image width="16" height="16" type="image/x-icon">{app.state.config.WEBUI_URL}/static/favicon.png</Image>
<Url type="text/html" method="get" template="{app.state.config.WEBUI_URL}/?q={"{searchTerms}"}"/> <Url type="text/html" method="get" template="{app.state.config.WEBUI_URL}/?q={"{searchTerms}"}"/>


@ -0,0 +1,33 @@
"""Add note table
Revision ID: 9f0c9cd09105
Revises: 3781e22d8b01
Create Date: 2025-05-03 03:00:00.000000
"""
from alembic import op
import sqlalchemy as sa
revision = "9f0c9cd09105"
down_revision = "3781e22d8b01"
branch_labels = None
depends_on = None
def upgrade():
op.create_table(
"note",
sa.Column("id", sa.Text(), nullable=False, primary_key=True, unique=True),
sa.Column("user_id", sa.Text(), nullable=True),
sa.Column("title", sa.Text(), nullable=True),
sa.Column("data", sa.JSON(), nullable=True),
sa.Column("meta", sa.JSON(), nullable=True),
sa.Column("access_control", sa.JSON(), nullable=True),
sa.Column("created_at", sa.BigInteger(), nullable=True),
sa.Column("updated_at", sa.BigInteger(), nullable=True),
)
def downgrade():
op.drop_table("note")


@ -129,12 +129,16 @@ class AuthsTable:
def authenticate_user(self, email: str, password: str) -> Optional[UserModel]:
log.info(f"authenticate_user: {email}")
user = Users.get_user_by_email(email)
if not user:
return None
try:
with get_db() as db:
auth = db.query(Auth).filter_by(email=email, active=True).first()
auth = db.query(Auth).filter_by(id=user.id, active=True).first()
if auth:
if verify_password(password, auth.password):
user = Users.get_user_by_id(auth.id)
return user
else:
return None
@ -155,8 +159,8 @@ class AuthsTable:
except Exception: except Exception:
return False return False
def authenticate_user_by_trusted_header(self, email: str) -> Optional[UserModel]:
def authenticate_user_by_email(self, email: str) -> Optional[UserModel]:
log.info(f"authenticate_user_by_trusted_header: {email}")
log.info(f"authenticate_user_by_email: {email}")
try:
with get_db() as db:
auth = db.query(Auth).filter_by(email=email, active=True).first()


@ -1,3 +1,4 @@
import logging
import json
import time
import uuid
@ -5,7 +6,7 @@ from typing import Optional
from open_webui.internal.db import Base, get_db
from open_webui.models.tags import TagModel, Tag, Tags
from open_webui.env import SRC_LOG_LEVELS
from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON
@ -16,6 +17,9 @@ from sqlalchemy.sql import exists
# Chat DB Schema
####################
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
class Chat(Base):
__tablename__ = "chat"
@ -373,22 +377,47 @@ class ChatTable:
return False
def get_archived_chat_list_by_user_id(
self, user_id: str, skip: int = 0, limit: int = 50
self,
user_id: str,
filter: Optional[dict] = None,
skip: int = 0,
limit: int = 50,
) -> list[ChatModel]:
with get_db() as db:
all_chats = (
db.query(Chat)
.filter_by(user_id=user_id, archived=True)
.order_by(Chat.updated_at.desc())
# .limit(limit).offset(skip)
.all()
)
query = db.query(Chat).filter_by(user_id=user_id, archived=True)
if filter:
query_key = filter.get("query")
if query_key:
query = query.filter(Chat.title.ilike(f"%{query_key}%"))
order_by = filter.get("order_by")
direction = filter.get("direction")
if order_by and direction and getattr(Chat, order_by):
if direction.lower() == "asc":
query = query.order_by(getattr(Chat, order_by).asc())
elif direction.lower() == "desc":
query = query.order_by(getattr(Chat, order_by).desc())
else:
raise ValueError("Invalid direction for ordering")
else:
query = query.order_by(Chat.updated_at.desc())
if skip:
query = query.offset(skip)
if limit:
query = query.limit(limit)
all_chats = query.all()
return [ChatModel.model_validate(chat) for chat in all_chats]
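A hedged usage sketch for the new `filter` argument: a dict carrying an optional title substring plus ordering hints, with unsupported directions raising ValueError as above (`Chats` is assumed to be the module-level ChatTable instance used elsewhere in the codebase).

chats = Chats.get_archived_chat_list_by_user_id(
    user_id="user-123",  # example id
    filter={"query": "roadmap", "order_by": "updated_at", "direction": "desc"},
    skip=0,
    limit=20,
)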
def get_chat_list_by_user_id(
self,
user_id: str,
include_archived: bool = False,
filter: Optional[dict] = None,
skip: int = 0,
limit: int = 50,
) -> list[ChatModel]:
@ -397,6 +426,22 @@ class ChatTable:
if not include_archived: if not include_archived:
query = query.filter_by(archived=False) query = query.filter_by(archived=False)
if filter:
query_key = filter.get("query")
if query_key:
query = query.filter(Chat.title.ilike(f"%{query_key}%"))
order_by = filter.get("order_by")
direction = filter.get("direction")
if order_by and direction and getattr(Chat, order_by):
if direction.lower() == "asc":
query = query.order_by(getattr(Chat, order_by).asc())
elif direction.lower() == "desc":
query = query.order_by(getattr(Chat, order_by).desc())
else:
raise ValueError("Invalid direction for ordering")
else:
query = query.order_by(Chat.updated_at.desc()) query = query.order_by(Chat.updated_at.desc())
if skip: if skip:
@ -432,7 +477,7 @@ class ChatTable:
all_chats = query.all() all_chats = query.all()
# result has to be destrctured from sqlalchemy `row` and mapped to a dict since the `ChatModel`is not the returned dataclass. # result has to be destructured from sqlalchemy `row` and mapped to a dict since the `ChatModel`is not the returned dataclass.
return [ return [
ChatTitleIdResponse.model_validate( ChatTitleIdResponse.model_validate(
{ {
@ -538,7 +583,9 @@ class ChatTable:
search_text = search_text.lower().strip() search_text = search_text.lower().strip()
if not search_text: if not search_text:
return self.get_chat_list_by_user_id(user_id, include_archived, skip, limit) return self.get_chat_list_by_user_id(
user_id, include_archived, filter={}, skip=skip, limit=limit
)
search_text_words = search_text.split(" ") search_text_words = search_text.split(" ")
@ -670,7 +717,7 @@ class ChatTable:
# Perform pagination at the SQL level # Perform pagination at the SQL level
all_chats = query.offset(skip).limit(limit).all() all_chats = query.offset(skip).limit(limit).all()
print(len(all_chats)) log.info(f"The number of chats: {len(all_chats)}")
# Validate and return chats # Validate and return chats
return [ChatModel.model_validate(chat) for chat in all_chats] return [ChatModel.model_validate(chat) for chat in all_chats]
@ -731,7 +778,7 @@ class ChatTable:
query = db.query(Chat).filter_by(user_id=user_id) query = db.query(Chat).filter_by(user_id=user_id)
tag_id = tag_name.replace(" ", "_").lower() tag_id = tag_name.replace(" ", "_").lower()
print(db.bind.dialect.name) log.info(f"DB dialect name: {db.bind.dialect.name}")
if db.bind.dialect.name == "sqlite": if db.bind.dialect.name == "sqlite":
# SQLite JSON1 querying for tags within the meta JSON field # SQLite JSON1 querying for tags within the meta JSON field
query = query.filter( query = query.filter(
@ -752,7 +799,7 @@ class ChatTable:
) )
all_chats = query.all() all_chats = query.all()
print("all_chats", all_chats) log.debug(f"all_chats: {all_chats}")
return [ChatModel.model_validate(chat) for chat in all_chats] return [ChatModel.model_validate(chat) for chat in all_chats]
def add_chat_tag_by_id_and_user_id_and_tag_name( def add_chat_tag_by_id_and_user_id_and_tag_name(
@ -810,7 +857,7 @@ class ChatTable:
count = query.count() count = query.count()
# Debugging output for inspection # Debugging output for inspection
print(f"Count of chats for tag '{tag_name}':", count) log.info(f"Count of chats for tag '{tag_name}': {count}")
return count return count

View file

@ -118,7 +118,7 @@ class FeedbackTable:
else: else:
return None return None
except Exception as e: except Exception as e:
print(e) log.exception(f"Error creating a new feedback: {e}")
return None return None
def get_feedback_by_id(self, id: str) -> Optional[FeedbackModel]: def get_feedback_by_id(self, id: str) -> Optional[FeedbackModel]:

View file

@ -119,7 +119,7 @@ class FilesTable:
else: else:
return None return None
except Exception as e: except Exception as e:
print(f"Error creating tool: {e}") log.exception(f"Error inserting a new file: {e}")
return None return None
def get_file_by_id(self, id: str) -> Optional[FileModel]: def get_file_by_id(self, id: str) -> Optional[FileModel]:

View file

@ -9,6 +9,8 @@ from open_webui.models.chats import Chats
from open_webui.env import SRC_LOG_LEVELS from open_webui.env import SRC_LOG_LEVELS
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Column, Text, JSON, Boolean from sqlalchemy import BigInteger, Column, Text, JSON, Boolean
from open_webui.utils.access_control import get_permissions
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"]) log.setLevel(SRC_LOG_LEVELS["MODELS"])
@ -82,7 +84,7 @@ class FolderTable:
else: else:
return None return None
except Exception as e: except Exception as e:
print(e) log.exception(f"Error inserting a new folder: {e}")
return None return None
def get_folder_by_id_and_user_id( def get_folder_by_id_and_user_id(
@ -234,13 +236,16 @@ class FolderTable:
log.error(f"update_folder: {e}") log.error(f"update_folder: {e}")
return return
def delete_folder_by_id_and_user_id(self, id: str, user_id: str) -> bool: def delete_folder_by_id_and_user_id(
self, id: str, user_id: str, delete_chats=True
) -> bool:
try: try:
with get_db() as db: with get_db() as db:
folder = db.query(Folder).filter_by(id=id, user_id=user_id).first() folder = db.query(Folder).filter_by(id=id, user_id=user_id).first()
if not folder: if not folder:
return False return False
if delete_chats:
# Delete all chats in the folder # Delete all chats in the folder
Chats.delete_chats_by_user_id_and_folder_id(user_id, folder.id) Chats.delete_chats_by_user_id_and_folder_id(user_id, folder.id)
@ -250,9 +255,11 @@ class FolderTable:
folder.id, user_id folder.id, user_id
) )
for folder_child in folder_children: for folder_child in folder_children:
if delete_chats:
Chats.delete_chats_by_user_id_and_folder_id( Chats.delete_chats_by_user_id_and_folder_id(
user_id, folder_child.id user_id, folder_child.id
) )
delete_children(folder_child) delete_children(folder_child)
folder = db.query(Folder).filter_by(id=folder_child.id).first() folder = db.query(Folder).filter_by(id=folder_child.id).first()

View file

@ -105,9 +105,57 @@ class FunctionsTable:
else: else:
return None return None
except Exception as e: except Exception as e:
print(f"Error creating tool: {e}") log.exception(f"Error creating a new function: {e}")
return None return None
def sync_functions(
self, user_id: str, functions: list[FunctionModel]
) -> list[FunctionModel]:
# Synchronize functions for a user by updating existing ones, inserting new ones, and removing those that are no longer present.
try:
with get_db() as db:
# Get existing functions
existing_functions = db.query(Function).all()
existing_ids = {func.id for func in existing_functions}
# Prepare a set of new function IDs
new_function_ids = {func.id for func in functions}
# Update or insert functions
for func in functions:
if func.id in existing_ids:
db.query(Function).filter_by(id=func.id).update(
{
**func.model_dump(),
"user_id": user_id,
"updated_at": int(time.time()),
}
)
else:
new_func = Function(
**{
**func.model_dump(),
"user_id": user_id,
"updated_at": int(time.time()),
}
)
db.add(new_func)
# Remove functions that are no longer present
for func in existing_functions:
if func.id not in new_function_ids:
db.delete(func)
db.commit()
return [
FunctionModel.model_validate(func)
for func in db.query(Function).all()
]
except Exception as e:
log.exception(f"Error syncing functions for user {user_id}: {e}")
return []
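A hedged sketch of how sync_functions is meant to be driven; the Functions singleton and the get_functions helper are assumed to exist in open_webui.models.functions in the same pattern as the other tables in this change, and the user id is a placeholder.

from open_webui.models.functions import Functions  # assumed module path / singleton

desired = Functions.get_functions()[:1]   # assumed helper; keep only the first function
synced = Functions.sync_functions(user_id="admin-id", functions=desired)
# Every id in `desired` is updated (or inserted) with a fresh updated_at;
# any function not listed is deleted, and the full post-sync list is returned.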
def get_function_by_id(self, id: str) -> Optional[FunctionModel]: def get_function_by_id(self, id: str) -> Optional[FunctionModel]:
try: try:
with get_db() as db: with get_db() as db:
@ -170,7 +218,7 @@ class FunctionsTable:
function = db.get(Function, id) function = db.get(Function, id)
return function.valves if function.valves else {} return function.valves if function.valves else {}
except Exception as e: except Exception as e:
print(f"An error occurred: {e}") log.exception(f"Error getting function valves by id {id}: {e}")
return None return None
def update_function_valves_by_id( def update_function_valves_by_id(
@ -202,7 +250,9 @@ class FunctionsTable:
return user_settings["functions"]["valves"].get(id, {}) return user_settings["functions"]["valves"].get(id, {})
except Exception as e: except Exception as e:
print(f"An error occurred: {e}") log.exception(
f"Error getting user values by id {id} and user id {user_id}: {e}"
)
return None return None
def update_user_valves_by_id_and_user_id( def update_user_valves_by_id_and_user_id(
@ -225,7 +275,9 @@ class FunctionsTable:
return user_settings["functions"]["valves"][id] return user_settings["functions"]["valves"][id]
except Exception as e: except Exception as e:
print(f"An error occurred: {e}") log.exception(
f"Error updating user valves by id {id} and user_id {user_id}: {e}"
)
return None return None
def update_function_by_id(self, id: str, updated: dict) -> Optional[FunctionModel]: def update_function_by_id(self, id: str, updated: dict) -> Optional[FunctionModel]:

View file

@ -207,5 +207,43 @@ class GroupTable:
except Exception: except Exception:
return False return False
def sync_user_groups_by_group_names(
self, user_id: str, group_names: list[str]
) -> bool:
with get_db() as db:
try:
groups = db.query(Group).filter(Group.name.in_(group_names)).all()
group_ids = [group.id for group in groups]
# Remove user from groups not in the new list
existing_groups = self.get_groups_by_member_id(user_id)
for group in existing_groups:
if group.id not in group_ids:
group.user_ids.remove(user_id)
db.query(Group).filter_by(id=group.id).update(
{
"user_ids": group.user_ids,
"updated_at": int(time.time()),
}
)
# Add user to new groups
for group in groups:
if user_id not in group.user_ids:
group.user_ids.append(user_id)
db.query(Group).filter_by(id=group.id).update(
{
"user_ids": group.user_ids,
"updated_at": int(time.time()),
}
)
db.commit()
return True
except Exception as e:
log.exception(e)
return False
Groups = GroupTable() Groups = GroupTable()
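A minimal sketch of the new group synchronization helper; the user id and group names are placeholders.

from open_webui.models.groups import Groups

ok = Groups.sync_user_groups_by_group_names(
    user_id="user-123",                       # placeholder
    group_names=["engineering", "oncall"],    # placeholder group names
)
# The user is removed from any group not named above, added to the named
# groups they are missing from, and False is returned on any DB error.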

View file

@ -63,14 +63,15 @@ class MemoriesTable:
else: else:
return None return None
def update_memory_by_id( def update_memory_by_id_and_user_id(
self, self,
id: str, id: str,
user_id: str,
content: str, content: str,
) -> Optional[MemoryModel]: ) -> Optional[MemoryModel]:
with get_db() as db: with get_db() as db:
try: try:
db.query(Memory).filter_by(id=id).update( db.query(Memory).filter_by(id=id, user_id=user_id).update(
{"content": content, "updated_at": int(time.time())} {"content": content, "updated_at": int(time.time())}
) )
db.commit() db.commit()

backend/open_webui/models/models.py Normal file → Executable file
View file

@ -166,7 +166,7 @@ class ModelsTable:
else: else:
return None return None
except Exception as e: except Exception as e:
print(e) log.exception(f"Failed to insert a new model: {e}")
return None return None
def get_all_models(self) -> list[ModelModel]: def get_all_models(self) -> list[ModelModel]:
@ -246,8 +246,7 @@ class ModelsTable:
db.refresh(model) db.refresh(model)
return ModelModel.model_validate(model) return ModelModel.model_validate(model)
except Exception as e: except Exception as e:
print(e) log.exception(f"Failed to update the model by id {id}: {e}")
return None return None
def delete_model_by_id(self, id: str) -> bool: def delete_model_by_id(self, id: str) -> bool:

View file

@ -0,0 +1,135 @@
import json
import time
import uuid
from typing import Optional
from open_webui.internal.db import Base, get_db
from open_webui.utils.access_control import has_access
from open_webui.models.users import Users, UserResponse
from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON
from sqlalchemy import or_, func, select, and_, text
from sqlalchemy.sql import exists
####################
# Note DB Schema
####################
class Note(Base):
__tablename__ = "note"
id = Column(Text, primary_key=True)
user_id = Column(Text)
title = Column(Text)
data = Column(JSON, nullable=True)
meta = Column(JSON, nullable=True)
access_control = Column(JSON, nullable=True)
created_at = Column(BigInteger)
updated_at = Column(BigInteger)
class NoteModel(BaseModel):
model_config = ConfigDict(from_attributes=True)
id: str
user_id: str
title: str
data: Optional[dict] = None
meta: Optional[dict] = None
access_control: Optional[dict] = None
created_at: int  # epoch timestamp in nanoseconds (time.time_ns())
updated_at: int  # epoch timestamp in nanoseconds (time.time_ns())
####################
# Forms
####################
class NoteForm(BaseModel):
title: str
data: Optional[dict] = None
meta: Optional[dict] = None
access_control: Optional[dict] = None
class NoteUserResponse(NoteModel):
user: Optional[UserResponse] = None
class NoteTable:
def insert_new_note(
self,
form_data: NoteForm,
user_id: str,
) -> Optional[NoteModel]:
with get_db() as db:
note = NoteModel(
**{
"id": str(uuid.uuid4()),
"user_id": user_id,
**form_data.model_dump(),
"created_at": int(time.time_ns()),
"updated_at": int(time.time_ns()),
}
)
new_note = Note(**note.model_dump())
db.add(new_note)
db.commit()
return note
def get_notes(self) -> list[NoteModel]:
with get_db() as db:
notes = db.query(Note).order_by(Note.updated_at.desc()).all()
return [NoteModel.model_validate(note) for note in notes]
def get_notes_by_user_id(
self, user_id: str, permission: str = "write"
) -> list[NoteModel]:
notes = self.get_notes()
return [
note
for note in notes
if note.user_id == user_id
or has_access(user_id, permission, note.access_control)
]
def get_note_by_id(self, id: str) -> Optional[NoteModel]:
with get_db() as db:
note = db.query(Note).filter(Note.id == id).first()
return NoteModel.model_validate(note) if note else None
def update_note_by_id(self, id: str, form_data: NoteForm) -> Optional[NoteModel]:
with get_db() as db:
note = db.query(Note).filter(Note.id == id).first()
if not note:
return None
note.title = form_data.title
note.data = form_data.data
note.meta = form_data.meta
note.access_control = form_data.access_control
note.updated_at = int(time.time_ns())
db.commit()
return NoteModel.model_validate(note) if note else None
def delete_note_by_id(self, id: str):
with get_db() as db:
db.query(Note).filter(Note.id == id).delete()
db.commit()
return True
Notes = NoteTable()
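A short usage sketch of the new notes table; the module path open_webui.models.notes is assumed from the file layout, and titles and content are illustrative.

from open_webui.models.notes import Notes, NoteForm  # assumed module path

note = Notes.insert_new_note(
    NoteForm(title="Meeting notes", data={"content": "draft"}),  # placeholder content
    user_id="user-123",
)
mine = Notes.get_notes_by_user_id("user-123", permission="write")
Notes.update_note_by_id(note.id, NoteForm(title="Meeting notes (edited)", data=note.data))
Notes.delete_note_by_id(note.id)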

View file

@ -61,7 +61,7 @@ class TagTable:
else: else:
return None return None
except Exception as e: except Exception as e:
print(e) log.exception(f"Error inserting a new tag: {e}")
return None return None
def get_tag_by_name_and_user_id( def get_tag_by_name_and_user_id(

View file

@ -131,7 +131,7 @@ class ToolsTable:
else: else:
return None return None
except Exception as e: except Exception as e:
print(f"Error creating tool: {e}") log.exception(f"Error creating a new tool: {e}")
return None return None
def get_tool_by_id(self, id: str) -> Optional[ToolModel]: def get_tool_by_id(self, id: str) -> Optional[ToolModel]:
@ -175,7 +175,7 @@ class ToolsTable:
tool = db.get(Tool, id) tool = db.get(Tool, id)
return tool.valves if tool.valves else {} return tool.valves if tool.valves else {}
except Exception as e: except Exception as e:
print(f"An error occurred: {e}") log.exception(f"Error getting tool valves by id {id}: {e}")
return None return None
def update_tool_valves_by_id(self, id: str, valves: dict) -> Optional[ToolValves]: def update_tool_valves_by_id(self, id: str, valves: dict) -> Optional[ToolValves]:
@ -204,7 +204,9 @@ class ToolsTable:
return user_settings["tools"]["valves"].get(id, {}) return user_settings["tools"]["valves"].get(id, {})
except Exception as e: except Exception as e:
print(f"An error occurred: {e}") log.exception(
f"Error getting user values by id {id} and user_id {user_id}: {e}"
)
return None return None
def update_user_valves_by_id_and_user_id( def update_user_valves_by_id_and_user_id(
@ -227,7 +229,9 @@ class ToolsTable:
return user_settings["tools"]["valves"][id] return user_settings["tools"]["valves"][id]
except Exception as e: except Exception as e:
print(f"An error occurred: {e}") log.exception(
f"Error updating user valves by id {id} and user_id {user_id}: {e}"
)
return None return None
def update_tool_by_id(self, id: str, updated: dict) -> Optional[ToolModel]: def update_tool_by_id(self, id: str, updated: dict) -> Optional[ToolModel]:

View file

@ -10,6 +10,8 @@ from open_webui.models.groups import Groups
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Column, String, Text from sqlalchemy import BigInteger, Column, String, Text
from sqlalchemy import or_
#################### ####################
# User DB Schema # User DB Schema
@ -67,6 +69,11 @@ class UserModel(BaseModel):
#################### ####################
class UserListResponse(BaseModel):
users: list[UserModel]
total: int
class UserResponse(BaseModel): class UserResponse(BaseModel):
id: str id: str
name: str name: str
@ -160,11 +167,63 @@ class UsersTable:
return None return None
def get_users( def get_users(
self, skip: Optional[int] = None, limit: Optional[int] = None self,
) -> list[UserModel]: filter: Optional[dict] = None,
skip: Optional[int] = None,
limit: Optional[int] = None,
) -> UserListResponse:
with get_db() as db: with get_db() as db:
query = db.query(User)
query = db.query(User).order_by(User.created_at.desc()) if filter:
query_key = filter.get("query")
if query_key:
query = query.filter(
or_(
User.name.ilike(f"%{query_key}%"),
User.email.ilike(f"%{query_key}%"),
)
)
order_by = filter.get("order_by")
direction = filter.get("direction")
if order_by == "name":
if direction == "asc":
query = query.order_by(User.name.asc())
else:
query = query.order_by(User.name.desc())
elif order_by == "email":
if direction == "asc":
query = query.order_by(User.email.asc())
else:
query = query.order_by(User.email.desc())
elif order_by == "created_at":
if direction == "asc":
query = query.order_by(User.created_at.asc())
else:
query = query.order_by(User.created_at.desc())
elif order_by == "last_active_at":
if direction == "asc":
query = query.order_by(User.last_active_at.asc())
else:
query = query.order_by(User.last_active_at.desc())
elif order_by == "updated_at":
if direction == "asc":
query = query.order_by(User.updated_at.asc())
else:
query = query.order_by(User.updated_at.desc())
elif order_by == "role":
if direction == "asc":
query = query.order_by(User.role.asc())
else:
query = query.order_by(User.role.desc())
else:
query = query.order_by(User.created_at.desc())
if skip: if skip:
query = query.offset(skip) query = query.offset(skip)
@ -172,8 +231,10 @@ class UsersTable:
query = query.limit(limit) query = query.limit(limit)
users = query.all() users = query.all()
return {
return [UserModel.model_validate(user) for user in users] "users": [UserModel.model_validate(user) for user in users],
"total": db.query(User).count(),
}
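A minimal sketch of the new paginated, filterable user listing; note that the method returns a plain dict matching the UserListResponse shape. The search string is a placeholder.

from open_webui.models.users import Users

page = Users.get_users(
    filter={"query": "alice", "order_by": "last_active_at", "direction": "desc"},
    skip=0,
    limit=25,
)
print(page["total"], [u.email for u in page["users"]])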
def get_users_by_user_ids(self, user_ids: list[str]) -> list[UserModel]: def get_users_by_user_ids(self, user_ids: list[str]) -> list[UserModel]:
with get_db() as db: with get_db() as db:
@ -330,5 +391,13 @@ class UsersTable:
users = db.query(User).filter(User.id.in_(user_ids)).all() users = db.query(User).filter(User.id.in_(user_ids)).all()
return [user.id for user in users] return [user.id for user in users]
def get_super_admin_user(self) -> Optional[UserModel]:
with get_db() as db:
user = db.query(User).filter_by(role="admin").first()
if user:
return UserModel.model_validate(user)
else:
return None
Users = UsersTable() Users = UsersTable()

View file

@ -0,0 +1,251 @@
import os
import time
import requests
import logging
import json
from typing import List, Optional
from langchain_core.documents import Document
from fastapi import HTTPException, status
log = logging.getLogger(__name__)
class DatalabMarkerLoader:
def __init__(
self,
file_path: str,
api_key: str,
langs: Optional[str] = None,
use_llm: bool = False,
skip_cache: bool = False,
force_ocr: bool = False,
paginate: bool = False,
strip_existing_ocr: bool = False,
disable_image_extraction: bool = False,
output_format: Optional[str] = None,
):
self.file_path = file_path
self.api_key = api_key
self.langs = langs
self.use_llm = use_llm
self.skip_cache = skip_cache
self.force_ocr = force_ocr
self.paginate = paginate
self.strip_existing_ocr = strip_existing_ocr
self.disable_image_extraction = disable_image_extraction
self.output_format = output_format
def _get_mime_type(self, filename: str) -> str:
ext = filename.rsplit(".", 1)[-1].lower()
mime_map = {
"pdf": "application/pdf",
"xls": "application/vnd.ms-excel",
"xlsx": "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
"ods": "application/vnd.oasis.opendocument.spreadsheet",
"doc": "application/msword",
"docx": "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
"odt": "application/vnd.oasis.opendocument.text",
"ppt": "application/vnd.ms-powerpoint",
"pptx": "application/vnd.openxmlformats-officedocument.presentationml.presentation",
"odp": "application/vnd.oasis.opendocument.presentation",
"html": "text/html",
"epub": "application/epub+zip",
"png": "image/png",
"jpeg": "image/jpeg",
"jpg": "image/jpeg",
"webp": "image/webp",
"gif": "image/gif",
"tiff": "image/tiff",
}
return mime_map.get(ext, "application/octet-stream")
def check_marker_request_status(self, request_id: str) -> dict:
url = f"https://www.datalab.to/api/v1/marker/{request_id}"
headers = {"X-Api-Key": self.api_key}
try:
response = requests.get(url, headers=headers)
response.raise_for_status()
result = response.json()
log.info(f"Marker API status check for request {request_id}: {result}")
return result
except requests.HTTPError as e:
log.error(f"Error checking Marker request status: {e}")
raise HTTPException(
status.HTTP_502_BAD_GATEWAY,
detail=f"Failed to check Marker request: {e}",
)
except ValueError as e:
log.error(f"Invalid JSON checking Marker request: {e}")
raise HTTPException(
status.HTTP_502_BAD_GATEWAY, detail=f"Invalid JSON: {e}"
)
def load(self) -> List[Document]:
url = "https://www.datalab.to/api/v1/marker"
filename = os.path.basename(self.file_path)
mime_type = self._get_mime_type(filename)
headers = {"X-Api-Key": self.api_key}
form_data = {
"langs": self.langs,
"use_llm": str(self.use_llm).lower(),
"skip_cache": str(self.skip_cache).lower(),
"force_ocr": str(self.force_ocr).lower(),
"paginate": str(self.paginate).lower(),
"strip_existing_ocr": str(self.strip_existing_ocr).lower(),
"disable_image_extraction": str(self.disable_image_extraction).lower(),
"output_format": self.output_format,
}
log.info(
f"Datalab Marker POST request parameters: {{'filename': '{filename}', 'mime_type': '{mime_type}', **{form_data}}}"
)
try:
with open(self.file_path, "rb") as f:
files = {"file": (filename, f, mime_type)}
response = requests.post(
url, data=form_data, files=files, headers=headers
)
response.raise_for_status()
result = response.json()
except FileNotFoundError:
raise HTTPException(
status.HTTP_404_NOT_FOUND, detail=f"File not found: {self.file_path}"
)
except requests.HTTPError as e:
raise HTTPException(
status.HTTP_400_BAD_REQUEST,
detail=f"Datalab Marker request failed: {e}",
)
except ValueError as e:
raise HTTPException(
status.HTTP_502_BAD_GATEWAY, detail=f"Invalid JSON response: {e}"
)
except Exception as e:
raise HTTPException(status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e))
if not result.get("success"):
raise HTTPException(
status.HTTP_400_BAD_REQUEST,
detail=f"Datalab Marker request failed: {result.get('error', 'Unknown error')}",
)
check_url = result.get("request_check_url")
request_id = result.get("request_id")
if not check_url:
raise HTTPException(
status.HTTP_502_BAD_GATEWAY, detail="No request_check_url returned."
)
for _ in range(300): # Up to 10 minutes
time.sleep(2)
try:
poll_response = requests.get(check_url, headers=headers)
poll_response.raise_for_status()
poll_result = poll_response.json()
except (requests.HTTPError, ValueError) as e:
raw_body = poll_response.text
log.error(f"Polling error: {e}, response body: {raw_body}")
raise HTTPException(
status.HTTP_502_BAD_GATEWAY, detail=f"Polling failed: {e}"
)
status_val = poll_result.get("status")
success_val = poll_result.get("success")
if status_val == "complete":
summary = {
k: poll_result.get(k)
for k in (
"status",
"output_format",
"success",
"error",
"page_count",
"total_cost",
)
}
log.info(
f"Marker processing completed successfully: {json.dumps(summary, indent=2)}"
)
break
if status_val == "failed" or success_val is False:
log.error(
f"Marker poll failed full response: {json.dumps(poll_result, indent=2)}"
)
error_msg = (
poll_result.get("error")
or "Marker returned failure without error message"
)
raise HTTPException(
status.HTTP_400_BAD_REQUEST,
detail=f"Marker processing failed: {error_msg}",
)
else:
raise HTTPException(
status.HTTP_504_GATEWAY_TIMEOUT, detail="Marker processing timed out"
)
if not poll_result.get("success", False):
error_msg = poll_result.get("error") or "Unknown processing error"
raise HTTPException(
status.HTTP_400_BAD_REQUEST,
detail=f"Final processing failed: {error_msg}",
)
content_key = self.output_format.lower()
raw_content = poll_result.get(content_key)
if content_key == "json":
full_text = json.dumps(raw_content, indent=2)
elif content_key in {"markdown", "html"}:
full_text = str(raw_content).strip()
else:
raise HTTPException(
status.HTTP_400_BAD_REQUEST,
detail=f"Unsupported output format: {self.output_format}",
)
if not full_text:
raise HTTPException(
status.HTTP_400_BAD_REQUEST,
detail="Datalab Marker returned empty content",
)
marker_output_dir = os.path.join("/app/backend/data/uploads", "marker_output")
os.makedirs(marker_output_dir, exist_ok=True)
file_ext_map = {"markdown": "md", "json": "json", "html": "html"}
file_ext = file_ext_map.get(content_key, "txt")
output_filename = f"{os.path.splitext(filename)[0]}.{file_ext}"
output_path = os.path.join(marker_output_dir, output_filename)
try:
with open(output_path, "w", encoding="utf-8") as f:
f.write(full_text)
log.info(f"Saved Marker output to: {output_path}")
except Exception as e:
log.warning(f"Failed to write marker output to disk: {e}")
metadata = {
"source": filename,
"output_format": poll_result.get("output_format", self.output_format),
"page_count": poll_result.get("page_count", 0),
"processed_with_llm": self.use_llm,
"request_id": request_id or "",
}
images = poll_result.get("images", {})
if images:
metadata["image_count"] = len(images)
metadata["images"] = json.dumps(list(images.keys()))
for k, v in metadata.items():
if isinstance(v, (dict, list)):
metadata[k] = json.dumps(v)
elif v is None:
metadata[k] = ""
return [Document(page_content=full_text, metadata=metadata)]
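A hedged usage sketch of the Datalab Marker loader; the API key and file path are placeholders.

from open_webui.retrieval.loaders.datalab_marker import DatalabMarkerLoader

loader = DatalabMarkerLoader(
    file_path="/tmp/report.pdf",       # placeholder
    api_key="dl-xxxx",                 # placeholder
    use_llm=False,
    output_format="markdown",          # "markdown", "json", or "html"
)
docs = loader.load()                   # blocks while polling, up to ~10 minutes
print(docs[0].metadata["page_count"], len(docs[0].page_content))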

View file

@ -0,0 +1,58 @@
import requests
import logging
from typing import Iterator, List, Union
from langchain_core.document_loaders import BaseLoader
from langchain_core.documents import Document
from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
class ExternalDocumentLoader(BaseLoader):
def __init__(
self,
file_path,
url: str,
api_key: str,
mime_type=None,
**kwargs,
) -> None:
self.url = url
self.api_key = api_key
self.file_path = file_path
self.mime_type = mime_type
def load(self) -> list[Document]:
with open(self.file_path, "rb") as f:
data = f.read()
headers = {}
if self.mime_type is not None:
headers["Content-Type"] = self.mime_type
if self.api_key is not None:
headers["Authorization"] = f"Bearer {self.api_key}"
url = self.url
if url.endswith("/"):
url = url[:-1]
r = requests.put(f"{url}/process", data=data, headers=headers)
if r.ok:
res = r.json()
if res:
return [
Document(
page_content=res.get("page_content"),
metadata=res.get("metadata"),
)
]
else:
raise Exception("Error loading document: No content returned")
else:
raise Exception(f"Error loading document: {r.status_code} {r.text}")

View file

@ -0,0 +1,53 @@
import requests
import logging
from typing import Iterator, List, Union
from langchain_core.document_loaders import BaseLoader
from langchain_core.documents import Document
from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
class ExternalWebLoader(BaseLoader):
def __init__(
self,
web_paths: Union[str, List[str]],
external_url: str,
external_api_key: str,
continue_on_failure: bool = True,
**kwargs,
) -> None:
self.external_url = external_url
self.external_api_key = external_api_key
self.urls = web_paths if isinstance(web_paths, list) else [web_paths]
self.continue_on_failure = continue_on_failure
def lazy_load(self) -> Iterator[Document]:
batch_size = 20
for i in range(0, len(self.urls), batch_size):
urls = self.urls[i : i + batch_size]
try:
response = requests.post(
self.external_url,
headers={
"User-Agent": "Open WebUI (https://github.com/open-webui/open-webui) External Web Loader",
"Authorization": f"Bearer {self.external_api_key}",
},
json={
"urls": urls,
},
)
response.raise_for_status()
results = response.json()
for result in results:
yield Document(
page_content=result.get("page_content", ""),
metadata=result.get("metadata", {}),
)
except Exception as e:
if self.continue_on_failure:
log.error(f"Error extracting content from batch {urls}: {e}")
else:
raise e
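A usage sketch; the import path and the crawler endpoint are assumptions. URLs are posted in batches of 20, and failures are only logged while continue_on_failure is True.

from open_webui.retrieval.loaders.external_web import ExternalWebLoader  # assumed path

loader = ExternalWebLoader(
    web_paths=["https://example.com/a", "https://example.com/b"],
    external_url="https://crawler.internal/extract",   # placeholder endpoint
    external_api_key="secret",                          # placeholder
)
for doc in loader.lazy_load():
    print(doc.metadata, len(doc.page_content))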

View file

@ -4,6 +4,7 @@ import ftfy
import sys import sys
from langchain_community.document_loaders import ( from langchain_community.document_loaders import (
AzureAIDocumentIntelligenceLoader,
BSHTMLLoader, BSHTMLLoader,
CSVLoader, CSVLoader,
Docx2txtLoader, Docx2txtLoader,
@ -19,6 +20,13 @@ from langchain_community.document_loaders import (
YoutubeLoader, YoutubeLoader,
) )
from langchain_core.documents import Document from langchain_core.documents import Document
from open_webui.retrieval.loaders.external_document import ExternalDocumentLoader
from open_webui.retrieval.loaders.mistral import MistralLoader
from open_webui.retrieval.loaders.datalab_marker import DatalabMarkerLoader
from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL
logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL) logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
@ -76,15 +84,18 @@ known_source_ext = [
"jsx", "jsx",
"hs", "hs",
"lhs", "lhs",
"json",
] ]
class TikaLoader: class TikaLoader:
def __init__(self, url, file_path, mime_type=None): def __init__(self, url, file_path, mime_type=None, extract_images=None):
self.url = url self.url = url
self.file_path = file_path self.file_path = file_path
self.mime_type = mime_type self.mime_type = mime_type
self.extract_images = extract_images
def load(self) -> list[Document]: def load(self) -> list[Document]:
with open(self.file_path, "rb") as f: with open(self.file_path, "rb") as f:
data = f.read() data = f.read()
@ -94,6 +105,9 @@ class TikaLoader:
else: else:
headers = {} headers = {}
if self.extract_images == True:
headers["X-Tika-PDFextractInlineImages"] = "true"
endpoint = self.url endpoint = self.url
if not endpoint.endswith("/"): if not endpoint.endswith("/"):
endpoint += "/" endpoint += "/"
@ -103,7 +117,7 @@ class TikaLoader:
if r.ok: if r.ok:
raw_metadata = r.json() raw_metadata = r.json()
text = raw_metadata.get("X-TIKA:content", "<No text content found>") text = raw_metadata.get("X-TIKA:content", "<No text content found>").strip()
if "Content-Type" in raw_metadata: if "Content-Type" in raw_metadata:
headers["Content-Type"] = raw_metadata["Content-Type"] headers["Content-Type"] = raw_metadata["Content-Type"]
@ -115,6 +129,68 @@ class TikaLoader:
raise Exception(f"Error calling Tika: {r.reason}") raise Exception(f"Error calling Tika: {r.reason}")
class DoclingLoader:
def __init__(self, url, file_path=None, mime_type=None, params=None):
self.url = url.rstrip("/")
self.file_path = file_path
self.mime_type = mime_type
self.params = params or {}
def load(self) -> list[Document]:
with open(self.file_path, "rb") as f:
files = {
"files": (
self.file_path,
f,
self.mime_type or "application/octet-stream",
)
}
params = {
"image_export_mode": "placeholder",
"table_mode": "accurate",
}
if self.params:
if self.params.get("do_picture_classification"):
params["do_picture_classification"] = self.params.get(
"do_picture_classification"
)
if self.params.get("ocr_engine") and self.params.get("ocr_lang"):
params["ocr_engine"] = self.params.get("ocr_engine")
params["ocr_lang"] = [
lang.strip()
for lang in self.params.get("ocr_lang").split(",")
if lang.strip()
]
endpoint = f"{self.url}/v1alpha/convert/file"
r = requests.post(endpoint, files=files, data=params)
if r.ok:
result = r.json()
document_data = result.get("document", {})
text = document_data.get("md_content", "<No text content found>")
metadata = {"Content-Type": self.mime_type} if self.mime_type else {}
log.debug("Docling extracted text: %s", text)
return [Document(page_content=text, metadata=metadata)]
else:
error_msg = f"Error calling Docling API: {r.reason}"
if r.text:
try:
error_data = r.json()
if "detail" in error_data:
error_msg += f" - {error_data['detail']}"
except Exception:
error_msg += f" - {r.text}"
raise Exception(f"Error calling Docling: {error_msg}")
class Loader: class Loader:
def __init__(self, engine: str = "", **kwargs): def __init__(self, engine: str = "", **kwargs):
self.engine = engine self.engine = engine
@ -133,19 +209,132 @@ class Loader:
for doc in docs for doc in docs
] ]
def _is_text_file(self, file_ext: str, file_content_type: str) -> bool:
return file_ext in known_source_ext or (
file_content_type and file_content_type.find("text/") >= 0
)
def _get_loader(self, filename: str, file_content_type: str, file_path: str): def _get_loader(self, filename: str, file_content_type: str, file_path: str):
file_ext = filename.split(".")[-1].lower() file_ext = filename.split(".")[-1].lower()
if self.engine == "tika" and self.kwargs.get("TIKA_SERVER_URL"): if (
if file_ext in known_source_ext or ( self.engine == "external"
file_content_type and file_content_type.find("text/") >= 0 and self.kwargs.get("EXTERNAL_DOCUMENT_LOADER_URL")
and self.kwargs.get("EXTERNAL_DOCUMENT_LOADER_API_KEY")
): ):
loader = ExternalDocumentLoader(
file_path=file_path,
url=self.kwargs.get("EXTERNAL_DOCUMENT_LOADER_URL"),
api_key=self.kwargs.get("EXTERNAL_DOCUMENT_LOADER_API_KEY"),
mime_type=file_content_type,
)
elif self.engine == "tika" and self.kwargs.get("TIKA_SERVER_URL"):
if self._is_text_file(file_ext, file_content_type):
loader = TextLoader(file_path, autodetect_encoding=True) loader = TextLoader(file_path, autodetect_encoding=True)
else: else:
loader = TikaLoader( loader = TikaLoader(
url=self.kwargs.get("TIKA_SERVER_URL"), url=self.kwargs.get("TIKA_SERVER_URL"),
file_path=file_path, file_path=file_path,
mime_type=file_content_type, mime_type=file_content_type,
extract_images=self.kwargs.get("PDF_EXTRACT_IMAGES"),
)
elif (
self.engine == "datalab_marker"
and self.kwargs.get("DATALAB_MARKER_API_KEY")
and file_ext
in [
"pdf",
"xls",
"xlsx",
"ods",
"doc",
"docx",
"odt",
"ppt",
"pptx",
"odp",
"html",
"epub",
"png",
"jpeg",
"jpg",
"webp",
"gif",
"tiff",
]
):
loader = DatalabMarkerLoader(
file_path=file_path,
api_key=self.kwargs["DATALAB_MARKER_API_KEY"],
langs=self.kwargs.get("DATALAB_MARKER_LANGS"),
use_llm=self.kwargs.get("DATALAB_MARKER_USE_LLM", False),
skip_cache=self.kwargs.get("DATALAB_MARKER_SKIP_CACHE", False),
force_ocr=self.kwargs.get("DATALAB_MARKER_FORCE_OCR", False),
paginate=self.kwargs.get("DATALAB_MARKER_PAGINATE", False),
strip_existing_ocr=self.kwargs.get(
"DATALAB_MARKER_STRIP_EXISTING_OCR", False
),
disable_image_extraction=self.kwargs.get(
"DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION", False
),
output_format=self.kwargs.get(
"DATALAB_MARKER_OUTPUT_FORMAT", "markdown"
),
)
elif self.engine == "docling" and self.kwargs.get("DOCLING_SERVER_URL"):
if self._is_text_file(file_ext, file_content_type):
loader = TextLoader(file_path, autodetect_encoding=True)
else:
loader = DoclingLoader(
url=self.kwargs.get("DOCLING_SERVER_URL"),
file_path=file_path,
mime_type=file_content_type,
params={
"ocr_engine": self.kwargs.get("DOCLING_OCR_ENGINE"),
"ocr_lang": self.kwargs.get("DOCLING_OCR_LANG"),
"do_picture_classification": self.kwargs.get(
"DOCLING_DO_PICTURE_DESCRIPTION"
),
},
)
elif (
self.engine == "document_intelligence"
and self.kwargs.get("DOCUMENT_INTELLIGENCE_ENDPOINT") != ""
and self.kwargs.get("DOCUMENT_INTELLIGENCE_KEY") != ""
and (
file_ext in ["pdf", "xls", "xlsx", "docx", "ppt", "pptx"]
or file_content_type
in [
"application/vnd.ms-excel",
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
"application/vnd.openxmlformats-officedocument.wordprocessingml.document",
"application/vnd.ms-powerpoint",
"application/vnd.openxmlformats-officedocument.presentationml.presentation",
]
)
):
loader = AzureAIDocumentIntelligenceLoader(
file_path=file_path,
api_endpoint=self.kwargs.get("DOCUMENT_INTELLIGENCE_ENDPOINT"),
api_key=self.kwargs.get("DOCUMENT_INTELLIGENCE_KEY"),
)
elif (
self.engine == "mistral_ocr"
and self.kwargs.get("MISTRAL_OCR_API_KEY") != ""
and file_ext
in ["pdf"] # Mistral OCR currently only supports PDF and images
):
loader = MistralLoader(
api_key=self.kwargs.get("MISTRAL_OCR_API_KEY"), file_path=file_path
)
elif (
self.engine == "external"
and self.kwargs.get("MISTRAL_OCR_API_KEY") != ""
and file_ext
in ["pdf"] # Mistral OCR currently only supports PDF and images
):
loader = MistralLoader(
api_key=self.kwargs.get("MISTRAL_OCR_API_KEY"), file_path=file_path
) )
else: else:
if file_ext == "pdf": if file_ext == "pdf":
@ -153,7 +342,7 @@ class Loader:
file_path, extract_images=self.kwargs.get("PDF_EXTRACT_IMAGES") file_path, extract_images=self.kwargs.get("PDF_EXTRACT_IMAGES")
) )
elif file_ext == "csv": elif file_ext == "csv":
loader = CSVLoader(file_path) loader = CSVLoader(file_path, autodetect_encoding=True)
elif file_ext == "rst": elif file_ext == "rst":
loader = UnstructuredRSTLoader(file_path, mode="elements") loader = UnstructuredRSTLoader(file_path, mode="elements")
elif file_ext == "xml": elif file_ext == "xml":
@ -182,9 +371,7 @@ class Loader:
loader = UnstructuredPowerPointLoader(file_path) loader = UnstructuredPowerPointLoader(file_path)
elif file_ext == "msg": elif file_ext == "msg":
loader = OutlookMessageLoader(file_path) loader = OutlookMessageLoader(file_path)
elif file_ext in known_source_ext or ( elif self._is_text_file(file_ext, file_content_type):
file_content_type and file_content_type.find("text/") >= 0
):
loader = TextLoader(file_path, autodetect_encoding=True) loader = TextLoader(file_path, autodetect_encoding=True)
else: else:
loader = TextLoader(file_path, autodetect_encoding=True) loader = TextLoader(file_path, autodetect_encoding=True)
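A hedged example of driving the engine dispatch above through the Loader wrapper; the module path and the load(filename, content_type, path) entry point are assumed from the truncated class shown here, and all config values are placeholders.

from open_webui.retrieval.loaders.main import Loader   # assumed module path

loader = Loader(
    engine="docling",
    DOCLING_SERVER_URL="http://docling:5001",   # placeholder
    DOCLING_OCR_ENGINE="easyocr",               # placeholder
    DOCLING_OCR_LANG="en",
    PDF_EXTRACT_IMAGES=False,
)
docs = loader.load("scan.pdf", "application/pdf", "/tmp/scan.pdf")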

View file

@ -0,0 +1,633 @@
import requests
import aiohttp
import asyncio
import logging
import os
import sys
import time
from typing import List, Dict, Any
from contextlib import asynccontextmanager
from langchain_core.documents import Document
from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL
logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
class MistralLoader:
"""
Enhanced Mistral OCR loader with both sync and async support.
Loads documents by processing them through the Mistral OCR API.
"""
BASE_API_URL = "https://api.mistral.ai/v1"
def __init__(
self,
api_key: str,
file_path: str,
timeout: int = 300, # 5 minutes default
max_retries: int = 3,
enable_debug_logging: bool = False,
):
"""
Initializes the loader with enhanced features.
Args:
api_key: Your Mistral API key.
file_path: The local path to the PDF file to process.
timeout: Request timeout in seconds.
max_retries: Maximum number of retry attempts.
enable_debug_logging: Enable detailed debug logs.
"""
if not api_key:
raise ValueError("API key cannot be empty.")
if not os.path.exists(file_path):
raise FileNotFoundError(f"File not found at {file_path}")
self.api_key = api_key
self.file_path = file_path
self.timeout = timeout
self.max_retries = max_retries
self.debug = enable_debug_logging
# Pre-compute file info for performance
self.file_name = os.path.basename(file_path)
self.file_size = os.path.getsize(file_path)
self.headers = {
"Authorization": f"Bearer {self.api_key}",
"User-Agent": "OpenWebUI-MistralLoader/2.0",
}
def _debug_log(self, message: str, *args) -> None:
"""Conditional debug logging for performance."""
if self.debug:
log.debug(message, *args)
def _handle_response(self, response: requests.Response) -> Dict[str, Any]:
"""Checks response status and returns JSON content."""
try:
response.raise_for_status() # Raises HTTPError for bad responses (4xx or 5xx)
# Handle potential empty responses for certain successful requests (e.g., DELETE)
if response.status_code == 204 or not response.content:
return {} # Return empty dict if no content
return response.json()
except requests.exceptions.HTTPError as http_err:
log.error(f"HTTP error occurred: {http_err} - Response: {response.text}")
raise
except requests.exceptions.RequestException as req_err:
log.error(f"Request exception occurred: {req_err}")
raise
except ValueError as json_err: # Includes JSONDecodeError
log.error(f"JSON decode error: {json_err} - Response: {response.text}")
raise # Re-raise after logging
async def _handle_response_async(
self, response: aiohttp.ClientResponse
) -> Dict[str, Any]:
"""Async version of response handling with better error info."""
try:
response.raise_for_status()
# Check content type
content_type = response.headers.get("content-type", "")
if "application/json" not in content_type:
if response.status == 204:
return {}
text = await response.text()
raise ValueError(
f"Unexpected content type: {content_type}, body: {text[:200]}..."
)
return await response.json()
except aiohttp.ClientResponseError as e:
error_text = await response.text() if response else "No response"
log.error(f"HTTP {e.status}: {e.message} - Response: {error_text[:500]}")
raise
except aiohttp.ClientError as e:
log.error(f"Client error: {e}")
raise
except Exception as e:
log.error(f"Unexpected error processing response: {e}")
raise
def _retry_request_sync(self, request_func, *args, **kwargs):
"""Synchronous retry logic with exponential backoff."""
for attempt in range(self.max_retries):
try:
return request_func(*args, **kwargs)
except (requests.exceptions.RequestException, Exception) as e:
if attempt == self.max_retries - 1:
raise
wait_time = (2**attempt) + 0.5
log.warning(
f"Request failed (attempt {attempt + 1}/{self.max_retries}): {e}. Retrying in {wait_time}s..."
)
time.sleep(wait_time)
async def _retry_request_async(self, request_func, *args, **kwargs):
"""Async retry logic with exponential backoff."""
for attempt in range(self.max_retries):
try:
return await request_func(*args, **kwargs)
except (aiohttp.ClientError, asyncio.TimeoutError) as e:
if attempt == self.max_retries - 1:
raise
wait_time = (2**attempt) + 0.5
log.warning(
f"Request failed (attempt {attempt + 1}/{self.max_retries}): {e}. Retrying in {wait_time}s..."
)
await asyncio.sleep(wait_time)
def _upload_file(self) -> str:
"""Uploads the file to Mistral for OCR processing (sync version)."""
log.info("Uploading file to Mistral API")
url = f"{self.BASE_API_URL}/files"
file_name = os.path.basename(self.file_path)
def upload_request():
with open(self.file_path, "rb") as f:
files = {"file": (file_name, f, "application/pdf")}
data = {"purpose": "ocr"}
response = requests.post(
url,
headers=self.headers,
files=files,
data=data,
timeout=self.timeout,
)
return self._handle_response(response)
try:
response_data = self._retry_request_sync(upload_request)
file_id = response_data.get("id")
if not file_id:
raise ValueError("File ID not found in upload response.")
log.info(f"File uploaded successfully. File ID: {file_id}")
return file_id
except Exception as e:
log.error(f"Failed to upload file: {e}")
raise
async def _upload_file_async(self, session: aiohttp.ClientSession) -> str:
"""Async file upload with streaming for better memory efficiency."""
url = f"{self.BASE_API_URL}/files"
async def upload_request():
# Create multipart writer for streaming upload
writer = aiohttp.MultipartWriter("form-data")
# Add purpose field
purpose_part = writer.append("ocr")
purpose_part.set_content_disposition("form-data", name="purpose")
# Add file part with streaming
file_part = writer.append_payload(
aiohttp.streams.FilePayload(
self.file_path,
filename=self.file_name,
content_type="application/pdf",
)
)
file_part.set_content_disposition(
"form-data", name="file", filename=self.file_name
)
self._debug_log(
f"Uploading file: {self.file_name} ({self.file_size:,} bytes)"
)
async with session.post(
url,
data=writer,
headers=self.headers,
timeout=aiohttp.ClientTimeout(total=self.timeout),
) as response:
return await self._handle_response_async(response)
response_data = await self._retry_request_async(upload_request)
file_id = response_data.get("id")
if not file_id:
raise ValueError("File ID not found in upload response.")
log.info(f"File uploaded successfully. File ID: {file_id}")
return file_id
def _get_signed_url(self, file_id: str) -> str:
"""Retrieves a temporary signed URL for the uploaded file (sync version)."""
log.info(f"Getting signed URL for file ID: {file_id}")
url = f"{self.BASE_API_URL}/files/{file_id}/url"
params = {"expiry": 1}
signed_url_headers = {**self.headers, "Accept": "application/json"}
def url_request():
response = requests.get(
url, headers=signed_url_headers, params=params, timeout=self.timeout
)
return self._handle_response(response)
try:
response_data = self._retry_request_sync(url_request)
signed_url = response_data.get("url")
if not signed_url:
raise ValueError("Signed URL not found in response.")
log.info("Signed URL received.")
return signed_url
except Exception as e:
log.error(f"Failed to get signed URL: {e}")
raise
async def _get_signed_url_async(
self, session: aiohttp.ClientSession, file_id: str
) -> str:
"""Async signed URL retrieval."""
url = f"{self.BASE_API_URL}/files/{file_id}/url"
params = {"expiry": 1}
headers = {**self.headers, "Accept": "application/json"}
async def url_request():
self._debug_log(f"Getting signed URL for file ID: {file_id}")
async with session.get(
url,
headers=headers,
params=params,
timeout=aiohttp.ClientTimeout(total=self.timeout),
) as response:
return await self._handle_response_async(response)
response_data = await self._retry_request_async(url_request)
signed_url = response_data.get("url")
if not signed_url:
raise ValueError("Signed URL not found in response.")
self._debug_log("Signed URL received successfully")
return signed_url
def _process_ocr(self, signed_url: str) -> Dict[str, Any]:
"""Sends the signed URL to the OCR endpoint for processing (sync version)."""
log.info("Processing OCR via Mistral API")
url = f"{self.BASE_API_URL}/ocr"
ocr_headers = {
**self.headers,
"Content-Type": "application/json",
"Accept": "application/json",
}
payload = {
"model": "mistral-ocr-latest",
"document": {
"type": "document_url",
"document_url": signed_url,
},
"include_image_base64": False,
}
def ocr_request():
response = requests.post(
url, headers=ocr_headers, json=payload, timeout=self.timeout
)
return self._handle_response(response)
try:
ocr_response = self._retry_request_sync(ocr_request)
log.info("OCR processing done.")
self._debug_log("OCR response: %s", ocr_response)
return ocr_response
except Exception as e:
log.error(f"Failed during OCR processing: {e}")
raise
async def _process_ocr_async(
self, session: aiohttp.ClientSession, signed_url: str
) -> Dict[str, Any]:
"""Async OCR processing with timing metrics."""
url = f"{self.BASE_API_URL}/ocr"
headers = {
**self.headers,
"Content-Type": "application/json",
"Accept": "application/json",
}
payload = {
"model": "mistral-ocr-latest",
"document": {
"type": "document_url",
"document_url": signed_url,
},
"include_image_base64": False,
}
async def ocr_request():
log.info("Starting OCR processing via Mistral API")
start_time = time.time()
async with session.post(
url,
json=payload,
headers=headers,
timeout=aiohttp.ClientTimeout(total=self.timeout),
) as response:
ocr_response = await self._handle_response_async(response)
processing_time = time.time() - start_time
log.info(f"OCR processing completed in {processing_time:.2f}s")
return ocr_response
return await self._retry_request_async(ocr_request)
def _delete_file(self, file_id: str) -> None:
"""Deletes the file from Mistral storage (sync version)."""
log.info(f"Deleting uploaded file ID: {file_id}")
url = f"{self.BASE_API_URL}/files/{file_id}"
try:
response = requests.delete(url, headers=self.headers, timeout=30)
delete_response = self._handle_response(response)
log.info(f"File deleted successfully: {delete_response}")
except Exception as e:
# Log error but don't necessarily halt execution if deletion fails
log.error(f"Failed to delete file ID {file_id}: {e}")
async def _delete_file_async(
self, session: aiohttp.ClientSession, file_id: str
) -> None:
"""Async file deletion with error tolerance."""
try:
async def delete_request():
self._debug_log(f"Deleting file ID: {file_id}")
async with session.delete(
url=f"{self.BASE_API_URL}/files/{file_id}",
headers=self.headers,
timeout=aiohttp.ClientTimeout(
total=30
), # Shorter timeout for cleanup
) as response:
return await self._handle_response_async(response)
await self._retry_request_async(delete_request)
self._debug_log(f"File {file_id} deleted successfully")
except Exception as e:
# Don't fail the entire process if cleanup fails
log.warning(f"Failed to delete file ID {file_id}: {e}")
@asynccontextmanager
async def _get_session(self):
"""Context manager for HTTP session with optimized settings."""
connector = aiohttp.TCPConnector(
limit=10, # Total connection limit
limit_per_host=5, # Per-host connection limit
ttl_dns_cache=300, # DNS cache TTL
use_dns_cache=True,
keepalive_timeout=30,
enable_cleanup_closed=True,
)
async with aiohttp.ClientSession(
connector=connector,
timeout=aiohttp.ClientTimeout(total=self.timeout),
headers={"User-Agent": "OpenWebUI-MistralLoader/2.0"},
) as session:
yield session
def _process_results(self, ocr_response: Dict[str, Any]) -> List[Document]:
"""Process OCR results into Document objects with enhanced metadata."""
pages_data = ocr_response.get("pages")
if not pages_data:
log.warning("No pages found in OCR response.")
return [
Document(
page_content="No text content found", metadata={"error": "no_pages"}
)
]
documents = []
total_pages = len(pages_data)
skipped_pages = 0
for page_data in pages_data:
page_content = page_data.get("markdown")
page_index = page_data.get("index") # API uses 0-based index
if page_content is not None and page_index is not None:
# Clean up content efficiently
cleaned_content = (
page_content.strip()
if isinstance(page_content, str)
else str(page_content)
)
if cleaned_content: # Only add non-empty pages
documents.append(
Document(
page_content=cleaned_content,
metadata={
"page": page_index, # 0-based index from API
"page_label": page_index
+ 1, # 1-based label for convenience
"total_pages": total_pages,
"file_name": self.file_name,
"file_size": self.file_size,
"processing_engine": "mistral-ocr",
},
)
)
else:
skipped_pages += 1
self._debug_log(f"Skipping empty page {page_index}")
else:
skipped_pages += 1
self._debug_log(
f"Skipping page due to missing 'markdown' or 'index'. Data: {page_data}"
)
if skipped_pages > 0:
log.info(
f"Processed {len(documents)} pages, skipped {skipped_pages} empty/invalid pages"
)
if not documents:
# Case where pages existed but none had valid markdown/index
log.warning(
"OCR response contained pages, but none had valid content/index."
)
return [
Document(
page_content="No valid text content found in document",
metadata={"error": "no_valid_pages", "total_pages": total_pages},
)
]
return documents
def load(self) -> List[Document]:
"""
Executes the full OCR workflow: upload, get URL, process OCR, delete file.
Synchronous version for backward compatibility.
Returns:
A list of Document objects, one for each page processed.
"""
file_id = None
start_time = time.time()
try:
# 1. Upload file
file_id = self._upload_file()
# 2. Get Signed URL
signed_url = self._get_signed_url(file_id)
# 3. Process OCR
ocr_response = self._process_ocr(signed_url)
# 4. Process results
documents = self._process_results(ocr_response)
total_time = time.time() - start_time
log.info(
f"Sync OCR workflow completed in {total_time:.2f}s, produced {len(documents)} documents"
)
return documents
except Exception as e:
total_time = time.time() - start_time
log.error(
f"An error occurred during the loading process after {total_time:.2f}s: {e}"
)
# Return an error document on failure
return [
Document(
page_content=f"Error during processing: {e}",
metadata={
"error": "processing_failed",
"file_name": self.file_name,
},
)
]
finally:
# 5. Delete file (attempt even if prior steps failed after upload)
if file_id:
try:
self._delete_file(file_id)
except Exception as del_e:
# Log deletion error, but don't overwrite original error if one occurred
log.error(
f"Cleanup error: Could not delete file ID {file_id}. Reason: {del_e}"
)
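A synchronous usage sketch; the API key and file path are placeholders.

from open_webui.retrieval.loaders.mistral import MistralLoader

loader = MistralLoader(api_key="sk-placeholder", file_path="/tmp/paper.pdf", max_retries=3)
docs = loader.load()                     # one Document per non-empty OCR page
for d in docs:
    print(d.metadata.get("page_label"), len(d.page_content))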
async def load_async(self) -> List[Document]:
"""
Asynchronous OCR workflow execution with optimized performance.
Returns:
A list of Document objects, one for each page processed.
"""
file_id = None
start_time = time.time()
try:
async with self._get_session() as session:
# 1. Upload file with streaming
file_id = await self._upload_file_async(session)
# 2. Get signed URL
signed_url = await self._get_signed_url_async(session, file_id)
# 3. Process OCR
ocr_response = await self._process_ocr_async(session, signed_url)
# 4. Process results
documents = self._process_results(ocr_response)
total_time = time.time() - start_time
log.info(
f"Async OCR workflow completed in {total_time:.2f}s, produced {len(documents)} documents"
)
return documents
except Exception as e:
total_time = time.time() - start_time
log.error(f"Async OCR workflow failed after {total_time:.2f}s: {e}")
return [
Document(
page_content=f"Error during OCR processing: {e}",
metadata={
"error": "processing_failed",
"file_name": self.file_name,
},
)
]
finally:
# 5. Cleanup - always attempt file deletion
if file_id:
try:
async with self._get_session() as session:
await self._delete_file_async(session, file_id)
except Exception as cleanup_error:
log.error(f"Cleanup failed for file ID {file_id}: {cleanup_error}")
@staticmethod
async def load_multiple_async(
loaders: List["MistralLoader"],
) -> List[List[Document]]:
"""
Process multiple files concurrently for maximum performance.
Args:
loaders: List of MistralLoader instances
Returns:
List of document lists, one for each loader
"""
if not loaders:
return []
log.info(f"Starting concurrent processing of {len(loaders)} files")
start_time = time.time()
# Process all files concurrently
tasks = [loader.load_async() for loader in loaders]
results = await asyncio.gather(*tasks, return_exceptions=True)
# Handle any exceptions in results
processed_results = []
for i, result in enumerate(results):
if isinstance(result, Exception):
log.error(f"File {i} failed: {result}")
processed_results.append(
[
Document(
page_content=f"Error processing file: {result}",
metadata={
"error": "batch_processing_failed",
"file_index": i,
},
)
]
)
else:
processed_results.append(result)
total_time = time.time() - start_time
total_docs = sum(len(docs) for docs in processed_results)
log.info(
f"Batch processing completed in {total_time:.2f}s, produced {total_docs} total documents"
)
return processed_results
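A sketch of the concurrent batch path: the loaders are awaited together through asyncio.gather inside load_multiple_async. Keys and paths are placeholders.

import asyncio
from open_webui.retrieval.loaders.mistral import MistralLoader

loaders = [
    MistralLoader(api_key="sk-placeholder", file_path=p)
    for p in ["/tmp/a.pdf", "/tmp/b.pdf"]    # placeholder files
]
results = asyncio.run(MistralLoader.load_multiple_async(loaders))
for docs in results:
    print(len(docs), "documents")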

View file

@ -0,0 +1,93 @@
import requests
import logging
from typing import Iterator, List, Literal, Union
from langchain_core.document_loaders import BaseLoader
from langchain_core.documents import Document
from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
class TavilyLoader(BaseLoader):
"""Extract web page content from URLs using Tavily Extract API.
This is a LangChain document loader that uses Tavily's Extract API to
retrieve content from web pages and return it as Document objects.
Args:
urls: URL or list of URLs to extract content from.
api_key: The Tavily API key.
extract_depth: Depth of extraction, either "basic" or "advanced".
continue_on_failure: Whether to continue if extraction of a URL fails.
"""
def __init__(
self,
urls: Union[str, List[str]],
api_key: str,
extract_depth: Literal["basic", "advanced"] = "basic",
continue_on_failure: bool = True,
) -> None:
"""Initialize Tavily Extract client.
Args:
urls: URL or list of URLs to extract content from.
api_key: The Tavily API key.
extract_depth: Depth of extraction, either "basic" or "advanced".
    Advanced extraction retrieves more data, including tables and
    embedded content, with a higher success rate, but may increase latency.
    Basic costs 1 credit per 5 successful URL extractions;
    advanced costs 2 credits per 5 successful URL extractions.
continue_on_failure: Whether to continue if extraction of a URL fails.
"""
if not urls:
raise ValueError("At least one URL must be provided.")
self.api_key = api_key
self.urls = urls if isinstance(urls, list) else [urls]
self.extract_depth = extract_depth
self.continue_on_failure = continue_on_failure
self.api_url = "https://api.tavily.com/extract"
def lazy_load(self) -> Iterator[Document]:
"""Extract and yield documents from the URLs using Tavily Extract API."""
batch_size = 20
for i in range(0, len(self.urls), batch_size):
batch_urls = self.urls[i : i + batch_size]
try:
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {self.api_key}",
}
# Use string for single URL, array for multiple URLs
urls_param = batch_urls[0] if len(batch_urls) == 1 else batch_urls
payload = {"urls": urls_param, "extract_depth": self.extract_depth}
# Make the API call
response = requests.post(self.api_url, headers=headers, json=payload)
response.raise_for_status()
response_data = response.json()
# Process successful results
for result in response_data.get("results", []):
url = result.get("url", "")
content = result.get("raw_content", "")
if not content:
log.warning(f"No content extracted from {url}")
continue
# Add URLs as metadata
metadata = {"source": url}
yield Document(
page_content=content,
metadata=metadata,
)
for failed in response_data.get("failed_results", []):
url = failed.get("url", "")
error = failed.get("error", "Unknown error")
log.error(f"Failed to extract content from {url}: {error}")
except Exception as e:
if self.continue_on_failure:
log.error(f"Error extracting content from batch {batch_urls}: {e}")
else:
raise e
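
A minimal usage sketch for the loader above; the API key is a placeholder.

loader = TavilyLoader(
    urls=["https://example.com/a", "https://example.com/b"],
    api_key="tvly-<your-key>",  # placeholder
    extract_depth="basic",
)
for doc in loader.lazy_load():
    print(doc.metadata["source"], len(doc.page_content))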


@ -62,12 +62,17 @@ class YoutubeLoader:
         _video_id = _parse_video_id(video_id)
         self.video_id = _video_id if _video_id is not None else video_id
         self._metadata = {"source": video_id}
-        self.language = language
         self.proxy_url = proxy_url
+        # Ensure language is a list
         if isinstance(language, str):
             self.language = [language]
         else:
-            self.language = language
+            self.language = list(language)
+
+        # Add English as fallback if not already in the list
+        if "en" not in self.language:
+            self.language.append("en")

     def load(self) -> List[Document]:
         """Load YouTube transcripts into `Document` objects."""

@ -101,17 +106,31 @@
             log.exception("Loading YouTube transcript failed")
             return []

-        try:
-            transcript = transcript_list.find_transcript(self.language)
-        except NoTranscriptFound:
-            transcript = transcript_list.find_transcript(["en"])
-        transcript_pieces: List[Dict[str, Any]] = transcript.fetch()
-        transcript = " ".join(
-            map(
-                lambda transcript_piece: transcript_piece["text"].strip(" "),
-                transcript_pieces,
-            )
-        )
-        return [Document(page_content=transcript, metadata=self._metadata)]
+        # Try each language in order of priority
+        for lang in self.language:
+            try:
+                transcript = transcript_list.find_transcript([lang])
+                log.debug(f"Found transcript for language '{lang}'")
+                transcript_pieces: List[Dict[str, Any]] = transcript.fetch()
+                transcript_text = " ".join(
+                    map(
+                        lambda transcript_piece: transcript_piece.text.strip(" "),
+                        transcript_pieces,
+                    )
+                )
+                return [Document(page_content=transcript_text, metadata=self._metadata)]
+            except NoTranscriptFound:
+                log.debug(f"No transcript found for language '{lang}'")
+                continue
+            except Exception as e:
+                log.info(f"Error finding transcript for language '{lang}'")
+                raise e

+        # If we get here, all languages failed
+        languages_tried = ", ".join(self.language)
+        log.warning(
+            f"No transcript found for any of the specified languages: {languages_tried}. Verify if the video has transcripts, add more languages if needed."
+        )
+        raise NoTranscriptFound(
+            f"No transcript found for any supported language. Verify if the video has transcripts, add more languages if needed."
+        )
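
A small sketch of the new fallback behavior; the positional arguments are abridged and assumed — see the full constructor for the real signature.

loader = YoutubeLoader("dQw4w9WgXcQ", language="de")  # video id and kwargs are illustrative
print(loader.language)  # ["de", "en"] — English is appended automatically as a fallback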


@ -0,0 +1,8 @@
from abc import ABC, abstractmethod
from typing import Optional, List, Tuple
class BaseReranker(ABC):
@abstractmethod
def predict(self, sentences: List[Tuple[str, str]]) -> Optional[List[float]]:
pass
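
An illustrative subclass (not one of the implementations in this commit) showing the predict() contract: it receives (query, passage) pairs and returns one relevance score per pair, or None on failure.

class OverlapReranker(BaseReranker):
    # Toy scorer: fraction of query terms that also appear in the passage.
    def predict(self, sentences: List[Tuple[str, str]]) -> Optional[List[float]]:
        scores = []
        for query, passage in sentences:
            query_terms = set(query.lower().split())
            passage_terms = set(passage.lower().split())
            scores.append(len(query_terms & passage_terms) / max(len(query_terms), 1))
        return scores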


@ -1,13 +1,21 @@
 import os
+import logging
 import torch
 import numpy as np
 from colbert.infra import ColBERTConfig
 from colbert.modeling.checkpoint import Checkpoint
+from open_webui.env import SRC_LOG_LEVELS
+from open_webui.retrieval.models.base_reranker import BaseReranker

+log = logging.getLogger(__name__)
+log.setLevel(SRC_LOG_LEVELS["RAG"])

-class ColBERT:
+class ColBERT(BaseReranker):
     def __init__(self, name, **kwargs) -> None:
-        print("ColBERT: Loading model", name)
+        log.info("ColBERT: Loading model", name)
         self.device = "cuda" if torch.cuda.is_available() else "cpu"
         DOCKER = kwargs.get("env") == "docker"


@ -0,0 +1,60 @@
import logging
import requests
from typing import Optional, List, Tuple
from open_webui.env import SRC_LOG_LEVELS
from open_webui.retrieval.models.base_reranker import BaseReranker
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
class ExternalReranker(BaseReranker):
def __init__(
self,
api_key: str,
url: str = "http://localhost:8080/v1/rerank",
model: str = "reranker",
):
self.api_key = api_key
self.url = url
self.model = model
def predict(self, sentences: List[Tuple[str, str]]) -> Optional[List[float]]:
query = sentences[0][0]
docs = [i[1] for i in sentences]
payload = {
"model": self.model,
"query": query,
"documents": docs,
"top_n": len(docs),
}
try:
log.info(f"ExternalReranker:predict:model {self.model}")
log.info(f"ExternalReranker:predict:query {query}")
r = requests.post(
f"{self.url}",
headers={
"Content-Type": "application/json",
"Authorization": f"Bearer {self.api_key}",
},
json=payload,
)
r.raise_for_status()
data = r.json()
if "results" in data:
sorted_results = sorted(data["results"], key=lambda x: x["index"])
return [result["relevance_score"] for result in sorted_results]
else:
log.error("No results found in external reranking response")
return None
except Exception as e:
log.exception(f"Error in external reranking: {e}")
return None
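
A minimal usage sketch; the URL and API key are placeholders, and the endpoint is expected to implement the rerank request/response schema used above.

reranker = ExternalReranker(api_key="<key>", url="http://localhost:8080/v1/rerank", model="reranker")
scores = reranker.predict(
    [
        ("what is open webui?", "Open WebUI is a self-hosted AI interface."),
        ("what is open webui?", "Bananas are rich in potassium."),
    ]
)
print(scores)  # e.g. [0.91, 0.02]; None is returned if the request fails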


@ -1,27 +1,36 @@
import logging import logging
import os import os
import uuid
from typing import Optional, Union from typing import Optional, Union
import asyncio
import requests import requests
import hashlib
from concurrent.futures import ThreadPoolExecutor
import time
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
from langchain.retrievers import ContextualCompressionRetriever, EnsembleRetriever from langchain.retrievers import ContextualCompressionRetriever, EnsembleRetriever
from langchain_community.retrievers import BM25Retriever from langchain_community.retrievers import BM25Retriever
from langchain_core.documents import Document from langchain_core.documents import Document
from open_webui.config import VECTOR_DB from open_webui.config import VECTOR_DB
from open_webui.retrieval.vector.connector import VECTOR_DB_CLIENT from open_webui.retrieval.vector.factory import VECTOR_DB_CLIENT
from open_webui.utils.misc import get_last_user_message
from open_webui.models.users import UserModel from open_webui.models.users import UserModel
from open_webui.models.files import Files
from open_webui.retrieval.vector.main import GetResult
from open_webui.env import ( from open_webui.env import (
SRC_LOG_LEVELS, SRC_LOG_LEVELS,
OFFLINE_MODE, OFFLINE_MODE,
ENABLE_FORWARD_USER_INFO_HEADERS, ENABLE_FORWARD_USER_INFO_HEADERS,
) )
from open_webui.config import (
RAG_EMBEDDING_QUERY_PREFIX,
RAG_EMBEDDING_CONTENT_PREFIX,
RAG_EMBEDDING_PREFIX_FIELD_NAME,
)
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"]) log.setLevel(SRC_LOG_LEVELS["RAG"])
@ -46,7 +55,7 @@ class VectorSearchRetriever(BaseRetriever):
) -> list[Document]: ) -> list[Document]:
result = VECTOR_DB_CLIENT.search( result = VECTOR_DB_CLIENT.search(
collection_name=self.collection_name, collection_name=self.collection_name,
vectors=[self.embedding_function(query)], vectors=[self.embedding_function(query, RAG_EMBEDDING_QUERY_PREFIX)],
limit=self.top_k, limit=self.top_k,
) )
@ -69,6 +78,7 @@ def query_doc(
collection_name: str, query_embedding: list[float], k: int, user: UserModel = None collection_name: str, query_embedding: list[float], k: int, user: UserModel = None
): ):
try: try:
log.debug(f"query_doc:doc {collection_name}")
result = VECTOR_DB_CLIENT.search( result = VECTOR_DB_CLIENT.search(
collection_name=collection_name, collection_name=collection_name,
vectors=[query_embedding], vectors=[query_embedding],
@ -80,24 +90,40 @@ def query_doc(
return result return result
except Exception as e: except Exception as e:
print(e) log.exception(f"Error querying doc {collection_name} with limit {k}: {e}")
raise e
def get_doc(collection_name: str, user: UserModel = None):
try:
log.debug(f"get_doc:doc {collection_name}")
result = VECTOR_DB_CLIENT.get(collection_name=collection_name)
if result:
log.info(f"query_doc:result {result.ids} {result.metadatas}")
return result
except Exception as e:
log.exception(f"Error getting doc {collection_name}: {e}")
raise e raise e
def query_doc_with_hybrid_search( def query_doc_with_hybrid_search(
collection_name: str, collection_name: str,
collection_result: GetResult,
query: str, query: str,
embedding_function, embedding_function,
k: int, k: int,
reranking_function, reranking_function,
k_reranker: int,
r: float, r: float,
hybrid_bm25_weight: float,
) -> dict: ) -> dict:
try: try:
result = VECTOR_DB_CLIENT.get(collection_name=collection_name) log.debug(f"query_doc_with_hybrid_search:doc {collection_name}")
bm25_retriever = BM25Retriever.from_texts( bm25_retriever = BM25Retriever.from_texts(
texts=result.documents[0], texts=collection_result.documents[0],
metadatas=result.metadatas[0], metadatas=collection_result.metadatas[0],
) )
bm25_retriever.k = k bm25_retriever.k = k
@ -107,12 +133,23 @@ def query_doc_with_hybrid_search(
top_k=k, top_k=k,
) )
if hybrid_bm25_weight <= 0:
ensemble_retriever = EnsembleRetriever( ensemble_retriever = EnsembleRetriever(
retrievers=[bm25_retriever, vector_search_retriever], weights=[0.5, 0.5] retrievers=[vector_search_retriever], weights=[1.0]
) )
elif hybrid_bm25_weight >= 1:
ensemble_retriever = EnsembleRetriever(
retrievers=[bm25_retriever], weights=[1.0]
)
else:
ensemble_retriever = EnsembleRetriever(
retrievers=[bm25_retriever, vector_search_retriever],
weights=[hybrid_bm25_weight, 1.0 - hybrid_bm25_weight],
)
compressor = RerankCompressor( compressor = RerankCompressor(
embedding_function=embedding_function, embedding_function=embedding_function,
top_n=k, top_n=k_reranker,
reranking_function=reranking_function, reranking_function=reranking_function,
r_score=r, r_score=r,
) )
@ -122,10 +159,23 @@ def query_doc_with_hybrid_search(
) )
result = compression_retriever.invoke(query) result = compression_retriever.invoke(query)
distances = [d.metadata.get("score") for d in result]
documents = [d.page_content for d in result]
metadatas = [d.metadata for d in result]
# retrieve only min(k, k_reranker) items, sort and cut by distance if k < k_reranker
if k < k_reranker:
sorted_items = sorted(
zip(distances, metadatas, documents), key=lambda x: x[0], reverse=True
)
sorted_items = sorted_items[:k]
distances, documents, metadatas = map(list, zip(*sorted_items))
result = { result = {
"distances": [[d.metadata.get("score") for d in result]], "distances": [distances],
"documents": [[d.page_content for d in result]], "documents": [documents],
"metadatas": [[d.metadata for d in result]], "metadatas": [metadatas],
} }
log.info( log.info(
@ -134,52 +184,88 @@ def query_doc_with_hybrid_search(
) )
return result return result
except Exception as e: except Exception as e:
log.exception(f"Error querying doc {collection_name} with hybrid search: {e}")
raise e raise e
-def merge_and_sort_query_results(
-    query_results: list[dict], k: int, reverse: bool = False
-) -> list[dict]:
+def merge_get_results(get_results: list[dict]) -> dict:
     # Initialize lists to store combined data
-    combined_distances = []
     combined_documents = []
     combined_metadatas = []
+    combined_ids = []

-    for data in query_results:
-        combined_distances.extend(data["distances"][0])
+    for data in get_results:
         combined_documents.extend(data["documents"][0])
         combined_metadatas.extend(data["metadatas"][0])
+        combined_ids.extend(data["ids"][0])

-    # Create a list of tuples (distance, document, metadata)
-    combined = list(zip(combined_distances, combined_documents, combined_metadatas))
-    # Sort the list based on distances
-    combined.sort(key=lambda x: x[0], reverse=reverse)
-    # We don't have anything :-(
-    if not combined:
-        sorted_distances = []
-        sorted_documents = []
-        sorted_metadatas = []
-    else:
-        # Unzip the sorted list
-        sorted_distances, sorted_documents, sorted_metadatas = zip(*combined)
-        # Slicing the lists to include only k elements
-        sorted_distances = list(sorted_distances)[:k]
-        sorted_documents = list(sorted_documents)[:k]
-        sorted_metadatas = list(sorted_metadatas)[:k]

     # Create the output dictionary
     result = {
-        "distances": [sorted_distances],
-        "documents": [sorted_documents],
-        "metadatas": [sorted_metadatas],
+        "documents": [combined_documents],
+        "metadatas": [combined_metadatas],
+        "ids": [combined_ids],
     }

     return result
def merge_and_sort_query_results(query_results: list[dict], k: int) -> dict:
# Initialize lists to store combined data
combined = dict() # To store documents with unique document hashes
for data in query_results:
distances = data["distances"][0]
documents = data["documents"][0]
metadatas = data["metadatas"][0]
for distance, document, metadata in zip(distances, documents, metadatas):
if isinstance(document, str):
doc_hash = hashlib.sha256(
document.encode()
).hexdigest() # Compute a hash for uniqueness
if doc_hash not in combined.keys():
combined[doc_hash] = (distance, document, metadata)
continue # if doc is new, no further comparison is needed
# if doc is already in, but new distance is better, update
if distance > combined[doc_hash][0]:
combined[doc_hash] = (distance, document, metadata)
combined = list(combined.values())
# Sort the list based on distances
combined.sort(key=lambda x: x[0], reverse=True)
# Slice to keep only the top k elements
sorted_distances, sorted_documents, sorted_metadatas = (
zip(*combined[:k]) if combined else ([], [], [])
)
# Create and return the output dictionary
return {
"distances": [list(sorted_distances)],
"documents": [list(sorted_documents)],
"metadatas": [list(sorted_metadatas)],
}
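
A small worked example of the new merge behavior: duplicate documents are collapsed to their best score, results are sorted by score in descending order, and only the top k are kept.

query_results = [
    {"distances": [[0.9, 0.2]], "documents": [["foo", "bar"]], "metadatas": [[{}, {}]]},
    {"distances": [[0.7]], "documents": [["foo"]], "metadatas": [[{}]]},
]
merged = merge_and_sort_query_results(query_results, k=2)
# merged["documents"] -> [["foo", "bar"]]  (the duplicate "foo" keeps its best score, 0.9)
# merged["distances"] -> [[0.9, 0.2]]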
def get_all_items_from_collections(collection_names: list[str]) -> dict:
results = []
for collection_name in collection_names:
if collection_name:
try:
result = get_doc(collection_name=collection_name)
if result is not None:
results.append(result.model_dump())
except Exception as e:
log.exception(f"Error when querying the collection: {e}")
else:
pass
return merge_get_results(results)
def query_collection( def query_collection(
collection_names: list[str], collection_names: list[str],
queries: list[str], queries: list[str],
@ -187,29 +273,49 @@ def query_collection(
k: int, k: int,
) -> dict: ) -> dict:
results = [] results = []
for query in queries: error = False
query_embedding = embedding_function(query)
for collection_name in collection_names: def process_query_collection(collection_name, query_embedding):
if collection_name:
try: try:
if collection_name:
result = query_doc( result = query_doc(
collection_name=collection_name, collection_name=collection_name,
k=k, k=k,
query_embedding=query_embedding, query_embedding=query_embedding,
) )
if result is not None: if result is not None:
results.append(result.model_dump()) return result.model_dump(), None
return None, None
except Exception as e: except Exception as e:
log.exception(f"Error when querying the collection: {e}") log.exception(f"Error when querying the collection: {e}")
else: return None, e
pass
if VECTOR_DB == "chroma": # Generate all query embeddings (in one call)
# Chroma uses unconventional cosine similarity, so we don't need to reverse the results query_embeddings = embedding_function(queries, prefix=RAG_EMBEDDING_QUERY_PREFIX)
# https://docs.trychroma.com/docs/collections/configure#configuring-chroma-collections log.debug(
return merge_and_sort_query_results(results, k=k, reverse=False) f"query_collection: processing {len(queries)} queries across {len(collection_names)} collections"
else: )
return merge_and_sort_query_results(results, k=k, reverse=True)
with ThreadPoolExecutor() as executor:
future_results = []
for query_embedding in query_embeddings:
for collection_name in collection_names:
result = executor.submit(
process_query_collection, collection_name, query_embedding
)
future_results.append(result)
task_results = [future.result() for future in future_results]
for result, err in task_results:
if err is not None:
error = True
elif result is not None:
results.append(result)
if error and not results:
log.warning("All collection queries failed. No results returned.")
return merge_and_sort_query_results(results, k=k)
def query_collection_with_hybrid_search( def query_collection_with_hybrid_search(
@ -218,39 +324,74 @@ def query_collection_with_hybrid_search(
embedding_function, embedding_function,
k: int, k: int,
reranking_function, reranking_function,
k_reranker: int,
r: float, r: float,
hybrid_bm25_weight: float,
) -> dict: ) -> dict:
results = [] results = []
error = False error = False
# Fetch collection data once per collection sequentially
# Avoid fetching the same data multiple times later
collection_results = {}
for collection_name in collection_names: for collection_name in collection_names:
try: try:
for query in queries: log.debug(
f"query_collection_with_hybrid_search:VECTOR_DB_CLIENT.get:collection {collection_name}"
)
collection_results[collection_name] = VECTOR_DB_CLIENT.get(
collection_name=collection_name
)
except Exception as e:
log.exception(f"Failed to fetch collection {collection_name}: {e}")
collection_results[collection_name] = None
log.info(
f"Starting hybrid search for {len(queries)} queries in {len(collection_names)} collections..."
)
def process_query(collection_name, query):
try:
result = query_doc_with_hybrid_search( result = query_doc_with_hybrid_search(
collection_name=collection_name, collection_name=collection_name,
collection_result=collection_results[collection_name],
query=query, query=query,
embedding_function=embedding_function, embedding_function=embedding_function,
k=k, k=k,
reranking_function=reranking_function, reranking_function=reranking_function,
k_reranker=k_reranker,
r=r, r=r,
hybrid_bm25_weight=hybrid_bm25_weight,
) )
results.append(result) return result, None
except Exception as e: except Exception as e:
log.exception( log.exception(f"Error when querying the collection with hybrid_search: {e}")
"Error when querying the collection with " f"hybrid_search: {e}" return None, e
)
# Prepare tasks for all collections and queries
# Avoid running any tasks for collections that failed to fetch data (have assigned None)
tasks = [
(cn, q)
for cn in collection_names
if collection_results[cn] is not None
for q in queries
]
with ThreadPoolExecutor() as executor:
future_results = [executor.submit(process_query, cn, q) for cn, q in tasks]
task_results = [future.result() for future in future_results]
for result, err in task_results:
if err is not None:
error = True error = True
elif result is not None:
results.append(result)
if error: if error and not results:
raise Exception( raise Exception(
"Hybrid search failed for all collections. Using Non hybrid search as fallback." "Hybrid search failed for all collections. Using Non-hybrid search as fallback."
) )
if VECTOR_DB == "chroma": return merge_and_sort_query_results(results, k=k)
# Chroma uses unconventional cosine similarity, so we don't need to reverse the results
# https://docs.trychroma.com/docs/collections/configure#configuring-chroma-collections
return merge_and_sort_query_results(results, k=k, reverse=False)
else:
return merge_and_sort_query_results(results, k=k, reverse=True)
def get_embedding_function( def get_embedding_function(
@ -260,58 +401,132 @@ def get_embedding_function(
url, url,
key, key,
embedding_batch_size, embedding_batch_size,
azure_api_version=None,
): ):
if embedding_engine == "": if embedding_engine == "":
return lambda query, user=None: embedding_function.encode(query).tolist() return lambda query, prefix=None, user=None: embedding_function.encode(
elif embedding_engine in ["ollama", "openai"]: query, **({"prompt": prefix} if prefix else {})
func = lambda query, user=None: generate_embeddings( ).tolist()
elif embedding_engine in ["ollama", "openai", "azure_openai"]:
func = lambda query, prefix=None, user=None: generate_embeddings(
engine=embedding_engine, engine=embedding_engine,
model=embedding_model, model=embedding_model,
text=query, text=query,
prefix=prefix,
url=url, url=url,
key=key, key=key,
user=user, user=user,
azure_api_version=azure_api_version,
) )
def generate_multiple(query, user, func): def generate_multiple(query, prefix, user, func):
if isinstance(query, list): if isinstance(query, list):
embeddings = [] embeddings = []
for i in range(0, len(query), embedding_batch_size): for i in range(0, len(query), embedding_batch_size):
embeddings.extend( embeddings.extend(
func(query[i : i + embedding_batch_size], user=user) func(
query[i : i + embedding_batch_size],
prefix=prefix,
user=user,
)
) )
return embeddings return embeddings
else: else:
return func(query, user) return func(query, prefix, user)
return lambda query, user=None: generate_multiple(query, user, func) return lambda query, prefix=None, user=None: generate_multiple(
query, prefix, user, func
)
else: else:
raise ValueError(f"Unknown embedding engine: {embedding_engine}") raise ValueError(f"Unknown embedding engine: {embedding_engine}")
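
A short sketch of how the returned callable is used: it accepts a string or a list of strings plus an optional prefix, which is either sent as a dedicated request field (RAG_EMBEDDING_PREFIX_FIELD_NAME) or prepended to the text, depending on configuration.

# Assuming `embed` was produced by get_embedding_function(...) above:
query_vector = embed("what is open webui?", prefix=RAG_EMBEDDING_QUERY_PREFIX)
chunk_vectors = embed(["chunk one", "chunk two"], prefix=RAG_EMBEDDING_CONTENT_PREFIX)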
def get_sources_from_files( def get_sources_from_files(
request,
files, files,
queries, queries,
embedding_function, embedding_function,
k, k,
reranking_function, reranking_function,
k_reranker,
r, r,
hybrid_bm25_weight,
hybrid_search, hybrid_search,
full_context=False,
): ):
log.debug(f"files: {files} {queries} {embedding_function} {reranking_function}") log.debug(
f"files: {files} {queries} {embedding_function} {reranking_function} {full_context}"
)
extracted_collections = [] extracted_collections = []
relevant_contexts = [] relevant_contexts = []
for file in files: for file in files:
if file.get("context") == "full":
context = None
if file.get("docs"):
# BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL
context = {
"documents": [[doc.get("content") for doc in file.get("docs")]],
"metadatas": [[doc.get("metadata") for doc in file.get("docs")]],
}
elif file.get("context") == "full":
# Manual Full Mode Toggle
context = { context = {
"documents": [[file.get("file").get("data", {}).get("content")]], "documents": [[file.get("file").get("data", {}).get("content")]],
"metadatas": [[{"file_id": file.get("id"), "name": file.get("name")}]], "metadatas": [[{"file_id": file.get("id"), "name": file.get("name")}]],
} }
else: elif (
context = None file.get("type") != "web_search"
and request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL
):
# BYPASS_EMBEDDING_AND_RETRIEVAL
if file.get("type") == "collection":
file_ids = file.get("data", {}).get("file_ids", [])
documents = []
metadatas = []
for file_id in file_ids:
file_object = Files.get_file_by_id(file_id)
if file_object:
documents.append(file_object.data.get("content", ""))
metadatas.append(
{
"file_id": file_id,
"name": file_object.filename,
"source": file_object.filename,
}
)
context = {
"documents": [documents],
"metadatas": [metadatas],
}
elif file.get("id"):
file_object = Files.get_file_by_id(file.get("id"))
if file_object:
context = {
"documents": [[file_object.data.get("content", "")]],
"metadatas": [
[
{
"file_id": file.get("id"),
"name": file_object.filename,
"source": file_object.filename,
}
]
],
}
elif file.get("file").get("data"):
context = {
"documents": [[file.get("file").get("data", {}).get("content")]],
"metadatas": [
[file.get("file").get("data", {}).get("metadata", {})]
],
}
else:
collection_names = [] collection_names = []
if file.get("type") == "collection": if file.get("type") == "collection":
if file.get("legacy"): if file.get("legacy"):
@ -331,6 +546,13 @@ def get_sources_from_files(
log.debug(f"skipping {file} as it has already been extracted") log.debug(f"skipping {file} as it has already been extracted")
continue continue
if full_context:
try:
context = get_all_items_from_collections(collection_names)
except Exception as e:
log.exception(e)
else:
try: try:
context = None context = None
if file.get("type") == "text": if file.get("type") == "text":
@ -344,7 +566,9 @@ def get_sources_from_files(
embedding_function=embedding_function, embedding_function=embedding_function,
k=k, k=k,
reranking_function=reranking_function, reranking_function=reranking_function,
k_reranker=k_reranker,
r=r, r=r,
hybrid_bm25_weight=hybrid_bm25_weight,
) )
except Exception as e: except Exception as e:
log.debug( log.debug(
@ -367,6 +591,7 @@ def get_sources_from_files(
if context: if context:
if "data" in file: if "data" in file:
del file["data"] del file["data"]
relevant_contexts.append({**context, "file": file}) relevant_contexts.append({**context, "file": file})
sources = [] sources = []
@ -435,9 +660,17 @@ def generate_openai_batch_embeddings(
texts: list[str], texts: list[str],
url: str = "https://api.openai.com/v1", url: str = "https://api.openai.com/v1",
key: str = "", key: str = "",
prefix: str = None,
user: UserModel = None, user: UserModel = None,
) -> Optional[list[list[float]]]: ) -> Optional[list[list[float]]]:
try: try:
log.debug(
f"generate_openai_batch_embeddings:model {model} batch size: {len(texts)}"
)
json_data = {"input": texts, "model": model}
if isinstance(RAG_EMBEDDING_PREFIX_FIELD_NAME, str) and isinstance(prefix, str):
json_data[RAG_EMBEDDING_PREFIX_FIELD_NAME] = prefix
r = requests.post( r = requests.post(
f"{url}/embeddings", f"{url}/embeddings",
headers={ headers={
@ -454,7 +687,7 @@ def generate_openai_batch_embeddings(
else {} else {}
), ),
}, },
json={"input": texts, "model": model}, json=json_data,
) )
r.raise_for_status() r.raise_for_status()
data = r.json() data = r.json()
@ -463,14 +696,80 @@ def generate_openai_batch_embeddings(
else: else:
raise "Something went wrong :/" raise "Something went wrong :/"
except Exception as e: except Exception as e:
print(e) log.exception(f"Error generating openai batch embeddings: {e}")
return None
def generate_azure_openai_batch_embeddings(
model: str,
texts: list[str],
url: str,
key: str = "",
version: str = "",
prefix: str = None,
user: UserModel = None,
) -> Optional[list[list[float]]]:
try:
log.debug(
f"generate_azure_openai_batch_embeddings:deployment {model} batch size: {len(texts)}"
)
json_data = {"input": texts}
if isinstance(RAG_EMBEDDING_PREFIX_FIELD_NAME, str) and isinstance(prefix, str):
json_data[RAG_EMBEDDING_PREFIX_FIELD_NAME] = prefix
url = f"{url}/openai/deployments/{model}/embeddings?api-version={version}"
for _ in range(5):
r = requests.post(
url,
headers={
"Content-Type": "application/json",
"api-key": key,
**(
{
"X-OpenWebUI-User-Name": user.name,
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
}
if ENABLE_FORWARD_USER_INFO_HEADERS and user
else {}
),
},
json=json_data,
)
if r.status_code == 429:
retry = float(r.headers.get("Retry-After", "1"))
time.sleep(retry)
continue
r.raise_for_status()
data = r.json()
if "data" in data:
return [elem["embedding"] for elem in data["data"]]
else:
raise Exception("Something went wrong :/")
return None
except Exception as e:
log.exception(f"Error generating azure openai batch embeddings: {e}")
return None return None
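
A minimal call sketch for the Azure variant above; all values are placeholders. The helper retries up to five times on HTTP 429 (honoring Retry-After) and returns None on failure.

vectors = generate_azure_openai_batch_embeddings(
    model="my-embedding-deployment",             # Azure deployment name (placeholder)
    texts=["hello", "world"],
    url="https://my-resource.openai.azure.com",  # placeholder resource URL
    key="<azure-api-key>",
    version="2024-02-01",                        # placeholder api-version
)
# vectors is a list of embedding lists, or None if the request ultimately failed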
def generate_ollama_batch_embeddings( def generate_ollama_batch_embeddings(
model: str, texts: list[str], url: str, key: str = "", user: UserModel = None model: str,
texts: list[str],
url: str,
key: str = "",
prefix: str = None,
user: UserModel = None,
) -> Optional[list[list[float]]]: ) -> Optional[list[list[float]]]:
try: try:
log.debug(
f"generate_ollama_batch_embeddings:model {model} batch size: {len(texts)}"
)
json_data = {"input": texts, "model": model}
if isinstance(RAG_EMBEDDING_PREFIX_FIELD_NAME, str) and isinstance(prefix, str):
json_data[RAG_EMBEDDING_PREFIX_FIELD_NAME] = prefix
r = requests.post( r = requests.post(
f"{url}/api/embed", f"{url}/api/embed",
headers={ headers={
@ -487,7 +786,7 @@ def generate_ollama_batch_embeddings(
else {} else {}
), ),
}, },
json={"input": texts, "model": model}, json=json_data,
) )
r.raise_for_status() r.raise_for_status()
data = r.json() data = r.json()
@ -497,37 +796,55 @@ def generate_ollama_batch_embeddings(
else: else:
raise "Something went wrong :/" raise "Something went wrong :/"
except Exception as e: except Exception as e:
print(e) log.exception(f"Error generating ollama batch embeddings: {e}")
return None return None
def generate_embeddings(engine: str, model: str, text: Union[str, list[str]], **kwargs): def generate_embeddings(
engine: str,
model: str,
text: Union[str, list[str]],
prefix: Union[str, None] = None,
**kwargs,
):
url = kwargs.get("url", "") url = kwargs.get("url", "")
key = kwargs.get("key", "") key = kwargs.get("key", "")
user = kwargs.get("user") user = kwargs.get("user")
if engine == "ollama": if prefix is not None and RAG_EMBEDDING_PREFIX_FIELD_NAME is None:
if isinstance(text, list): if isinstance(text, list):
embeddings = generate_ollama_batch_embeddings( text = [f"{prefix}{text_element}" for text_element in text]
**{"model": model, "texts": text, "url": url, "key": key, "user": user}
)
else: else:
text = f"{prefix}{text}"
if engine == "ollama":
embeddings = generate_ollama_batch_embeddings( embeddings = generate_ollama_batch_embeddings(
**{ **{
"model": model, "model": model,
"texts": [text], "texts": text if isinstance(text, list) else [text],
"url": url, "url": url,
"key": key, "key": key,
"prefix": prefix,
"user": user, "user": user,
} }
) )
return embeddings[0] if isinstance(text, str) else embeddings return embeddings[0] if isinstance(text, str) else embeddings
elif engine == "openai": elif engine == "openai":
if isinstance(text, list): embeddings = generate_openai_batch_embeddings(
embeddings = generate_openai_batch_embeddings(model, text, url, key, user) model, text if isinstance(text, list) else [text], url, key, prefix, user
else: )
embeddings = generate_openai_batch_embeddings(model, [text], url, key, user) return embeddings[0] if isinstance(text, str) else embeddings
elif engine == "azure_openai":
azure_api_version = kwargs.get("azure_api_version", "")
embeddings = generate_azure_openai_batch_embeddings(
model,
text if isinstance(text, list) else [text],
url,
key,
azure_api_version,
prefix,
user,
)
return embeddings[0] if isinstance(text, str) else embeddings return embeddings[0] if isinstance(text, str) else embeddings
@ -563,13 +880,15 @@ class RerankCompressor(BaseDocumentCompressor):
else: else:
from sentence_transformers import util from sentence_transformers import util
query_embedding = self.embedding_function(query) query_embedding = self.embedding_function(query, RAG_EMBEDDING_QUERY_PREFIX)
document_embedding = self.embedding_function( document_embedding = self.embedding_function(
[doc.page_content for doc in documents] [doc.page_content for doc in documents], RAG_EMBEDDING_CONTENT_PREFIX
) )
scores = util.cos_sim(query_embedding, document_embedding)[0] scores = util.cos_sim(query_embedding, document_embedding)[0]
docs_with_scores = list(zip(documents, scores.tolist())) docs_with_scores = list(
zip(documents, scores.tolist() if not isinstance(scores, list) else scores)
)
if self.r_score: if self.r_score:
docs_with_scores = [ docs_with_scores = [
(d, s) for d, s in docs_with_scores if s >= self.r_score (d, s) for d, s in docs_with_scores if s >= self.r_score


@ -1,22 +0,0 @@
from open_webui.config import VECTOR_DB
if VECTOR_DB == "milvus":
from open_webui.retrieval.vector.dbs.milvus import MilvusClient
VECTOR_DB_CLIENT = MilvusClient()
elif VECTOR_DB == "qdrant":
from open_webui.retrieval.vector.dbs.qdrant import QdrantClient
VECTOR_DB_CLIENT = QdrantClient()
elif VECTOR_DB == "opensearch":
from open_webui.retrieval.vector.dbs.opensearch import OpenSearchClient
VECTOR_DB_CLIENT = OpenSearchClient()
elif VECTOR_DB == "pgvector":
from open_webui.retrieval.vector.dbs.pgvector import PgvectorClient
VECTOR_DB_CLIENT = PgvectorClient()
else:
from open_webui.retrieval.vector.dbs.chroma import ChromaClient
VECTOR_DB_CLIENT = ChromaClient()

backend/open_webui/retrieval/vector/dbs/chroma.py Normal file → Executable file

@ -1,10 +1,16 @@
import chromadb import chromadb
import logging
from chromadb import Settings from chromadb import Settings
from chromadb.utils.batch_utils import create_batches from chromadb.utils.batch_utils import create_batches
from typing import Optional from typing import Optional
from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult from open_webui.retrieval.vector.main import (
VectorDBBase,
VectorItem,
SearchResult,
GetResult,
)
from open_webui.config import ( from open_webui.config import (
CHROMA_DATA_PATH, CHROMA_DATA_PATH,
CHROMA_HTTP_HOST, CHROMA_HTTP_HOST,
@ -16,9 +22,13 @@ from open_webui.config import (
CHROMA_CLIENT_AUTH_PROVIDER, CHROMA_CLIENT_AUTH_PROVIDER,
CHROMA_CLIENT_AUTH_CREDENTIALS, CHROMA_CLIENT_AUTH_CREDENTIALS,
) )
from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
class ChromaClient: class ChromaClient(VectorDBBase):
def __init__(self): def __init__(self):
settings_dict = { settings_dict = {
"allow_reset": True, "allow_reset": True,
@ -70,10 +80,16 @@ class ChromaClient:
n_results=limit, n_results=limit,
) )
# chromadb has cosine distance, 2 (worst) -> 0 (best). Re-odering to 0 -> 1
# https://docs.trychroma.com/docs/collections/configure cosine equation
distances: list = result["distances"][0]
distances = [2 - dist for dist in distances]
distances = [[dist / 2 for dist in distances]]
return SearchResult( return SearchResult(
**{ **{
"ids": result["ids"], "ids": result["ids"],
"distances": result["distances"], "distances": distances,
"documents": result["documents"], "documents": result["documents"],
"metadatas": result["metadatas"], "metadatas": result["metadatas"],
} }
@ -102,8 +118,7 @@ class ChromaClient:
} }
) )
return None return None
except Exception as e: except:
print(e)
return None return None
def get(self, collection_name: str) -> Optional[GetResult]: def get(self, collection_name: str) -> Optional[GetResult]:
@ -162,12 +177,19 @@ class ChromaClient:
filter: Optional[dict] = None, filter: Optional[dict] = None,
): ):
# Delete the items from the collection based on the ids. # Delete the items from the collection based on the ids.
try:
collection = self.client.get_collection(name=collection_name) collection = self.client.get_collection(name=collection_name)
if collection: if collection:
if ids: if ids:
collection.delete(ids=ids) collection.delete(ids=ids)
elif filter: elif filter:
collection.delete(where=filter) collection.delete(where=filter)
except Exception as e:
# If collection doesn't exist, that's fine - nothing to delete
log.debug(
f"Attempted to delete from non-existent collection {collection_name}. Ignoring."
)
pass
def reset(self): def reset(self):
# Resets the database. This will delete all collections and item entries. # Resets the database. This will delete all collections and item entries.


@ -0,0 +1,300 @@
from elasticsearch import Elasticsearch, BadRequestError
from typing import Optional
import ssl
from elasticsearch.helpers import bulk, scan
from open_webui.retrieval.vector.main import (
VectorDBBase,
VectorItem,
SearchResult,
GetResult,
)
from open_webui.config import (
ELASTICSEARCH_URL,
ELASTICSEARCH_CA_CERTS,
ELASTICSEARCH_API_KEY,
ELASTICSEARCH_USERNAME,
ELASTICSEARCH_PASSWORD,
ELASTICSEARCH_CLOUD_ID,
ELASTICSEARCH_INDEX_PREFIX,
SSL_ASSERT_FINGERPRINT,
)
class ElasticsearchClient(VectorDBBase):
"""
Important:
    To reduce the number of indexes, and since the embedding vector length is fixed, we avoid creating
    an index per file. Instead, the collection name is stored as a field on each document, and documents
    are separated into different indexes based on the embedding length.
"""
def __init__(self):
self.index_prefix = ELASTICSEARCH_INDEX_PREFIX
self.client = Elasticsearch(
hosts=[ELASTICSEARCH_URL],
ca_certs=ELASTICSEARCH_CA_CERTS,
api_key=ELASTICSEARCH_API_KEY,
cloud_id=ELASTICSEARCH_CLOUD_ID,
basic_auth=(
(ELASTICSEARCH_USERNAME, ELASTICSEARCH_PASSWORD)
if ELASTICSEARCH_USERNAME and ELASTICSEARCH_PASSWORD
else None
),
ssl_assert_fingerprint=SSL_ASSERT_FINGERPRINT,
)
# Status: works
def _get_index_name(self, dimension: int) -> str:
return f"{self.index_prefix}_d{str(dimension)}"
# Status: works
def _scan_result_to_get_result(self, result) -> GetResult:
if not result:
return None
ids = []
documents = []
metadatas = []
for hit in result:
ids.append(hit["_id"])
documents.append(hit["_source"].get("text"))
metadatas.append(hit["_source"].get("metadata"))
return GetResult(ids=[ids], documents=[documents], metadatas=[metadatas])
# Status: works
def _result_to_get_result(self, result) -> GetResult:
if not result["hits"]["hits"]:
return None
ids = []
documents = []
metadatas = []
for hit in result["hits"]["hits"]:
ids.append(hit["_id"])
documents.append(hit["_source"].get("text"))
metadatas.append(hit["_source"].get("metadata"))
return GetResult(ids=[ids], documents=[documents], metadatas=[metadatas])
# Status: works
def _result_to_search_result(self, result) -> SearchResult:
ids = []
distances = []
documents = []
metadatas = []
for hit in result["hits"]["hits"]:
ids.append(hit["_id"])
distances.append(hit["_score"])
documents.append(hit["_source"].get("text"))
metadatas.append(hit["_source"].get("metadata"))
return SearchResult(
ids=[ids],
distances=[distances],
documents=[documents],
metadatas=[metadatas],
)
# Status: works
def _create_index(self, dimension: int):
body = {
"mappings": {
"dynamic_templates": [
{
"strings": {
"match_mapping_type": "string",
"mapping": {"type": "keyword"},
}
}
],
"properties": {
"collection": {"type": "keyword"},
"id": {"type": "keyword"},
"vector": {
"type": "dense_vector",
"dims": dimension, # Adjust based on your vector dimensions
"index": True,
"similarity": "cosine",
},
"text": {"type": "text"},
"metadata": {"type": "object"},
},
}
}
self.client.indices.create(index=self._get_index_name(dimension), body=body)
# Status: works
def _create_batches(self, items: list[VectorItem], batch_size=100):
for i in range(0, len(items), batch_size):
yield items[i : min(i + batch_size, len(items))]
# Status: works
def has_collection(self, collection_name) -> bool:
query_body = {"query": {"bool": {"filter": []}}}
query_body["query"]["bool"]["filter"].append(
{"term": {"collection": collection_name}}
)
try:
result = self.client.count(index=f"{self.index_prefix}*", body=query_body)
return result.body["count"] > 0
except Exception as e:
return None
def delete_collection(self, collection_name: str):
query = {"query": {"term": {"collection": collection_name}}}
self.client.delete_by_query(index=f"{self.index_prefix}*", body=query)
# Status: works
def search(
self, collection_name: str, vectors: list[list[float]], limit: int
) -> Optional[SearchResult]:
query = {
"size": limit,
"_source": ["text", "metadata"],
"query": {
"script_score": {
"query": {
"bool": {"filter": [{"term": {"collection": collection_name}}]}
},
"script": {
"source": "cosineSimilarity(params.vector, 'vector') + 1.0",
"params": {
"vector": vectors[0]
}, # Assuming single query vector
},
}
},
}
result = self.client.search(
index=self._get_index_name(len(vectors[0])), body=query
)
return self._result_to_search_result(result)
# Status: only tested halfway
def query(
self, collection_name: str, filter: dict, limit: Optional[int] = None
) -> Optional[GetResult]:
if not self.has_collection(collection_name):
return None
query_body = {
"query": {"bool": {"filter": []}},
"_source": ["text", "metadata"],
}
for field, value in filter.items():
query_body["query"]["bool"]["filter"].append({"term": {field: value}})
query_body["query"]["bool"]["filter"].append(
{"term": {"collection": collection_name}}
)
size = limit if limit else 10
try:
result = self.client.search(
index=f"{self.index_prefix}*",
body=query_body,
size=size,
)
return self._result_to_get_result(result)
except Exception as e:
return None
# Status: works
def _has_index(self, dimension: int):
return self.client.indices.exists(
index=self._get_index_name(dimension=dimension)
)
def get_or_create_index(self, dimension: int):
if not self._has_index(dimension=dimension):
self._create_index(dimension=dimension)
# Status: works
def get(self, collection_name: str) -> Optional[GetResult]:
# Get all the items in the collection.
query = {
"query": {"bool": {"filter": [{"term": {"collection": collection_name}}]}},
"_source": ["text", "metadata"],
}
results = list(scan(self.client, index=f"{self.index_prefix}*", query=query))
return self._scan_result_to_get_result(results)
# Status: works
def insert(self, collection_name: str, items: list[VectorItem]):
if not self._has_index(dimension=len(items[0]["vector"])):
self._create_index(dimension=len(items[0]["vector"]))
for batch in self._create_batches(items):
actions = [
{
"_index": self._get_index_name(dimension=len(items[0]["vector"])),
"_id": item["id"],
"_source": {
"collection": collection_name,
"vector": item["vector"],
"text": item["text"],
"metadata": item["metadata"],
},
}
for item in batch
]
bulk(self.client, actions)
# Upsert documents using the update API with doc_as_upsert=True.
def upsert(self, collection_name: str, items: list[VectorItem]):
if not self._has_index(dimension=len(items[0]["vector"])):
self._create_index(dimension=len(items[0]["vector"]))
for batch in self._create_batches(items):
actions = [
{
"_op_type": "update",
"_index": self._get_index_name(dimension=len(item["vector"])),
"_id": item["id"],
"doc": {
"collection": collection_name,
"vector": item["vector"],
"text": item["text"],
"metadata": item["metadata"],
},
"doc_as_upsert": True,
}
for item in batch
]
bulk(self.client, actions)
# Delete specific documents from a collection by filtering on both collection and document IDs.
def delete(
self,
collection_name: str,
ids: Optional[list[str]] = None,
filter: Optional[dict] = None,
):
query = {
"query": {"bool": {"filter": [{"term": {"collection": collection_name}}]}}
}
# logic based on chromaDB
if ids:
query["query"]["bool"]["filter"].append({"terms": {"_id": ids}})
elif filter:
for field, value in filter.items():
query["query"]["bool"]["filter"].append(
{"term": {f"metadata.{field}": value}}
)
self.client.delete_by_query(index=f"{self.index_prefix}*", body=query)
def reset(self):
indices = self.client.indices.get(index=f"{self.index_prefix}*")
for index in indices:
self.client.indices.delete(index=index)
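
A minimal end-to-end sketch, assuming a reachable Elasticsearch instance configured through the env vars above; the item values are placeholders.

client = ElasticsearchClient()
items = [
    {"id": "doc-1", "vector": [0.1, 0.2, 0.3], "text": "hello world", "metadata": {"name": "a.txt"}},
]
client.insert("my-collection", items)  # stored in the shared per-dimension index, e.g. "<prefix>_d3"
result = client.search("my-collection", vectors=[[0.1, 0.2, 0.3]], limit=1)
print(result.ids, result.distances)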


@ -1,30 +1,42 @@
from pymilvus import MilvusClient as Client from pymilvus import MilvusClient as Client
from pymilvus import FieldSchema, DataType from pymilvus import FieldSchema, DataType
import json import json
import logging
from typing import Optional from typing import Optional
from open_webui.retrieval.vector.main import (
from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult VectorDBBase,
VectorItem,
SearchResult,
GetResult,
)
from open_webui.config import ( from open_webui.config import (
MILVUS_URI, MILVUS_URI,
MILVUS_DB, MILVUS_DB,
MILVUS_TOKEN, MILVUS_TOKEN,
MILVUS_INDEX_TYPE,
MILVUS_METRIC_TYPE,
MILVUS_HNSW_M,
MILVUS_HNSW_EFCONSTRUCTION,
MILVUS_IVF_FLAT_NLIST,
) )
from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
class MilvusClient: class MilvusClient(VectorDBBase):
def __init__(self): def __init__(self):
self.collection_prefix = "open_webui" self.collection_prefix = "open_webui"
if MILVUS_TOKEN is None: if MILVUS_TOKEN is None:
self.client = Client(uri=MILVUS_URI, database=MILVUS_DB) self.client = Client(uri=MILVUS_URI, db_name=MILVUS_DB)
else: else:
self.client = Client(uri=MILVUS_URI, database=MILVUS_DB, token=MILVUS_TOKEN) self.client = Client(uri=MILVUS_URI, db_name=MILVUS_DB, token=MILVUS_TOKEN)
def _result_to_get_result(self, result) -> GetResult: def _result_to_get_result(self, result) -> GetResult:
ids = [] ids = []
documents = [] documents = []
metadatas = [] metadatas = []
for match in result: for match in result:
_ids = [] _ids = []
_documents = [] _documents = []
@ -33,11 +45,9 @@ class MilvusClient:
_ids.append(item.get("id")) _ids.append(item.get("id"))
_documents.append(item.get("data", {}).get("text")) _documents.append(item.get("data", {}).get("text"))
_metadatas.append(item.get("metadata")) _metadatas.append(item.get("metadata"))
ids.append(_ids) ids.append(_ids)
documents.append(_documents) documents.append(_documents)
metadatas.append(_metadatas) metadatas.append(_metadatas)
return GetResult( return GetResult(
**{ **{
"ids": ids, "ids": ids,
@ -51,24 +61,23 @@ class MilvusClient:
distances = [] distances = []
documents = [] documents = []
metadatas = [] metadatas = []
for match in result: for match in result:
_ids = [] _ids = []
_distances = [] _distances = []
_documents = [] _documents = []
_metadatas = [] _metadatas = []
for item in match: for item in match:
_ids.append(item.get("id")) _ids.append(item.get("id"))
_distances.append(item.get("distance")) # normalize milvus score from [-1, 1] to [0, 1] range
# https://milvus.io/docs/de/metric.md
_dist = (item.get("distance") + 1.0) / 2.0
_distances.append(_dist)
_documents.append(item.get("entity", {}).get("data", {}).get("text")) _documents.append(item.get("entity", {}).get("data", {}).get("text"))
_metadatas.append(item.get("entity", {}).get("metadata")) _metadatas.append(item.get("entity", {}).get("metadata"))
ids.append(_ids) ids.append(_ids)
distances.append(_distances) distances.append(_distances)
documents.append(_documents) documents.append(_documents)
metadatas.append(_metadatas) metadatas.append(_metadatas)
return SearchResult( return SearchResult(
**{ **{
"ids": ids, "ids": ids,
@ -101,11 +110,39 @@ class MilvusClient:
) )
index_params = self.client.prepare_index_params() index_params = self.client.prepare_index_params()
# Use configurations from config.py
index_type = MILVUS_INDEX_TYPE.upper()
metric_type = MILVUS_METRIC_TYPE.upper()
log.info(f"Using Milvus index type: {index_type}, metric type: {metric_type}")
index_creation_params = {}
if index_type == "HNSW":
index_creation_params = {
"M": MILVUS_HNSW_M,
"efConstruction": MILVUS_HNSW_EFCONSTRUCTION,
}
log.info(f"HNSW params: {index_creation_params}")
elif index_type == "IVF_FLAT":
index_creation_params = {"nlist": MILVUS_IVF_FLAT_NLIST}
log.info(f"IVF_FLAT params: {index_creation_params}")
elif index_type in ["FLAT", "AUTOINDEX"]:
log.info(f"Using {index_type} index with no specific build-time params.")
else:
log.warning(
f"Unsupported MILVUS_INDEX_TYPE: '{index_type}'. "
f"Supported types: HNSW, IVF_FLAT, FLAT, AUTOINDEX. "
f"Milvus will use its default for the collection if this type is not directly supported for index creation."
)
# For unsupported types, pass the type directly to Milvus; it might handle it or use a default.
# If Milvus errors out, the user needs to correct the MILVUS_INDEX_TYPE env var.
index_params.add_index( index_params.add_index(
field_name="vector", field_name="vector",
index_type="HNSW", index_type=index_type,
metric_type="COSINE", metric_type=metric_type,
params={"M": 16, "efConstruction": 100}, params=index_creation_params,
) )
self.client.create_collection( self.client.create_collection(
@ -113,6 +150,9 @@ class MilvusClient:
schema=schema, schema=schema,
index_params=index_params, index_params=index_params,
) )
log.info(
f"Successfully created collection '{self.collection_prefix}_{collection_name}' with index type '{index_type}' and metric '{metric_type}'."
)
def has_collection(self, collection_name: str) -> bool: def has_collection(self, collection_name: str) -> bool:
# Check if the collection exists based on the collection name. # Check if the collection exists based on the collection name.
@ -133,82 +173,113 @@ class MilvusClient:
) -> Optional[SearchResult]: ) -> Optional[SearchResult]:
# Search for the nearest neighbor items based on the vectors and return 'limit' number of results. # Search for the nearest neighbor items based on the vectors and return 'limit' number of results.
collection_name = collection_name.replace("-", "_") collection_name = collection_name.replace("-", "_")
# For some index types like IVF_FLAT, search params like nprobe can be set.
# Example: search_params = {"nprobe": 10} if using IVF_FLAT
# For simplicity, not adding configurable search_params here, but could be extended.
result = self.client.search( result = self.client.search(
collection_name=f"{self.collection_prefix}_{collection_name}", collection_name=f"{self.collection_prefix}_{collection_name}",
data=vectors, data=vectors,
limit=limit, limit=limit,
output_fields=["data", "metadata"], output_fields=["data", "metadata"],
# search_params=search_params # Potentially add later if needed
) )
return self._result_to_search_result(result) return self._result_to_search_result(result)
def query(self, collection_name: str, filter: dict, limit: Optional[int] = None): def query(self, collection_name: str, filter: dict, limit: Optional[int] = None):
# Construct the filter string for querying # Construct the filter string for querying
collection_name = collection_name.replace("-", "_") collection_name = collection_name.replace("-", "_")
if not self.has_collection(collection_name): if not self.has_collection(collection_name):
log.warning(
f"Query attempted on non-existent collection: {self.collection_prefix}_{collection_name}"
)
return None return None
filter_string = " && ".join( filter_string = " && ".join(
[ [
f'metadata["{key}"] == {json.dumps(value)}' f'metadata["{key}"] == {json.dumps(value)}'
for key, value in filter.items() for key, value in filter.items()
] ]
) )
max_limit = 16383 # The maximum number of records per request max_limit = 16383 # The maximum number of records per request
all_results = [] all_results = []
if limit is None: if limit is None:
limit = float("inf") # Use infinity as a placeholder for no limit # Milvus default limit for query if not specified is 16384, but docs mention iteration.
# Let's set a practical high number if "all" is intended, or handle true pagination.
# For now, if limit is None, we'll fetch in batches up to a very large number.
# This part could be refined based on expected use cases for "get all".
# For this function signature, None implies "as many as possible" up to Milvus limits.
limit = (
16384 * 10
) # A large number to signify fetching many, will be capped by actual data or max_limit per call.
log.info(
f"Limit not specified for query, fetching up to {limit} results in batches."
)
# Initialize offset and remaining to handle pagination # Initialize offset and remaining to handle pagination
offset = 0 offset = 0
remaining = limit remaining = limit
try: try:
log.info(
f"Querying collection {self.collection_prefix}_{collection_name} with filter: '{filter_string}', limit: {limit}"
)
# Loop until there are no more items to fetch or the desired limit is reached # Loop until there are no more items to fetch or the desired limit is reached
while remaining > 0: while remaining > 0:
print("remaining", remaining)
current_fetch = min( current_fetch = min(
max_limit, remaining max_limit, remaining if isinstance(remaining, int) else max_limit
) # Determine how many items to fetch in this iteration )
log.debug(
f"Querying with offset: {offset}, current_fetch: {current_fetch}"
)
results = self.client.query( results = self.client.query(
collection_name=f"{self.collection_prefix}_{collection_name}", collection_name=f"{self.collection_prefix}_{collection_name}",
filter=filter_string, filter=filter_string,
output_fields=["*"], output_fields=[
"id",
"data",
"metadata",
], # Explicitly list needed fields. Vector not usually needed in query.
limit=current_fetch, limit=current_fetch,
offset=offset, offset=offset,
) )
if not results: if not results:
log.debug("No more results from query.")
break break
all_results.extend(results) all_results.extend(results)
results_count = len(results) results_count = len(results)
remaining -= ( log.debug(f"Fetched {results_count} results in this batch.")
results_count # Decrease remaining by the number of items fetched
) if isinstance(remaining, int):
remaining -= results_count
offset += results_count offset += results_count
# Break the loop if the results returned are less than the requested fetch count # Break the loop if the results returned are less than the requested fetch count (means end of data)
if results_count < current_fetch: if results_count < current_fetch:
log.debug(
"Fetched less than requested, assuming end of results for this query."
)
break break
print(all_results) log.info(f"Total results from query: {len(all_results)}")
return self._result_to_get_result([all_results]) return self._result_to_get_result([all_results])
except Exception as e: except Exception as e:
print(e) log.exception(
f"Error querying collection {self.collection_prefix}_{collection_name} with filter '{filter_string}' and limit {limit}: {e}"
)
return None return None
def get(self, collection_name: str) -> Optional[GetResult]: def get(self, collection_name: str) -> Optional[GetResult]:
# Get all the items in the collection. # Get all the items in the collection. This can be very resource-intensive for large collections.
collection_name = collection_name.replace("-", "_") collection_name = collection_name.replace("-", "_")
result = self.client.query( log.warning(
collection_name=f"{self.collection_prefix}_{collection_name}", f"Fetching ALL items from collection '{self.collection_prefix}_{collection_name}'. This might be slow for large collections."
filter='id != ""',
) )
return self._result_to_get_result([result]) # Using query with a trivial filter to get all items.
# This will use the paginated query logic.
return self.query(collection_name=collection_name, filter={}, limit=None)
def insert(self, collection_name: str, items: list[VectorItem]):
# Insert the items into the collection, if the collection does not exist, it will be created.
@@ -216,10 +287,23 @@ class MilvusClient:
if not self.client.has_collection(
collection_name=f"{self.collection_prefix}_{collection_name}"
):
log.info(
f"Collection {self.collection_prefix}_{collection_name} does not exist. Creating now."
)
if not items:
log.error(
f"Cannot create collection {self.collection_prefix}_{collection_name} without items to determine dimension."
)
raise ValueError(
"Cannot create Milvus collection without items to determine vector dimension."
)
self._create_collection(
collection_name=collection_name, dimension=len(items[0]["vector"])
)
log.info(
f"Inserting {len(items)} items into collection {self.collection_prefix}_{collection_name}."
)
return self.client.insert(
collection_name=f"{self.collection_prefix}_{collection_name}",
data=[
@@ -239,10 +323,23 @@ class MilvusClient:
if not self.client.has_collection(
collection_name=f"{self.collection_prefix}_{collection_name}"
):
log.info(
f"Collection {self.collection_prefix}_{collection_name} does not exist for upsert. Creating now."
)
if not items:
log.error(
f"Cannot create collection {self.collection_prefix}_{collection_name} for upsert without items to determine dimension."
)
raise ValueError(
"Cannot create Milvus collection for upsert without items to determine vector dimension."
)
self._create_collection(
collection_name=collection_name, dimension=len(items[0]["vector"])
)
log.info(
f"Upserting {len(items)} items into collection {self.collection_prefix}_{collection_name}."
)
return self.client.upsert(
collection_name=f"{self.collection_prefix}_{collection_name}",
data=[
@@ -262,30 +359,55 @@ class MilvusClient:
ids: Optional[list[str]] = None,
filter: Optional[dict] = None,
):
# Delete the items from the collection based on the ids or filter.
collection_name = collection_name.replace("-", "_")
if not self.has_collection(collection_name):
log.warning(
f"Delete attempted on non-existent collection: {self.collection_prefix}_{collection_name}"
)
return None
if ids:
log.info(
f"Deleting items by IDs from {self.collection_prefix}_{collection_name}. IDs: {ids}"
)
return self.client.delete(
collection_name=f"{self.collection_prefix}_{collection_name}",
ids=ids,
)
elif filter:
# Convert the filter dictionary to a string using JSON_CONTAINS.
filter_string = " && ".join(
[
f'metadata["{key}"] == {json.dumps(value)}'
for key, value in filter.items()
]
)
log.info(
f"Deleting items by filter from {self.collection_prefix}_{collection_name}. Filter: {filter_string}"
)
return self.client.delete(
collection_name=f"{self.collection_prefix}_{collection_name}",
filter=filter_string,
)
else:
log.warning(
f"Delete operation on {self.collection_prefix}_{collection_name} called without IDs or filter. No action taken."
)
return None
def reset(self):
# Resets the database. This will delete all collections and item entries that match the prefix.
log.warning(
f"Resetting Milvus: Deleting all collections with prefix '{self.collection_prefix}'."
)
collection_names = self.client.list_collections()
deleted_collections = []
for collection_name_full in collection_names:
if collection_name_full.startswith(self.collection_prefix):
try:
self.client.drop_collection(collection_name=collection_name_full)
deleted_collections.append(collection_name_full)
log.info(f"Deleted collection: {collection_name_full}")
except Exception as e:
log.error(f"Error deleting collection {collection_name_full}: {e}")
log.info(f"Milvus reset complete. Deleted collections: {deleted_collections}")


@@ -1,7 +1,13 @@
from opensearchpy import OpenSearch
from opensearchpy.helpers import bulk
from typing import Optional
from open_webui.retrieval.vector.main import (
VectorDBBase,
VectorItem,
SearchResult,
GetResult,
)
from open_webui.config import (
OPENSEARCH_URI,
OPENSEARCH_SSL,
@@ -11,7 +17,7 @@ from open_webui.config import (
)
class OpenSearchClient(VectorDBBase):
def __init__(self):
self.index_prefix = "open_webui"
self.client = OpenSearch(
@@ -21,7 +27,13 @@ class OpenSearchClient:
http_auth=(OPENSEARCH_USERNAME, OPENSEARCH_PASSWORD),
)
def _get_index_name(self, collection_name: str) -> str:
return f"{self.index_prefix}_{collection_name}"
def _result_to_get_result(self, result) -> GetResult:
if not result["hits"]["hits"]:
return None
ids = []
documents = []
metadatas = []
@@ -31,9 +43,12 @@ class OpenSearchClient:
documents.append(hit["_source"].get("text"))
metadatas.append(hit["_source"].get("metadata"))
return GetResult(ids=[ids], documents=[documents], metadatas=[metadatas])
def _result_to_search_result(self, result) -> SearchResult:
if not result["hits"]["hits"]:
return None
ids = []
distances = []
documents = []
@@ -46,51 +61,63 @@ class OpenSearchClient:
metadatas.append(hit["_source"].get("metadata"))
return SearchResult(
ids=[ids],
distances=[distances],
documents=[documents],
metadatas=[metadatas],
)
def _create_index(self, collection_name: str, dimension: int):
body = {
"settings": {"index": {"knn": True}},
"mappings": {
"properties": {
"id": {"type": "keyword"},
"vector": {
"type": "knn_vector",
"dimension": dimension, # Adjust based on your vector dimensions
"index": True,
"similarity": "faiss",
"method": {
"name": "hnsw",
"space_type": "innerproduct", # Use inner product to approximate cosine similarity
"engine": "faiss",
"parameters": {
"ef_construction": 128,
"m": 16,
},
},
},
"text": {"type": "text"},
"metadata": {"type": "object"},
}
},
}
self.client.indices.create(
index=self._get_index_name(collection_name), body=body
)
def _create_batches(self, items: list[VectorItem], batch_size=100):
for i in range(0, len(items), batch_size):
yield items[i : i + batch_size]
def has_collection(self, collection_name: str) -> bool:
# has_collection here means has index.
# We are simply adapting to the norms of the other DBs.
return self.client.indices.exists(index=self._get_index_name(collection_name))
def delete_collection(self, collection_name: str):
# delete_collection here means delete index.
# We are simply adapting to the norms of the other DBs.
self.client.indices.delete(index=self._get_index_name(collection_name))
def search(
self, collection_name: str, vectors: list[list[float | int]], limit: int
) -> Optional[SearchResult]:
try:
if not self.has_collection(collection_name):
return None
query = {
"size": limit,
"_source": ["text", "metadata"],
@@ -98,9 +125,10 @@
"script_score": {
"query": {"match_all": {}},
"script": {
"source": "(cosineSimilarity(params.query_value, doc[params.field]) + 1.0) / 2.0",
"params": {
"field": "vector",
"query_value": vectors[0],
}, # Assuming single query vector
},
}
@@ -108,11 +136,14 @@
}
result = self.client.search(
index=self._get_index_name(collection_name), body=query
)
return self._result_to_search_result(result)
except Exception as e:
return None
def query(
self, collection_name: str, filter: dict, limit: Optional[int] = None
) -> Optional[GetResult]:
@@ -125,13 +156,15 @@
}
for field, value in filter.items():
query_body["query"]["bool"]["filter"].append(
{"match": {"metadata." + str(field): value}}
)
size = limit if limit else 10
try:
result = self.client.search(
index=self._get_index_name(collection_name),
body=query_body,
size=size,
)
@@ -141,26 +174,28 @@
except Exception as e:
return None
def _create_index_if_not_exists(self, collection_name: str, dimension: int):
if not self.has_collection(collection_name):
self._create_index(collection_name, dimension)
def get(self, collection_name: str) -> Optional[GetResult]:
query = {"query": {"match_all": {}}, "_source": ["text", "metadata"]}
result = self.client.search(
index=self._get_index_name(collection_name), body=query
)
return self._result_to_get_result(result)
def insert(self, collection_name: str, items: list[VectorItem]):
self._create_index_if_not_exists(
collection_name=collection_name, dimension=len(items[0]["vector"])
)
for batch in self._create_batches(items):
actions = [
{
"_op_type": "index",
"_index": self._get_index_name(collection_name),
"_id": item["id"],
"_source": {
"vector": item["vector"],
@@ -168,37 +203,59 @@
"metadata": item["metadata"],
},
}
for item in batch
]
bulk(self.client, actions)
def upsert(self, collection_name: str, items: list[VectorItem]):
self._create_index_if_not_exists(
collection_name=collection_name, dimension=len(items[0]["vector"])
)
for batch in self._create_batches(items):
actions = [
{
"_op_type": "update",
"_index": self._get_index_name(collection_name),
"_id": item["id"],
"doc": {
"vector": item["vector"],
"text": item["text"],
"metadata": item["metadata"],
},
"doc_as_upsert": True,
}
for item in batch
]
bulk(self.client, actions)
def delete(
self,
collection_name: str,
ids: Optional[list[str]] = None,
filter: Optional[dict] = None,
):
if ids:
actions = [
{
"_op_type": "delete",
"_index": self._get_index_name(collection_name),
"_id": id,
}
for id in ids
]
bulk(self.client, actions)
elif filter:
query_body = {
"query": {"bool": {"filter": []}},
}
for field, value in filter.items():
query_body["query"]["bool"]["filter"].append(
{"match": {"metadata." + str(field): value}}
)
self.client.delete_by_query(
index=self._get_index_name(collection_name), body=query_body
)
def reset(self):
indices = self.client.indices.get(index=f"{self.index_prefix}_*")
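For reference, a minimal standalone sketch (not from the diff) of the bulk upsert action shape used above, assuming the opensearch-py bulk() helper; the host, index name, and sample item are made up and a running OpenSearch instance is required.

from opensearchpy import OpenSearch
from opensearchpy.helpers import bulk

client = OpenSearch(hosts=["https://localhost:9200"])  # assumed local instance

items = [{"id": "doc-1", "vector": [0.1, 0.2], "text": "hello", "metadata": {"k": "v"}}]
actions = [
    {
        "_op_type": "update",          # update + doc_as_upsert acts as an upsert
        "_index": "open_webui_demo",   # assumed index name
        "_id": item["id"],
        "doc": {
            "vector": item["vector"],
            "text": item["text"],
            "metadata": item["metadata"],
        },
        "doc_as_upsert": True,
    }
    for item in items
]
bulk(client, actions)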


@@ -1,4 +1,5 @@
from typing import Optional, List, Dict, Any
import logging
from sqlalchemy import (
cast,
column,
@@ -21,12 +22,22 @@ from pgvector.sqlalchemy import Vector
from sqlalchemy.ext.mutable import MutableDict
from sqlalchemy.exc import NoSuchTableError
from open_webui.retrieval.vector.main import (
VectorDBBase,
VectorItem,
SearchResult,
GetResult,
)
from open_webui.config import PGVECTOR_DB_URL, PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH
from open_webui.env import SRC_LOG_LEVELS
VECTOR_LENGTH = PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH
Base = declarative_base()
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
class DocumentChunk(Base):
__tablename__ = "document_chunk"
@@ -38,7 +49,7 @@ class DocumentChunk(Base):
vmetadata = Column(MutableDict.as_mutable(JSONB), nullable=True)
class PgvectorClient(VectorDBBase):
def __init__(self) -> None:
# if no pgvector uri, use the existing database connection
@@ -82,10 +93,10 @@ class PgvectorClient:
)
)
self.session.commit()
log.info("Initialization complete.")
except Exception as e:
self.session.rollback()
log.exception(f"Error during initialization: {e}")
raise
def check_vector_length(self) -> None:
@@ -130,9 +141,8 @@
# Pad the vector with zeros
vector += [0.0] * (VECTOR_LENGTH - current_length)
elif current_length > VECTOR_LENGTH:
# Truncate the vector to VECTOR_LENGTH
vector = vector[:VECTOR_LENGTH]
return vector
def insert(self, collection_name: str, items: List[VectorItem]) -> None:
@@ -150,12 +160,12 @@
new_items.append(new_chunk)
self.session.bulk_save_objects(new_items)
self.session.commit()
log.info(
f"Inserted {len(new_items)} items into collection '{collection_name}'."
)
except Exception as e:
self.session.rollback()
log.exception(f"Error during insert: {e}")
raise
def upsert(self, collection_name: str, items: List[VectorItem]) -> None:
@@ -184,10 +194,12 @@
)
self.session.add(new_chunk)
self.session.commit()
log.info(
f"Upserted {len(items)} items into collection '{collection_name}'."
)
except Exception as e:
self.session.rollback()
log.exception(f"Error during upsert: {e}")
raise
def search(
@@ -270,7 +282,9 @@
for row in results:
qid = int(row.qid)
ids[qid].append(row.id)
# normalize and re-orders pgvec distance from [2, 0] to [0, 1] score range
# https://github.com/pgvector/pgvector?tab=readme-ov-file#querying
distances[qid].append((2.0 - row.distance) / 2.0)
documents[qid].append(row.text)
metadatas[qid].append(row.vmetadata)
@@ -278,7 +292,7 @@
ids=ids, distances=distances, documents=documents, metadatas=metadatas
)
except Exception as e:
log.exception(f"Error during search: {e}")
return None
def query(
@@ -310,7 +324,7 @@
metadatas=metadatas,
)
except Exception as e:
log.exception(f"Error during query: {e}")
return None
def get(
@@ -334,7 +348,7 @@
return GetResult(ids=ids, documents=documents, metadatas=metadatas)
except Exception as e:
log.exception(f"Error during get: {e}")
return None
def delete(
@@ -356,22 +370,22 @@
)
deleted = query.delete(synchronize_session=False)
self.session.commit()
log.info(f"Deleted {deleted} items from collection '{collection_name}'.")
except Exception as e:
self.session.rollback()
log.exception(f"Error during delete: {e}")
raise
def reset(self) -> None:
try:
deleted = self.session.query(DocumentChunk).delete()
self.session.commit()
log.info(
f"Reset complete. Deleted {deleted} items from 'document_chunk' table."
)
except Exception as e:
self.session.rollback()
log.exception(f"Error during reset: {e}")
raise
def close(self) -> None:
@@ -387,9 +401,9 @@
)
return exists
except Exception as e:
log.exception(f"Error checking collection existence: {e}")
return False
def delete_collection(self, collection_name: str) -> None:
self.delete(collection_name)
log.info(f"Collection '{collection_name}' deleted.")


@@ -0,0 +1,508 @@
from typing import Optional, List, Dict, Any, Union
import logging
import time # for measuring elapsed time
from pinecone import Pinecone, ServerlessSpec
import asyncio # for async upserts
import functools # for partial binding in async tasks
import concurrent.futures # for parallel batch upserts
from open_webui.retrieval.vector.main import (
VectorDBBase,
VectorItem,
SearchResult,
GetResult,
)
from open_webui.config import (
PINECONE_API_KEY,
PINECONE_ENVIRONMENT,
PINECONE_INDEX_NAME,
PINECONE_DIMENSION,
PINECONE_METRIC,
PINECONE_CLOUD,
)
from open_webui.env import SRC_LOG_LEVELS
NO_LIMIT = 10000 # Reasonable limit to avoid overwhelming the system
BATCH_SIZE = 100 # Recommended batch size for Pinecone operations
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
class PineconeClient(VectorDBBase):
def __init__(self):
self.collection_prefix = "open-webui"
# Validate required configuration
self._validate_config()
# Store configuration values
self.api_key = PINECONE_API_KEY
self.environment = PINECONE_ENVIRONMENT
self.index_name = PINECONE_INDEX_NAME
self.dimension = PINECONE_DIMENSION
self.metric = PINECONE_METRIC
self.cloud = PINECONE_CLOUD
# Initialize Pinecone client for improved performance
self.client = Pinecone(api_key=self.api_key)
# Persistent executor for batch operations
self._executor = concurrent.futures.ThreadPoolExecutor(max_workers=5)
# Create index if it doesn't exist
self._initialize_index()
def _validate_config(self) -> None:
"""Validate that all required configuration variables are set."""
missing_vars = []
if not PINECONE_API_KEY:
missing_vars.append("PINECONE_API_KEY")
if not PINECONE_ENVIRONMENT:
missing_vars.append("PINECONE_ENVIRONMENT")
if not PINECONE_INDEX_NAME:
missing_vars.append("PINECONE_INDEX_NAME")
if not PINECONE_DIMENSION:
missing_vars.append("PINECONE_DIMENSION")
if not PINECONE_CLOUD:
missing_vars.append("PINECONE_CLOUD")
if missing_vars:
raise ValueError(
f"Required configuration missing: {', '.join(missing_vars)}"
)
def _initialize_index(self) -> None:
"""Initialize the Pinecone index."""
try:
# Check if index exists
if self.index_name not in self.client.list_indexes().names():
log.info(f"Creating Pinecone index '{self.index_name}'...")
self.client.create_index(
name=self.index_name,
dimension=self.dimension,
metric=self.metric,
spec=ServerlessSpec(cloud=self.cloud, region=self.environment),
)
log.info(f"Successfully created Pinecone index '{self.index_name}'")
else:
log.info(f"Using existing Pinecone index '{self.index_name}'")
# Connect to the index
self.index = self.client.Index(self.index_name)
except Exception as e:
log.error(f"Failed to initialize Pinecone index: {e}")
raise RuntimeError(f"Failed to initialize Pinecone index: {e}")
def _create_points(
self, items: List[VectorItem], collection_name_with_prefix: str
) -> List[Dict[str, Any]]:
"""Convert VectorItem objects to Pinecone point format."""
points = []
for item in items:
# Start with any existing metadata or an empty dict
metadata = item.get("metadata", {}).copy() if item.get("metadata") else {}
# Add text to metadata if available
if "text" in item:
metadata["text"] = item["text"]
# Always add collection_name to metadata for filtering
metadata["collection_name"] = collection_name_with_prefix
point = {
"id": item["id"],
"values": item["vector"],
"metadata": metadata,
}
points.append(point)
return points
def _get_collection_name_with_prefix(self, collection_name: str) -> str:
"""Get the collection name with prefix."""
return f"{self.collection_prefix}_{collection_name}"
def _normalize_distance(self, score: float) -> float:
"""Normalize distance score based on the metric used."""
if self.metric.lower() == "cosine":
# Cosine similarity ranges from -1 to 1, normalize to 0 to 1
return (score + 1.0) / 2.0
elif self.metric.lower() in ["euclidean", "dotproduct"]:
# These are already suitable for ranking (smaller is better for Euclidean)
return score
else:
# For other metrics, use as is
return score
def _result_to_get_result(self, matches: list) -> GetResult:
"""Convert Pinecone matches to GetResult format."""
ids = []
documents = []
metadatas = []
for match in matches:
metadata = getattr(match, "metadata", {}) or {}
ids.append(match.id if hasattr(match, "id") else match["id"])
documents.append(metadata.get("text", ""))
metadatas.append(metadata)
return GetResult(
**{
"ids": [ids],
"documents": [documents],
"metadatas": [metadatas],
}
)
def has_collection(self, collection_name: str) -> bool:
"""Check if a collection exists by searching for at least one item."""
collection_name_with_prefix = self._get_collection_name_with_prefix(
collection_name
)
try:
# Search for at least 1 item with this collection name in metadata
response = self.index.query(
vector=[0.0] * self.dimension, # dummy vector
top_k=1,
filter={"collection_name": collection_name_with_prefix},
include_metadata=False,
)
matches = getattr(response, "matches", []) or []
return len(matches) > 0
except Exception as e:
log.exception(
f"Error checking collection '{collection_name_with_prefix}': {e}"
)
return False
def delete_collection(self, collection_name: str) -> None:
"""Delete a collection by removing all vectors with the collection name in metadata."""
collection_name_with_prefix = self._get_collection_name_with_prefix(
collection_name
)
try:
self.index.delete(filter={"collection_name": collection_name_with_prefix})
log.info(
f"Collection '{collection_name_with_prefix}' deleted (all vectors removed)."
)
except Exception as e:
log.warning(
f"Failed to delete collection '{collection_name_with_prefix}': {e}"
)
raise
def insert(self, collection_name: str, items: List[VectorItem]) -> None:
"""Insert vectors into a collection."""
if not items:
log.warning("No items to insert")
return
start_time = time.time()
collection_name_with_prefix = self._get_collection_name_with_prefix(
collection_name
)
points = self._create_points(items, collection_name_with_prefix)
# Parallelize batch inserts for performance
executor = self._executor
futures = []
for i in range(0, len(points), BATCH_SIZE):
batch = points[i : i + BATCH_SIZE]
futures.append(executor.submit(self.index.upsert, vectors=batch))
for future in concurrent.futures.as_completed(futures):
try:
future.result()
except Exception as e:
log.error(f"Error inserting batch: {e}")
raise
elapsed = time.time() - start_time
log.debug(f"Insert of {len(points)} vectors took {elapsed:.2f} seconds")
log.info(
f"Successfully inserted {len(points)} vectors in parallel batches into '{collection_name_with_prefix}'"
)
def upsert(self, collection_name: str, items: List[VectorItem]) -> None:
"""Upsert (insert or update) vectors into a collection."""
if not items:
log.warning("No items to upsert")
return
start_time = time.time()
collection_name_with_prefix = self._get_collection_name_with_prefix(
collection_name
)
points = self._create_points(items, collection_name_with_prefix)
# Parallelize batch upserts for performance
executor = self._executor
futures = []
for i in range(0, len(points), BATCH_SIZE):
batch = points[i : i + BATCH_SIZE]
futures.append(executor.submit(self.index.upsert, vectors=batch))
for future in concurrent.futures.as_completed(futures):
try:
future.result()
except Exception as e:
log.error(f"Error upserting batch: {e}")
raise
elapsed = time.time() - start_time
log.debug(f"Upsert of {len(points)} vectors took {elapsed:.2f} seconds")
log.info(
f"Successfully upserted {len(points)} vectors in parallel batches into '{collection_name_with_prefix}'"
)
async def insert_async(self, collection_name: str, items: List[VectorItem]) -> None:
"""Async version of insert using asyncio and run_in_executor for improved performance."""
if not items:
log.warning("No items to insert")
return
collection_name_with_prefix = self._get_collection_name_with_prefix(
collection_name
)
points = self._create_points(items, collection_name_with_prefix)
# Create batches
batches = [
points[i : i + BATCH_SIZE] for i in range(0, len(points), BATCH_SIZE)
]
loop = asyncio.get_event_loop()
tasks = [
loop.run_in_executor(
None, functools.partial(self.index.upsert, vectors=batch)
)
for batch in batches
]
results = await asyncio.gather(*tasks, return_exceptions=True)
for result in results:
if isinstance(result, Exception):
log.error(f"Error in async insert batch: {result}")
raise result
log.info(
f"Successfully async inserted {len(points)} vectors in batches into '{collection_name_with_prefix}'"
)
async def upsert_async(self, collection_name: str, items: List[VectorItem]) -> None:
"""Async version of upsert using asyncio and run_in_executor for improved performance."""
if not items:
log.warning("No items to upsert")
return
collection_name_with_prefix = self._get_collection_name_with_prefix(
collection_name
)
points = self._create_points(items, collection_name_with_prefix)
# Create batches
batches = [
points[i : i + BATCH_SIZE] for i in range(0, len(points), BATCH_SIZE)
]
loop = asyncio.get_event_loop()
tasks = [
loop.run_in_executor(
None, functools.partial(self.index.upsert, vectors=batch)
)
for batch in batches
]
results = await asyncio.gather(*tasks, return_exceptions=True)
for result in results:
if isinstance(result, Exception):
log.error(f"Error in async upsert batch: {result}")
raise result
log.info(
f"Successfully async upserted {len(points)} vectors in batches into '{collection_name_with_prefix}'"
)
def search(
self, collection_name: str, vectors: List[List[Union[float, int]]], limit: int
) -> Optional[SearchResult]:
"""Search for similar vectors in a collection."""
if not vectors or not vectors[0]:
log.warning("No vectors provided for search")
return None
collection_name_with_prefix = self._get_collection_name_with_prefix(
collection_name
)
if limit is None or limit <= 0:
limit = NO_LIMIT
try:
# Search using the first vector (assuming this is the intended behavior)
query_vector = vectors[0]
# Perform the search
query_response = self.index.query(
vector=query_vector,
top_k=limit,
include_metadata=True,
filter={"collection_name": collection_name_with_prefix},
)
matches = getattr(query_response, "matches", []) or []
if not matches:
# Return empty result if no matches
return SearchResult(
ids=[[]],
documents=[[]],
metadatas=[[]],
distances=[[]],
)
# Convert to GetResult format
get_result = self._result_to_get_result(matches)
# Calculate normalized distances based on metric
distances = [
[
self._normalize_distance(getattr(match, "score", 0.0))
for match in matches
]
]
return SearchResult(
ids=get_result.ids,
documents=get_result.documents,
metadatas=get_result.metadatas,
distances=distances,
)
except Exception as e:
log.error(f"Error searching in '{collection_name_with_prefix}': {e}")
return None
def query(
self, collection_name: str, filter: Dict, limit: Optional[int] = None
) -> Optional[GetResult]:
"""Query vectors by metadata filter."""
collection_name_with_prefix = self._get_collection_name_with_prefix(
collection_name
)
if limit is None or limit <= 0:
limit = NO_LIMIT
try:
# Create a zero vector for the dimension as Pinecone requires a vector
zero_vector = [0.0] * self.dimension
# Combine user filter with collection_name
pinecone_filter = {"collection_name": collection_name_with_prefix}
if filter:
pinecone_filter.update(filter)
# Perform metadata-only query
query_response = self.index.query(
vector=zero_vector,
filter=pinecone_filter,
top_k=limit,
include_metadata=True,
)
matches = getattr(query_response, "matches", []) or []
return self._result_to_get_result(matches)
except Exception as e:
log.error(f"Error querying collection '{collection_name}': {e}")
return None
def get(self, collection_name: str) -> Optional[GetResult]:
"""Get all vectors in a collection."""
collection_name_with_prefix = self._get_collection_name_with_prefix(
collection_name
)
try:
# Use a zero vector for fetching all entries
zero_vector = [0.0] * self.dimension
# Add filter to only get vectors for this collection
query_response = self.index.query(
vector=zero_vector,
top_k=NO_LIMIT,
include_metadata=True,
filter={"collection_name": collection_name_with_prefix},
)
matches = getattr(query_response, "matches", []) or []
return self._result_to_get_result(matches)
except Exception as e:
log.error(f"Error getting collection '{collection_name}': {e}")
return None
def delete(
self,
collection_name: str,
ids: Optional[List[str]] = None,
filter: Optional[Dict] = None,
) -> None:
"""Delete vectors by IDs or filter."""
collection_name_with_prefix = self._get_collection_name_with_prefix(
collection_name
)
try:
if ids:
# Delete by IDs (in batches for large deletions)
for i in range(0, len(ids), BATCH_SIZE):
batch_ids = ids[i : i + BATCH_SIZE]
# Note: When deleting by ID, we can't filter by collection_name
# This is a limitation of Pinecone - be careful with ID uniqueness
self.index.delete(ids=batch_ids)
log.debug(
f"Deleted batch of {len(batch_ids)} vectors by ID from '{collection_name_with_prefix}'"
)
log.info(
f"Successfully deleted {len(ids)} vectors by ID from '{collection_name_with_prefix}'"
)
elif filter:
# Combine user filter with collection_name
pinecone_filter = {"collection_name": collection_name_with_prefix}
if filter:
pinecone_filter.update(filter)
# Delete by metadata filter
self.index.delete(filter=pinecone_filter)
log.info(
f"Successfully deleted vectors by filter from '{collection_name_with_prefix}'"
)
else:
log.warning("No ids or filter provided for delete operation")
except Exception as e:
log.error(f"Error deleting from collection '{collection_name}': {e}")
raise
def reset(self) -> None:
"""Reset the database by deleting all collections."""
try:
self.index.delete(delete_all=True)
log.info("All vectors successfully deleted from the index.")
except Exception as e:
log.error(f"Failed to reset Pinecone index: {e}")
raise
def close(self):
"""Shut down resources."""
try:
# The new Pinecone client doesn't need explicit closing
pass
except Exception as e:
log.warning(f"Failed to clean up Pinecone resources: {e}")
self._executor.shutdown(wait=True)
def __enter__(self):
"""Enter context manager."""
return self
def __exit__(self, exc_type, exc_val, exc_tb):
"""Exit context manager, ensuring resources are cleaned up."""
self.close()
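For reference, a standalone sketch (not part of the client above) of the parallel batch-upsert pattern used by insert() and upsert(); upsert_batch stands in for the blocking self.index.upsert call and is assumed.

import concurrent.futures

BATCH_SIZE = 100

def upsert_batch(vectors):
    # Placeholder for index.upsert(vectors=batch); returns the batch size here.
    return len(vectors)

def upsert_in_parallel(points, max_workers=5):
    # Split points into batches and submit each batch to a thread pool.
    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [
            executor.submit(upsert_batch, points[i : i + BATCH_SIZE])
            for i in range(0, len(points), BATCH_SIZE)
        ]
        for future in concurrent.futures.as_completed(futures):
            future.result()  # re-raise any batch error

upsert_in_parallel([{"id": str(i), "values": [0.0]} for i in range(250)])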


@@ -1,25 +1,60 @@
from typing import Optional
import logging
from urllib.parse import urlparse
from qdrant_client import QdrantClient as Qclient
from qdrant_client.http.models import PointStruct
from qdrant_client.models import models
from open_webui.retrieval.vector.main import (
VectorDBBase,
VectorItem,
SearchResult,
GetResult,
)
from open_webui.config import (
QDRANT_URI,
QDRANT_API_KEY,
QDRANT_ON_DISK,
QDRANT_GRPC_PORT,
QDRANT_PREFER_GRPC,
)
from open_webui.env import SRC_LOG_LEVELS
NO_LIMIT = 999999999
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
class QdrantClient(VectorDBBase):
def __init__(self):
self.collection_prefix = "open-webui"
self.QDRANT_URI = QDRANT_URI
self.QDRANT_API_KEY = QDRANT_API_KEY
self.QDRANT_ON_DISK = QDRANT_ON_DISK
self.PREFER_GRPC = QDRANT_PREFER_GRPC
self.GRPC_PORT = QDRANT_GRPC_PORT
if not self.QDRANT_URI:
self.client = None
return
# Unified handling for either scheme
parsed = urlparse(self.QDRANT_URI)
host = parsed.hostname or self.QDRANT_URI
http_port = parsed.port or 6333 # default REST port
if self.PREFER_GRPC:
self.client = Qclient(
host=host,
port=http_port,
grpc_port=self.GRPC_PORT,
prefer_grpc=self.PREFER_GRPC,
api_key=self.QDRANT_API_KEY,
)
else:
self.client = Qclient(url=self.QDRANT_URI, api_key=self.QDRANT_API_KEY)
def _result_to_get_result(self, points) -> GetResult:
ids = []
@@ -45,11 +80,13 @@ class QdrantClient:
self.client.create_collection(
collection_name=collection_name_with_prefix,
vectors_config=models.VectorParams(
size=dimension,
distance=models.Distance.COSINE,
on_disk=self.QDRANT_ON_DISK,
),
)
log.info(f"collection {collection_name_with_prefix} successfully created!")
def _create_collection_if_not_exists(self, collection_name, dimension):
if not self.has_collection(collection_name=collection_name):
@@ -94,7 +131,8 @@ class QdrantClient:
ids=get_result.ids,
documents=get_result.documents,
metadatas=get_result.metadatas,
# qdrant distance is [-1, 1], normalize to [0, 1]
distances=[[(point.score + 1.0) / 2.0 for point in query_response.points]],
)
def query(self, collection_name: str, filter: dict, limit: Optional[int] = None):
@@ -120,7 +158,7 @@
)
return self._result_to_get_result(points.points)
except Exception as e:
log.exception(f"Error querying a collection '{collection_name}': {e}")
return None
def get(self, collection_name: str) -> Optional[GetResult]:
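A small worked example (not from the diff) of the Qdrant score normalization added in search() above: cosine scores in [-1, 1] are mapped onto a [0, 1] range.

def normalize_qdrant_score(score: float) -> float:
    # Qdrant cosine score in [-1, 1] -> [0, 1]
    return (score + 1.0) / 2.0

assert normalize_qdrant_score(-1.0) == 0.0
assert normalize_qdrant_score(0.0) == 0.5
assert normalize_qdrant_score(1.0) == 1.0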


@@ -0,0 +1,712 @@
import logging
from typing import Optional, Tuple
from urllib.parse import urlparse
import grpc
from open_webui.config import (
QDRANT_API_KEY,
QDRANT_GRPC_PORT,
QDRANT_ON_DISK,
QDRANT_PREFER_GRPC,
QDRANT_URI,
)
from open_webui.env import SRC_LOG_LEVELS
from open_webui.retrieval.vector.main import (
GetResult,
SearchResult,
VectorDBBase,
VectorItem,
)
from qdrant_client import QdrantClient as Qclient
from qdrant_client.http.exceptions import UnexpectedResponse
from qdrant_client.http.models import PointStruct
from qdrant_client.models import models
NO_LIMIT = 999999999
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
class QdrantClient(VectorDBBase):
def __init__(self):
self.collection_prefix = "open-webui"
self.QDRANT_URI = QDRANT_URI
self.QDRANT_API_KEY = QDRANT_API_KEY
self.QDRANT_ON_DISK = QDRANT_ON_DISK
self.PREFER_GRPC = QDRANT_PREFER_GRPC
self.GRPC_PORT = QDRANT_GRPC_PORT
if not self.QDRANT_URI:
self.client = None
return
# Unified handling for either scheme
parsed = urlparse(self.QDRANT_URI)
host = parsed.hostname or self.QDRANT_URI
http_port = parsed.port or 6333 # default REST port
if self.PREFER_GRPC:
self.client = Qclient(
host=host,
port=http_port,
grpc_port=self.GRPC_PORT,
prefer_grpc=self.PREFER_GRPC,
api_key=self.QDRANT_API_KEY,
)
else:
self.client = Qclient(url=self.QDRANT_URI, api_key=self.QDRANT_API_KEY)
# Main collection types for multi-tenancy
self.MEMORY_COLLECTION = f"{self.collection_prefix}_memories"
self.KNOWLEDGE_COLLECTION = f"{self.collection_prefix}_knowledge"
self.FILE_COLLECTION = f"{self.collection_prefix}_files"
self.WEB_SEARCH_COLLECTION = f"{self.collection_prefix}_web-search"
self.HASH_BASED_COLLECTION = f"{self.collection_prefix}_hash-based"
def _result_to_get_result(self, points) -> GetResult:
ids = []
documents = []
metadatas = []
for point in points:
payload = point.payload
ids.append(point.id)
documents.append(payload["text"])
metadatas.append(payload["metadata"])
return GetResult(
**{
"ids": [ids],
"documents": [documents],
"metadatas": [metadatas],
}
)
def _get_collection_and_tenant_id(self, collection_name: str) -> Tuple[str, str]:
"""
Maps the traditional collection name to multi-tenant collection and tenant ID.
Returns:
tuple: (collection_name, tenant_id)
"""
# Check for user memory collections
tenant_id = collection_name
if collection_name.startswith("user-memory-"):
return self.MEMORY_COLLECTION, tenant_id
# Check for file collections
elif collection_name.startswith("file-"):
return self.FILE_COLLECTION, tenant_id
# Check for web search collections
elif collection_name.startswith("web-search-"):
return self.WEB_SEARCH_COLLECTION, tenant_id
# Handle hash-based collections (YouTube and web URLs)
elif len(collection_name) == 63 and all(
c in "0123456789abcdef" for c in collection_name
):
return self.HASH_BASED_COLLECTION, tenant_id
else:
return self.KNOWLEDGE_COLLECTION, tenant_id
def _extract_error_message(self, exception):
"""
Extract error message from either HTTP or gRPC exceptions
Returns:
tuple: (status_code, error_message)
"""
# Check if it's an HTTP exception
if isinstance(exception, UnexpectedResponse):
try:
error_data = exception.structured()
error_msg = error_data.get("status", {}).get("error", "")
return exception.status_code, error_msg
except Exception as inner_e:
log.error(f"Failed to parse HTTP error: {inner_e}")
return exception.status_code, str(exception)
# Check if it's a gRPC exception
elif isinstance(exception, grpc.RpcError):
# Extract status code from gRPC error
status_code = None
if hasattr(exception, "code") and callable(exception.code):
status_code = exception.code().value[0]
# Extract error message
error_msg = str(exception)
if "details =" in error_msg:
# Parse the details line which contains the actual error message
try:
details_line = [
line.strip()
for line in error_msg.split("\n")
if "details =" in line
][0]
error_msg = details_line.split("details =")[1].strip(' "')
except (IndexError, AttributeError):
# Fall back to full message if parsing fails
pass
return status_code, error_msg
# For any other type of exception
return None, str(exception)
def _is_collection_not_found_error(self, exception):
"""
Check if the exception is due to collection not found, supporting both HTTP and gRPC
"""
status_code, error_msg = self._extract_error_message(exception)
# HTTP error (404)
if (
status_code == 404
and "Collection" in error_msg
and "doesn't exist" in error_msg
):
return True
# gRPC error (NOT_FOUND status)
if (
isinstance(exception, grpc.RpcError)
and exception.code() == grpc.StatusCode.NOT_FOUND
):
return True
return False
def _is_dimension_mismatch_error(self, exception):
"""
Check if the exception is due to dimension mismatch, supporting both HTTP and gRPC
"""
status_code, error_msg = self._extract_error_message(exception)
# Common patterns in both HTTP and gRPC
return (
"Vector dimension error" in error_msg
or "dimensions mismatch" in error_msg
or "invalid vector size" in error_msg
)
def _create_multi_tenant_collection_if_not_exists(
self, mt_collection_name: str, dimension: int = 384
):
"""
Creates a collection with multi-tenancy configuration if it doesn't exist.
Default dimension is set to 384 which corresponds to 'sentence-transformers/all-MiniLM-L6-v2'.
When creating collections dynamically (insert/upsert), the actual vector dimensions will be used.
"""
try:
# Try to create the collection directly - will fail if it already exists
self.client.create_collection(
collection_name=mt_collection_name,
vectors_config=models.VectorParams(
size=dimension,
distance=models.Distance.COSINE,
on_disk=self.QDRANT_ON_DISK,
),
hnsw_config=models.HnswConfigDiff(
payload_m=16, # Enable per-tenant indexing
m=0,
on_disk=self.QDRANT_ON_DISK,
),
)
# Create tenant ID payload index
self.client.create_payload_index(
collection_name=mt_collection_name,
field_name="tenant_id",
field_schema=models.KeywordIndexParams(
type=models.KeywordIndexType.KEYWORD,
is_tenant=True,
on_disk=self.QDRANT_ON_DISK,
),
wait=True,
)
log.info(
f"Multi-tenant collection {mt_collection_name} created with dimension {dimension}!"
)
except (UnexpectedResponse, grpc.RpcError) as e:
# Check for the specific error indicating collection already exists
status_code, error_msg = self._extract_error_message(e)
# HTTP status code 409 or gRPC ALREADY_EXISTS
if (isinstance(e, UnexpectedResponse) and status_code == 409) or (
isinstance(e, grpc.RpcError)
and e.code() == grpc.StatusCode.ALREADY_EXISTS
):
if "already exists" in error_msg:
log.debug(f"Collection {mt_collection_name} already exists")
return
# If it's not an already exists error, re-raise
raise e
except Exception as e:
raise e
def _create_points(self, items: list[VectorItem], tenant_id: str):
"""
Create point structs from vector items with tenant ID.
"""
return [
PointStruct(
id=item["id"],
vector=item["vector"],
payload={
"text": item["text"],
"metadata": item["metadata"],
"tenant_id": tenant_id,
},
)
for item in items
]
def has_collection(self, collection_name: str) -> bool:
"""
Check if a logical collection exists by checking for any points with the tenant ID.
"""
if not self.client:
return False
# Map to multi-tenant collection and tenant ID
mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)
# Create tenant filter
tenant_filter = models.FieldCondition(
key="tenant_id", match=models.MatchValue(value=tenant_id)
)
try:
# Try directly querying - most of the time collection should exist
response = self.client.query_points(
collection_name=mt_collection,
query_filter=models.Filter(must=[tenant_filter]),
limit=1,
)
# Collection exists with this tenant ID if there are points
return len(response.points) > 0
except (UnexpectedResponse, grpc.RpcError) as e:
if self._is_collection_not_found_error(e):
log.debug(f"Collection {mt_collection} doesn't exist")
return False
else:
# For other API errors, log and return False
_, error_msg = self._extract_error_message(e)
log.warning(f"Unexpected Qdrant error: {error_msg}")
return False
except Exception as e:
# For any other errors, log and return False
log.debug(f"Error checking collection {mt_collection}: {e}")
return False
def delete(
self,
collection_name: str,
ids: Optional[list[str]] = None,
filter: Optional[dict] = None,
):
"""
Delete vectors by ID or filter from a collection with tenant isolation.
"""
if not self.client:
return None
# Map to multi-tenant collection and tenant ID
mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)
# Create tenant filter
tenant_filter = models.FieldCondition(
key="tenant_id", match=models.MatchValue(value=tenant_id)
)
must_conditions = [tenant_filter]
should_conditions = []
if ids:
for id_value in ids:
should_conditions.append(
models.FieldCondition(
key="metadata.id",
match=models.MatchValue(value=id_value),
),
)
elif filter:
for key, value in filter.items():
must_conditions.append(
models.FieldCondition(
key=f"metadata.{key}",
match=models.MatchValue(value=value),
),
)
try:
# Try to delete directly - most of the time collection should exist
update_result = self.client.delete(
collection_name=mt_collection,
points_selector=models.FilterSelector(
filter=models.Filter(must=must_conditions, should=should_conditions)
),
)
return update_result
except (UnexpectedResponse, grpc.RpcError) as e:
if self._is_collection_not_found_error(e):
log.debug(
f"Collection {mt_collection} doesn't exist, nothing to delete"
)
return None
else:
# For other API errors, log and re-raise
_, error_msg = self._extract_error_message(e)
log.warning(f"Unexpected Qdrant error: {error_msg}")
raise
except Exception as e:
# For non-Qdrant exceptions, re-raise
raise
def search(
self, collection_name: str, vectors: list[list[float | int]], limit: int
) -> Optional[SearchResult]:
"""
Search for the nearest neighbor items based on the vectors with tenant isolation.
"""
if not self.client:
return None
# Map to multi-tenant collection and tenant ID
mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)
# Get the vector dimension from the query vector
dimension = len(vectors[0]) if vectors and len(vectors) > 0 else None
try:
# Try the search operation directly - most of the time collection should exist
# Create tenant filter
tenant_filter = models.FieldCondition(
key="tenant_id", match=models.MatchValue(value=tenant_id)
)
# Ensure vector dimensions match the collection
collection_dim = self.client.get_collection(
mt_collection
).config.params.vectors.size
if collection_dim != dimension:
if collection_dim < dimension:
vectors = [vector[:collection_dim] for vector in vectors]
else:
vectors = [
vector + [0] * (collection_dim - dimension)
for vector in vectors
]
# Search with tenant filter
prefetch_query = models.Prefetch(
filter=models.Filter(must=[tenant_filter]),
limit=NO_LIMIT,
)
query_response = self.client.query_points(
collection_name=mt_collection,
query=vectors[0],
prefetch=prefetch_query,
limit=limit,
)
get_result = self._result_to_get_result(query_response.points)
return SearchResult(
ids=get_result.ids,
documents=get_result.documents,
metadatas=get_result.metadatas,
# qdrant distance is [-1, 1], normalize to [0, 1]
distances=[
[(point.score + 1.0) / 2.0 for point in query_response.points]
],
)
except (UnexpectedResponse, grpc.RpcError) as e:
if self._is_collection_not_found_error(e):
log.debug(
f"Collection {mt_collection} doesn't exist, search returns None"
)
return None
else:
# For other API errors, log and re-raise
_, error_msg = self._extract_error_message(e)
log.warning(f"Unexpected Qdrant error during search: {error_msg}")
raise
except Exception as e:
# For non-Qdrant exceptions, log and return None
log.exception(f"Error searching collection '{collection_name}': {e}")
return None
def query(self, collection_name: str, filter: dict, limit: Optional[int] = None):
"""
Query points with filters and tenant isolation.
"""
if not self.client:
return None
# Map to multi-tenant collection and tenant ID
mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)
# Set default limit if not provided
if limit is None:
limit = NO_LIMIT
# Create tenant filter
tenant_filter = models.FieldCondition(
key="tenant_id", match=models.MatchValue(value=tenant_id)
)
# Create metadata filters
field_conditions = []
for key, value in filter.items():
field_conditions.append(
models.FieldCondition(
key=f"metadata.{key}", match=models.MatchValue(value=value)
)
)
# Combine tenant filter with metadata filters
combined_filter = models.Filter(must=[tenant_filter, *field_conditions])
try:
# Try the query directly - most of the time collection should exist
points = self.client.query_points(
collection_name=mt_collection,
query_filter=combined_filter,
limit=limit,
)
return self._result_to_get_result(points.points)
except (UnexpectedResponse, grpc.RpcError) as e:
if self._is_collection_not_found_error(e):
log.debug(
f"Collection {mt_collection} doesn't exist, query returns None"
)
return None
else:
# For other API errors, log and re-raise
_, error_msg = self._extract_error_message(e)
log.warning(f"Unexpected Qdrant error during query: {error_msg}")
raise
except Exception as e:
# For non-Qdrant exceptions, log and re-raise
log.exception(f"Error querying collection '{collection_name}': {e}")
return None
def get(self, collection_name: str) -> Optional[GetResult]:
"""
Get all items in a collection with tenant isolation.
"""
if not self.client:
return None
# Map to multi-tenant collection and tenant ID
mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)
# Create tenant filter
tenant_filter = models.FieldCondition(
key="tenant_id", match=models.MatchValue(value=tenant_id)
)
try:
# Try to get points directly - most of the time collection should exist
points = self.client.query_points(
collection_name=mt_collection,
query_filter=models.Filter(must=[tenant_filter]),
limit=NO_LIMIT,
)
return self._result_to_get_result(points.points)
except (UnexpectedResponse, grpc.RpcError) as e:
if self._is_collection_not_found_error(e):
log.debug(f"Collection {mt_collection} doesn't exist, get returns None")
return None
else:
# For other API errors, log and re-raise
_, error_msg = self._extract_error_message(e)
log.warning(f"Unexpected Qdrant error during get: {error_msg}")
raise
except Exception as e:
# For non-Qdrant exceptions, log and return None
log.exception(f"Error getting collection '{collection_name}': {e}")
return None
def _handle_operation_with_error_retry(
self, operation_name, mt_collection, points, dimension
):
"""
Private helper to handle common error cases for insert and upsert operations.
Args:
operation_name: 'insert' or 'upsert'
mt_collection: The multi-tenant collection name
points: The vector points to insert/upsert
dimension: The dimension of the vectors
Returns:
The operation result (for upsert) or None (for insert)
"""
try:
if operation_name == "insert":
self.client.upload_points(mt_collection, points)
return None
else: # upsert
return self.client.upsert(mt_collection, points)
except (UnexpectedResponse, grpc.RpcError) as e:
# Handle collection not found
if self._is_collection_not_found_error(e):
log.info(
f"Collection {mt_collection} doesn't exist. Creating it with dimension {dimension}."
)
# Create collection with correct dimensions from our vectors
self._create_multi_tenant_collection_if_not_exists(
mt_collection_name=mt_collection, dimension=dimension
)
# Try operation again - no need for dimension adjustment since we just created with correct dimensions
if operation_name == "insert":
self.client.upload_points(mt_collection, points)
return None
else: # upsert
return self.client.upsert(mt_collection, points)
# Handle dimension mismatch
elif self._is_dimension_mismatch_error(e):
# For dimension errors, the collection must exist, so get its configuration
mt_collection_info = self.client.get_collection(mt_collection)
existing_size = mt_collection_info.config.params.vectors.size
log.info(
f"Dimension mismatch: Collection {mt_collection} expects {existing_size}, got {dimension}"
)
if existing_size < dimension:
# Truncate vectors to fit
log.info(
f"Truncating vectors from {dimension} to {existing_size} dimensions"
)
points = [
PointStruct(
id=point.id,
vector=point.vector[:existing_size],
payload=point.payload,
)
for point in points
]
elif existing_size > dimension:
# Pad vectors with zeros
log.info(
f"Padding vectors from {dimension} to {existing_size} dimensions with zeros"
)
points = [
PointStruct(
id=point.id,
vector=point.vector
+ [0] * (existing_size - len(point.vector)),
payload=point.payload,
)
for point in points
]
# Try operation again with adjusted dimensions
if operation_name == "insert":
self.client.upload_points(mt_collection, points)
return None
else: # upsert
return self.client.upsert(mt_collection, points)
else:
# Not a known error we can handle, log and re-raise
_, error_msg = self._extract_error_message(e)
log.warning(f"Unhandled Qdrant error: {error_msg}")
raise
except Exception as e:
# For non-Qdrant exceptions, re-raise
raise
def insert(self, collection_name: str, items: list[VectorItem]):
"""
Insert items with tenant ID.
"""
if not self.client or not items:
return None
# Map to multi-tenant collection and tenant ID
mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)
# Get dimensions from the actual vectors
dimension = len(items[0]["vector"]) if items else None
# Create points with tenant ID
points = self._create_points(items, tenant_id)
# Handle the operation with error retry
return self._handle_operation_with_error_retry(
"insert", mt_collection, points, dimension
)
def upsert(self, collection_name: str, items: list[VectorItem]):
"""
Upsert items with tenant ID.
"""
if not self.client or not items:
return None
# Map to multi-tenant collection and tenant ID
mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)
# Get dimensions from the actual vectors
dimension = len(items[0]["vector"]) if items else None
# Create points with tenant ID
points = self._create_points(items, tenant_id)
# Handle the operation with error retry
return self._handle_operation_with_error_retry(
"upsert", mt_collection, points, dimension
)
def reset(self):
"""
Reset the database by deleting all collections.
"""
if not self.client:
return None
collection_names = self.client.get_collections().collections
for collection_name in collection_names:
if collection_name.name.startswith(self.collection_prefix):
self.client.delete_collection(collection_name=collection_name.name)
def delete_collection(self, collection_name: str):
"""
Delete a collection.
"""
if not self.client:
return None
# Map to multi-tenant collection and tenant ID
mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)
tenant_filter = models.FieldCondition(
key="tenant_id", match=models.MatchValue(value=tenant_id)
)
field_conditions = [tenant_filter]
update_result = self.client.delete(
collection_name=mt_collection,
points_selector=models.FilterSelector(
filter=models.Filter(must=field_conditions)
),
)
if self.client.get_collection(mt_collection).points_count == 0:
self.client.delete_collection(mt_collection)
return update_result
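The recovery path above reduces to a simple rule: truncate when the multi-tenant collection expects fewer dimensions than the incoming vectors, and zero-pad when it expects more. A standalone sketch of that adjustment (illustrative only, not part of the diff; the helper name is hypothetical):
def adjust_vector(vector: list[float], expected_size: int) -> list[float]:
    # Truncate extra dimensions when the collection is smaller than the vector.
    if len(vector) > expected_size:
        return vector[:expected_size]
    # Zero-pad when the collection expects more dimensions than the vector has.
    if len(vector) < expected_size:
        return vector + [0.0] * (expected_size - len(vector))
    return vector

assert adjust_vector([0.1, 0.2, 0.3], 2) == [0.1, 0.2]
assert adjust_vector([0.1, 0.2], 4) == [0.1, 0.2, 0.0, 0.0]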

View file

@ -0,0 +1,55 @@
from open_webui.retrieval.vector.main import VectorDBBase
from open_webui.retrieval.vector.type import VectorType
from open_webui.config import VECTOR_DB, ENABLE_QDRANT_MULTITENANCY_MODE
class Vector:
@staticmethod
def get_vector(vector_type: str) -> VectorDBBase:
"""
get vector db instance by vector type
"""
match vector_type:
case VectorType.MILVUS:
from open_webui.retrieval.vector.dbs.milvus import MilvusClient
return MilvusClient()
case VectorType.QDRANT:
if ENABLE_QDRANT_MULTITENANCY_MODE:
from open_webui.retrieval.vector.dbs.qdrant_multitenancy import (
QdrantClient,
)
return QdrantClient()
else:
from open_webui.retrieval.vector.dbs.qdrant import QdrantClient
return QdrantClient()
case VectorType.PINECONE:
from open_webui.retrieval.vector.dbs.pinecone import PineconeClient
return PineconeClient()
case VectorType.OPENSEARCH:
from open_webui.retrieval.vector.dbs.opensearch import OpenSearchClient
return OpenSearchClient()
case VectorType.PGVECTOR:
from open_webui.retrieval.vector.dbs.pgvector import PgvectorClient
return PgvectorClient()
case VectorType.ELASTICSEARCH:
from open_webui.retrieval.vector.dbs.elasticsearch import (
ElasticsearchClient,
)
return ElasticsearchClient()
case VectorType.CHROMA:
from open_webui.retrieval.vector.dbs.chroma import ChromaClient
return ChromaClient()
case _:
raise ValueError(f"Unsupported vector type: {vector_type}")
VECTOR_DB_CLIENT = Vector.get_vector(VECTOR_DB)
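A hypothetical usage sketch of the factory above (not part of the diff). It assumes the module is importable as open_webui.retrieval.vector.factory and that VECTOR_DB names one of the supported backends; the collection name and query vector are placeholders:
from open_webui.retrieval.vector.factory import Vector, VECTOR_DB_CLIENT
from open_webui.retrieval.vector.type import VectorType

# The module-level singleton is resolved once from the configured VECTOR_DB value.
if VECTOR_DB_CLIENT.has_collection("example-collection"):
    results = VECTOR_DB_CLIENT.search(
        "example-collection",
        vectors=[[0.1, 0.2, 0.3]],  # query embedding; length must match the stored vectors
        limit=3,
    )

# A specific backend can also be constructed directly, bypassing the configured default.
chroma_client = Vector.get_vector(VectorType.CHROMA)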

View file

@ -1,5 +1,6 @@
from pydantic import BaseModel
-from typing import Optional, List, Any
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Union
class VectorItem(BaseModel):
@ -17,3 +18,69 @@ class GetResult(BaseModel):
class SearchResult(GetResult):
    distances: Optional[List[List[float | int]]]
class VectorDBBase(ABC):
"""
Abstract base class for all vector database backends.
Implementations of this class provide methods for collection management,
vector insertion, deletion, similarity search, and metadata filtering.
Any custom vector database integration must inherit from this class and
implement all abstract methods.
"""
@abstractmethod
def has_collection(self, collection_name: str) -> bool:
"""Check if the collection exists in the vector DB."""
pass
@abstractmethod
def delete_collection(self, collection_name: str) -> None:
"""Delete a collection from the vector DB."""
pass
@abstractmethod
def insert(self, collection_name: str, items: List[VectorItem]) -> None:
"""Insert a list of vector items into a collection."""
pass
@abstractmethod
def upsert(self, collection_name: str, items: List[VectorItem]) -> None:
"""Insert or update vector items in a collection."""
pass
@abstractmethod
def search(
self, collection_name: str, vectors: List[List[Union[float, int]]], limit: int
) -> Optional[SearchResult]:
"""Search for similar vectors in a collection."""
pass
@abstractmethod
def query(
self, collection_name: str, filter: Dict, limit: Optional[int] = None
) -> Optional[GetResult]:
"""Query vectors from a collection using metadata filter."""
pass
@abstractmethod
def get(self, collection_name: str) -> Optional[GetResult]:
"""Retrieve all vectors from a collection."""
pass
@abstractmethod
def delete(
self,
collection_name: str,
ids: Optional[List[str]] = None,
filter: Optional[Dict] = None,
) -> None:
"""Delete vectors by ID or filter from a collection."""
pass
@abstractmethod
def reset(self) -> None:
"""Reset the vector database by removing all collections or those matching a condition."""
pass
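A minimal sketch of a custom backend against the contract above (illustrative only, not part of the diff). It assumes GetResult/SearchResult carry ids, documents and metadatas as nested lists alongside distances, and it indexes the same dict-shaped item payloads the existing backends read; lookups are brute-force with no real indexing, so it is only suitable for tests:
from typing import Dict, List, Optional, Union


class InMemoryVectorClient(VectorDBBase):
    def __init__(self) -> None:
        self._collections: Dict[str, Dict[str, dict]] = {}

    def has_collection(self, collection_name: str) -> bool:
        return collection_name in self._collections

    def delete_collection(self, collection_name: str) -> None:
        self._collections.pop(collection_name, None)

    def insert(self, collection_name: str, items: List[VectorItem]) -> None:
        self.upsert(collection_name, items)

    def upsert(self, collection_name: str, items: List[VectorItem]) -> None:
        collection = self._collections.setdefault(collection_name, {})
        for item in items:
            collection[item["id"]] = item

    def search(
        self, collection_name: str, vectors: List[List[Union[float, int]]], limit: int
    ) -> Optional[SearchResult]:
        items = list(self._collections.get(collection_name, {}).values())
        if not items or not vectors:
            return None
        query = vectors[0]
        # Brute-force squared-Euclidean distance, smallest first.
        ranked = sorted(
            items, key=lambda i: sum((a - b) ** 2 for a, b in zip(i["vector"], query))
        )[:limit]
        return SearchResult(
            ids=[[i["id"] for i in ranked]],
            documents=[[i["text"] for i in ranked]],
            metadatas=[[i["metadata"] for i in ranked]],
            distances=[
                [sum((a - b) ** 2 for a, b in zip(i["vector"], query)) for i in ranked]
            ],
        )

    def query(
        self, collection_name: str, filter: Dict, limit: Optional[int] = None
    ) -> Optional[GetResult]:
        items = [
            item
            for item in self._collections.get(collection_name, {}).values()
            if all(item["metadata"].get(k) == v for k, v in filter.items())
        ]
        if limit is not None:
            items = items[:limit]
        return GetResult(
            ids=[[item["id"] for item in items]],
            documents=[[item["text"] for item in items]],
            metadatas=[[item["metadata"] for item in items]],
        )

    def get(self, collection_name: str) -> Optional[GetResult]:
        return self.query(collection_name, filter={}, limit=None)

    def delete(
        self,
        collection_name: str,
        ids: Optional[List[str]] = None,
        filter: Optional[Dict] = None,
    ) -> None:
        # Remove items matching the given ids or metadata filter; no-op if neither is given.
        collection = self._collections.get(collection_name, {})
        for item_id in list(collection):
            item = collection[item_id]
            if (ids and item_id in ids) or (
                filter and all(item["metadata"].get(k) == v for k, v in filter.items())
            ):
                collection.pop(item_id)

    def reset(self) -> None:
        self._collections.clear()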

View file

@ -0,0 +1,11 @@
from enum import StrEnum
class VectorType(StrEnum):
MILVUS = "milvus"
QDRANT = "qdrant"
CHROMA = "chroma"
PINECONE = "pinecone"
ELASTICSEARCH = "elasticsearch"
OPENSEARCH = "opensearch"
PGVECTOR = "pgvector"

View file

@ -3,6 +3,7 @@ from typing import Optional
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
from duckduckgo_search import DDGS
from duckduckgo_search.exceptions import RatelimitException
from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
@ -22,29 +23,24 @@ def search_duckduckgo(
        list[SearchResult]: A list of search results
    """
    # Use the DDGS context manager to create a DDGS object
-    with DDGS() as ddgs:
-        # Use the ddgs.text() method to perform the search
-        ddgs_gen = ddgs.text(
-            query, safesearch="moderate", max_results=count, backend="api"
-        )
-        # Check if there are search results
-        if ddgs_gen:
-            # Convert the search results into a list
-            search_results = [r for r in ddgs_gen]
-    # Create an empty list to store the SearchResult objects
-    results = []
-    # Iterate over each search result
-    for result in search_results:
-        # Create a SearchResult object and append it to the results list
-        results.append(
-            SearchResult(
-                link=result["href"],
-                title=result.get("title"),
-                snippet=result.get("body"),
-            )
-        )
-    if filter_list:
-        results = get_filtered_results(results, filter_list)
-    # Return the list of search results
-    return results
    search_results = []
    with DDGS() as ddgs:
        # Use the ddgs.text() method to perform the search
        try:
            search_results = ddgs.text(
                query, safesearch="moderate", max_results=count, backend="lite"
            )
        except RatelimitException as e:
            log.error(f"RatelimitException: {e}")
    if filter_list:
        search_results = get_filtered_results(search_results, filter_list)
    # Return the list of search results
    return [
        SearchResult(
            link=result["href"],
            title=result.get("title"),
            snippet=result.get("body"),
        )
        for result in search_results
    ]

View file

@ -0,0 +1,47 @@
import logging
from typing import Optional, List
import requests
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
def search_external(
external_url: str,
external_api_key: str,
query: str,
count: int,
filter_list: Optional[List[str]] = None,
) -> List[SearchResult]:
try:
response = requests.post(
external_url,
headers={
"User-Agent": "Open WebUI (https://github.com/open-webui/open-webui) RAG Bot",
"Authorization": f"Bearer {external_api_key}",
},
json={
"query": query,
"count": count,
},
)
response.raise_for_status()
results = response.json()
if filter_list:
results = get_filtered_results(results, filter_list)
results = [
SearchResult(
link=result.get("link"),
title=result.get("title"),
snippet=result.get("snippet"),
)
for result in results[:count]
]
log.info(f"External search results: {results}")
return results
except Exception as e:
log.error(f"Error in External search: {e}")
return []
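An illustrative sketch (not part of the diff) of the contract search_external() implies: the configured endpoint is expected to return a JSON array of objects with "link", "title" and "snippet" keys. The URL and key below are placeholders:
# Expected response shape, e.g.:
#   [{"link": "https://example.com", "title": "Example", "snippet": "Matching text..."}]
results = search_external(
    external_url="https://search.example.com/query",  # assumed endpoint
    external_api_key="sk-example",                    # placeholder credential
    query="open webui rag",
    count=3,
    filter_list=["example.com"],                      # optional allow-list
)
for r in results:
    print(r.link, r.title)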

View file

@ -0,0 +1,49 @@
import logging
from typing import Optional, List
from urllib.parse import urljoin
import requests
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
def search_firecrawl(
firecrawl_url: str,
firecrawl_api_key: str,
query: str,
count: int,
filter_list: Optional[List[str]] = None,
) -> List[SearchResult]:
try:
firecrawl_search_url = urljoin(firecrawl_url, "/v1/search")
response = requests.post(
firecrawl_search_url,
headers={
"User-Agent": "Open WebUI (https://github.com/open-webui/open-webui) RAG Bot",
"Authorization": f"Bearer {firecrawl_api_key}",
},
json={
"query": query,
"limit": count,
},
)
response.raise_for_status()
results = response.json().get("data", [])
if filter_list:
results = get_filtered_results(results, filter_list)
results = [
SearchResult(
link=result.get("url"),
title=result.get("title"),
snippet=result.get("description"),
)
for result in results[:count]
]
log.info(f"External search results: {results}")
return results
except Exception as e:
log.error(f"Error in External search: {e}")
return []

View file

@ -0,0 +1,87 @@
import logging
from typing import Optional, List
import requests
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
def search_perplexity(
api_key: str,
query: str,
count: int,
filter_list: Optional[list[str]] = None,
) -> list[SearchResult]:
"""Search using Perplexity API and return the results as a list of SearchResult objects.
Args:
api_key (str): A Perplexity API key
query (str): The query to search for
count (int): Maximum number of results to return
"""
# Handle PersistentConfig object
if hasattr(api_key, "__str__"):
api_key = str(api_key)
try:
url = "https://api.perplexity.ai/chat/completions"
# Create payload for the API call
payload = {
"model": "sonar",
"messages": [
{
"role": "system",
"content": "You are a search assistant. Provide factual information with citations.",
},
{"role": "user", "content": query},
],
"temperature": 0.2, # Lower temperature for more factual responses
"stream": False,
}
headers = {
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json",
}
# Make the API request
response = requests.request("POST", url, json=payload, headers=headers)
# Parse the JSON response
json_response = response.json()
# Extract citations from the response
citations = json_response.get("citations", [])
# Create search results from citations
results = []
for i, citation in enumerate(citations[:count]):
# Extract content from the response to use as snippet
content = ""
if "choices" in json_response and json_response["choices"]:
if i == 0:
content = json_response["choices"][0]["message"]["content"]
result = {"link": citation, "title": f"Source {i+1}", "snippet": content}
results.append(result)
if filter_list:
results = get_filtered_results(results, filter_list)
return [
SearchResult(
link=result["link"], title=result["title"], snippet=result["snippet"]
)
for result in results[:count]
]
except Exception as e:
log.error(f"Error searching with Perplexity API: {e}")
return []

View file

@ -42,7 +42,9 @@ def search_searchapi(
        results = get_filtered_results(results, filter_list)
    return [
        SearchResult(
-            link=result["link"], title=result["title"], snippet=result["snippet"]
            link=result["link"],
            title=result.get("title"),
            snippet=result.get("snippet"),
        )
        for result in results[:count]
    ]

View file

@ -42,7 +42,9 @@ def search_serpapi(
        results = get_filtered_results(results, filter_list)
    return [
        SearchResult(
-            link=result["link"], title=result["title"], snippet=result["snippet"]
            link=result["link"],
            title=result.get("title"),
            snippet=result.get("snippet"),
        )
        for result in results[:count]
    ]

View file

@ -0,0 +1,60 @@
import logging
import json
from typing import Optional, List
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
def search_sougou(
sougou_api_sid: str,
sougou_api_sk: str,
query: str,
count: int,
filter_list: Optional[List[str]] = None,
) -> List[SearchResult]:
from tencentcloud.common.common_client import CommonClient
from tencentcloud.common import credential
from tencentcloud.common.exception.tencent_cloud_sdk_exception import (
TencentCloudSDKException,
)
from tencentcloud.common.profile.client_profile import ClientProfile
from tencentcloud.common.profile.http_profile import HttpProfile
try:
cred = credential.Credential(sougou_api_sid, sougou_api_sk)
http_profile = HttpProfile()
http_profile.endpoint = "tms.tencentcloudapi.com"
client_profile = ClientProfile()
client_profile.http_profile = http_profile
params = json.dumps({"Query": query, "Cnt": 20})
common_client = CommonClient(
"tms", "2020-12-29", cred, "", profile=client_profile
)
results = [
json.loads(page)
for page in common_client.call_json("SearchPro", json.loads(params))[
"Response"
]["Pages"]
]
sorted_results = sorted(
results, key=lambda x: x.get("scour", 0.0), reverse=True
)
if filter_list:
sorted_results = get_filtered_results(sorted_results, filter_list)
return [
SearchResult(
link=result.get("url"),
title=result.get("title"),
snippet=result.get("passage"),
)
for result in sorted_results[:count]
]
except TencentCloudSDKException as err:
log.error(f"Error in Sougou search: {err}")
return []

View file

@ -1,32 +1,45 @@
import logging
from typing import Optional
import requests
-from open_webui.retrieval.web.main import SearchResult
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
-def search_tavily(api_key: str, query: str, count: int) -> list[SearchResult]:
def search_tavily(
    api_key: str,
    query: str,
    count: int,
    filter_list: Optional[list[str]] = None,
    # **kwargs,
) -> list[SearchResult]:
    """Search using Tavily's Search API and return the results as a list of SearchResult objects.
    Args:
        api_key (str): A Tavily Search API key
        query (str): The query to search for
        count (int): The maximum number of results to return
    Returns:
        list[SearchResult]: A list of search results
    """
    url = "https://api.tavily.com/search"
-    data = {"query": query, "api_key": api_key}
-    response = requests.post(url, json=data)
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {api_key}",
    }
    data = {"query": query, "max_results": count}
    response = requests.post(url, headers=headers, json=data)
    response.raise_for_status()
    json_response = response.json()
-    raw_search_results = json_response.get("results", [])
    results = json_response.get("results", [])
    if filter_list:
        results = get_filtered_results(results, filter_list)
    return [
        SearchResult(
@ -34,5 +47,5 @@ def search_tavily(api_key: str, query: str, count: int) -> list[SearchResult]:
            title=result.get("title", ""),
            snippet=result.get("content"),
        )
-        for result in raw_search_results[:count]
        for result in results
    ]

View file

@ -1,19 +1,45 @@
-import socket
-import urllib.parse
-import validators
-from typing import Union, Sequence, Iterator
-from langchain_community.document_loaders import (
-    WebBaseLoader,
-)
-from langchain_core.documents import Document
-from open_webui.constants import ERROR_MESSAGES
-from open_webui.config import ENABLE_RAG_LOCAL_WEB_FETCH
-from open_webui.env import SRC_LOG_LEVELS
import asyncio
import logging
import socket
import ssl
import urllib.parse
import urllib.request
from collections import defaultdict
from datetime import datetime, time, timedelta
from typing import (
Any,
AsyncIterator,
Dict,
Iterator,
List,
Optional,
Sequence,
Union,
Literal,
)
import aiohttp
import certifi
import validators
from langchain_community.document_loaders import PlaywrightURLLoader, WebBaseLoader
from langchain_community.document_loaders.firecrawl import FireCrawlLoader
from langchain_community.document_loaders.base import BaseLoader
from langchain_core.documents import Document
from open_webui.retrieval.loaders.tavily import TavilyLoader
from open_webui.retrieval.loaders.external_web import ExternalWebLoader
from open_webui.constants import ERROR_MESSAGES
from open_webui.config import (
ENABLE_RAG_LOCAL_WEB_FETCH,
PLAYWRIGHT_WS_URL,
PLAYWRIGHT_TIMEOUT,
WEB_LOADER_ENGINE,
FIRECRAWL_API_BASE_URL,
FIRECRAWL_API_KEY,
TAVILY_API_KEY,
TAVILY_EXTRACT_DEPTH,
EXTERNAL_WEB_LOADER_URL,
EXTERNAL_WEB_LOADER_API_KEY,
)
from open_webui.env import SRC_LOG_LEVELS, AIOHTTP_CLIENT_SESSION_SSL
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
@ -65,9 +91,472 @@ def resolve_hostname(hostname):
    return ipv4_addresses, ipv6_addresses
def extract_metadata(soup, url):
metadata = {"source": url}
if title := soup.find("title"):
metadata["title"] = title.get_text()
if description := soup.find("meta", attrs={"name": "description"}):
metadata["description"] = description.get("content", "No description found.")
if html := soup.find("html"):
metadata["language"] = html.get("lang", "No language found.")
return metadata
def verify_ssl_cert(url: str) -> bool:
"""Verify SSL certificate for the given URL."""
if not url.startswith("https://"):
return True
try:
hostname = url.split("://")[-1].split("/")[0]
context = ssl.create_default_context(cafile=certifi.where())
with context.wrap_socket(ssl.socket(), server_hostname=hostname) as s:
s.connect((hostname, 443))
return True
except ssl.SSLError:
return False
except Exception as e:
log.warning(f"SSL verification failed for {url}: {str(e)}")
return False
class RateLimitMixin:
async def _wait_for_rate_limit(self):
"""Wait to respect the rate limit if specified."""
if self.requests_per_second and self.last_request_time:
min_interval = timedelta(seconds=1.0 / self.requests_per_second)
time_since_last = datetime.now() - self.last_request_time
if time_since_last < min_interval:
await asyncio.sleep((min_interval - time_since_last).total_seconds())
self.last_request_time = datetime.now()
def _sync_wait_for_rate_limit(self):
"""Synchronous version of rate limit wait."""
if self.requests_per_second and self.last_request_time:
min_interval = timedelta(seconds=1.0 / self.requests_per_second)
time_since_last = datetime.now() - self.last_request_time
if time_since_last < min_interval:
time.sleep((min_interval - time_since_last).total_seconds())
self.last_request_time = datetime.now()
class URLProcessingMixin:
def _verify_ssl_cert(self, url: str) -> bool:
"""Verify SSL certificate for a URL."""
return verify_ssl_cert(url)
async def _safe_process_url(self, url: str) -> bool:
"""Perform safety checks before processing a URL."""
if self.verify_ssl and not self._verify_ssl_cert(url):
raise ValueError(f"SSL certificate verification failed for {url}")
await self._wait_for_rate_limit()
return True
def _safe_process_url_sync(self, url: str) -> bool:
"""Synchronous version of safety checks."""
if self.verify_ssl and not self._verify_ssl_cert(url):
raise ValueError(f"SSL certificate verification failed for {url}")
self._sync_wait_for_rate_limit()
return True
class SafeFireCrawlLoader(BaseLoader, RateLimitMixin, URLProcessingMixin):
def __init__(
self,
web_paths,
verify_ssl: bool = True,
trust_env: bool = False,
requests_per_second: Optional[float] = None,
continue_on_failure: bool = True,
api_key: Optional[str] = None,
api_url: Optional[str] = None,
mode: Literal["crawl", "scrape", "map"] = "scrape",
proxy: Optional[Dict[str, str]] = None,
params: Optional[Dict] = None,
):
"""Concurrent document loader for FireCrawl operations.
Executes multiple FireCrawlLoader instances concurrently using thread pooling
to improve bulk processing efficiency.
Args:
web_paths: List of URLs/paths to process.
verify_ssl: If True, verify SSL certificates.
trust_env: If True, use proxy settings from environment variables.
requests_per_second: Number of requests per second to limit to.
continue_on_failure (bool): If True, continue loading other URLs on failure.
api_key: API key for FireCrawl service. Defaults to None
(uses FIRE_CRAWL_API_KEY environment variable if not provided).
api_url: Base URL for FireCrawl API. Defaults to official API endpoint.
mode: Operation mode selection:
- 'crawl': Website crawling mode (default)
- 'scrape': Direct page scraping
- 'map': Site map generation
proxy: Proxy override settings for the FireCrawl API.
params: The parameters to pass to the Firecrawl API.
Examples include crawlerOptions.
For more details, visit: https://github.com/mendableai/firecrawl-py
"""
proxy_server = proxy.get("server") if proxy else None
if trust_env and not proxy_server:
env_proxies = urllib.request.getproxies()
env_proxy_server = env_proxies.get("https") or env_proxies.get("http")
if env_proxy_server:
if proxy:
proxy["server"] = env_proxy_server
else:
proxy = {"server": env_proxy_server}
self.web_paths = web_paths
self.verify_ssl = verify_ssl
self.requests_per_second = requests_per_second
self.last_request_time = None
self.trust_env = trust_env
self.continue_on_failure = continue_on_failure
self.api_key = api_key
self.api_url = api_url
self.mode = mode
self.params = params
def lazy_load(self) -> Iterator[Document]:
"""Load documents concurrently using FireCrawl."""
for url in self.web_paths:
try:
self._safe_process_url_sync(url)
loader = FireCrawlLoader(
url=url,
api_key=self.api_key,
api_url=self.api_url,
mode=self.mode,
params=self.params,
)
for document in loader.lazy_load():
if not document.metadata.get("source"):
document.metadata["source"] = document.metadata.get("sourceURL")
yield document
except Exception as e:
if self.continue_on_failure:
log.exception(f"Error loading {url}: {e}")
continue
raise e
async def alazy_load(self):
"""Async version of lazy_load."""
for url in self.web_paths:
try:
await self._safe_process_url(url)
loader = FireCrawlLoader(
url=url,
api_key=self.api_key,
api_url=self.api_url,
mode=self.mode,
params=self.params,
)
async for document in loader.alazy_load():
if not document.metadata.get("source"):
document.metadata["source"] = document.metadata.get("sourceURL")
yield document
except Exception as e:
if self.continue_on_failure:
log.exception(f"Error loading {url}: {e}")
continue
raise e
class SafeTavilyLoader(BaseLoader, RateLimitMixin, URLProcessingMixin):
def __init__(
self,
web_paths: Union[str, List[str]],
api_key: str,
extract_depth: Literal["basic", "advanced"] = "basic",
continue_on_failure: bool = True,
requests_per_second: Optional[float] = None,
verify_ssl: bool = True,
trust_env: bool = False,
proxy: Optional[Dict[str, str]] = None,
):
"""Initialize SafeTavilyLoader with rate limiting and SSL verification support.
Args:
web_paths: List of URLs/paths to process.
api_key: The Tavily API key.
extract_depth: Depth of extraction ("basic" or "advanced").
continue_on_failure: Whether to continue if extraction of a URL fails.
requests_per_second: Number of requests per second to limit to.
verify_ssl: If True, verify SSL certificates.
trust_env: If True, use proxy settings from environment variables.
proxy: Optional proxy configuration.
"""
# Initialize proxy configuration if using environment variables
proxy_server = proxy.get("server") if proxy else None
if trust_env and not proxy_server:
env_proxies = urllib.request.getproxies()
env_proxy_server = env_proxies.get("https") or env_proxies.get("http")
if env_proxy_server:
if proxy:
proxy["server"] = env_proxy_server
else:
proxy = {"server": env_proxy_server}
# Store parameters for creating TavilyLoader instances
self.web_paths = web_paths if isinstance(web_paths, list) else [web_paths]
self.api_key = api_key
self.extract_depth = extract_depth
self.continue_on_failure = continue_on_failure
self.verify_ssl = verify_ssl
self.trust_env = trust_env
self.proxy = proxy
# Add rate limiting
self.requests_per_second = requests_per_second
self.last_request_time = None
def lazy_load(self) -> Iterator[Document]:
"""Load documents with rate limiting support, delegating to TavilyLoader."""
valid_urls = []
for url in self.web_paths:
try:
self._safe_process_url_sync(url)
valid_urls.append(url)
except Exception as e:
log.warning(f"SSL verification failed for {url}: {str(e)}")
if not self.continue_on_failure:
raise e
if not valid_urls:
if self.continue_on_failure:
log.warning("No valid URLs to process after SSL verification")
return
raise ValueError("No valid URLs to process after SSL verification")
try:
loader = TavilyLoader(
urls=valid_urls,
api_key=self.api_key,
extract_depth=self.extract_depth,
continue_on_failure=self.continue_on_failure,
)
yield from loader.lazy_load()
except Exception as e:
if self.continue_on_failure:
log.exception(f"Error extracting content from URLs: {e}")
else:
raise e
async def alazy_load(self) -> AsyncIterator[Document]:
"""Async version with rate limiting and SSL verification."""
valid_urls = []
for url in self.web_paths:
try:
await self._safe_process_url(url)
valid_urls.append(url)
except Exception as e:
log.warning(f"SSL verification failed for {url}: {str(e)}")
if not self.continue_on_failure:
raise e
if not valid_urls:
if self.continue_on_failure:
log.warning("No valid URLs to process after SSL verification")
return
raise ValueError("No valid URLs to process after SSL verification")
try:
loader = TavilyLoader(
urls=valid_urls,
api_key=self.api_key,
extract_depth=self.extract_depth,
continue_on_failure=self.continue_on_failure,
)
async for document in loader.alazy_load():
yield document
except Exception as e:
if self.continue_on_failure:
log.exception(f"Error loading URLs: {e}")
else:
raise e
class SafePlaywrightURLLoader(PlaywrightURLLoader, RateLimitMixin, URLProcessingMixin):
"""Load HTML pages safely with Playwright, supporting SSL verification, rate limiting, and remote browser connection.
Attributes:
web_paths (List[str]): List of URLs to load.
verify_ssl (bool): If True, verify SSL certificates.
trust_env (bool): If True, use proxy settings from environment variables.
requests_per_second (Optional[float]): Number of requests per second to limit to.
continue_on_failure (bool): If True, continue loading other URLs on failure.
headless (bool): If True, the browser will run in headless mode.
proxy (dict): Proxy override settings for the Playwright session.
playwright_ws_url (Optional[str]): WebSocket endpoint URI for remote browser connection.
playwright_timeout (Optional[int]): Maximum operation time in milliseconds.
"""
def __init__(
self,
web_paths: List[str],
verify_ssl: bool = True,
trust_env: bool = False,
requests_per_second: Optional[float] = None,
continue_on_failure: bool = True,
headless: bool = True,
remove_selectors: Optional[List[str]] = None,
proxy: Optional[Dict[str, str]] = None,
playwright_ws_url: Optional[str] = None,
playwright_timeout: Optional[int] = 10000,
):
"""Initialize with additional safety parameters and remote browser support."""
proxy_server = proxy.get("server") if proxy else None
if trust_env and not proxy_server:
env_proxies = urllib.request.getproxies()
env_proxy_server = env_proxies.get("https") or env_proxies.get("http")
if env_proxy_server:
if proxy:
proxy["server"] = env_proxy_server
else:
proxy = {"server": env_proxy_server}
# We'll set headless to False if using playwright_ws_url since it's handled by the remote browser
super().__init__(
urls=web_paths,
continue_on_failure=continue_on_failure,
headless=headless if playwright_ws_url is None else False,
remove_selectors=remove_selectors,
proxy=proxy,
)
self.verify_ssl = verify_ssl
self.requests_per_second = requests_per_second
self.last_request_time = None
self.playwright_ws_url = playwright_ws_url
self.trust_env = trust_env
self.playwright_timeout = playwright_timeout
def lazy_load(self) -> Iterator[Document]:
"""Safely load URLs synchronously with support for remote browser."""
from playwright.sync_api import sync_playwright
with sync_playwright() as p:
# Use remote browser if ws_endpoint is provided, otherwise use local browser
if self.playwright_ws_url:
browser = p.chromium.connect(self.playwright_ws_url)
else:
browser = p.chromium.launch(headless=self.headless, proxy=self.proxy)
for url in self.urls:
try:
self._safe_process_url_sync(url)
page = browser.new_page()
response = page.goto(url, timeout=self.playwright_timeout)
if response is None:
raise ValueError(f"page.goto() returned None for url {url}")
text = self.evaluator.evaluate(page, browser, response)
metadata = {"source": url}
yield Document(page_content=text, metadata=metadata)
except Exception as e:
if self.continue_on_failure:
log.exception(f"Error loading {url}: {e}")
continue
raise e
browser.close()
async def alazy_load(self) -> AsyncIterator[Document]:
"""Safely load URLs asynchronously with support for remote browser."""
from playwright.async_api import async_playwright
async with async_playwright() as p:
# Use remote browser if ws_endpoint is provided, otherwise use local browser
if self.playwright_ws_url:
browser = await p.chromium.connect(self.playwright_ws_url)
else:
browser = await p.chromium.launch(
headless=self.headless, proxy=self.proxy
)
for url in self.urls:
try:
await self._safe_process_url(url)
page = await browser.new_page()
response = await page.goto(url, timeout=self.playwright_timeout)
if response is None:
raise ValueError(f"page.goto() returned None for url {url}")
text = await self.evaluator.evaluate_async(page, browser, response)
metadata = {"source": url}
yield Document(page_content=text, metadata=metadata)
except Exception as e:
if self.continue_on_failure:
log.exception(f"Error loading {url}: {e}")
continue
raise e
await browser.close()
class SafeWebBaseLoader(WebBaseLoader):
    """WebBaseLoader with enhanced error handling for URLs."""
def __init__(self, trust_env: bool = False, *args, **kwargs):
"""Initialize SafeWebBaseLoader
Args:
trust_env (bool, optional): set to True if using proxy to make web requests, for example
using http(s)_proxy environment variables. Defaults to False.
"""
super().__init__(*args, **kwargs)
self.trust_env = trust_env
async def _fetch(
self, url: str, retries: int = 3, cooldown: int = 2, backoff: float = 1.5
) -> str:
async with aiohttp.ClientSession(trust_env=self.trust_env) as session:
for i in range(retries):
try:
kwargs: Dict = dict(
headers=self.session.headers,
cookies=self.session.cookies.get_dict(),
)
if not self.session.verify:
kwargs["ssl"] = False
async with session.get(
url,
**(self.requests_kwargs | kwargs),
) as response:
if self.raise_for_status:
response.raise_for_status()
return await response.text()
except aiohttp.ClientConnectionError as e:
if i == retries - 1:
raise
else:
log.warning(
f"Error fetching {url} with attempt "
f"{i + 1}/{retries}: {e}. Retrying..."
)
await asyncio.sleep(cooldown * backoff**i)
raise ValueError("retry count exceeded")
def _unpack_fetch_results(
self, results: Any, urls: List[str], parser: Union[str, None] = None
) -> List[Any]:
"""Unpack fetch results into BeautifulSoup objects."""
from bs4 import BeautifulSoup
final_results = []
for i, result in enumerate(results):
url = urls[i]
if parser is None:
if url.endswith(".xml"):
parser = "xml"
else:
parser = self.default_parser
self._check_parser(parser)
final_results.append(BeautifulSoup(result, parser, **self.bs_kwargs))
return final_results
async def ascrape_all(
self, urls: List[str], parser: Union[str, None] = None
) -> List[Any]:
"""Async fetch all urls, then return soups for all results."""
results = await self.fetch_all(urls)
return self._unpack_fetch_results(results, urls, parser=parser)
    def lazy_load(self) -> Iterator[Document]:
        """Lazy load text from the url(s) in web_path with error handling."""
        for path in self.web_paths:
@ -76,6 +565,18 @@ class SafeWebBaseLoader(WebBaseLoader):
                text = soup.get_text(**self.bs_get_text_kwargs)
                # Build metadata
                metadata = extract_metadata(soup, path)
                yield Document(page_content=text, metadata=metadata)
            except Exception as e:
                # Log the error and continue with the next URL
                log.exception(f"Error loading {path}: {e}")
    async def alazy_load(self) -> AsyncIterator[Document]:
        """Async lazy load text from the url(s) in web_path."""
        results = await self.ascrape_all(self.web_paths)
        for path, soup in zip(self.web_paths, results):
            text = soup.get_text(**self.bs_get_text_kwargs)
            metadata = {"source": path}
            if title := soup.find("title"):
                metadata["title"] = title.get_text()
@ -85,24 +586,65 @@ class SafeWebBaseLoader(WebBaseLoader):
                )
            if html := soup.find("html"):
                metadata["language"] = html.get("lang", "No language found.")
            yield Document(page_content=text, metadata=metadata)
-            except Exception as e:
-                # Log the error and continue with the next URL
-                log.error(f"Error loading {path}: {e}")
    async def aload(self) -> list[Document]:
        """Load data into Document objects."""
        return [document async for document in self.alazy_load()]
def get_web_loader(
    urls: Union[str, Sequence[str]],
    verify_ssl: bool = True,
    requests_per_second: int = 2,
    trust_env: bool = False,
):
    # Check if the URLs are valid
    safe_urls = safe_validate_urls([urls] if isinstance(urls, str) else urls)
-    return SafeWebBaseLoader(
-        safe_urls,
-        verify_ssl=verify_ssl,
-        requests_per_second=requests_per_second,
-        continue_on_failure=True,
-    )
    web_loader_args = {
        "web_paths": safe_urls,
        "verify_ssl": verify_ssl,
        "requests_per_second": requests_per_second,
        "continue_on_failure": True,
        "trust_env": trust_env,
    }
if WEB_LOADER_ENGINE.value == "" or WEB_LOADER_ENGINE.value == "safe_web":
WebLoaderClass = SafeWebBaseLoader
if WEB_LOADER_ENGINE.value == "playwright":
WebLoaderClass = SafePlaywrightURLLoader
web_loader_args["playwright_timeout"] = PLAYWRIGHT_TIMEOUT.value * 1000
if PLAYWRIGHT_WS_URL.value:
web_loader_args["playwright_ws_url"] = PLAYWRIGHT_WS_URL.value
if WEB_LOADER_ENGINE.value == "firecrawl":
WebLoaderClass = SafeFireCrawlLoader
web_loader_args["api_key"] = FIRECRAWL_API_KEY.value
web_loader_args["api_url"] = FIRECRAWL_API_BASE_URL.value
if WEB_LOADER_ENGINE.value == "tavily":
WebLoaderClass = SafeTavilyLoader
web_loader_args["api_key"] = TAVILY_API_KEY.value
web_loader_args["extract_depth"] = TAVILY_EXTRACT_DEPTH.value
if WEB_LOADER_ENGINE.value == "external":
WebLoaderClass = ExternalWebLoader
web_loader_args["external_url"] = EXTERNAL_WEB_LOADER_URL.value
web_loader_args["external_api_key"] = EXTERNAL_WEB_LOADER_API_KEY.value
if WebLoaderClass:
web_loader = WebLoaderClass(**web_loader_args)
log.debug(
"Using WEB_LOADER_ENGINE %s for %s URLs",
web_loader.__class__.__name__,
len(safe_urls),
)
return web_loader
else:
raise ValueError(
f"Invalid WEB_LOADER_ENGINE: {WEB_LOADER_ENGINE.value}. "
"Please set it to 'safe_web', 'playwright', 'firecrawl', or 'tavily'."
    )
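An illustrative usage sketch of get_web_loader() (not part of the diff); the URL is a placeholder and the concrete loader class depends on the configured WEB_LOADER_ENGINE:
loader = get_web_loader(
    "https://example.com",
    verify_ssl=True,
    requests_per_second=2,
    trust_env=False,
)
docs = loader.load()  # synchronous; async callers can iterate loader.alazy_load() instead
for doc in docs:
    print(doc.metadata.get("source"), len(doc.page_content))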

View file

@ -0,0 +1,87 @@
import logging
from typing import Optional
import requests
from requests.auth import HTTPDigestAuth
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
def search_yacy(
query_url: str,
username: Optional[str],
password: Optional[str],
query: str,
count: int,
filter_list: Optional[list[str]] = None,
) -> list[SearchResult]:
"""
Search a Yacy instance for a given query and return the results as a list of SearchResult objects.
The function accepts username and password for authenticating to Yacy.
Args:
query_url (str): The base URL of the Yacy server.
username (str): Optional YaCy username.
password (str): Optional YaCy password.
query (str): The search term or question to find in the Yacy database.
count (int): The maximum number of results to retrieve from the search.
Returns:
list[SearchResult]: A list of SearchResults sorted by relevance score in descending order.
    Raises:
requests.exceptions.RequestException: If a request error occurs during the search process.
"""
# Use authentication if either username or password is set
yacy_auth = None
if username or password:
yacy_auth = HTTPDigestAuth(username, password)
params = {
"query": query,
"contentdom": "text",
"resource": "global",
"maximumRecords": count,
"nav": "none",
}
    # If the provided URL is not already the JSON search endpoint, append it
    if not query_url.endswith("yacysearch.json"):
query_url = query_url.rstrip("/") + "/yacysearch.json"
log.debug(f"searching {query_url}")
response = requests.get(
query_url,
auth=yacy_auth,
headers={
"User-Agent": "Open WebUI (https://github.com/open-webui/open-webui) RAG Bot",
"Accept": "text/html",
"Accept-Encoding": "gzip, deflate",
"Accept-Language": "en-US,en;q=0.5",
"Connection": "keep-alive",
},
params=params,
)
response.raise_for_status() # Raise an exception for HTTP errors.
json_response = response.json()
results = json_response.get("channels", [{}])[0].get("items", [])
sorted_results = sorted(results, key=lambda x: x.get("ranking", 0), reverse=True)
if filter_list:
sorted_results = get_filtered_results(sorted_results, filter_list)
return [
SearchResult(
link=result["link"],
title=result.get("title"),
snippet=result.get("description"),
)
for result in sorted_results[:count]
]
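A hypothetical call to search_yacy() (not part of the diff); the server URL and credentials are placeholders:
results = search_yacy(
    query_url="http://localhost:8090",  # /yacysearch.json is appended automatically
    username=None,
    password=None,
    query="open source search",
    count=5,
)
for r in results:
    print(r.link, r.title)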

View file

@ -7,6 +7,9 @@ from functools import lru_cache
from pathlib import Path
from pydub import AudioSegment
from pydub.silence import split_on_silence
from concurrent.futures import ThreadPoolExecutor
from typing import Optional
import aiohttp
import aiofiles
@ -17,6 +20,7 @@ from fastapi import (
    Depends,
    FastAPI,
    File,
    Form,
    HTTPException,
    Request,
    UploadFile,
@ -33,10 +37,13 @@ from open_webui.config import (
    WHISPER_MODEL_AUTO_UPDATE,
    WHISPER_MODEL_DIR,
    CACHE_DIR,
    WHISPER_LANGUAGE,
)
from open_webui.constants import ERROR_MESSAGES
from open_webui.env import (
    AIOHTTP_CLIENT_SESSION_SSL,
    AIOHTTP_CLIENT_TIMEOUT,
    ENV,
    SRC_LOG_LEVELS,
    DEVICE_TYPE,
@ -47,13 +54,15 @@ from open_webui.env import (
router = APIRouter()
# Constants
-MAX_FILE_SIZE_MB = 25
MAX_FILE_SIZE_MB = 20
MAX_FILE_SIZE = MAX_FILE_SIZE_MB * 1024 * 1024  # Convert MB to bytes
AZURE_MAX_FILE_SIZE_MB = 200
AZURE_MAX_FILE_SIZE = AZURE_MAX_FILE_SIZE_MB * 1024 * 1024  # Convert MB to bytes
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["AUDIO"])
-SPEECH_CACHE_DIR = Path(CACHE_DIR).joinpath("./audio/speech/")
SPEECH_CACHE_DIR = CACHE_DIR / "audio" / "speech"
SPEECH_CACHE_DIR.mkdir(parents=True, exist_ok=True)
@ -67,27 +76,47 @@ from pydub import AudioSegment
from pydub.utils import mediainfo
-def is_mp4_audio(file_path):
-    """Check if the given file is an MP4 audio file."""
-    if not os.path.isfile(file_path):
-        print(f"File not found: {file_path}")
-        return False
-    info = mediainfo(file_path)
-    if (
-        info.get("codec_name") == "aac"
-        and info.get("codec_type") == "audio"
-        and info.get("codec_tag_string") == "mp4a"
-    ):
-        return True
-    return False
def is_audio_conversion_required(file_path):
    """
    Check if the given audio file needs conversion to mp3.
    """
    SUPPORTED_FORMATS = {"flac", "m4a", "mp3", "mp4", "mpeg", "wav", "webm"}
    if not os.path.isfile(file_path):
        log.error(f"File not found: {file_path}")
        return False
    try:
        info = mediainfo(file_path)
        codec_name = info.get("codec_name", "").lower()
        codec_type = info.get("codec_type", "").lower()
        codec_tag_string = info.get("codec_tag_string", "").lower()
        if codec_name == "aac" and codec_type == "audio" and codec_tag_string == "mp4a":
            # File is AAC/mp4a audio, recommend mp3 conversion
            return True
        # If the codec name is in the supported formats
        if codec_name in SUPPORTED_FORMATS:
            return False
        return True
    except Exception as e:
        log.error(f"Error getting audio format: {e}")
        return False
-def convert_mp4_to_wav(file_path, output_path):
-    """Convert MP4 audio file to WAV format."""
-    audio = AudioSegment.from_file(file_path, format="mp4")
-    audio.export(output_path, format="wav")
-    print(f"Converted {file_path} to {output_path}")
def convert_audio_to_mp3(file_path):
    """Convert audio file to mp3 format."""
    try:
        output_path = os.path.splitext(file_path)[0] + ".mp3"
        audio = AudioSegment.from_file(file_path)
        audio.export(output_path, format="mp3")
        log.info(f"Converted {file_path} to {output_path}")
        return output_path
    except Exception as e:
        log.error(f"Error converting audio file: {e}")
        return None
def set_faster_whisper_model(model: str, auto_update: bool = False):
@ -130,6 +159,7 @@ class TTSConfigForm(BaseModel):
    VOICE: str
    SPLIT_ON: str
    AZURE_SPEECH_REGION: str
    AZURE_SPEECH_BASE_URL: str
    AZURE_SPEECH_OUTPUT_FORMAT: str
@ -140,6 +170,11 @@ class STTConfigForm(BaseModel):
    MODEL: str
    WHISPER_MODEL: str
    DEEPGRAM_API_KEY: str
AZURE_API_KEY: str
AZURE_REGION: str
AZURE_LOCALES: str
AZURE_BASE_URL: str
AZURE_MAX_SPEAKERS: str
class AudioConfigUpdateForm(BaseModel):
@ -159,6 +194,7 @@ async def get_audio_config(request: Request, user=Depends(get_admin_user)):
"VOICE": request.app.state.config.TTS_VOICE, "VOICE": request.app.state.config.TTS_VOICE,
"SPLIT_ON": request.app.state.config.TTS_SPLIT_ON, "SPLIT_ON": request.app.state.config.TTS_SPLIT_ON,
"AZURE_SPEECH_REGION": request.app.state.config.TTS_AZURE_SPEECH_REGION, "AZURE_SPEECH_REGION": request.app.state.config.TTS_AZURE_SPEECH_REGION,
"AZURE_SPEECH_BASE_URL": request.app.state.config.TTS_AZURE_SPEECH_BASE_URL,
"AZURE_SPEECH_OUTPUT_FORMAT": request.app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT, "AZURE_SPEECH_OUTPUT_FORMAT": request.app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT,
}, },
"stt": { "stt": {
@ -168,6 +204,11 @@ async def get_audio_config(request: Request, user=Depends(get_admin_user)):
"MODEL": request.app.state.config.STT_MODEL, "MODEL": request.app.state.config.STT_MODEL,
"WHISPER_MODEL": request.app.state.config.WHISPER_MODEL, "WHISPER_MODEL": request.app.state.config.WHISPER_MODEL,
"DEEPGRAM_API_KEY": request.app.state.config.DEEPGRAM_API_KEY, "DEEPGRAM_API_KEY": request.app.state.config.DEEPGRAM_API_KEY,
"AZURE_API_KEY": request.app.state.config.AUDIO_STT_AZURE_API_KEY,
"AZURE_REGION": request.app.state.config.AUDIO_STT_AZURE_REGION,
"AZURE_LOCALES": request.app.state.config.AUDIO_STT_AZURE_LOCALES,
"AZURE_BASE_URL": request.app.state.config.AUDIO_STT_AZURE_BASE_URL,
"AZURE_MAX_SPEAKERS": request.app.state.config.AUDIO_STT_AZURE_MAX_SPEAKERS,
}, },
} }
@ -184,6 +225,9 @@ async def update_audio_config(
request.app.state.config.TTS_VOICE = form_data.tts.VOICE request.app.state.config.TTS_VOICE = form_data.tts.VOICE
request.app.state.config.TTS_SPLIT_ON = form_data.tts.SPLIT_ON request.app.state.config.TTS_SPLIT_ON = form_data.tts.SPLIT_ON
request.app.state.config.TTS_AZURE_SPEECH_REGION = form_data.tts.AZURE_SPEECH_REGION request.app.state.config.TTS_AZURE_SPEECH_REGION = form_data.tts.AZURE_SPEECH_REGION
request.app.state.config.TTS_AZURE_SPEECH_BASE_URL = (
form_data.tts.AZURE_SPEECH_BASE_URL
)
request.app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT = ( request.app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT = (
form_data.tts.AZURE_SPEECH_OUTPUT_FORMAT form_data.tts.AZURE_SPEECH_OUTPUT_FORMAT
) )
@ -194,6 +238,13 @@ async def update_audio_config(
request.app.state.config.STT_MODEL = form_data.stt.MODEL request.app.state.config.STT_MODEL = form_data.stt.MODEL
request.app.state.config.WHISPER_MODEL = form_data.stt.WHISPER_MODEL request.app.state.config.WHISPER_MODEL = form_data.stt.WHISPER_MODEL
request.app.state.config.DEEPGRAM_API_KEY = form_data.stt.DEEPGRAM_API_KEY request.app.state.config.DEEPGRAM_API_KEY = form_data.stt.DEEPGRAM_API_KEY
request.app.state.config.AUDIO_STT_AZURE_API_KEY = form_data.stt.AZURE_API_KEY
request.app.state.config.AUDIO_STT_AZURE_REGION = form_data.stt.AZURE_REGION
request.app.state.config.AUDIO_STT_AZURE_LOCALES = form_data.stt.AZURE_LOCALES
request.app.state.config.AUDIO_STT_AZURE_BASE_URL = form_data.stt.AZURE_BASE_URL
request.app.state.config.AUDIO_STT_AZURE_MAX_SPEAKERS = (
form_data.stt.AZURE_MAX_SPEAKERS
)
if request.app.state.config.STT_ENGINE == "": if request.app.state.config.STT_ENGINE == "":
request.app.state.faster_whisper_model = set_faster_whisper_model( request.app.state.faster_whisper_model = set_faster_whisper_model(
@ -210,6 +261,7 @@ async def update_audio_config(
"VOICE": request.app.state.config.TTS_VOICE, "VOICE": request.app.state.config.TTS_VOICE,
"SPLIT_ON": request.app.state.config.TTS_SPLIT_ON, "SPLIT_ON": request.app.state.config.TTS_SPLIT_ON,
"AZURE_SPEECH_REGION": request.app.state.config.TTS_AZURE_SPEECH_REGION, "AZURE_SPEECH_REGION": request.app.state.config.TTS_AZURE_SPEECH_REGION,
"AZURE_SPEECH_BASE_URL": request.app.state.config.TTS_AZURE_SPEECH_BASE_URL,
"AZURE_SPEECH_OUTPUT_FORMAT": request.app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT, "AZURE_SPEECH_OUTPUT_FORMAT": request.app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT,
}, },
"stt": { "stt": {
@ -219,6 +271,11 @@ async def update_audio_config(
"MODEL": request.app.state.config.STT_MODEL, "MODEL": request.app.state.config.STT_MODEL,
"WHISPER_MODEL": request.app.state.config.WHISPER_MODEL, "WHISPER_MODEL": request.app.state.config.WHISPER_MODEL,
"DEEPGRAM_API_KEY": request.app.state.config.DEEPGRAM_API_KEY, "DEEPGRAM_API_KEY": request.app.state.config.DEEPGRAM_API_KEY,
"AZURE_API_KEY": request.app.state.config.AUDIO_STT_AZURE_API_KEY,
"AZURE_REGION": request.app.state.config.AUDIO_STT_AZURE_REGION,
"AZURE_LOCALES": request.app.state.config.AUDIO_STT_AZURE_LOCALES,
"AZURE_BASE_URL": request.app.state.config.AUDIO_STT_AZURE_BASE_URL,
"AZURE_MAX_SPEAKERS": request.app.state.config.AUDIO_STT_AZURE_MAX_SPEAKERS,
}, },
} }
@ -265,8 +322,10 @@ async def speech(request: Request, user=Depends(get_verified_user)):
payload["model"] = request.app.state.config.TTS_MODEL payload["model"] = request.app.state.config.TTS_MODEL
try: try:
# print(payload) timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession(
timeout=timeout, trust_env=True
) as session:
async with session.post( async with session.post(
url=f"{request.app.state.config.TTS_OPENAI_API_BASE_URL}/audio/speech", url=f"{request.app.state.config.TTS_OPENAI_API_BASE_URL}/audio/speech",
json=payload, json=payload,
@ -284,6 +343,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
else {} else {}
), ),
}, },
ssl=AIOHTTP_CLIENT_SESSION_SSL,
) as r: ) as r:
r.raise_for_status() r.raise_for_status()
@ -309,7 +369,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
detail = f"External: {e}" detail = f"External: {e}"
raise HTTPException( raise HTTPException(
status_code=getattr(r, "status", 500), status_code=getattr(r, "status", 500) if r else 500,
detail=detail if detail else "Open WebUI: Server Connection Error", detail=detail if detail else "Open WebUI: Server Connection Error",
) )
@ -323,7 +383,10 @@ async def speech(request: Request, user=Depends(get_verified_user)):
) )
try: try:
async with aiohttp.ClientSession() as session: timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
async with aiohttp.ClientSession(
timeout=timeout, trust_env=True
) as session:
async with session.post( async with session.post(
f"https://api.elevenlabs.io/v1/text-to-speech/{voice_id}", f"https://api.elevenlabs.io/v1/text-to-speech/{voice_id}",
json={ json={
@ -336,6 +399,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
"Content-Type": "application/json", "Content-Type": "application/json",
"xi-api-key": request.app.state.config.TTS_API_KEY, "xi-api-key": request.app.state.config.TTS_API_KEY,
}, },
ssl=AIOHTTP_CLIENT_SESSION_SSL,
) as r: ) as r:
r.raise_for_status() r.raise_for_status()
@ -360,7 +424,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
detail = f"External: {e}" detail = f"External: {e}"
raise HTTPException( raise HTTPException(
status_code=getattr(r, "status", 500), status_code=getattr(r, "status", 500) if r else 500,
detail=detail if detail else "Open WebUI: Server Connection Error", detail=detail if detail else "Open WebUI: Server Connection Error",
) )
@ -371,7 +435,8 @@ async def speech(request: Request, user=Depends(get_verified_user)):
log.exception(e) log.exception(e)
raise HTTPException(status_code=400, detail="Invalid JSON payload") raise HTTPException(status_code=400, detail="Invalid JSON payload")
region = request.app.state.config.TTS_AZURE_SPEECH_REGION region = request.app.state.config.TTS_AZURE_SPEECH_REGION or "eastus"
base_url = request.app.state.config.TTS_AZURE_SPEECH_BASE_URL
language = request.app.state.config.TTS_VOICE language = request.app.state.config.TTS_VOICE
locale = "-".join(request.app.state.config.TTS_VOICE.split("-")[:1]) locale = "-".join(request.app.state.config.TTS_VOICE.split("-")[:1])
output_format = request.app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT output_format = request.app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT
@ -380,15 +445,20 @@ async def speech(request: Request, user=Depends(get_verified_user)):
data = f"""<speak version="1.0" xmlns="http://www.w3.org/2001/10/synthesis" xml:lang="{locale}"> data = f"""<speak version="1.0" xmlns="http://www.w3.org/2001/10/synthesis" xml:lang="{locale}">
<voice name="{language}">{payload["input"]}</voice> <voice name="{language}">{payload["input"]}</voice>
</speak>""" </speak>"""
async with aiohttp.ClientSession() as session: timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
async with aiohttp.ClientSession(
timeout=timeout, trust_env=True
) as session:
async with session.post( async with session.post(
f"https://{region}.tts.speech.microsoft.com/cognitiveservices/v1", (base_url or f"https://{region}.tts.speech.microsoft.com")
+ "/cognitiveservices/v1",
headers={ headers={
"Ocp-Apim-Subscription-Key": request.app.state.config.TTS_API_KEY, "Ocp-Apim-Subscription-Key": request.app.state.config.TTS_API_KEY,
"Content-Type": "application/ssml+xml", "Content-Type": "application/ssml+xml",
"X-Microsoft-OutputFormat": output_format, "X-Microsoft-OutputFormat": output_format,
}, },
data=data, data=data,
ssl=AIOHTTP_CLIENT_SESSION_SSL,
) as r: ) as r:
r.raise_for_status() r.raise_for_status()
@ -413,7 +483,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
detail = f"External: {e}" detail = f"External: {e}"
raise HTTPException( raise HTTPException(
status_code=getattr(r, "status", 500), status_code=getattr(r, "status", 500) if r else 500,
detail=detail if detail else "Open WebUI: Server Connection Error", detail=detail if detail else "Open WebUI: Server Connection Error",
) )
@ -457,12 +527,13 @@ async def speech(request: Request, user=Depends(get_verified_user)):
return FileResponse(file_path) return FileResponse(file_path)
def transcribe(request: Request, file_path): def transcription_handler(request, file_path, metadata):
print("transcribe", file_path)
filename = os.path.basename(file_path) filename = os.path.basename(file_path)
file_dir = os.path.dirname(file_path) file_dir = os.path.dirname(file_path)
id = filename.split(".")[0] id = filename.split(".")[0]
metadata = metadata or {}
if request.app.state.config.STT_ENGINE == "": if request.app.state.config.STT_ENGINE == "":
if request.app.state.faster_whisper_model is None: if request.app.state.faster_whisper_model is None:
request.app.state.faster_whisper_model = set_faster_whisper_model( request.app.state.faster_whisper_model = set_faster_whisper_model(
@ -470,7 +541,12 @@ def transcribe(request: Request, file_path):
) )
model = request.app.state.faster_whisper_model model = request.app.state.faster_whisper_model
segments, info = model.transcribe(file_path, beam_size=5) segments, info = model.transcribe(
file_path,
beam_size=5,
vad_filter=request.app.state.config.WHISPER_VAD_FILTER,
language=metadata.get("language") or WHISPER_LANGUAGE,
)
log.info( log.info(
"Detected language '%s' with probability %f" "Detected language '%s' with probability %f"
% (info.language, info.language_probability) % (info.language, info.language_probability)
@ -487,11 +563,6 @@ def transcribe(request: Request, file_path):
log.debug(data) log.debug(data)
return data return data
elif request.app.state.config.STT_ENGINE == "openai": elif request.app.state.config.STT_ENGINE == "openai":
if is_mp4_audio(file_path):
os.rename(file_path, file_path.replace(".wav", ".mp4"))
# Convert MP4 audio file to WAV format
convert_mp4_to_wav(file_path.replace(".wav", ".mp4"), file_path)
r = None r = None
try: try:
r = requests.post( r = requests.post(
@ -500,7 +571,14 @@ def transcribe(request: Request, file_path):
"Authorization": f"Bearer {request.app.state.config.STT_OPENAI_API_KEY}" "Authorization": f"Bearer {request.app.state.config.STT_OPENAI_API_KEY}"
}, },
files={"file": (filename, open(file_path, "rb"))}, files={"file": (filename, open(file_path, "rb"))},
data={"model": request.app.state.config.STT_MODEL}, data={
"model": request.app.state.config.STT_MODEL,
**(
{"language": metadata.get("language")}
if metadata.get("language")
else {}
),
},
) )
r.raise_for_status() r.raise_for_status()
@ -589,34 +667,254 @@ def transcribe(request: Request, file_path):
detail = f"External: {e}" detail = f"External: {e}"
raise Exception(detail if detail else "Open WebUI: Server Connection Error") raise Exception(detail if detail else "Open WebUI: Server Connection Error")
elif request.app.state.config.STT_ENGINE == "azure":
# Check file exists and size
if not os.path.exists(file_path):
raise HTTPException(status_code=400, detail="Audio file not found")
# Check file size (Azure has a larger limit of 200MB)
file_size = os.path.getsize(file_path)
if file_size > AZURE_MAX_FILE_SIZE:
raise HTTPException(
status_code=400,
detail=f"File size exceeds Azure's limit of {AZURE_MAX_FILE_SIZE_MB}MB",
)
api_key = request.app.state.config.AUDIO_STT_AZURE_API_KEY
region = request.app.state.config.AUDIO_STT_AZURE_REGION or "eastus"
locales = request.app.state.config.AUDIO_STT_AZURE_LOCALES
base_url = request.app.state.config.AUDIO_STT_AZURE_BASE_URL
max_speakers = request.app.state.config.AUDIO_STT_AZURE_MAX_SPEAKERS or 3
# IF NO LOCALES, USE DEFAULTS
if len(locales) < 2:
locales = [
"en-US",
"es-ES",
"es-MX",
"fr-FR",
"hi-IN",
"it-IT",
"de-DE",
"en-GB",
"en-IN",
"ja-JP",
"ko-KR",
"pt-BR",
"zh-CN",
]
locales = ",".join(locales)
if not api_key or not region:
raise HTTPException(
status_code=400,
detail="Azure API key is required for Azure STT",
)
r = None
try:
# Prepare the request
data = {
"definition": json.dumps(
{
"locales": locales.split(","),
"diarization": {"maxSpeakers": max_speakers, "enabled": True},
}
if locales
else {}
)
}
url = (
base_url or f"https://{region}.api.cognitive.microsoft.com"
) + "/speechtotext/transcriptions:transcribe?api-version=2024-11-15"
# Use context manager to ensure file is properly closed
with open(file_path, "rb") as audio_file:
r = requests.post(
url=url,
files={"audio": audio_file},
data=data,
headers={
"Ocp-Apim-Subscription-Key": api_key,
},
)
r.raise_for_status()
response = r.json()
# Extract transcript from response
if not response.get("combinedPhrases"):
raise ValueError("No transcription found in response")
# Get the full transcript from combinedPhrases
transcript = response["combinedPhrases"][0].get("text", "").strip()
if not transcript:
raise ValueError("Empty transcript in response")
data = {"text": transcript}
# Save transcript to json file (consistent with other providers)
transcript_file = f"{file_dir}/{id}.json"
with open(transcript_file, "w") as f:
json.dump(data, f)
log.debug(data)
return data
except (KeyError, IndexError, ValueError) as e:
log.exception("Error parsing Azure response")
raise HTTPException(
status_code=500,
detail=f"Failed to parse Azure response: {str(e)}",
)
except requests.exceptions.RequestException as e:
log.exception(e)
detail = None
try:
if r is not None and r.status_code != 200:
res = r.json()
if "error" in res:
detail = f"External: {res['error'].get('message', '')}"
except Exception:
detail = f"External: {e}"
raise HTTPException(
status_code=getattr(r, "status_code", 500) if r else 500,
detail=detail if detail else "Open WebUI: Server Connection Error",
)
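The Azure branch above expects the transcription service to return a JSON body with a "combinedPhrases" array and takes the first entry's "text" as the full transcript. A minimal sketch of that parsing in isolation; the sample payload is illustrative only, not a captured Azure response.

# Minimal sketch (assumption: illustrative payload shape, not a captured Azure response).
# It mirrors the parsing in the Azure branch above: take the first "combinedPhrases"
# entry and treat its "text" field as the full transcript.
sample_response = {
    "combinedPhrases": [
        {"channel": 0, "text": "Hello and welcome to the meeting."}
    ],
    "phrases": [],  # per-phrase details (offsets, words) are ignored by the code above
}


def extract_transcript(response: dict) -> str:
    if not response.get("combinedPhrases"):
        raise ValueError("No transcription found in response")
    transcript = response["combinedPhrases"][0].get("text", "").strip()
    if not transcript:
        raise ValueError("Empty transcript in response")
    return transcript


print(extract_transcript(sample_response))  # -> "Hello and welcome to the meeting."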
def transcribe(request: Request, file_path: str, metadata: Optional[dict] = None):
log.info(f"transcribe: {file_path} {metadata}")
if is_audio_conversion_required(file_path):
file_path = convert_audio_to_mp3(file_path)
try:
file_path = compress_audio(file_path)
except Exception as e:
log.exception(e)
# Always produce a list of chunk paths (could be one entry if small)
try:
chunk_paths = split_audio(file_path, MAX_FILE_SIZE)
print(f"Chunk paths: {chunk_paths}")
except Exception as e:
log.exception(e)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(e),
)
results = []
try:
with ThreadPoolExecutor() as executor:
# Submit tasks for each chunk_path
futures = [
executor.submit(transcription_handler, request, chunk_path, metadata)
for chunk_path in chunk_paths
]
# Gather results as they complete
for future in futures:
try:
results.append(future.result())
except Exception as transcribe_exc:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Error transcribing chunk: {transcribe_exc}",
)
finally:
# Clean up only the temporary chunks, never the original file
for chunk_path in chunk_paths:
if chunk_path != file_path and os.path.isfile(chunk_path):
try:
os.remove(chunk_path)
except Exception:
pass
return {
"text": " ".join([result["text"] for result in results]),
}
def compress_audio(file_path):
    if os.path.getsize(file_path) > MAX_FILE_SIZE:
        id = os.path.splitext(os.path.basename(file_path))[
            0
        ]  # Handles names with multiple dots
        file_dir = os.path.dirname(file_path)
        audio = AudioSegment.from_file(file_path)
        audio = audio.set_frame_rate(16000).set_channels(1)  # Compress audio
        compressed_path = f"{file_dir}/{id}_compressed.opus"
        audio.export(compressed_path, format="opus", bitrate="32k")
        log.debug(f"Compressed audio to {compressed_path}")
        if (
            os.path.getsize(compressed_path) > MAX_FILE_SIZE
        ):  # Still larger than MAX_FILE_SIZE after compression
            raise Exception(ERROR_MESSAGES.FILE_TOO_LARGE(size=f"{MAX_FILE_SIZE_MB}MB"))
        compressed_path = os.path.join(file_dir, f"{id}_compressed.mp3")
        audio.export(compressed_path, format="mp3", bitrate="32k")
        # log.debug(f"Compressed audio to {compressed_path}") # Uncomment if log is defined
        return compressed_path
    else:
        return file_path
def split_audio(file_path, max_bytes, format="mp3", bitrate="32k"):
"""
Splits audio into chunks not exceeding max_bytes.
Returns a list of chunk file paths. If audio fits, returns list with original path.
"""
file_size = os.path.getsize(file_path)
if file_size <= max_bytes:
return [file_path] # Nothing to split
audio = AudioSegment.from_file(file_path)
duration_ms = len(audio)
orig_size = file_size
approx_chunk_ms = max(int(duration_ms * (max_bytes / orig_size)) - 1000, 1000)
chunks = []
start = 0
i = 0
base, _ = os.path.splitext(file_path)
while start < duration_ms:
end = min(start + approx_chunk_ms, duration_ms)
chunk = audio[start:end]
chunk_path = f"{base}_chunk_{i}.{format}"
chunk.export(chunk_path, format=format, bitrate=bitrate)
# Reduce chunk duration if still too large
while os.path.getsize(chunk_path) > max_bytes and (end - start) > 5000:
end = start + ((end - start) // 2)
chunk = audio[start:end]
chunk.export(chunk_path, format=format, bitrate=bitrate)
if os.path.getsize(chunk_path) > max_bytes:
os.remove(chunk_path)
raise Exception("Audio chunk cannot be reduced below max file size.")
chunks.append(chunk_path)
start = end
i += 1
return chunks
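Before transcription, oversized files are compressed and then split into chunks that fit the provider's upload limit. A standalone sketch of that size-targeted splitting idea, assuming pydub and ffmpeg are available; it omits the shrink-and-retry loop that split_audio uses for chunks that still come out too large, and "long_recording.wav" is a hypothetical input file.

# Standalone sketch of the size-targeted splitting used above (assumes pydub + ffmpeg;
# "long_recording.wav" is hypothetical; the retry loop for oversized chunks is omitted).
import os

from pydub import AudioSegment

MAX_BYTES = 25 * 1024 * 1024  # mirrors a 25MB upload cap such as OpenAI's


def split_for_upload(path: str, max_bytes: int = MAX_BYTES) -> list[str]:
    if os.path.getsize(path) <= max_bytes:
        return [path]

    audio = AudioSegment.from_file(path)
    # Estimate a chunk duration proportional to the size budget, with a small safety margin.
    ratio = max_bytes / os.path.getsize(path)
    chunk_ms = max(int(len(audio) * ratio) - 1000, 1000)

    base, _ = os.path.splitext(path)
    chunk_paths = []
    for i, start in enumerate(range(0, len(audio), chunk_ms)):
        chunk_path = f"{base}_chunk_{i}.mp3"
        audio[start : start + chunk_ms].export(chunk_path, format="mp3", bitrate="32k")
        chunk_paths.append(chunk_path)
    return chunk_paths


# chunks = split_for_upload("long_recording.wav")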
@router.post("/transcriptions") @router.post("/transcriptions")
def transcription( def transcription(
request: Request, request: Request,
file: UploadFile = File(...), file: UploadFile = File(...),
language: Optional[str] = Form(None),
user=Depends(get_verified_user), user=Depends(get_verified_user),
): ):
log.info(f"file.content_type: {file.content_type}") log.info(f"file.content_type: {file.content_type}")
if file.content_type not in ["audio/mpeg", "audio/wav", "audio/ogg", "audio/x-m4a"]:
SUPPORTED_CONTENT_TYPES = {"video/webm"}  # Extend if you add more video types!
if not (
    file.content_type.startswith("audio/")
    or file.content_type in SUPPORTED_CONTENT_TYPES
):
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.FILE_NOT_SUPPORTED, detail=ERROR_MESSAGES.FILE_NOT_SUPPORTED,
@ -637,19 +935,18 @@ def transcription(
f.write(contents) f.write(contents)
try: try:
    try:
        file_path = compress_audio(file_path)
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )

    data = transcribe(request, file_path)
    file_path = file_path.split("/")[-1]
    return {**data, "filename": file_path}

    metadata = None

    if language:
        metadata = {"language": language}

    result = transcribe(request, file_path, metadata)
    return {
        **result,
        "filename": os.path.basename(file_path),
    }
except Exception as e: except Exception as e:
log.exception(e) log.exception(e)
@ -670,6 +967,21 @@ def transcription(
def get_available_models(request: Request) -> list[dict]: def get_available_models(request: Request) -> list[dict]:
available_models = [] available_models = []
if request.app.state.config.TTS_ENGINE == "openai": if request.app.state.config.TTS_ENGINE == "openai":
# Use custom endpoint if not using the official OpenAI API URL
if not request.app.state.config.TTS_OPENAI_API_BASE_URL.startswith(
"https://api.openai.com"
):
try:
response = requests.get(
f"{request.app.state.config.TTS_OPENAI_API_BASE_URL}/audio/models"
)
response.raise_for_status()
data = response.json()
available_models = data.get("models", [])
except Exception as e:
log.error(f"Error fetching models from custom endpoint: {str(e)}")
available_models = [{"id": "tts-1"}, {"id": "tts-1-hd"}]
else:
available_models = [{"id": "tts-1"}, {"id": "tts-1-hd"}] available_models = [{"id": "tts-1"}, {"id": "tts-1-hd"}]
elif request.app.state.config.TTS_ENGINE == "elevenlabs": elif request.app.state.config.TTS_ENGINE == "elevenlabs":
try: try:
@ -701,6 +1013,29 @@ def get_available_voices(request) -> dict:
"""Returns {voice_id: voice_name} dict""" """Returns {voice_id: voice_name} dict"""
available_voices = {} available_voices = {}
if request.app.state.config.TTS_ENGINE == "openai": if request.app.state.config.TTS_ENGINE == "openai":
# Use custom endpoint if not using the official OpenAI API URL
if not request.app.state.config.TTS_OPENAI_API_BASE_URL.startswith(
"https://api.openai.com"
):
try:
response = requests.get(
f"{request.app.state.config.TTS_OPENAI_API_BASE_URL}/audio/voices"
)
response.raise_for_status()
data = response.json()
voices_list = data.get("voices", [])
available_voices = {voice["id"]: voice["name"] for voice in voices_list}
except Exception as e:
log.error(f"Error fetching voices from custom endpoint: {str(e)}")
available_voices = {
"alloy": "alloy",
"echo": "echo",
"fable": "fable",
"onyx": "onyx",
"nova": "nova",
"shimmer": "shimmer",
}
else:
available_voices = { available_voices = {
"alloy": "alloy", "alloy": "alloy",
"echo": "echo", "echo": "echo",
@ -720,7 +1055,10 @@ def get_available_voices(request) -> dict:
elif request.app.state.config.TTS_ENGINE == "azure": elif request.app.state.config.TTS_ENGINE == "azure":
try: try:
region = request.app.state.config.TTS_AZURE_SPEECH_REGION region = request.app.state.config.TTS_AZURE_SPEECH_REGION
url = f"https://{region}.tts.speech.microsoft.com/cognitiveservices/voices/list" base_url = request.app.state.config.TTS_AZURE_SPEECH_BASE_URL
url = (
base_url or f"https://{region}.tts.speech.microsoft.com"
) + "/cognitiveservices/voices/list"
headers = { headers = {
"Ocp-Apim-Subscription-Key": request.app.state.config.TTS_API_KEY "Ocp-Apim-Subscription-Key": request.app.state.config.TTS_API_KEY
} }
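The custom-endpoint branches above assume an OpenAI-compatible TTS server may expose /audio/models and /audio/voices routes. A sketch of the response shapes those branches expect, inferred from the parsing code rather than from any published spec; the model and voice names are hypothetical.

# Response shapes the custom-endpoint branches above expect (inferred from the parsing
# code; names are hypothetical examples).
models_payload = {"models": [{"id": "tts-1"}, {"id": "kokoro-v1"}]}  # GET {base_url}/audio/models
voices_payload = {
    "voices": [
        {"id": "af_sky", "name": "Sky"},
        {"id": "am_adam", "name": "Adam"},
    ]
}  # GET {base_url}/audio/voices

available_models = models_payload.get("models", [])
available_voices = {voice["id"]: voice["name"] for voice in voices_payload.get("voices", [])}

print(available_models)   # [{'id': 'tts-1'}, {'id': 'kokoro-v1'}]
print(available_voices)   # {'af_sky': 'Sky', 'am_adam': 'Adam'}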
@ -19,38 +19,43 @@ from open_webui.models.auths import (
UserResponse, UserResponse,
) )
from open_webui.models.users import Users from open_webui.models.users import Users
from open_webui.models.groups import Groups
from open_webui.constants import ERROR_MESSAGES, WEBHOOK_MESSAGES from open_webui.constants import ERROR_MESSAGES, WEBHOOK_MESSAGES
from open_webui.env import ( from open_webui.env import (
WEBUI_AUTH, WEBUI_AUTH,
WEBUI_AUTH_TRUSTED_EMAIL_HEADER, WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
WEBUI_AUTH_TRUSTED_NAME_HEADER, WEBUI_AUTH_TRUSTED_NAME_HEADER,
WEBUI_AUTH_TRUSTED_GROUPS_HEADER,
WEBUI_AUTH_COOKIE_SAME_SITE, WEBUI_AUTH_COOKIE_SAME_SITE,
WEBUI_AUTH_COOKIE_SECURE, WEBUI_AUTH_COOKIE_SECURE,
WEBUI_AUTH_SIGNOUT_REDIRECT_URL,
SRC_LOG_LEVELS, SRC_LOG_LEVELS,
) )
from fastapi import APIRouter, Depends, HTTPException, Request, status from fastapi import APIRouter, Depends, HTTPException, Request, status
from fastapi.responses import RedirectResponse, Response
from fastapi.responses import RedirectResponse, Response, JSONResponse
from open_webui.config import (
    OPENID_PROVIDER_URL,
    ENABLE_OAUTH_SIGNUP,
)
from open_webui.config import OPENID_PROVIDER_URL, ENABLE_OAUTH_SIGNUP, ENABLE_LDAP
from pydantic import BaseModel from pydantic import BaseModel
from open_webui.utils.misc import parse_duration, validate_email_format from open_webui.utils.misc import parse_duration, validate_email_format
from open_webui.utils.auth import ( from open_webui.utils.auth import (
decode_token,
create_api_key, create_api_key,
create_token, create_token,
get_admin_user, get_admin_user,
get_verified_user, get_verified_user,
get_current_user, get_current_user,
get_password_hash, get_password_hash,
get_http_authorization_cred,
) )
from open_webui.utils.webhook import post_webhook from open_webui.utils.webhook import post_webhook
from open_webui.utils.access_control import get_permissions from open_webui.utils.access_control import get_permissions
from typing import Optional, List from typing import Optional, List
from ssl import CERT_REQUIRED, PROTOCOL_TLS
from ssl import CERT_NONE, CERT_REQUIRED, PROTOCOL_TLS

if ENABLE_LDAP.value:
    from ldap3 import Server, Connection, NONE, Tls
    from ldap3.utils.conv import escape_filter_chars
@ -73,27 +78,32 @@ class SessionUserResponse(Token, UserResponse):
async def get_session_user( async def get_session_user(
request: Request, response: Response, user=Depends(get_current_user) request: Request, response: Response, user=Depends(get_current_user)
): ):
    expires_delta = parse_duration(request.app.state.config.JWT_EXPIRES_IN)
    expires_at = None
    if expires_delta:
        expires_at = int(time.time()) + int(expires_delta.total_seconds())

    token = create_token(
        data={"id": user.id},
        expires_delta=expires_delta,
    )

    datetime_expires_at = (
        datetime.datetime.fromtimestamp(expires_at, datetime.timezone.utc)
        if expires_at
        else None
    )

    auth_header = request.headers.get("Authorization")
    auth_token = get_http_authorization_cred(auth_header)
    token = auth_token.credentials
    data = decode_token(token)

    expires_at = None
    if data:
        expires_at = data.get("exp")

    if (expires_at is not None) and int(time.time()) > expires_at:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.INVALID_TOKEN,
        )

    # Set the cookie token
    response.set_cookie(
        key="token",
        value=token,
        expires=datetime_expires_at,
        expires=(
            datetime.datetime.fromtimestamp(expires_at, datetime.timezone.utc)
            if expires_at
            else None
        ),
httponly=True, # Ensures the cookie is not accessible via JavaScript httponly=True, # Ensures the cookie is not accessible via JavaScript
samesite=WEBUI_AUTH_COOKIE_SAME_SITE, samesite=WEBUI_AUTH_COOKIE_SAME_SITE,
secure=WEBUI_AUTH_COOKIE_SECURE, secure=WEBUI_AUTH_COOKIE_SECURE,
@ -178,6 +188,9 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
LDAP_APP_PASSWORD = request.app.state.config.LDAP_APP_PASSWORD LDAP_APP_PASSWORD = request.app.state.config.LDAP_APP_PASSWORD
LDAP_USE_TLS = request.app.state.config.LDAP_USE_TLS LDAP_USE_TLS = request.app.state.config.LDAP_USE_TLS
LDAP_CA_CERT_FILE = request.app.state.config.LDAP_CA_CERT_FILE LDAP_CA_CERT_FILE = request.app.state.config.LDAP_CA_CERT_FILE
LDAP_VALIDATE_CERT = (
CERT_REQUIRED if request.app.state.config.LDAP_VALIDATE_CERT else CERT_NONE
)
LDAP_CIPHERS = ( LDAP_CIPHERS = (
request.app.state.config.LDAP_CIPHERS request.app.state.config.LDAP_CIPHERS
if request.app.state.config.LDAP_CIPHERS if request.app.state.config.LDAP_CIPHERS
@ -189,14 +202,14 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
try: try:
tls = Tls( tls = Tls(
validate=CERT_REQUIRED,
validate=LDAP_VALIDATE_CERT,
version=PROTOCOL_TLS, version=PROTOCOL_TLS,
ca_certs_file=LDAP_CA_CERT_FILE, ca_certs_file=LDAP_CA_CERT_FILE,
ciphers=LDAP_CIPHERS, ciphers=LDAP_CIPHERS,
) )
except Exception as e: except Exception as e:
log.error(f"An error occurred on TLS: {str(e)}") log.error(f"TLS configuration error: {str(e)}")
raise HTTPException(400, detail=str(e)) raise HTTPException(400, detail="Failed to configure TLS for LDAP connection.")
try: try:
server = Server( server = Server(
@ -211,7 +224,7 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
LDAP_APP_DN, LDAP_APP_DN,
LDAP_APP_PASSWORD, LDAP_APP_PASSWORD,
auto_bind="NONE", auto_bind="NONE",
authentication="SIMPLE", authentication="SIMPLE" if LDAP_APP_DN else "ANONYMOUS",
) )
if not connection_app.bind(): if not connection_app.bind():
raise HTTPException(400, detail="Application account bind failed") raise HTTPException(400, detail="Application account bind failed")
@ -226,14 +239,23 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
], ],
) )
if not search_success:
if not search_success or not connection_app.entries:
    raise HTTPException(400, detail="User not found in the LDAP server")
entry = connection_app.entries[0] entry = connection_app.entries[0]
username = str(entry[f"{LDAP_ATTRIBUTE_FOR_USERNAME}"]).lower() username = str(entry[f"{LDAP_ATTRIBUTE_FOR_USERNAME}"]).lower()
mail = str(entry[f"{LDAP_ATTRIBUTE_FOR_MAIL}"])
if not mail or mail == "" or mail == "[]":
    raise HTTPException(400, f"User {form_data.user} does not have mail.")
email = entry[
    f"{LDAP_ATTRIBUTE_FOR_MAIL}"
].value  # retrieve the Attribute value
if not email:
raise HTTPException(400, "User does not have a valid email address.")
elif isinstance(email, str):
email = email.lower()
elif isinstance(email, list):
email = email[0].lower()
else:
email = str(email).lower()
cn = str(entry["cn"]) cn = str(entry["cn"])
user_dn = entry.entry_dn user_dn = entry.entry_dn
@ -246,19 +268,24 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
authentication="SIMPLE", authentication="SIMPLE",
) )
if not connection_user.bind(): if not connection_user.bind():
raise HTTPException(400, f"Authentication failed for {form_data.user}") raise HTTPException(400, "Authentication failed.")
user = Users.get_user_by_email(mail) user = Users.get_user_by_email(email)
if not user: if not user:
try: try:
user_count = Users.get_num_users()
role = ( role = (
"admin" "admin"
if Users.get_num_users() == 0
if user_count == 0
else request.app.state.config.DEFAULT_USER_ROLE else request.app.state.config.DEFAULT_USER_ROLE
) )
user = Auths.insert_new_auth( user = Auths.insert_new_auth(
email=mail, password=str(uuid.uuid4()), name=cn, role=role
email=email,
password=str(uuid.uuid4()),
name=cn,
role=role,
) )
if not user: if not user:
@ -269,23 +296,38 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
except HTTPException: except HTTPException:
raise raise
except Exception as err: except Exception as err:
raise HTTPException(500, detail=ERROR_MESSAGES.DEFAULT(err))
log.error(f"LDAP user creation error: {str(err)}")
raise HTTPException(
500, detail="Internal error occurred during LDAP user creation."
)
user = Auths.authenticate_user_by_trusted_header(mail)
user = Auths.authenticate_user_by_email(email)
if user: if user:
expires_delta = parse_duration(request.app.state.config.JWT_EXPIRES_IN)
expires_at = None
if expires_delta:
expires_at = int(time.time()) + int(expires_delta.total_seconds())
token = create_token( token = create_token(
data={"id": user.id}, data={"id": user.id},
expires_delta=parse_duration(
    request.app.state.config.JWT_EXPIRES_IN
),
expires_delta=expires_delta,
) )
# Set the cookie token # Set the cookie token
response.set_cookie( response.set_cookie(
key="token", key="token",
value=token, value=token,
expires=(
datetime.datetime.fromtimestamp(
expires_at, datetime.timezone.utc
)
if expires_at
else None
),
httponly=True, # Ensures the cookie is not accessible via JavaScript httponly=True, # Ensures the cookie is not accessible via JavaScript
samesite=WEBUI_AUTH_COOKIE_SAME_SITE,
secure=WEBUI_AUTH_COOKIE_SECURE,
) )
user_permissions = get_permissions( user_permissions = get_permissions(
@ -295,6 +337,7 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
return { return {
"token": token, "token": token,
"token_type": "Bearer", "token_type": "Bearer",
"expires_at": expires_at,
"id": user.id, "id": user.id,
"email": user.email, "email": user.email,
"name": user.name, "name": user.name,
@ -305,12 +348,10 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
else: else:
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED) raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
else: else:
raise HTTPException(
    400,
    f"User {form_data.user} does not match the record. Search result: {str(entry[f'{LDAP_ATTRIBUTE_FOR_USERNAME}'])}",
)
raise HTTPException(400, "User record mismatch.")
except Exception as e: except Exception as e:
raise HTTPException(400, detail=str(e))
log.error(f"LDAP authentication error: {str(e)}")
raise HTTPException(400, detail="LDAP authentication failed.")
############################ ############################
@ -324,21 +365,29 @@ async def signin(request: Request, response: Response, form_data: SigninForm):
if WEBUI_AUTH_TRUSTED_EMAIL_HEADER not in request.headers: if WEBUI_AUTH_TRUSTED_EMAIL_HEADER not in request.headers:
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_TRUSTED_HEADER) raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_TRUSTED_HEADER)
        trusted_email = request.headers[WEBUI_AUTH_TRUSTED_EMAIL_HEADER].lower()
        trusted_name = trusted_email
        if WEBUI_AUTH_TRUSTED_NAME_HEADER:
            trusted_name = request.headers.get(
                WEBUI_AUTH_TRUSTED_NAME_HEADER, trusted_email
            )
        if not Users.get_user_by_email(trusted_email.lower()):

        email = request.headers[WEBUI_AUTH_TRUSTED_EMAIL_HEADER].lower()
        name = email
        if WEBUI_AUTH_TRUSTED_NAME_HEADER:
            name = request.headers.get(WEBUI_AUTH_TRUSTED_NAME_HEADER, email)
        if not Users.get_user_by_email(email.lower()):
await signup( await signup(
request, request,
response, response,
SignupForm(
    email=trusted_email, password=str(uuid.uuid4()), name=trusted_name
),
SignupForm(email=email, password=str(uuid.uuid4()), name=name),
) )
user = Auths.authenticate_user_by_trusted_header(trusted_email)
user = Auths.authenticate_user_by_email(email)
if WEBUI_AUTH_TRUSTED_GROUPS_HEADER and user and user.role != "admin":
group_names = request.headers.get(
WEBUI_AUTH_TRUSTED_GROUPS_HEADER, ""
).split(",")
group_names = [name.strip() for name in group_names if name.strip()]
if group_names:
Groups.sync_user_groups_by_group_names(user.id, group_names)
elif WEBUI_AUTH == False: elif WEBUI_AUTH == False:
admin_email = "admin@localhost" admin_email = "admin@localhost"
admin_password = "admin" admin_password = "admin"
@ -413,6 +462,7 @@ async def signin(request: Request, response: Response, form_data: SigninForm):
@router.post("/signup", response_model=SessionUserResponse) @router.post("/signup", response_model=SessionUserResponse)
async def signup(request: Request, response: Response, form_data: SignupForm): async def signup(request: Request, response: Response, form_data: SignupForm):
if WEBUI_AUTH: if WEBUI_AUTH:
if ( if (
not request.app.state.config.ENABLE_SIGNUP not request.app.state.config.ENABLE_SIGNUP
@ -427,6 +477,7 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED
) )
user_count = Users.get_num_users()
if not validate_email_format(form_data.email.lower()): if not validate_email_format(form_data.email.lower()):
raise HTTPException( raise HTTPException(
status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT
@ -437,14 +488,15 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
try: try:
role = ( role = (
"admin" "admin" if user_count == 0 else request.app.state.config.DEFAULT_USER_ROLE
if Users.get_num_users() == 0
else request.app.state.config.DEFAULT_USER_ROLE
) )
if Users.get_num_users() == 0:
    # Disable signup after the first user is created
    request.app.state.config.ENABLE_SIGNUP = False
# The password passed to bcrypt must be 72 bytes or fewer. If it is longer, it will be truncated before hashing.
if len(form_data.password.encode("utf-8")) > 72:
    raise HTTPException(
status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.PASSWORD_TOO_LONG,
)
hashed = get_password_hash(form_data.password) hashed = get_password_hash(form_data.password)
user = Auths.insert_new_auth( user = Auths.insert_new_auth(
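The new length check above exists because bcrypt only hashes the first 72 bytes of the password, and multi-byte UTF-8 characters reach that limit sooner than the character count suggests. A minimal sketch of the same guard in isolation:

# Why the check measures bytes, not characters: bcrypt only uses the first 72 bytes.
def validate_password_length(password: str, max_bytes: int = 72) -> None:
    encoded = password.encode("utf-8")
    if len(encoded) > max_bytes:
        raise ValueError(
            f"Password is {len(encoded)} bytes once UTF-8 encoded; bcrypt only uses the first {max_bytes}."
        )


validate_password_length("a" * 72)    # OK: exactly 72 bytes
validate_password_length("€" * 24)    # OK: 24 characters * 3 bytes each = 72 bytes
# validate_password_length("€" * 25)  # would raise: 75 bytes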
@ -484,6 +536,7 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
if request.app.state.config.WEBHOOK_URL: if request.app.state.config.WEBHOOK_URL:
post_webhook( post_webhook(
request.app.state.WEBUI_NAME,
request.app.state.config.WEBHOOK_URL, request.app.state.config.WEBHOOK_URL,
WEBHOOK_MESSAGES.USER_SIGNUP(user.name), WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
{ {
@ -497,6 +550,10 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
user.id, request.app.state.config.USER_PERMISSIONS user.id, request.app.state.config.USER_PERMISSIONS
) )
if user_count == 0:
# Disable signup after the first user is created
request.app.state.config.ENABLE_SIGNUP = False
return { return {
"token": token, "token": token,
"token_type": "Bearer", "token_type": "Bearer",
@ -511,7 +568,8 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
else: else:
raise HTTPException(500, detail=ERROR_MESSAGES.CREATE_USER_ERROR) raise HTTPException(500, detail=ERROR_MESSAGES.CREATE_USER_ERROR)
except Exception as err: except Exception as err:
raise HTTPException(500, detail=ERROR_MESSAGES.DEFAULT(err))
log.error(f"Signup error: {str(err)}")
raise HTTPException(500, detail="An internal error occurred during signup.")
@router.get("/signout") @router.get("/signout")
@ -529,8 +587,14 @@ async def signout(request: Request, response: Response):
logout_url = openid_data.get("end_session_endpoint") logout_url = openid_data.get("end_session_endpoint")
if logout_url: if logout_url:
response.delete_cookie("oauth_id_token") response.delete_cookie("oauth_id_token")
return RedirectResponse(
url=f"{logout_url}?id_token_hint={oauth_id_token}" return JSONResponse(
status_code=200,
content={
"status": True,
"redirect_url": f"{logout_url}?id_token_hint={oauth_id_token}",
},
headers=response.headers,
) )
else: else:
raise HTTPException( raise HTTPException(
@ -538,9 +602,25 @@ async def signout(request: Request, response: Response):
detail="Failed to fetch OpenID configuration", detail="Failed to fetch OpenID configuration",
) )
except Exception as e: except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
log.error(f"OpenID signout error: {str(e)}")
raise HTTPException(
status_code=500,
detail="Failed to sign out from the OpenID provider.",
)
return {"status": True} if WEBUI_AUTH_SIGNOUT_REDIRECT_URL:
return JSONResponse(
status_code=200,
content={
"status": True,
"redirect_url": WEBUI_AUTH_SIGNOUT_REDIRECT_URL,
},
headers=response.headers,
)
return JSONResponse(
status_code=200, content={"status": True}, headers=response.headers
)
############################ ############################
@ -582,7 +662,10 @@ async def add_user(form_data: AddUserForm, user=Depends(get_admin_user)):
else: else:
raise HTTPException(500, detail=ERROR_MESSAGES.CREATE_USER_ERROR) raise HTTPException(500, detail=ERROR_MESSAGES.CREATE_USER_ERROR)
except Exception as err: except Exception as err:
raise HTTPException(500, detail=ERROR_MESSAGES.DEFAULT(err))
log.error(f"Add user error: {str(err)}")
raise HTTPException(
500, detail="An internal error occurred while adding the user."
)
############################ ############################
@ -596,7 +679,7 @@ async def get_admin_details(request: Request, user=Depends(get_current_user)):
admin_email = request.app.state.config.ADMIN_EMAIL admin_email = request.app.state.config.ADMIN_EMAIL
admin_name = None admin_name = None
print(admin_email, admin_name)
log.info(f"Admin details - Email: {admin_email}, Name: {admin_name}")
if admin_email: if admin_email:
admin = Users.get_user_by_email(admin_email) admin = Users.get_user_by_email(admin_email)
@ -630,11 +713,16 @@ async def get_admin_config(request: Request, user=Depends(get_admin_user)):
"ENABLE_API_KEY": request.app.state.config.ENABLE_API_KEY, "ENABLE_API_KEY": request.app.state.config.ENABLE_API_KEY,
"ENABLE_API_KEY_ENDPOINT_RESTRICTIONS": request.app.state.config.ENABLE_API_KEY_ENDPOINT_RESTRICTIONS, "ENABLE_API_KEY_ENDPOINT_RESTRICTIONS": request.app.state.config.ENABLE_API_KEY_ENDPOINT_RESTRICTIONS,
"API_KEY_ALLOWED_ENDPOINTS": request.app.state.config.API_KEY_ALLOWED_ENDPOINTS, "API_KEY_ALLOWED_ENDPOINTS": request.app.state.config.API_KEY_ALLOWED_ENDPOINTS,
"ENABLE_CHANNELS": request.app.state.config.ENABLE_CHANNELS,
"DEFAULT_USER_ROLE": request.app.state.config.DEFAULT_USER_ROLE, "DEFAULT_USER_ROLE": request.app.state.config.DEFAULT_USER_ROLE,
"JWT_EXPIRES_IN": request.app.state.config.JWT_EXPIRES_IN, "JWT_EXPIRES_IN": request.app.state.config.JWT_EXPIRES_IN,
"ENABLE_COMMUNITY_SHARING": request.app.state.config.ENABLE_COMMUNITY_SHARING, "ENABLE_COMMUNITY_SHARING": request.app.state.config.ENABLE_COMMUNITY_SHARING,
"ENABLE_MESSAGE_RATING": request.app.state.config.ENABLE_MESSAGE_RATING, "ENABLE_MESSAGE_RATING": request.app.state.config.ENABLE_MESSAGE_RATING,
"ENABLE_CHANNELS": request.app.state.config.ENABLE_CHANNELS,
"ENABLE_NOTES": request.app.state.config.ENABLE_NOTES,
"ENABLE_USER_WEBHOOKS": request.app.state.config.ENABLE_USER_WEBHOOKS,
"PENDING_USER_OVERLAY_TITLE": request.app.state.config.PENDING_USER_OVERLAY_TITLE,
"PENDING_USER_OVERLAY_CONTENT": request.app.state.config.PENDING_USER_OVERLAY_CONTENT,
"RESPONSE_WATERMARK": request.app.state.config.RESPONSE_WATERMARK,
} }
@ -645,11 +733,16 @@ class AdminConfig(BaseModel):
ENABLE_API_KEY: bool ENABLE_API_KEY: bool
ENABLE_API_KEY_ENDPOINT_RESTRICTIONS: bool ENABLE_API_KEY_ENDPOINT_RESTRICTIONS: bool
API_KEY_ALLOWED_ENDPOINTS: str API_KEY_ALLOWED_ENDPOINTS: str
ENABLE_CHANNELS: bool
DEFAULT_USER_ROLE: str DEFAULT_USER_ROLE: str
JWT_EXPIRES_IN: str JWT_EXPIRES_IN: str
ENABLE_COMMUNITY_SHARING: bool ENABLE_COMMUNITY_SHARING: bool
ENABLE_MESSAGE_RATING: bool ENABLE_MESSAGE_RATING: bool
ENABLE_CHANNELS: bool
ENABLE_NOTES: bool
ENABLE_USER_WEBHOOKS: bool
PENDING_USER_OVERLAY_TITLE: Optional[str] = None
PENDING_USER_OVERLAY_CONTENT: Optional[str] = None
RESPONSE_WATERMARK: Optional[str] = None
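For reference, a possible request body for the admin config endpoint covering the newly added fields; every value here is an assumption for illustration, and the full mount path (e.g. /api/v1/auths/admin/config) depends on how the router is included.

# Hypothetical payload for POST /admin/config exercising the new fields.
admin_config_payload = {
    "SHOW_ADMIN_DETAILS": True,
    "ENABLE_SIGNUP": True,
    "ENABLE_API_KEY": True,
    "ENABLE_API_KEY_ENDPOINT_RESTRICTIONS": False,
    "API_KEY_ALLOWED_ENDPOINTS": "",
    "DEFAULT_USER_ROLE": "pending",
    "JWT_EXPIRES_IN": "4w",
    "ENABLE_COMMUNITY_SHARING": True,
    "ENABLE_MESSAGE_RATING": True,
    "ENABLE_CHANNELS": False,
    "ENABLE_NOTES": True,
    "ENABLE_USER_WEBHOOKS": True,
    "PENDING_USER_OVERLAY_TITLE": "Account pending approval",
    "PENDING_USER_OVERLAY_CONTENT": "An administrator needs to activate your account.",
    "RESPONSE_WATERMARK": "Generated with Open WebUI",
}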
@router.post("/admin/config") @router.post("/admin/config")
@ -669,6 +762,7 @@ async def update_admin_config(
) )
request.app.state.config.ENABLE_CHANNELS = form_data.ENABLE_CHANNELS request.app.state.config.ENABLE_CHANNELS = form_data.ENABLE_CHANNELS
request.app.state.config.ENABLE_NOTES = form_data.ENABLE_NOTES
if form_data.DEFAULT_USER_ROLE in ["pending", "user", "admin"]: if form_data.DEFAULT_USER_ROLE in ["pending", "user", "admin"]:
request.app.state.config.DEFAULT_USER_ROLE = form_data.DEFAULT_USER_ROLE request.app.state.config.DEFAULT_USER_ROLE = form_data.DEFAULT_USER_ROLE
@ -684,6 +778,17 @@ async def update_admin_config(
) )
request.app.state.config.ENABLE_MESSAGE_RATING = form_data.ENABLE_MESSAGE_RATING request.app.state.config.ENABLE_MESSAGE_RATING = form_data.ENABLE_MESSAGE_RATING
request.app.state.config.ENABLE_USER_WEBHOOKS = form_data.ENABLE_USER_WEBHOOKS
request.app.state.config.PENDING_USER_OVERLAY_TITLE = (
form_data.PENDING_USER_OVERLAY_TITLE
)
request.app.state.config.PENDING_USER_OVERLAY_CONTENT = (
form_data.PENDING_USER_OVERLAY_CONTENT
)
request.app.state.config.RESPONSE_WATERMARK = form_data.RESPONSE_WATERMARK
return { return {
"SHOW_ADMIN_DETAILS": request.app.state.config.SHOW_ADMIN_DETAILS, "SHOW_ADMIN_DETAILS": request.app.state.config.SHOW_ADMIN_DETAILS,
"WEBUI_URL": request.app.state.config.WEBUI_URL, "WEBUI_URL": request.app.state.config.WEBUI_URL,
@ -691,11 +796,16 @@ async def update_admin_config(
"ENABLE_API_KEY": request.app.state.config.ENABLE_API_KEY, "ENABLE_API_KEY": request.app.state.config.ENABLE_API_KEY,
"ENABLE_API_KEY_ENDPOINT_RESTRICTIONS": request.app.state.config.ENABLE_API_KEY_ENDPOINT_RESTRICTIONS, "ENABLE_API_KEY_ENDPOINT_RESTRICTIONS": request.app.state.config.ENABLE_API_KEY_ENDPOINT_RESTRICTIONS,
"API_KEY_ALLOWED_ENDPOINTS": request.app.state.config.API_KEY_ALLOWED_ENDPOINTS, "API_KEY_ALLOWED_ENDPOINTS": request.app.state.config.API_KEY_ALLOWED_ENDPOINTS,
"ENABLE_CHANNELS": request.app.state.config.ENABLE_CHANNELS,
"DEFAULT_USER_ROLE": request.app.state.config.DEFAULT_USER_ROLE, "DEFAULT_USER_ROLE": request.app.state.config.DEFAULT_USER_ROLE,
"JWT_EXPIRES_IN": request.app.state.config.JWT_EXPIRES_IN, "JWT_EXPIRES_IN": request.app.state.config.JWT_EXPIRES_IN,
"ENABLE_COMMUNITY_SHARING": request.app.state.config.ENABLE_COMMUNITY_SHARING, "ENABLE_COMMUNITY_SHARING": request.app.state.config.ENABLE_COMMUNITY_SHARING,
"ENABLE_MESSAGE_RATING": request.app.state.config.ENABLE_MESSAGE_RATING, "ENABLE_MESSAGE_RATING": request.app.state.config.ENABLE_MESSAGE_RATING,
"ENABLE_CHANNELS": request.app.state.config.ENABLE_CHANNELS,
"ENABLE_NOTES": request.app.state.config.ENABLE_NOTES,
"ENABLE_USER_WEBHOOKS": request.app.state.config.ENABLE_USER_WEBHOOKS,
"PENDING_USER_OVERLAY_TITLE": request.app.state.config.PENDING_USER_OVERLAY_TITLE,
"PENDING_USER_OVERLAY_CONTENT": request.app.state.config.PENDING_USER_OVERLAY_CONTENT,
"RESPONSE_WATERMARK": request.app.state.config.RESPONSE_WATERMARK,
} }
@ -711,6 +821,7 @@ class LdapServerConfig(BaseModel):
search_filters: str = "" search_filters: str = ""
use_tls: bool = True use_tls: bool = True
certificate_path: Optional[str] = None certificate_path: Optional[str] = None
validate_cert: bool = True
ciphers: Optional[str] = "ALL" ciphers: Optional[str] = "ALL"
@ -728,6 +839,7 @@ async def get_ldap_server(request: Request, user=Depends(get_admin_user)):
"search_filters": request.app.state.config.LDAP_SEARCH_FILTERS, "search_filters": request.app.state.config.LDAP_SEARCH_FILTERS,
"use_tls": request.app.state.config.LDAP_USE_TLS, "use_tls": request.app.state.config.LDAP_USE_TLS,
"certificate_path": request.app.state.config.LDAP_CA_CERT_FILE, "certificate_path": request.app.state.config.LDAP_CA_CERT_FILE,
"validate_cert": request.app.state.config.LDAP_VALIDATE_CERT,
"ciphers": request.app.state.config.LDAP_CIPHERS, "ciphers": request.app.state.config.LDAP_CIPHERS,
} }
@ -750,11 +862,6 @@ async def update_ldap_server(
if not value: if not value:
raise HTTPException(400, detail=f"Required field {key} is empty") raise HTTPException(400, detail=f"Required field {key} is empty")
if form_data.use_tls and not form_data.certificate_path:
raise HTTPException(
400, detail="TLS is enabled but certificate file path is missing"
)
request.app.state.config.LDAP_SERVER_LABEL = form_data.label request.app.state.config.LDAP_SERVER_LABEL = form_data.label
request.app.state.config.LDAP_SERVER_HOST = form_data.host request.app.state.config.LDAP_SERVER_HOST = form_data.host
request.app.state.config.LDAP_SERVER_PORT = form_data.port request.app.state.config.LDAP_SERVER_PORT = form_data.port
@ -768,6 +875,7 @@ async def update_ldap_server(
request.app.state.config.LDAP_SEARCH_FILTERS = form_data.search_filters request.app.state.config.LDAP_SEARCH_FILTERS = form_data.search_filters
request.app.state.config.LDAP_USE_TLS = form_data.use_tls request.app.state.config.LDAP_USE_TLS = form_data.use_tls
request.app.state.config.LDAP_CA_CERT_FILE = form_data.certificate_path request.app.state.config.LDAP_CA_CERT_FILE = form_data.certificate_path
request.app.state.config.LDAP_VALIDATE_CERT = form_data.validate_cert
request.app.state.config.LDAP_CIPHERS = form_data.ciphers request.app.state.config.LDAP_CIPHERS = form_data.ciphers
return { return {
@ -782,6 +890,7 @@ async def update_ldap_server(
"search_filters": request.app.state.config.LDAP_SEARCH_FILTERS, "search_filters": request.app.state.config.LDAP_SEARCH_FILTERS,
"use_tls": request.app.state.config.LDAP_USE_TLS, "use_tls": request.app.state.config.LDAP_USE_TLS,
"certificate_path": request.app.state.config.LDAP_CA_CERT_FILE, "certificate_path": request.app.state.config.LDAP_CA_CERT_FILE,
"validate_cert": request.app.state.config.LDAP_VALIDATE_CERT,
"ciphers": request.app.state.config.LDAP_CIPHERS, "ciphers": request.app.state.config.LDAP_CIPHERS,
} }
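The new validate_cert option maps directly onto ldap3's Tls validate parameter. A minimal sketch of what the toggle controls, assuming ldap3 is installed; the host, bind DN, password, and CA path are hypothetical. With validation off, the channel is still encrypted, but the server certificate is no longer verified.

# Minimal sketch of the validate_cert toggle (hypothetical host, DN, password, CA path).
from ssl import CERT_NONE, CERT_REQUIRED, PROTOCOL_TLS

from ldap3 import Connection, Server, Tls

validate_cert = False  # e.g. request.app.state.config.LDAP_VALIDATE_CERT

tls = Tls(
    validate=CERT_REQUIRED if validate_cert else CERT_NONE,
    version=PROTOCOL_TLS,
    ca_certs_file="/etc/ssl/certs/ldap-ca.pem" if validate_cert else None,
)
server = Server("ldap.example.com", port=636, use_ssl=True, tls=tls)
# Would attempt a real bind when executed against an actual directory server:
conn = Connection(server, "cn=app,dc=example,dc=com", "app-password", auto_bind=True)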
@ -192,7 +192,7 @@ async def get_channel_messages(
############################ ############################
async def send_notification(webui_url, channel, message, active_user_ids):
async def send_notification(name, webui_url, channel, message, active_user_ids):
users = get_users_with_access("read", channel.access_control) users = get_users_with_access("read", channel.access_control)
for user in users: for user in users:
@ -206,6 +206,7 @@ async def send_notification(webui_url, channel, message, active_user_ids):
if webhook_url: if webhook_url:
post_webhook( post_webhook(
name,
webhook_url, webhook_url,
f"#{channel.name} - {webui_url}/channels/{channel.id}\n\n{message.content}", f"#{channel.name} - {webui_url}/channels/{channel.id}\n\n{message.content}",
{ {
@ -302,6 +303,7 @@ async def post_new_message(
background_tasks.add_task( background_tasks.add_task(
send_notification, send_notification,
request.app.state.WEBUI_NAME,
request.app.state.config.WEBUI_URL, request.app.state.config.WEBUI_URL,
channel, channel,
message, message,
@ -2,6 +2,8 @@ import json
import logging import logging
from typing import Optional from typing import Optional
from open_webui.socket.main import get_event_emitter
from open_webui.models.chats import ( from open_webui.models.chats import (
ChatForm, ChatForm,
ChatImportForm, ChatImportForm,
@ -74,17 +76,34 @@ async def delete_all_user_chats(request: Request, user=Depends(get_verified_user
@router.get("/list/user/{user_id}", response_model=list[ChatTitleIdResponse]) @router.get("/list/user/{user_id}", response_model=list[ChatTitleIdResponse])
async def get_user_chat_list_by_user_id( async def get_user_chat_list_by_user_id(
user_id: str, user_id: str,
page: Optional[int] = None,
query: Optional[str] = None,
order_by: Optional[str] = None,
direction: Optional[str] = None,
user=Depends(get_admin_user), user=Depends(get_admin_user),
skip: int = 0,
limit: int = 50,
): ):
if not ENABLE_ADMIN_CHAT_ACCESS: if not ENABLE_ADMIN_CHAT_ACCESS:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED, detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
) )
if page is None:
page = 1
limit = 60
skip = (page - 1) * limit
filter = {}
if query:
filter["query"] = query
if order_by:
filter["order_by"] = order_by
if direction:
filter["direction"] = direction
return Chats.get_chat_list_by_user_id( return Chats.get_chat_list_by_user_id(
user_id, include_archived=True, skip=skip, limit=limit
user_id, include_archived=True, filter=filter, skip=skip, limit=limit
) )
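The admin listing now accepts page, query, order_by, and direction parameters, with a fixed page size of 60. An example client call; the /api/v1/chats prefix, the bearer token, the user id, and the exact response fields (per ChatTitleIdResponse, roughly id, title, updated_at, created_at) are assumptions.

# Example client call against the paginated admin chat listing (values are hypothetical).
import requests

resp = requests.get(
    "http://localhost:8080/api/v1/chats/list/user/1a2b3c",
    params={"page": 2, "query": "release notes", "order_by": "updated_at", "direction": "desc"},
    headers={"Authorization": "Bearer <admin-api-key>"},
)
resp.raise_for_status()
for chat in resp.json():
    print(chat["id"], chat["title"])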
@ -192,10 +211,10 @@ async def get_chats_by_folder_id(folder_id: str, user=Depends(get_verified_user)
############################ ############################
@router.get("/pinned", response_model=list[ChatResponse]) @router.get("/pinned", response_model=list[ChatTitleIdResponse])
async def get_user_pinned_chats(user=Depends(get_verified_user)): async def get_user_pinned_chats(user=Depends(get_verified_user)):
return [ return [
ChatResponse(**chat.model_dump()) ChatTitleIdResponse(**chat.model_dump())
for chat in Chats.get_pinned_chats_by_user_id(user.id) for chat in Chats.get_pinned_chats_by_user_id(user.id)
] ]
@ -265,9 +284,37 @@ async def get_all_user_chats_in_db(user=Depends(get_admin_user)):
@router.get("/archived", response_model=list[ChatTitleIdResponse]) @router.get("/archived", response_model=list[ChatTitleIdResponse])
async def get_archived_session_user_chat_list( async def get_archived_session_user_chat_list(
user=Depends(get_verified_user), skip: int = 0, limit: int = 50
page: Optional[int] = None,
query: Optional[str] = None,
order_by: Optional[str] = None,
direction: Optional[str] = None,
user=Depends(get_verified_user),
): ):
return Chats.get_archived_chat_list_by_user_id(user.id, skip, limit)

if page is None:
page = 1
limit = 60
skip = (page - 1) * limit
filter = {}
if query:
filter["query"] = query
if order_by:
filter["order_by"] = order_by
if direction:
filter["direction"] = direction
chat_list = [
ChatTitleIdResponse(**chat.model_dump())
for chat in Chats.get_archived_chat_list_by_user_id(
user.id,
filter=filter,
skip=skip,
limit=limit,
)
]
return chat_list
############################ ############################
@ -372,6 +419,107 @@ async def update_chat_by_id(
) )
############################
# UpdateChatMessageById
############################
class MessageForm(BaseModel):
content: str
@router.post("/{id}/messages/{message_id}", response_model=Optional[ChatResponse])
async def update_chat_message_by_id(
id: str, message_id: str, form_data: MessageForm, user=Depends(get_verified_user)
):
chat = Chats.get_chat_by_id(id)
if not chat:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
if chat.user_id != user.id and user.role != "admin":
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
chat = Chats.upsert_message_to_chat_by_id_and_message_id(
id,
message_id,
{
"content": form_data.content,
},
)
event_emitter = get_event_emitter(
{
"user_id": user.id,
"chat_id": id,
"message_id": message_id,
},
False,
)
if event_emitter:
await event_emitter(
{
"type": "chat:message",
"data": {
"chat_id": id,
"message_id": message_id,
"content": form_data.content,
},
}
)
return ChatResponse(**chat.model_dump())
############################
# SendChatMessageEventById
############################
class EventForm(BaseModel):
type: str
data: dict
@router.post("/{id}/messages/{message_id}/event", response_model=Optional[bool])
async def send_chat_message_event_by_id(
id: str, message_id: str, form_data: EventForm, user=Depends(get_verified_user)
):
chat = Chats.get_chat_by_id(id)
if not chat:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
if chat.user_id != user.id and user.role != "admin":
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
event_emitter = get_event_emitter(
{
"user_id": user.id,
"chat_id": id,
"message_id": message_id,
}
)
try:
if event_emitter:
await event_emitter(form_data.model_dump())
else:
return False
return True
except:
return False
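The new event route relays arbitrary events to an active chat session. An example request body and call; the /api/v1/chats prefix, the ids, the token, and the "status" event shape are assumptions for illustration.

# Example call to the event relay endpoint above (values are hypothetical).
import requests

chat_id, message_id = "chat-123", "msg-456"
event = {
    "type": "status",
    "data": {"description": "Fetching sources...", "done": False},
}
resp = requests.post(
    f"http://localhost:8080/api/v1/chats/{chat_id}/messages/{message_id}/event",
    json=event,
    headers={"Authorization": "Bearer <api-key>"},
)
print(resp.json())  # True if an event emitter was found for the session, otherwise False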
############################ ############################
# DeleteChatById # DeleteChatById
############################ ############################
@ -476,7 +624,12 @@ async def clone_chat_by_id(
@router.post("/{id}/clone/shared", response_model=Optional[ChatResponse]) @router.post("/{id}/clone/shared", response_model=Optional[ChatResponse])
async def clone_shared_chat_by_id(id: str, user=Depends(get_verified_user)): async def clone_shared_chat_by_id(id: str, user=Depends(get_verified_user)):
if user.role == "admin":
chat = Chats.get_chat_by_id(id)
else:
chat = Chats.get_chat_by_share_id(id) chat = Chats.get_chat_by_share_id(id)
if chat: if chat:
updated_chat = { updated_chat = {
**chat.chat, **chat.chat,
@ -530,8 +683,17 @@ async def archive_chat_by_id(id: str, user=Depends(get_verified_user)):
@router.post("/{id}/share", response_model=Optional[ChatResponse]) @router.post("/{id}/share", response_model=Optional[ChatResponse])
async def share_chat_by_id(id: str, user=Depends(get_verified_user)):
async def share_chat_by_id(request: Request, id: str, user=Depends(get_verified_user)):
if not has_permission(
user.id, "chat.share", request.app.state.config.USER_PERMISSIONS
):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
chat = Chats.get_chat_by_id_and_user_id(id, user.id) chat = Chats.get_chat_by_id_and_user_id(id, user.id)
if chat: if chat:
if chat.share_id: if chat.share_id:
shared_chat = Chats.update_shared_chat_by_chat_id(chat.id) shared_chat = Chats.update_shared_chat_by_chat_id(chat.id)
@ -1,5 +1,5 @@
from fastapi import APIRouter, Depends, Request
from fastapi import APIRouter, Depends, Request, HTTPException
from pydantic import BaseModel
from pydantic import BaseModel, ConfigDict
from typing import Optional from typing import Optional
@ -7,6 +7,8 @@ from open_webui.utils.auth import get_admin_user, get_verified_user
from open_webui.config import get_config, save_config from open_webui.config import get_config, save_config
from open_webui.config import BannerModel from open_webui.config import BannerModel
from open_webui.utils.tools import get_tool_server_data, get_tool_servers_data
router = APIRouter() router = APIRouter()
@ -66,10 +68,86 @@ async def set_direct_connections_config(
} }
############################
# ToolServers Config
############################
class ToolServerConnection(BaseModel):
url: str
path: str
auth_type: Optional[str]
key: Optional[str]
config: Optional[dict]
model_config = ConfigDict(extra="allow")
class ToolServersConfigForm(BaseModel):
TOOL_SERVER_CONNECTIONS: list[ToolServerConnection]
@router.get("/tool_servers", response_model=ToolServersConfigForm)
async def get_tool_servers_config(request: Request, user=Depends(get_admin_user)):
return {
"TOOL_SERVER_CONNECTIONS": request.app.state.config.TOOL_SERVER_CONNECTIONS,
}
@router.post("/tool_servers", response_model=ToolServersConfigForm)
async def set_tool_servers_config(
request: Request,
form_data: ToolServersConfigForm,
user=Depends(get_admin_user),
):
request.app.state.config.TOOL_SERVER_CONNECTIONS = [
connection.model_dump() for connection in form_data.TOOL_SERVER_CONNECTIONS
]
request.app.state.TOOL_SERVERS = await get_tool_servers_data(
request.app.state.config.TOOL_SERVER_CONNECTIONS
)
return {
"TOOL_SERVER_CONNECTIONS": request.app.state.config.TOOL_SERVER_CONNECTIONS,
}
@router.post("/tool_servers/verify")
async def verify_tool_servers_config(
request: Request, form_data: ToolServerConnection, user=Depends(get_admin_user)
):
"""
Verify the connection to the tool server.
"""
try:
token = None
if form_data.auth_type == "bearer":
token = form_data.key
elif form_data.auth_type == "session":
token = request.state.token.credentials
url = f"{form_data.url}/{form_data.path}"
return await get_tool_server_data(token, url)
except Exception as e:
raise HTTPException(
status_code=400,
detail=f"Failed to connect to the tool server: {str(e)}",
)
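The verify route simply joins url and path, resolves a token from the selected auth_type, and fetches the result. An example ToolServerConnection payload; the server URL, path, and key are hypothetical.

# Hypothetical payload for POST /tool_servers/verify.
connection = {
    "url": "http://localhost:8000",
    "path": "openapi.json",
    "auth_type": "bearer",  # or "session" to reuse the caller's Open WebUI token
    "key": "sk-tool-server-key",
    "config": {},
}

# With auth_type == "bearer" the handler above effectively performs:
#   GET http://localhost:8000/openapi.json   with   Authorization: Bearer sk-tool-server-key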
############################ ############################
# CodeInterpreterConfig # CodeInterpreterConfig
############################ ############################
class CodeInterpreterConfigForm(BaseModel): class CodeInterpreterConfigForm(BaseModel):
ENABLE_CODE_EXECUTION: bool
CODE_EXECUTION_ENGINE: str
CODE_EXECUTION_JUPYTER_URL: Optional[str]
CODE_EXECUTION_JUPYTER_AUTH: Optional[str]
CODE_EXECUTION_JUPYTER_AUTH_TOKEN: Optional[str]
CODE_EXECUTION_JUPYTER_AUTH_PASSWORD: Optional[str]
CODE_EXECUTION_JUPYTER_TIMEOUT: Optional[int]
ENABLE_CODE_INTERPRETER: bool ENABLE_CODE_INTERPRETER: bool
CODE_INTERPRETER_ENGINE: str CODE_INTERPRETER_ENGINE: str
CODE_INTERPRETER_PROMPT_TEMPLATE: Optional[str] CODE_INTERPRETER_PROMPT_TEMPLATE: Optional[str]
@ -77,11 +155,19 @@ class CodeInterpreterConfigForm(BaseModel):
CODE_INTERPRETER_JUPYTER_AUTH: Optional[str] CODE_INTERPRETER_JUPYTER_AUTH: Optional[str]
CODE_INTERPRETER_JUPYTER_AUTH_TOKEN: Optional[str] CODE_INTERPRETER_JUPYTER_AUTH_TOKEN: Optional[str]
CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD: Optional[str] CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD: Optional[str]
CODE_INTERPRETER_JUPYTER_TIMEOUT: Optional[int]
@router.get("/code_interpreter", response_model=CodeInterpreterConfigForm) @router.get("/code_execution", response_model=CodeInterpreterConfigForm)
async def get_code_interpreter_config(request: Request, user=Depends(get_admin_user)): async def get_code_execution_config(request: Request, user=Depends(get_admin_user)):
return { return {
"ENABLE_CODE_EXECUTION": request.app.state.config.ENABLE_CODE_EXECUTION,
"CODE_EXECUTION_ENGINE": request.app.state.config.CODE_EXECUTION_ENGINE,
"CODE_EXECUTION_JUPYTER_URL": request.app.state.config.CODE_EXECUTION_JUPYTER_URL,
"CODE_EXECUTION_JUPYTER_AUTH": request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH,
"CODE_EXECUTION_JUPYTER_AUTH_TOKEN": request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_TOKEN,
"CODE_EXECUTION_JUPYTER_AUTH_PASSWORD": request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_PASSWORD,
"CODE_EXECUTION_JUPYTER_TIMEOUT": request.app.state.config.CODE_EXECUTION_JUPYTER_TIMEOUT,
"ENABLE_CODE_INTERPRETER": request.app.state.config.ENABLE_CODE_INTERPRETER, "ENABLE_CODE_INTERPRETER": request.app.state.config.ENABLE_CODE_INTERPRETER,
"CODE_INTERPRETER_ENGINE": request.app.state.config.CODE_INTERPRETER_ENGINE, "CODE_INTERPRETER_ENGINE": request.app.state.config.CODE_INTERPRETER_ENGINE,
"CODE_INTERPRETER_PROMPT_TEMPLATE": request.app.state.config.CODE_INTERPRETER_PROMPT_TEMPLATE, "CODE_INTERPRETER_PROMPT_TEMPLATE": request.app.state.config.CODE_INTERPRETER_PROMPT_TEMPLATE,
@ -89,13 +175,34 @@ async def get_code_interpreter_config(request: Request, user=Depends(get_admin_u
"CODE_INTERPRETER_JUPYTER_AUTH": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH, "CODE_INTERPRETER_JUPYTER_AUTH": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH,
"CODE_INTERPRETER_JUPYTER_AUTH_TOKEN": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_TOKEN, "CODE_INTERPRETER_JUPYTER_AUTH_TOKEN": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_TOKEN,
"CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD, "CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD,
"CODE_INTERPRETER_JUPYTER_TIMEOUT": request.app.state.config.CODE_INTERPRETER_JUPYTER_TIMEOUT,
} }
@router.post("/code_interpreter", response_model=CodeInterpreterConfigForm) @router.post("/code_execution", response_model=CodeInterpreterConfigForm)
async def set_code_interpreter_config( async def set_code_execution_config(
request: Request, form_data: CodeInterpreterConfigForm, user=Depends(get_admin_user) request: Request, form_data: CodeInterpreterConfigForm, user=Depends(get_admin_user)
): ):
request.app.state.config.ENABLE_CODE_EXECUTION = form_data.ENABLE_CODE_EXECUTION
request.app.state.config.CODE_EXECUTION_ENGINE = form_data.CODE_EXECUTION_ENGINE
request.app.state.config.CODE_EXECUTION_JUPYTER_URL = (
form_data.CODE_EXECUTION_JUPYTER_URL
)
request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH = (
form_data.CODE_EXECUTION_JUPYTER_AUTH
)
request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_TOKEN = (
form_data.CODE_EXECUTION_JUPYTER_AUTH_TOKEN
)
request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_PASSWORD = (
form_data.CODE_EXECUTION_JUPYTER_AUTH_PASSWORD
)
request.app.state.config.CODE_EXECUTION_JUPYTER_TIMEOUT = (
form_data.CODE_EXECUTION_JUPYTER_TIMEOUT
)
request.app.state.config.ENABLE_CODE_INTERPRETER = form_data.ENABLE_CODE_INTERPRETER request.app.state.config.ENABLE_CODE_INTERPRETER = form_data.ENABLE_CODE_INTERPRETER
request.app.state.config.CODE_INTERPRETER_ENGINE = form_data.CODE_INTERPRETER_ENGINE request.app.state.config.CODE_INTERPRETER_ENGINE = form_data.CODE_INTERPRETER_ENGINE
request.app.state.config.CODE_INTERPRETER_PROMPT_TEMPLATE = ( request.app.state.config.CODE_INTERPRETER_PROMPT_TEMPLATE = (
@ -116,8 +223,18 @@ async def set_code_interpreter_config(
request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD = ( request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD = (
form_data.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD form_data.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD
) )
request.app.state.config.CODE_INTERPRETER_JUPYTER_TIMEOUT = (
form_data.CODE_INTERPRETER_JUPYTER_TIMEOUT
)
return { return {
"ENABLE_CODE_EXECUTION": request.app.state.config.ENABLE_CODE_EXECUTION,
"CODE_EXECUTION_ENGINE": request.app.state.config.CODE_EXECUTION_ENGINE,
"CODE_EXECUTION_JUPYTER_URL": request.app.state.config.CODE_EXECUTION_JUPYTER_URL,
"CODE_EXECUTION_JUPYTER_AUTH": request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH,
"CODE_EXECUTION_JUPYTER_AUTH_TOKEN": request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_TOKEN,
"CODE_EXECUTION_JUPYTER_AUTH_PASSWORD": request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_PASSWORD,
"CODE_EXECUTION_JUPYTER_TIMEOUT": request.app.state.config.CODE_EXECUTION_JUPYTER_TIMEOUT,
"ENABLE_CODE_INTERPRETER": request.app.state.config.ENABLE_CODE_INTERPRETER, "ENABLE_CODE_INTERPRETER": request.app.state.config.ENABLE_CODE_INTERPRETER,
"CODE_INTERPRETER_ENGINE": request.app.state.config.CODE_INTERPRETER_ENGINE, "CODE_INTERPRETER_ENGINE": request.app.state.config.CODE_INTERPRETER_ENGINE,
"CODE_INTERPRETER_PROMPT_TEMPLATE": request.app.state.config.CODE_INTERPRETER_PROMPT_TEMPLATE, "CODE_INTERPRETER_PROMPT_TEMPLATE": request.app.state.config.CODE_INTERPRETER_PROMPT_TEMPLATE,
@ -125,6 +242,7 @@ async def set_code_interpreter_config(
"CODE_INTERPRETER_JUPYTER_AUTH": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH, "CODE_INTERPRETER_JUPYTER_AUTH": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH,
"CODE_INTERPRETER_JUPYTER_AUTH_TOKEN": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_TOKEN, "CODE_INTERPRETER_JUPYTER_AUTH_TOKEN": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_TOKEN,
"CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD, "CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD,
"CODE_INTERPRETER_JUPYTER_TIMEOUT": request.app.state.config.CODE_INTERPRETER_JUPYTER_TIMEOUT,
} }
@ -56,19 +56,35 @@ async def update_config(
} }
class UserResponse(BaseModel):
id: str
name: str
email: str
role: str = "pending"
last_active_at: int # timestamp in epoch
updated_at: int # timestamp in epoch
created_at: int # timestamp in epoch
class FeedbackUserResponse(FeedbackResponse): class FeedbackUserResponse(FeedbackResponse):
user: Optional[UserModel] = None
user: Optional[UserResponse] = None
@router.get("/feedbacks/all", response_model=list[FeedbackUserResponse]) @router.get("/feedbacks/all", response_model=list[FeedbackUserResponse])
async def get_all_feedbacks(user=Depends(get_admin_user)): async def get_all_feedbacks(user=Depends(get_admin_user)):
feedbacks = Feedbacks.get_all_feedbacks() feedbacks = Feedbacks.get_all_feedbacks()
    return [
        FeedbackUserResponse(
            **feedback.model_dump(), user=Users.get_user_by_id(feedback.user_id)
        )
        for feedback in feedbacks
    ]

    feedback_list = []
    for feedback in feedbacks:
        user = Users.get_user_by_id(feedback.user_id)
        feedback_list.append(
            FeedbackUserResponse(
                **feedback.model_dump(),
                user=UserResponse(**user.model_dump()) if user else None,
            )
        )

    return feedback_list
@router.delete("/feedbacks/all") @router.delete("/feedbacks/all")
@ -80,12 +96,7 @@ async def delete_all_feedbacks(user=Depends(get_admin_user)):
@router.get("/feedbacks/all/export", response_model=list[FeedbackModel]) @router.get("/feedbacks/all/export", response_model=list[FeedbackModel])
async def get_all_feedbacks(user=Depends(get_admin_user)): async def get_all_feedbacks(user=Depends(get_admin_user)):
feedbacks = Feedbacks.get_all_feedbacks() feedbacks = Feedbacks.get_all_feedbacks()
return [
return feedbacks
FeedbackModel(
**feedback.model_dump(), user=Users.get_user_by_id(feedback.user_id)
)
for feedback in feedbacks
]
@router.get("/feedbacks/user", response_model=list[FeedbackUserResponse]) @router.get("/feedbacks/user", response_model=list[FeedbackUserResponse])
@ -1,21 +1,39 @@
import logging import logging
import os import os
import uuid import uuid
import json
from fnmatch import fnmatch
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional
from urllib.parse import quote from urllib.parse import quote
from fastapi import APIRouter, Depends, File, HTTPException, Request, UploadFile, status
from fastapi import (
APIRouter,
Depends,
File,
Form,
HTTPException,
Request,
UploadFile,
status,
Query,
)
from fastapi.responses import FileResponse, StreamingResponse from fastapi.responses import FileResponse, StreamingResponse
from open_webui.constants import ERROR_MESSAGES from open_webui.constants import ERROR_MESSAGES
from open_webui.env import SRC_LOG_LEVELS from open_webui.env import SRC_LOG_LEVELS
from open_webui.models.users import Users
from open_webui.models.files import ( from open_webui.models.files import (
FileForm, FileForm,
FileModel, FileModel,
FileModelResponse, FileModelResponse,
Files, Files,
) )
from open_webui.models.knowledge import Knowledges
from open_webui.routers.knowledge import get_knowledge, get_knowledge_list
from open_webui.routers.retrieval import ProcessFileForm, process_file from open_webui.routers.retrieval import ProcessFileForm, process_file
from open_webui.routers.audio import transcribe
from open_webui.storage.provider import Storage from open_webui.storage.provider import Storage
from open_webui.utils.auth import get_admin_user, get_verified_user from open_webui.utils.auth import get_admin_user, get_verified_user
from pydantic import BaseModel from pydantic import BaseModel
@ -26,6 +44,39 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"])
router = APIRouter() router = APIRouter()
############################
# Check if the current user has access to a file through any knowledge bases the user may be in.
############################
def has_access_to_file(
file_id: Optional[str], access_type: str, user=Depends(get_verified_user)
) -> bool:
file = Files.get_file_by_id(file_id)
log.debug(f"Checking if user has {access_type} access to file")
if not file:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=ERROR_MESSAGES.NOT_FOUND,
)
has_access = False
knowledge_base_id = file.meta.get("collection_name") if file.meta else None
if knowledge_base_id:
knowledge_bases = Knowledges.get_knowledge_bases_by_user_id(
user.id, access_type
)
for knowledge_base in knowledge_bases:
if knowledge_base.id == knowledge_base_id:
has_access = True
break
return has_access
############################ ############################
# Upload File # Upload File
############################ ############################
@ -35,19 +86,55 @@ router = APIRouter()
def upload_file( def upload_file(
request: Request, request: Request,
file: UploadFile = File(...), file: UploadFile = File(...),
metadata: Optional[dict | str] = Form(None),
process: bool = Query(True),
internal: bool = False,
user=Depends(get_verified_user), user=Depends(get_verified_user),
file_metadata: dict = {},
): ):
log.info(f"file.content_type: {file.content_type}") log.info(f"file.content_type: {file.content_type}")
if isinstance(metadata, str):
try:
metadata = json.loads(metadata)
except json.JSONDecodeError:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT("Invalid metadata format"),
)
file_metadata = metadata if metadata else {}
try:
unsanitized_filename = file.filename
filename = os.path.basename(unsanitized_filename)
file_extension = os.path.splitext(filename)[1]
# Remove the leading dot from the file extension
file_extension = file_extension[1:] if file_extension else ""
if (not internal) and request.app.state.config.ALLOWED_FILE_EXTENSIONS:
request.app.state.config.ALLOWED_FILE_EXTENSIONS = [
ext for ext in request.app.state.config.ALLOWED_FILE_EXTENSIONS if ext
]
if file_extension not in request.app.state.config.ALLOWED_FILE_EXTENSIONS:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(
f"File type {file_extension} is not allowed"
),
)
# replace filename with uuid
id = str(uuid.uuid4())
name = filename
filename = f"{id}_{filename}"
contents, file_path = Storage.upload_file(file.file, filename)
tags = {
"OpenWebUI-User-Email": user.email,
"OpenWebUI-User-Id": user.id,
"OpenWebUI-User-Name": user.name,
"OpenWebUI-File-Id": id,
}
contents, file_path = Storage.upload_file(file.file, filename, tags)
file_item = Files.insert_new_file(
user.id,
@@ -65,9 +152,30 @@ def upload_file(
}
),
)
if process:
try:
if file.content_type:
if file.content_type.startswith("audio/") or file.content_type in {
"video/webm"
}:
file_path = Storage.get_file(file_path)
result = transcribe(request, file_path, file_metadata)
process_file(
request,
ProcessFileForm(file_id=id, content=result.get("text", "")),
user=user,
)
elif (not file.content_type.startswith(("image/", "video/"))) or (
request.app.state.config.CONTENT_EXTRACTION_ENGINE == "external"
):
process_file(request, ProcessFileForm(file_id=id), user=user)
else:
log.info(
f"File type {file.content_type} is not provided, but trying to process anyway"
)
process_file(request, ProcessFileForm(file_id=id), user=user)
file_item = Files.get_file_by_id(id=id)
except Exception as e:
log.exception(e)
@@ -91,7 +199,7 @@ def upload_file(
log.exception(e)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(e),
detail=ERROR_MESSAGES.DEFAULT("Error uploading file"),
)
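As a rough client-side sketch of the new upload contract (the `metadata` form field is JSON-encoded and `process` is a query flag; the base URL, mount point, and token below are assumptions, not part of the diff):

```python
import json
import requests

BASE_URL = "http://localhost:8080/api/v1/files"  # assumed mount point
headers = {"Authorization": "Bearer <token>"}     # placeholder token

with open("notes.txt", "rb") as f:
    resp = requests.post(
        f"{BASE_URL}/",
        params={"process": "false"},  # skip retrieval processing on upload
        data={"metadata": json.dumps({"source": "cli-upload"})},
        files={"file": ("notes.txt", f, "text/plain")},
        headers=headers,
    )
resp.raise_for_status()
print(resp.json()["id"])
```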
@@ -101,14 +209,62 @@ def upload_file(
@router.get("/", response_model=list[FileModelResponse])
async def list_files(user=Depends(get_verified_user)):
async def list_files(user=Depends(get_verified_user), content: bool = Query(True)):
if user.role == "admin":
files = Files.get_files()
else:
files = Files.get_files_by_user_id(user.id)
if not content:
for file in files:
if "content" in file.data:
del file.data["content"]
return files
############################
# Search Files
############################
@router.get("/search", response_model=list[FileModelResponse])
async def search_files(
filename: str = Query(
...,
description="Filename pattern to search for. Supports wildcards such as '*.txt'",
),
content: bool = Query(True),
user=Depends(get_verified_user),
):
"""
Search for files by filename with support for wildcard patterns.
"""
# Get files according to user role
if user.role == "admin":
files = Files.get_files()
else:
files = Files.get_files_by_user_id(user.id)
# Get matching files
matching_files = [
file for file in files if fnmatch(file.filename.lower(), filename.lower())
]
if not matching_files:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="No files found matching the pattern.",
)
if not content:
for file in matching_files:
if "content" in file.data:
del file.data["content"]
return matching_files
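Since the endpoint lowercases both the filename and the pattern before calling `fnmatch`, matching is effectively case-insensitive; a small illustration (example filenames are made up):

```python
from fnmatch import fnmatch

# Mirrors the comparison used by the /search route above.
print(fnmatch("report_2024.txt".lower(), "*.txt".lower()))      # True
print(fnmatch("Report_2024.TXT".lower(), "report_*".lower()))   # True
print(fnmatch("image.png".lower(), "*.txt".lower()))            # False
```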
############################
# Delete All Files
############################
@@ -144,7 +300,17 @@ async def delete_all_files(user=Depends(get_admin_user)):
async def get_file_by_id(id: str, user=Depends(get_verified_user)):
file = Files.get_file_by_id(id)
if file and (file.user_id == user.id or user.role == "admin"):
if not file:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=ERROR_MESSAGES.NOT_FOUND,
)
if (
file.user_id == user.id
or user.role == "admin"
or has_access_to_file(id, "read", user)
):
return file
else:
raise HTTPException(
@@ -162,7 +328,17 @@ async def get_file_by_id(id: str, user=Depends(get_verified_user)):
async def get_file_data_content_by_id(id: str, user=Depends(get_verified_user)):
file = Files.get_file_by_id(id)
if file and (file.user_id == user.id or user.role == "admin"):
if not file:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=ERROR_MESSAGES.NOT_FOUND,
)
if (
file.user_id == user.id
or user.role == "admin"
or has_access_to_file(id, "read", user)
):
return {"content": file.data.get("content", "")} return {"content": file.data.get("content", "")}
else: else:
raise HTTPException( raise HTTPException(
@ -186,7 +362,17 @@ async def update_file_data_content_by_id(
): ):
file = Files.get_file_by_id(id) file = Files.get_file_by_id(id)
if file and (file.user_id == user.id or user.role == "admin"): if not file:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=ERROR_MESSAGES.NOT_FOUND,
)
if (
file.user_id == user.id
or user.role == "admin"
or has_access_to_file(id, "write", user)
):
try:
process_file(
request,
@@ -212,9 +398,22 @@ async def update_file_data_content_by_id(
@router.get("/{id}/content")
async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
async def get_file_content_by_id(
id: str, user=Depends(get_verified_user), attachment: bool = Query(False)
):
file = Files.get_file_by_id(id)
if file and (file.user_id == user.id or user.role == "admin"):
if not file:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=ERROR_MESSAGES.NOT_FOUND,
)
if (
file.user_id == user.id
or user.role == "admin"
or has_access_to_file(id, "read", user)
):
try:
file_path = Storage.get_file(file.path)
file_path = Path(file_path)
@@ -225,17 +424,29 @@ async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
filename = file.meta.get("name", file.filename)
encoded_filename = quote(filename)  # RFC5987 encoding
content_type = file.meta.get("content_type")
filename = file.meta.get("name", file.filename)
encoded_filename = quote(filename)
headers = {}
if file.meta.get("content_type") not in [
"application/pdf",
"text/plain",
]:
headers = {
**headers,
"Content-Disposition": f"attachment; filename*=UTF-8''{encoded_filename}",
}
return FileResponse(file_path, headers=headers)
if attachment:
headers["Content-Disposition"] = (
f"attachment; filename*=UTF-8''{encoded_filename}"
)
else:
if content_type == "application/pdf" or filename.lower().endswith(
".pdf"
):
headers["Content-Disposition"] = (
f"inline; filename*=UTF-8''{encoded_filename}"
)
content_type = "application/pdf"
elif content_type != "text/plain":
headers["Content-Disposition"] = (
f"attachment; filename*=UTF-8''{encoded_filename}"
)
return FileResponse(file_path, headers=headers, media_type=content_type)
else:
raise HTTPException(
@@ -259,14 +470,32 @@ async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
@router.get("/{id}/content/html")
async def get_html_file_content_by_id(id: str, user=Depends(get_verified_user)):
file = Files.get_file_by_id(id)
if file and (file.user_id == user.id or user.role == "admin"):
if not file:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=ERROR_MESSAGES.NOT_FOUND,
)
file_user = Users.get_user_by_id(file.user_id)
if not file_user.role == "admin":
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=ERROR_MESSAGES.NOT_FOUND,
)
if (
file.user_id == user.id
or user.role == "admin"
or has_access_to_file(id, "read", user)
):
try:
file_path = Storage.get_file(file.path)
file_path = Path(file_path)
# Check if the file already exists in the cache
if file_path.is_file():
print(f"file_path: {file_path}")
log.info(f"file_path: {file_path}")
return FileResponse(file_path)
else:
raise HTTPException(
@@ -291,7 +520,17 @@ async def get_html_file_content_by_id(id: str, user=Depends(get_verified_user)):
async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
file = Files.get_file_by_id(id)
if file and (file.user_id == user.id or user.role == "admin"):
if not file:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=ERROR_MESSAGES.NOT_FOUND,
)
if (
file.user_id == user.id
or user.role == "admin"
or has_access_to_file(id, "read", user)
):
file_path = file.path
# Handle Unicode filenames
@@ -342,7 +581,18 @@ async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
@router.delete("/{id}")
async def delete_file_by_id(id: str, user=Depends(get_verified_user)):
file = Files.get_file_by_id(id)
if file and (file.user_id == user.id or user.role == "admin"):
if not file:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=ERROR_MESSAGES.NOT_FOUND,
)
if (
file.user_id == user.id
or user.role == "admin"
or has_access_to_file(id, "write", user)
):
# We should add Chroma cleanup here
result = Files.delete_file_by_id(id)

@@ -20,11 +20,13 @@ from open_webui.env import SRC_LOG_LEVELS
from open_webui.constants import ERROR_MESSAGES
from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, status
from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, status, Request
from fastapi.responses import FileResponse, StreamingResponse
from open_webui.utils.auth import get_admin_user, get_verified_user
from open_webui.utils.access_control import has_permission
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
@@ -228,7 +230,19 @@ async def update_folder_is_expanded_by_id(
@router.delete("/{id}")
async def delete_folder_by_id(id: str, user=Depends(get_verified_user)):
async def delete_folder_by_id(
request: Request, id: str, user=Depends(get_verified_user)
):
chat_delete_permission = has_permission(
user.id, "chat.delete", request.app.state.config.USER_PERMISSIONS
)
if user.role != "admin" and not chat_delete_permission:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
folder = Folders.get_folder_by_id_and_user_id(id, user.id)
if folder:
try:

@@ -1,4 +1,8 @@
import os
import re
import logging
import aiohttp
from pathlib import Path
from typing import Optional
@@ -8,11 +12,22 @@ from open_webui.models.functions import (
FunctionResponse,
Functions,
)
from open_webui.utils.plugin import load_function_module_by_id, replace_imports
from open_webui.utils.plugin import (
load_function_module_by_id,
replace_imports,
get_function_module_from_cache,
)
from open_webui.config import CACHE_DIR
from open_webui.constants import ERROR_MESSAGES
from fastapi import APIRouter, Depends, HTTPException, Request, status
from open_webui.utils.auth import get_admin_user, get_verified_user
from open_webui.env import SRC_LOG_LEVELS
from pydantic import BaseModel, HttpUrl
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MAIN"])
router = APIRouter()
@@ -36,6 +51,97 @@ async def get_functions(user=Depends(get_admin_user)):
return Functions.get_functions()
############################
# LoadFunctionFromLink
############################
class LoadUrlForm(BaseModel):
url: HttpUrl
def github_url_to_raw_url(url: str) -> str:
# Handle 'tree' (folder) URLs (add main.py at the end)
m1 = re.match(r"https://github\.com/([^/]+)/([^/]+)/tree/([^/]+)/(.*)", url)
if m1:
org, repo, branch, path = m1.groups()
return f"https://raw.githubusercontent.com/{org}/{repo}/refs/heads/{branch}/{path.rstrip('/')}/main.py"
# Handle 'blob' (file) URLs
m2 = re.match(r"https://github\.com/([^/]+)/([^/]+)/blob/([^/]+)/(.*)", url)
if m2:
org, repo, branch, path = m2.groups()
return (
f"https://raw.githubusercontent.com/{org}/{repo}/refs/heads/{branch}/{path}"
)
# No match; return as-is
return url
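The expected rewrites for the two GitHub URL shapes handled above (repository names are invented for illustration):

```python
# 'blob' (single file) URLs map straight to raw.githubusercontent.com:
print(github_url_to_raw_url(
    "https://github.com/acme/tools/blob/main/functions/filter.py"
))
# -> https://raw.githubusercontent.com/acme/tools/refs/heads/main/functions/filter.py

# 'tree' (folder) URLs get main.py appended:
print(github_url_to_raw_url(
    "https://github.com/acme/tools/tree/main/functions/filter"
))
# -> https://raw.githubusercontent.com/acme/tools/refs/heads/main/functions/filter/main.py

# Anything else is returned unchanged.
```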
@router.post("/load/url", response_model=Optional[dict])
async def load_function_from_url(
request: Request, form_data: LoadUrlForm, user=Depends(get_admin_user)
):
# NOTE: This is NOT a SSRF vulnerability:
# This endpoint is admin-only (see get_admin_user), meant for *trusted* internal use,
# and does NOT accept untrusted user input. Access is enforced by authentication.
url = str(form_data.url)
if not url:
raise HTTPException(status_code=400, detail="Please enter a valid URL")
url = github_url_to_raw_url(url)
url_parts = url.rstrip("/").split("/")
file_name = url_parts[-1]
function_name = (
file_name[:-3]
if (
file_name.endswith(".py")
and (not file_name.startswith(("main.py", "index.py", "__init__.py")))
)
else url_parts[-2] if len(url_parts) > 1 else "function"
)
try:
async with aiohttp.ClientSession() as session:
async with session.get(
url, headers={"Content-Type": "application/json"}
) as resp:
if resp.status != 200:
raise HTTPException(
status_code=resp.status, detail="Failed to fetch the function"
)
data = await resp.text()
if not data:
raise HTTPException(
status_code=400, detail="No data received from the URL"
)
return {
"name": function_name,
"content": data,
}
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error importing function: {e}")
############################
# SyncFunctions
############################
class SyncFunctionsForm(FunctionForm):
functions: list[FunctionModel] = []
@router.post("/sync", response_model=Optional[FunctionModel])
async def sync_functions(
request: Request, form_data: SyncFunctionsForm, user=Depends(get_admin_user)
):
return Functions.sync_functions(user.id, form_data.functions)
############################
# CreateNewFunction
############################
@@ -68,7 +174,7 @@ async def create_new_function(
function = Functions.insert_new_function(user.id, function_type, form_data)
function_cache_dir = Path(CACHE_DIR) / "functions" / form_data.id
function_cache_dir = CACHE_DIR / "functions" / form_data.id
function_cache_dir.mkdir(parents=True, exist_ok=True)
if function:
@@ -79,7 +185,7 @@ async def create_new_function(
detail=ERROR_MESSAGES.DEFAULT("Error creating function"),
)
except Exception as e:
print(e)
log.exception(f"Failed to create a new function: {e}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(e),
@@ -183,7 +289,7 @@ async def update_function_by_id(
FUNCTIONS[id] = function_module
updated = {**form_data.model_dump(exclude={"id"}), "type": function_type}
print(updated)
log.debug(updated)
function = Functions.update_function_by_id(id, updated)
@@ -256,11 +362,9 @@ async def get_function_valves_spec_by_id(
):
function = Functions.get_function_by_id(id)
if function:
if id in request.app.state.FUNCTIONS:
function_module = request.app.state.FUNCTIONS[id]
else:
function_module, function_type, frontmatter = load_function_module_by_id(id)
request.app.state.FUNCTIONS[id] = function_module
function_module, function_type, frontmatter = get_function_module_from_cache(
request, id
)
if hasattr(function_module, "Valves"):
Valves = function_module.Valves
@@ -284,11 +388,9 @@ async def update_function_valves_by_id(
):
function = Functions.get_function_by_id(id)
if function:
if id in request.app.state.FUNCTIONS:
function_module = request.app.state.FUNCTIONS[id]
else:
function_module, function_type, frontmatter = load_function_module_by_id(id)
request.app.state.FUNCTIONS[id] = function_module
function_module, function_type, frontmatter = get_function_module_from_cache(
request, id
)
if hasattr(function_module, "Valves"):
Valves = function_module.Valves
@@ -299,7 +401,7 @@ async def update_function_valves_by_id(
Functions.update_function_valves_by_id(id, valves.model_dump())
return valves.model_dump()
except Exception as e:
print(e)
log.exception(f"Error updating function values by id {id}: {e}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(e),
@@ -347,11 +449,9 @@ async def get_function_user_valves_spec_by_id(
):
function = Functions.get_function_by_id(id)
if function:
if id in request.app.state.FUNCTIONS:
function_module = request.app.state.FUNCTIONS[id]
else:
function_module, function_type, frontmatter = load_function_module_by_id(id)
request.app.state.FUNCTIONS[id] = function_module
function_module, function_type, frontmatter = get_function_module_from_cache(
request, id
)
if hasattr(function_module, "UserValves"):
UserValves = function_module.UserValves
@@ -371,11 +471,9 @@ async def update_function_user_valves_by_id(
function = Functions.get_function_by_id(id)
if function:
if id in request.app.state.FUNCTIONS:
function_module = request.app.state.FUNCTIONS[id]
else:
function_module, function_type, frontmatter = load_function_module_by_id(id)
request.app.state.FUNCTIONS[id] = function_module
function_module, function_type, frontmatter = get_function_module_from_cache(
request, id
)
if hasattr(function_module, "UserValves"):
UserValves = function_module.UserValves
@@ -388,7 +486,7 @@ async def update_function_user_valves_by_id(
)
return user_valves.model_dump()
except Exception as e:
print(e)
log.exception(f"Error updating function user valves by id {id}: {e}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(e),

backend/open_webui/routers/groups.py Normal file → Executable file
@@ -1,7 +1,7 @@
import os
from pathlib import Path
from typing import Optional
import logging
from open_webui.models.users import Users
from open_webui.models.groups import (
@@ -14,7 +14,13 @@ from open_webui.models.groups import (
from open_webui.config import CACHE_DIR
from open_webui.constants import ERROR_MESSAGES
from fastapi import APIRouter, Depends, HTTPException, Request, status
from open_webui.utils.auth import get_admin_user, get_verified_user
from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MAIN"])
router = APIRouter()
@@ -37,7 +43,7 @@ async def get_groups(user=Depends(get_verified_user)):
@router.post("/create", response_model=Optional[GroupResponse])
async def create_new_function(form_data: GroupForm, user=Depends(get_admin_user)):
async def create_new_group(form_data: GroupForm, user=Depends(get_admin_user)):
try:
group = Groups.insert_new_group(user.id, form_data)
if group:
@@ -48,7 +54,7 @@ async def create_new_function(form_data: GroupForm, user=Depends(get_admin_user)
detail=ERROR_MESSAGES.DEFAULT("Error creating group"),
)
except Exception as e:
print(e)
log.exception(f"Error creating a new group: {e}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(e),
@@ -94,7 +100,7 @@ async def update_group_by_id(
detail=ERROR_MESSAGES.DEFAULT("Error updating group"),
)
except Exception as e:
print(e)
log.exception(f"Error updating group {id}: {e}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(e),
@@ -118,7 +124,7 @@ async def delete_group_by_id(id: str, user=Depends(get_admin_user)):
detail=ERROR_MESSAGES.DEFAULT("Error deleting group"),
)
except Exception as e:
print(e)
log.exception(f"Error deleting group {id}: {e}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(e),

@@ -25,7 +25,7 @@ from pydantic import BaseModel
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["IMAGES"])
IMAGE_CACHE_DIR = Path(CACHE_DIR).joinpath("./image/generations/")
IMAGE_CACHE_DIR = CACHE_DIR / "image" / "generations"
IMAGE_CACHE_DIR.mkdir(parents=True, exist_ok=True)
@ -55,6 +55,10 @@ async def get_config(request: Request, user=Depends(get_admin_user)):
"COMFYUI_WORKFLOW": request.app.state.config.COMFYUI_WORKFLOW, "COMFYUI_WORKFLOW": request.app.state.config.COMFYUI_WORKFLOW,
"COMFYUI_WORKFLOW_NODES": request.app.state.config.COMFYUI_WORKFLOW_NODES, "COMFYUI_WORKFLOW_NODES": request.app.state.config.COMFYUI_WORKFLOW_NODES,
}, },
"gemini": {
"GEMINI_API_BASE_URL": request.app.state.config.IMAGES_GEMINI_API_BASE_URL,
"GEMINI_API_KEY": request.app.state.config.IMAGES_GEMINI_API_KEY,
},
} }
@ -78,6 +82,11 @@ class ComfyUIConfigForm(BaseModel):
COMFYUI_WORKFLOW_NODES: list[dict] COMFYUI_WORKFLOW_NODES: list[dict]
class GeminiConfigForm(BaseModel):
GEMINI_API_BASE_URL: str
GEMINI_API_KEY: str
class ConfigForm(BaseModel): class ConfigForm(BaseModel):
enabled: bool enabled: bool
engine: str engine: str
@ -85,6 +94,7 @@ class ConfigForm(BaseModel):
openai: OpenAIConfigForm openai: OpenAIConfigForm
automatic1111: Automatic1111ConfigForm automatic1111: Automatic1111ConfigForm
comfyui: ComfyUIConfigForm comfyui: ComfyUIConfigForm
gemini: GeminiConfigForm
@router.post("/config/update") @router.post("/config/update")
@ -103,6 +113,11 @@ async def update_config(
) )
request.app.state.config.IMAGES_OPENAI_API_KEY = form_data.openai.OPENAI_API_KEY request.app.state.config.IMAGES_OPENAI_API_KEY = form_data.openai.OPENAI_API_KEY
request.app.state.config.IMAGES_GEMINI_API_BASE_URL = (
form_data.gemini.GEMINI_API_BASE_URL
)
request.app.state.config.IMAGES_GEMINI_API_KEY = form_data.gemini.GEMINI_API_KEY
request.app.state.config.AUTOMATIC1111_BASE_URL = ( request.app.state.config.AUTOMATIC1111_BASE_URL = (
form_data.automatic1111.AUTOMATIC1111_BASE_URL form_data.automatic1111.AUTOMATIC1111_BASE_URL
) )
@ -129,6 +144,8 @@ async def update_config(
request.app.state.config.COMFYUI_BASE_URL = ( request.app.state.config.COMFYUI_BASE_URL = (
form_data.comfyui.COMFYUI_BASE_URL.strip("/") form_data.comfyui.COMFYUI_BASE_URL.strip("/")
) )
request.app.state.config.COMFYUI_API_KEY = form_data.comfyui.COMFYUI_API_KEY
request.app.state.config.COMFYUI_WORKFLOW = form_data.comfyui.COMFYUI_WORKFLOW request.app.state.config.COMFYUI_WORKFLOW = form_data.comfyui.COMFYUI_WORKFLOW
request.app.state.config.COMFYUI_WORKFLOW_NODES = ( request.app.state.config.COMFYUI_WORKFLOW_NODES = (
form_data.comfyui.COMFYUI_WORKFLOW_NODES form_data.comfyui.COMFYUI_WORKFLOW_NODES
@ -155,6 +172,10 @@ async def update_config(
"COMFYUI_WORKFLOW": request.app.state.config.COMFYUI_WORKFLOW, "COMFYUI_WORKFLOW": request.app.state.config.COMFYUI_WORKFLOW,
"COMFYUI_WORKFLOW_NODES": request.app.state.config.COMFYUI_WORKFLOW_NODES, "COMFYUI_WORKFLOW_NODES": request.app.state.config.COMFYUI_WORKFLOW_NODES,
}, },
"gemini": {
"GEMINI_API_BASE_URL": request.app.state.config.IMAGES_GEMINI_API_BASE_URL,
"GEMINI_API_KEY": request.app.state.config.IMAGES_GEMINI_API_KEY,
},
} }
@ -184,9 +205,17 @@ async def verify_url(request: Request, user=Depends(get_admin_user)):
request.app.state.config.ENABLE_IMAGE_GENERATION = False request.app.state.config.ENABLE_IMAGE_GENERATION = False
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.INVALID_URL) raise HTTPException(status_code=400, detail=ERROR_MESSAGES.INVALID_URL)
elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui": elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui":
headers = None
if request.app.state.config.COMFYUI_API_KEY:
headers = {
"Authorization": f"Bearer {request.app.state.config.COMFYUI_API_KEY}"
}
try: try:
r = requests.get( r = requests.get(
url=f"{request.app.state.config.COMFYUI_BASE_URL}/object_info" url=f"{request.app.state.config.COMFYUI_BASE_URL}/object_info",
headers=headers,
) )
r.raise_for_status() r.raise_for_status()
return True return True
@ -224,6 +253,12 @@ def get_image_model(request):
if request.app.state.config.IMAGE_GENERATION_MODEL if request.app.state.config.IMAGE_GENERATION_MODEL
else "dall-e-2" else "dall-e-2"
) )
elif request.app.state.config.IMAGE_GENERATION_ENGINE == "gemini":
return (
request.app.state.config.IMAGE_GENERATION_MODEL
if request.app.state.config.IMAGE_GENERATION_MODEL
else "imagen-3.0-generate-002"
)
elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui": elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui":
return ( return (
request.app.state.config.IMAGE_GENERATION_MODEL request.app.state.config.IMAGE_GENERATION_MODEL
@ -298,6 +333,11 @@ def get_models(request: Request, user=Depends(get_verified_user)):
return [ return [
{"id": "dall-e-2", "name": "DALL·E 2"}, {"id": "dall-e-2", "name": "DALL·E 2"},
{"id": "dall-e-3", "name": "DALL·E 3"}, {"id": "dall-e-3", "name": "DALL·E 3"},
{"id": "gpt-image-1", "name": "GPT-IMAGE 1"},
]
elif request.app.state.config.IMAGE_GENERATION_ENGINE == "gemini":
return [
{"id": "imagen-3.0-generate-002", "name": "imagen-3.0 generate-002"},
] ]
elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui": elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui":
# TODO - get models from comfyui # TODO - get models from comfyui
@@ -322,7 +362,7 @@ def get_models(request: Request, user=Depends(get_verified_user)):
if model_node_id:
model_list_key = None
print(workflow[model_node_id]["class_type"])
log.info(workflow[model_node_id]["class_type"])
for key in info[workflow[model_node_id]["class_type"]]["input"][
"required"
]:
@@ -411,7 +451,7 @@ def load_url_image_data(url, headers=None):
return None
def upload_image(request, image_metadata, image_data, content_type, user):
def upload_image(request, image_data, content_type, metadata, user):
image_format = mimetypes.guess_extension(content_type)
file = UploadFile(
file=io.BytesIO(image_data),
@@ -420,7 +460,7 @@ def upload_image(request, image_metadata, image_data, content_type, user):
"content-type": content_type,
},
)
file_item = upload_file(request, file, user, file_metadata=image_metadata)
file_item = upload_file(request, file, metadata=metadata, internal=True, user=user)
url = request.app.url_path_for("get_file_content_by_id", id=file_item.id)
return url
@@ -461,7 +501,11 @@ async def image_generations(
if form_data.size
else request.app.state.config.IMAGE_SIZE
),
"response_format": "b64_json",
**(
{}
if "gpt-image-1" in request.app.state.config.IMAGE_GENERATION_MODEL
else {"response_format": "b64_json"}
),
}
# Use asyncio.to_thread for the requests.post call
@@ -478,11 +522,50 @@ async def image_generations(
images = []
for image in res["data"]:
if image_url := image.get("url", None):
image_data, content_type = load_url_image_data(image_url, headers)
else:
image_data, content_type = load_b64_image_data(image["b64_json"])
url = upload_image(request, data, image_data, content_type, user)
url = upload_image(request, image_data, content_type, data, user)
images.append({"url": url})
return images
elif request.app.state.config.IMAGE_GENERATION_ENGINE == "gemini":
headers = {}
headers["Content-Type"] = "application/json"
headers["x-goog-api-key"] = request.app.state.config.IMAGES_GEMINI_API_KEY
model = get_image_model(request)
data = {
"instances": {"prompt": form_data.prompt},
"parameters": {
"sampleCount": form_data.n,
"outputOptions": {"mimeType": "image/png"},
},
}
# Use asyncio.to_thread for the requests.post call
r = await asyncio.to_thread(
requests.post,
url=f"{request.app.state.config.IMAGES_GEMINI_API_BASE_URL}/models/{model}:predict",
json=data,
headers=headers,
)
r.raise_for_status()
res = r.json()
images = []
for image in res["predictions"]:
image_data, content_type = load_b64_image_data(
image["bytesBase64Encoded"]
)
url = upload_image(request, image_data, content_type, data, user)
images.append({"url": url})
return images
elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui": elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui":
data = { data = {
"prompt": form_data.prompt, "prompt": form_data.prompt,
@@ -529,9 +612,9 @@ async def image_generations(
image_data, content_type = load_url_image_data(image["url"], headers)
url = upload_image(
request,
form_data.model_dump(exclude_none=True),
image_data,
content_type,
form_data.model_dump(exclude_none=True),
user,
)
images.append({"url": url})
@@ -541,7 +624,7 @@ async def image_generations(
or request.app.state.config.IMAGE_GENERATION_ENGINE == ""
):
if form_data.model:
set_image_model(form_data.model)
set_image_model(request, form_data.model)
data = {
"prompt": form_data.prompt,
@@ -582,9 +665,9 @@ async def image_generations(
image_data, content_type = load_b64_image_data(image)
url = upload_image(
request,
{**data, "info": res["info"]},
image_data,
content_type,
{**data, "info": res["info"]},
user,
)
images.append({"url": url})
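For context, a standalone sketch of the request the new Gemini branch issues against the Imagen predict endpoint, following the payload and headers built above (the base URL, key, and prompt are placeholders, not part of the diff):

```python
import requests

GEMINI_API_BASE_URL = "https://generativelanguage.googleapis.com/v1beta"  # placeholder
GEMINI_API_KEY = "<key>"                                                   # placeholder
model = "imagen-3.0-generate-002"

payload = {
    "instances": {"prompt": "a watercolor lighthouse at dusk"},
    "parameters": {"sampleCount": 1, "outputOptions": {"mimeType": "image/png"}},
}
r = requests.post(
    f"{GEMINI_API_BASE_URL}/models/{model}:predict",
    json=payload,
    headers={"Content-Type": "application/json", "x-goog-api-key": GEMINI_API_KEY},
)
r.raise_for_status()
# Each prediction carries the PNG bytes as base64 under "bytesBase64Encoded".
images_b64 = [p["bytesBase64Encoded"] for p in r.json()["predictions"]]
```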

@@ -9,8 +9,8 @@ from open_webui.models.knowledge import (
KnowledgeResponse,
KnowledgeUserResponse,
)
from open_webui.models.files import Files, FileModel
from open_webui.models.files import Files, FileModel, FileMetadataResponse
from open_webui.retrieval.vector.connector import VECTOR_DB_CLIENT
from open_webui.retrieval.vector.factory import VECTOR_DB_CLIENT
from open_webui.routers.retrieval import (
process_file,
ProcessFileForm,
@@ -161,13 +161,94 @@ async def create_new_knowledge(
)
############################
# ReindexKnowledgeFiles
############################
@router.post("/reindex", response_model=bool)
async def reindex_knowledge_files(request: Request, user=Depends(get_verified_user)):
if user.role != "admin":
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.UNAUTHORIZED,
)
knowledge_bases = Knowledges.get_knowledge_bases()
log.info(f"Starting reindexing for {len(knowledge_bases)} knowledge bases")
deleted_knowledge_bases = []
for knowledge_base in knowledge_bases:
# -- Robust error handling for missing or invalid data
if not knowledge_base.data or not isinstance(knowledge_base.data, dict):
log.warning(
f"Knowledge base {knowledge_base.id} has no data or invalid data ({knowledge_base.data!r}). Deleting."
)
try:
Knowledges.delete_knowledge_by_id(id=knowledge_base.id)
deleted_knowledge_bases.append(knowledge_base.id)
except Exception as e:
log.error(
f"Failed to delete invalid knowledge base {knowledge_base.id}: {e}"
)
continue
try:
file_ids = knowledge_base.data.get("file_ids", [])
files = Files.get_files_by_ids(file_ids)
try:
if VECTOR_DB_CLIENT.has_collection(collection_name=knowledge_base.id):
VECTOR_DB_CLIENT.delete_collection(
collection_name=knowledge_base.id
)
except Exception as e:
log.error(f"Error deleting collection {knowledge_base.id}: {str(e)}")
continue # Skip, don't raise
failed_files = []
for file in files:
try:
process_file(
request,
ProcessFileForm(
file_id=file.id, collection_name=knowledge_base.id
),
user=user,
)
except Exception as e:
log.error(
f"Error processing file {file.filename} (ID: {file.id}): {str(e)}"
)
failed_files.append({"file_id": file.id, "error": str(e)})
continue
except Exception as e:
log.error(f"Error processing knowledge base {knowledge_base.id}: {str(e)}")
# Don't raise, just continue
continue
if failed_files:
log.warning(
f"Failed to process {len(failed_files)} files in knowledge base {knowledge_base.id}"
)
for failed in failed_files:
log.warning(f"File ID: {failed['file_id']}, Error: {failed['error']}")
log.info(
f"Reindexing completed. Deleted {len(deleted_knowledge_bases)} invalid knowledge bases: {deleted_knowledge_bases}"
)
return True
############################
# GetKnowledgeById
############################
class KnowledgeFilesResponse(KnowledgeResponse):
files: list[FileModel]
files: list[FileMetadataResponse]
@router.get("/{id}", response_model=Optional[KnowledgeFilesResponse])
@@ -183,7 +264,7 @@ async def get_knowledge_by_id(id: str, user=Depends(get_verified_user)):
):
file_ids = knowledge.data.get("file_ids", []) if knowledge.data else []
files = Files.get_files_by_ids(file_ids)
files = Files.get_file_metadatas_by_ids(file_ids)
return KnowledgeFilesResponse( return KnowledgeFilesResponse(
**knowledge.model_dump(), **knowledge.model_dump(),
@ -311,7 +392,7 @@ def add_file_to_knowledge_by_id(
knowledge = Knowledges.update_knowledge_data_by_id(id=id, data=data) knowledge = Knowledges.update_knowledge_data_by_id(id=id, data=data)
if knowledge: if knowledge:
files = Files.get_files_by_ids(file_ids) files = Files.get_file_metadatas_by_ids(file_ids)
return KnowledgeFilesResponse( return KnowledgeFilesResponse(
**knowledge.model_dump(), **knowledge.model_dump(),
@ -388,7 +469,7 @@ def update_file_from_knowledge_by_id(
data = knowledge.data or {} data = knowledge.data or {}
file_ids = data.get("file_ids", []) file_ids = data.get("file_ids", [])
files = Files.get_files_by_ids(file_ids) files = Files.get_file_metadatas_by_ids(file_ids)
return KnowledgeFilesResponse( return KnowledgeFilesResponse(
**knowledge.model_dump(), **knowledge.model_dump(),
@ -437,14 +518,24 @@ def remove_file_from_knowledge_by_id(
) )
# Remove content from the vector database # Remove content from the vector database
try:
VECTOR_DB_CLIENT.delete( VECTOR_DB_CLIENT.delete(
collection_name=knowledge.id, filter={"file_id": form_data.file_id} collection_name=knowledge.id, filter={"file_id": form_data.file_id}
) )
except Exception as e:
log.debug("This was most likely caused by bypassing embedding processing")
log.debug(e)
pass
try:
# Remove the file's collection from vector database # Remove the file's collection from vector database
file_collection = f"file-{form_data.file_id}" file_collection = f"file-{form_data.file_id}"
if VECTOR_DB_CLIENT.has_collection(collection_name=file_collection): if VECTOR_DB_CLIENT.has_collection(collection_name=file_collection):
VECTOR_DB_CLIENT.delete_collection(collection_name=file_collection) VECTOR_DB_CLIENT.delete_collection(collection_name=file_collection)
except Exception as e:
log.debug("This was most likely caused by bypassing embedding processing")
log.debug(e)
pass
# Delete file from database # Delete file from database
Files.delete_file_by_id(form_data.file_id) Files.delete_file_by_id(form_data.file_id)
@ -460,7 +551,7 @@ def remove_file_from_knowledge_by_id(
knowledge = Knowledges.update_knowledge_data_by_id(id=id, data=data) knowledge = Knowledges.update_knowledge_data_by_id(id=id, data=data)
if knowledge: if knowledge:
files = Files.get_files_by_ids(file_ids) files = Files.get_file_metadatas_by_ids(file_ids)
return KnowledgeFilesResponse( return KnowledgeFilesResponse(
**knowledge.model_dump(), **knowledge.model_dump(),
@@ -614,7 +705,7 @@ def add_files_to_knowledge_batch(
)
# Get files content
print(f"files/batch/add - {len(form_data)} files")
log.info(f"files/batch/add - {len(form_data)} files")
files: List[FileModel] = []
for form in form_data:
file = Files.get_file_by_id(form.file_id)
@ -656,7 +747,7 @@ def add_files_to_knowledge_batch(
error_details = [f"{err.file_id}: {err.error}" for err in result.errors] error_details = [f"{err.file_id}: {err.error}" for err in result.errors]
return KnowledgeFilesResponse( return KnowledgeFilesResponse(
**knowledge.model_dump(), **knowledge.model_dump(),
files=Files.get_files_by_ids(existing_file_ids), files=Files.get_file_metadatas_by_ids(existing_file_ids),
warnings={ warnings={
"message": "Some files failed to process", "message": "Some files failed to process",
"errors": error_details, "errors": error_details,
@@ -664,5 +755,6 @@ def add_files_to_knowledge_batch(
)
return KnowledgeFilesResponse(
**knowledge.model_dump(), files=Files.get_files_by_ids(existing_file_ids)
**knowledge.model_dump(),
files=Files.get_file_metadatas_by_ids(existing_file_ids),
)

@@ -4,7 +4,7 @@ import logging
from typing import Optional
from open_webui.models.memories import Memories, MemoryModel
from open_webui.retrieval.vector.connector import VECTOR_DB_CLIENT
from open_webui.retrieval.vector.factory import VECTOR_DB_CLIENT
from open_webui.utils.auth import get_verified_user
from open_webui.env import SRC_LOG_LEVELS
@ -57,7 +57,9 @@ async def add_memory(
{ {
"id": memory.id, "id": memory.id,
"text": memory.content, "text": memory.content,
"vector": request.app.state.EMBEDDING_FUNCTION(memory.content, user), "vector": request.app.state.EMBEDDING_FUNCTION(
memory.content, user=user
),
"metadata": {"created_at": memory.created_at}, "metadata": {"created_at": memory.created_at},
} }
], ],
@ -82,7 +84,7 @@ async def query_memory(
): ):
results = VECTOR_DB_CLIENT.search( results = VECTOR_DB_CLIENT.search(
collection_name=f"user-memory-{user.id}", collection_name=f"user-memory-{user.id}",
vectors=[request.app.state.EMBEDDING_FUNCTION(form_data.content, user)], vectors=[request.app.state.EMBEDDING_FUNCTION(form_data.content, user=user)],
limit=form_data.k, limit=form_data.k,
) )
@ -105,7 +107,9 @@ async def reset_memory_from_vector_db(
{ {
"id": memory.id, "id": memory.id,
"text": memory.content, "text": memory.content,
"vector": request.app.state.EMBEDDING_FUNCTION(memory.content, user), "vector": request.app.state.EMBEDDING_FUNCTION(
memory.content, user=user
),
"metadata": { "metadata": {
"created_at": memory.created_at, "created_at": memory.created_at,
"updated_at": memory.updated_at, "updated_at": memory.updated_at,
@@ -149,7 +153,9 @@ async def update_memory_by_id(
form_data: MemoryUpdateModel,
user=Depends(get_verified_user),
):
memory = Memories.update_memory_by_id(memory_id, form_data.content)
memory = Memories.update_memory_by_id_and_user_id(
memory_id, user.id, form_data.content
)
if memory is None:
raise HTTPException(status_code=404, detail="Memory not found")
@ -161,7 +167,7 @@ async def update_memory_by_id(
"id": memory.id, "id": memory.id,
"text": memory.content, "text": memory.content,
"vector": request.app.state.EMBEDDING_FUNCTION( "vector": request.app.state.EMBEDDING_FUNCTION(
memory.content, user memory.content, user=user
), ),
"metadata": { "metadata": {
"created_at": memory.created_at, "created_at": memory.created_at,

@@ -0,0 +1,218 @@
import json
import logging
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Request, status, BackgroundTasks
from pydantic import BaseModel
from open_webui.models.users import Users, UserResponse
from open_webui.models.notes import Notes, NoteModel, NoteForm, NoteUserResponse
from open_webui.config import ENABLE_ADMIN_CHAT_ACCESS, ENABLE_ADMIN_EXPORT
from open_webui.constants import ERROR_MESSAGES
from open_webui.env import SRC_LOG_LEVELS
from open_webui.utils.auth import get_admin_user, get_verified_user
from open_webui.utils.access_control import has_access, has_permission
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
router = APIRouter()
############################
# GetNotes
############################
@router.get("/", response_model=list[NoteUserResponse])
async def get_notes(request: Request, user=Depends(get_verified_user)):
if user.role != "admin" and not has_permission(
user.id, "features.notes", request.app.state.config.USER_PERMISSIONS
):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.UNAUTHORIZED,
)
notes = [
NoteUserResponse(
**{
**note.model_dump(),
"user": UserResponse(**Users.get_user_by_id(note.user_id).model_dump()),
}
)
for note in Notes.get_notes_by_user_id(user.id, "write")
]
return notes
@router.get("/list", response_model=list[NoteUserResponse])
async def get_note_list(request: Request, user=Depends(get_verified_user)):
if user.role != "admin" and not has_permission(
user.id, "features.notes", request.app.state.config.USER_PERMISSIONS
):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.UNAUTHORIZED,
)
notes = [
NoteUserResponse(
**{
**note.model_dump(),
"user": UserResponse(**Users.get_user_by_id(note.user_id).model_dump()),
}
)
for note in Notes.get_notes_by_user_id(user.id, "read")
]
return notes
############################
# CreateNewNote
############################
@router.post("/create", response_model=Optional[NoteModel])
async def create_new_note(
request: Request, form_data: NoteForm, user=Depends(get_verified_user)
):
if user.role != "admin" and not has_permission(
user.id, "features.notes", request.app.state.config.USER_PERMISSIONS
):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.UNAUTHORIZED,
)
try:
note = Notes.insert_new_note(form_data, user.id)
return note
except Exception as e:
log.exception(e)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
)
############################
# GetNoteById
############################
@router.get("/{id}", response_model=Optional[NoteModel])
async def get_note_by_id(request: Request, id: str, user=Depends(get_verified_user)):
if user.role != "admin" and not has_permission(
user.id, "features.notes", request.app.state.config.USER_PERMISSIONS
):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.UNAUTHORIZED,
)
note = Notes.get_note_by_id(id)
if not note:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
)
if (
user.role != "admin"
and user.id != note.user_id
and not has_access(user.id, type="read", access_control=note.access_control)
):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
)
return note
############################
# UpdateNoteById
############################
@router.post("/{id}/update", response_model=Optional[NoteModel])
async def update_note_by_id(
request: Request, id: str, form_data: NoteForm, user=Depends(get_verified_user)
):
if user.role != "admin" and not has_permission(
user.id, "features.notes", request.app.state.config.USER_PERMISSIONS
):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.UNAUTHORIZED,
)
note = Notes.get_note_by_id(id)
if not note:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
)
if (
user.role != "admin"
and user.id != note.user_id
and not has_access(user.id, type="write", access_control=note.access_control)
):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
)
try:
note = Notes.update_note_by_id(id, form_data)
return note
except Exception as e:
log.exception(e)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
)
############################
# DeleteNoteById
############################
@router.delete("/{id}/delete", response_model=bool)
async def delete_note_by_id(request: Request, id: str, user=Depends(get_verified_user)):
if user.role != "admin" and not has_permission(
user.id, "features.notes", request.app.state.config.USER_PERMISSIONS
):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.UNAUTHORIZED,
)
note = Notes.get_note_by_id(id)
if not note:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
)
if (
user.role != "admin"
and user.id != note.user_id
and not has_access(user.id, type="write", access_control=note.access_control)
):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
)
try:
note = Notes.delete_note_by_id(id)
return True
except Exception as e:
log.exception(e)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
)
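A hypothetical end-to-end call sequence against the new notes router (the mount point, token, and NoteForm field names are assumptions for illustration; the paths mirror the route decorators above):

```python
import requests

BASE = "http://localhost:8080/api/v1/notes"  # assumed mount point
headers = {"Authorization": "Bearer <token>"}

# Create, fetch, update, then delete a note, following the routes above.
# The NoteForm fields shown here are assumed.
note = requests.post(
    f"{BASE}/create",
    json={"title": "Meeting notes", "data": {}, "meta": {}, "access_control": None},
    headers=headers,
).json()

fetched = requests.get(f"{BASE}/{note['id']}", headers=headers).json()

requests.post(
    f"{BASE}/{note['id']}/update",
    json={
        "title": "Meeting notes (edited)",
        "data": fetched.get("data", {}),
        "meta": fetched.get("meta", {}),
        "access_control": fetched.get("access_control"),
    },
    headers=headers,
)

requests.delete(f"{BASE}/{note['id']}/delete", headers=headers)
```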

@@ -9,11 +9,18 @@ import os
import random
import re
import time
from datetime import datetime
from typing import Optional, Union
from urllib.parse import urlparse
import aiohttp
from aiocache import cached
import requests
from open_webui.models.users import UserModel
from open_webui.env import (
ENABLE_FORWARD_USER_INFO_HEADERS,
)
from fastapi import (
Depends,
@@ -26,7 +33,7 @@ from fastapi import (
)
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, ConfigDict
from pydantic import BaseModel, ConfigDict, validator
from starlette.background import BackgroundTask
@@ -49,8 +56,9 @@ from open_webui.config import (
from open_webui.env import (
ENV,
SRC_LOG_LEVELS,
AIOHTTP_CLIENT_SESSION_SSL,
AIOHTTP_CLIENT_TIMEOUT,
AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST,
AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST,
BYPASS_MODEL_ACCESS_CONTROL,
)
from open_webui.constants import ERROR_MESSAGES
@@ -66,12 +74,27 @@ log.setLevel(SRC_LOG_LEVELS["OLLAMA"])
##########################################
async def send_get_request(url, key=None):
async def send_get_request(url, key=None, user: UserModel = None):
timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST)
timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST)
try:
async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
async with session.get(
url, headers={**({"Authorization": f"Bearer {key}"} if key else {})}
url,
headers={
"Content-Type": "application/json",
**({"Authorization": f"Bearer {key}"} if key else {}),
**(
{
"X-OpenWebUI-User-Name": user.name,
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
}
if ENABLE_FORWARD_USER_INFO_HEADERS and user
else {}
),
},
ssl=AIOHTTP_CLIENT_SESSION_SSL,
) as response: ) as response:
return await response.json() return await response.json()
except Exception as e: except Exception as e:
@ -96,6 +119,7 @@ async def send_post_request(
stream: bool = True, stream: bool = True,
key: Optional[str] = None, key: Optional[str] = None,
content_type: Optional[str] = None, content_type: Optional[str] = None,
user: UserModel = None,
): ):
r = None r = None
@ -110,7 +134,18 @@ async def send_post_request(
headers={ headers={
"Content-Type": "application/json", "Content-Type": "application/json",
**({"Authorization": f"Bearer {key}"} if key else {}), **({"Authorization": f"Bearer {key}"} if key else {}),
**(
{
"X-OpenWebUI-User-Name": user.name,
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
}
if ENABLE_FORWARD_USER_INFO_HEADERS and user
else {}
),
}, },
ssl=AIOHTTP_CLIENT_SESSION_SSL,
) )
r.raise_for_status() r.raise_for_status()
@@ -186,12 +221,26 @@ async def verify_connection(
key = form_data.key
async with aiohttp.ClientSession(
timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST)
trust_env=True,
timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST),
) as session:
try:
async with session.get(
f"{url}/api/version",
headers={**({"Authorization": f"Bearer {key}"} if key else {})},
headers={
**({"Authorization": f"Bearer {key}"} if key else {}),
**(
{
"X-OpenWebUI-User-Name": user.name,
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
}
if ENABLE_FORWARD_USER_INFO_HEADERS and user
else {}
),
},
ssl=AIOHTTP_CLIENT_SESSION_SSL,
) as r: ) as r:
if r.status != 200: if r.status != 200:
detail = f"HTTP Error: {r.status}" detail = f"HTTP Error: {r.status}"
@@ -253,8 +302,24 @@ async def update_config(
}
@cached(ttl=3)
async def get_all_models(request: Request):
def merge_ollama_models_lists(model_lists):
merged_models = {}
for idx, model_list in enumerate(model_lists):
if model_list is not None:
for model in model_list:
id = model["model"]
if id not in merged_models:
model["urls"] = [idx]
merged_models[id] = model
else:
merged_models[id]["urls"].append(idx)
return list(merged_models.values())
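A small illustration of what the helper above returns when the same model is reported by two Ollama nodes (model names are made up; the outer list index corresponds to the URL index in OLLAMA_BASE_URLS):

```python
node_0 = [{"model": "llama3:8b"}, {"model": "phi3:mini"}]
node_1 = [{"model": "llama3:8b"}]

print(merge_ollama_models_lists([node_0, node_1]))
# [{'model': 'llama3:8b', 'urls': [0, 1]}, {'model': 'phi3:mini', 'urls': [0]}]
```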
@cached(ttl=1)
async def get_all_models(request: Request, user: UserModel = None):
log.info("get_all_models()") log.info("get_all_models()")
if request.app.state.config.ENABLE_OLLAMA_API: if request.app.state.config.ENABLE_OLLAMA_API:
request_tasks = [] request_tasks = []
@ -262,7 +327,7 @@ async def get_all_models(request: Request):
if (str(idx) not in request.app.state.config.OLLAMA_API_CONFIGS) and ( if (str(idx) not in request.app.state.config.OLLAMA_API_CONFIGS) and (
url not in request.app.state.config.OLLAMA_API_CONFIGS # Legacy support url not in request.app.state.config.OLLAMA_API_CONFIGS # Legacy support
): ):
request_tasks.append(send_get_request(f"{url}/api/tags")) request_tasks.append(send_get_request(f"{url}/api/tags", user=user))
else: else:
api_config = request.app.state.config.OLLAMA_API_CONFIGS.get( api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
str(idx), str(idx),
@ -275,7 +340,9 @@ async def get_all_models(request: Request):
key = api_config.get("key", None) key = api_config.get("key", None)
if enable: if enable:
request_tasks.append(send_get_request(f"{url}/api/tags", key)) request_tasks.append(
send_get_request(f"{url}/api/tags", key, user=user)
)
else: else:
request_tasks.append(asyncio.ensure_future(asyncio.sleep(0, None))) request_tasks.append(asyncio.ensure_future(asyncio.sleep(0, None)))
@ -291,7 +358,10 @@ async def get_all_models(request: Request):
), # Legacy support ), # Legacy support
) )
connection_type = api_config.get("connection_type", "local")
prefix_id = api_config.get("prefix_id", None) prefix_id = api_config.get("prefix_id", None)
tags = api_config.get("tags", [])
model_ids = api_config.get("model_ids", []) model_ids = api_config.get("model_ids", [])
if len(model_ids) != 0 and "models" in response: if len(model_ids) != 0 and "models" in response:
@@ -302,27 +372,18 @@ async def get_all_models(request: Request):
)
)
if prefix_id:
for model in response.get("models", []):
model["model"] = f"{prefix_id}.{model['model']}"

def merge_models_lists(model_lists):
merged_models = {}
for idx, model_list in enumerate(model_lists):
if model_list is not None:
for model in model_list:
id = model["model"]
if id not in merged_models:
model["urls"] = [idx]
merged_models[id] = model
else:
merged_models[id]["urls"].append(idx)
return list(merged_models.values())

for model in response.get("models", []):
if prefix_id:
model["model"] = f"{prefix_id}.{model['model']}"
if tags:
model["tags"] = tags
if connection_type:
model["connection_type"] = connection_type

models = {
"models": merge_models_lists(
"models": merge_ollama_models_lists(
map( map(
lambda response: response.get("models", []) if response else None, lambda response: response.get("models", []) if response else None,
responses, responses,
@ -330,6 +391,22 @@ async def get_all_models(request: Request):
) )
} }
try:
loaded_models = await get_ollama_loaded_models(request, user=user)
expires_map = {
m["name"]: m["expires_at"]
for m in loaded_models["models"]
if "expires_at" in m
}
for m in models["models"]:
if m["name"] in expires_map:
# Parse ISO8601 datetime with offset, get unix timestamp as int
dt = datetime.fromisoformat(expires_map[m["name"]])
m["expires_at"] = int(dt.timestamp())
except Exception as e:
log.debug(f"Failed to get loaded models: {e}")
else: else:
models = {"models": []} models = {"models": []}
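The `expires_at` handling relies on `datetime.fromisoformat` accepting an ISO 8601 timestamp with a UTC offset. A small, self-contained illustration; the timestamp value here is made up:

```python
from datetime import datetime

# Hypothetical "expires_at" value as reported for a loaded model.
expires_at = "2025-06-04T15:42:10.123456+02:00"

dt = datetime.fromisoformat(expires_at)
print(int(dt.timestamp()))  # unix timestamp as an int, as stored on the model entry
```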
@ -360,7 +437,7 @@ async def get_ollama_tags(
models = [] models = []
if url_idx is None: if url_idx is None:
models = await get_all_models(request) models = await get_all_models(request, user=user)
else: else:
url = request.app.state.config.OLLAMA_BASE_URLS[url_idx] url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
key = get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS) key = get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS)
@ -370,7 +447,19 @@ async def get_ollama_tags(
r = requests.request( r = requests.request(
method="GET", method="GET",
url=f"{url}/api/tags", url=f"{url}/api/tags",
headers={**({"Authorization": f"Bearer {key}"} if key else {})}, headers={
**({"Authorization": f"Bearer {key}"} if key else {}),
**(
{
"X-OpenWebUI-User-Name": user.name,
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
}
if ENABLE_FORWARD_USER_INFO_HEADERS and user
else {}
),
},
) )
r.raise_for_status() r.raise_for_status()
@ -398,24 +487,95 @@ async def get_ollama_tags(
return models return models
@router.get("/api/ps")
async def get_ollama_loaded_models(request: Request, user=Depends(get_admin_user)):
"""
List models that are currently loaded into Ollama memory, and which node they are loaded on.
"""
if request.app.state.config.ENABLE_OLLAMA_API:
request_tasks = []
for idx, url in enumerate(request.app.state.config.OLLAMA_BASE_URLS):
if (str(idx) not in request.app.state.config.OLLAMA_API_CONFIGS) and (
url not in request.app.state.config.OLLAMA_API_CONFIGS # Legacy support
):
request_tasks.append(send_get_request(f"{url}/api/ps", user=user))
else:
api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
str(idx),
request.app.state.config.OLLAMA_API_CONFIGS.get(
url, {}
), # Legacy support
)
enable = api_config.get("enable", True)
key = api_config.get("key", None)
if enable:
request_tasks.append(
send_get_request(f"{url}/api/ps", key, user=user)
)
else:
request_tasks.append(asyncio.ensure_future(asyncio.sleep(0, None)))
responses = await asyncio.gather(*request_tasks)
for idx, response in enumerate(responses):
if response:
url = request.app.state.config.OLLAMA_BASE_URLS[idx]
api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
str(idx),
request.app.state.config.OLLAMA_API_CONFIGS.get(
url, {}
), # Legacy support
)
prefix_id = api_config.get("prefix_id", None)
for model in response.get("models", []):
if prefix_id:
model["model"] = f"{prefix_id}.{model['model']}"
models = {
"models": merge_ollama_models_lists(
map(
lambda response: response.get("models", []) if response else None,
responses,
)
)
}
else:
models = {"models": []}
return models
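Both the `/api/tags` and the new `/api/ps` aggregation keep the responses list aligned with `OLLAMA_BASE_URLS` by inserting a no-op task for disabled connections. A minimal sketch of that pattern; the URLs and the fetch function are placeholders:

```python
import asyncio

async def fetch(url):
    # Placeholder for send_get_request(); a real implementation would call the node.
    return {"models": [{"model": f"demo-from-{url}"}]}

async def gather_aligned(urls, enabled):
    tasks = []
    for url, on in zip(urls, enabled):
        if on:
            tasks.append(fetch(url))
        else:
            # asyncio.sleep(0, None) resolves immediately to None, keeping list positions aligned.
            tasks.append(asyncio.ensure_future(asyncio.sleep(0, None)))
    return await asyncio.gather(*tasks)

print(asyncio.run(gather_aligned(["http://node-a", "http://node-b"], [True, False])))
# -> [{'models': [...]}, None]
```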
@router.get("/api/version") @router.get("/api/version")
@router.get("/api/version/{url_idx}") @router.get("/api/version/{url_idx}")
async def get_ollama_versions(request: Request, url_idx: Optional[int] = None): async def get_ollama_versions(request: Request, url_idx: Optional[int] = None):
if request.app.state.config.ENABLE_OLLAMA_API: if request.app.state.config.ENABLE_OLLAMA_API:
if url_idx is None: if url_idx is None:
# returns lowest version # returns lowest version
request_tasks = [ request_tasks = []
send_get_request(
f"{url}/api/version", for idx, url in enumerate(request.app.state.config.OLLAMA_BASE_URLS):
request.app.state.config.OLLAMA_API_CONFIGS.get( api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
str(idx), str(idx),
request.app.state.config.OLLAMA_API_CONFIGS.get( request.app.state.config.OLLAMA_API_CONFIGS.get(
url, {} url, {}
), # Legacy support ), # Legacy support
).get("key", None),
) )
for idx, url in enumerate(request.app.state.config.OLLAMA_BASE_URLS)
] enable = api_config.get("enable", True)
key = api_config.get("key", None)
if enable:
request_tasks.append(
send_get_request(
f"{url}/api/version",
key,
)
)
responses = await asyncio.gather(*request_tasks) responses = await asyncio.gather(*request_tasks)
responses = list(filter(lambda x: x is not None, responses)) responses = list(filter(lambda x: x is not None, responses))
@ -462,35 +622,74 @@ async def get_ollama_versions(request: Request, url_idx: Optional[int] = None):
return {"version": False} return {"version": False}
@router.get("/api/ps")
async def get_ollama_loaded_models(request: Request, user=Depends(get_verified_user)):
"""
List models that are currently loaded into Ollama memory, and which node they are loaded on.
"""
if request.app.state.config.ENABLE_OLLAMA_API:
request_tasks = [
send_get_request(
f"{url}/api/ps",
request.app.state.config.OLLAMA_API_CONFIGS.get(
str(idx),
request.app.state.config.OLLAMA_API_CONFIGS.get(
url, {}
), # Legacy support
).get("key", None),
)
for idx, url in enumerate(request.app.state.config.OLLAMA_BASE_URLS)
]
responses = await asyncio.gather(*request_tasks)
return dict(zip(request.app.state.config.OLLAMA_BASE_URLS, responses))
else:
return {}
class ModelNameForm(BaseModel): class ModelNameForm(BaseModel):
name: str name: str
@router.post("/api/unload")
async def unload_model(
request: Request,
form_data: ModelNameForm,
user=Depends(get_admin_user),
):
model_name = form_data.name
if not model_name:
raise HTTPException(
status_code=400, detail="Missing 'name' of model to unload."
)
# Refresh/load models if needed, get mapping from name to URLs
await get_all_models(request, user=user)
models = request.app.state.OLLAMA_MODELS
# Canonicalize the model name (append ":latest" when no tag is supplied)
if ":" not in model_name:
model_name = f"{model_name}:latest"
if model_name not in models:
raise HTTPException(
status_code=400, detail=ERROR_MESSAGES.MODEL_NOT_FOUND(model_name)
)
url_indices = models[model_name]["urls"]
# Send unload to ALL url_indices
results = []
errors = []
for idx in url_indices:
url = request.app.state.config.OLLAMA_BASE_URLS[idx]
api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
str(idx), request.app.state.config.OLLAMA_API_CONFIGS.get(url, {})
)
key = get_api_key(idx, url, request.app.state.config.OLLAMA_API_CONFIGS)
prefix_id = api_config.get("prefix_id", None)
if prefix_id and model_name.startswith(f"{prefix_id}."):
model_name = model_name[len(f"{prefix_id}.") :]
payload = {"model": model_name, "keep_alive": 0, "prompt": ""}
try:
res = await send_post_request(
url=f"{url}/api/generate",
payload=json.dumps(payload),
stream=False,
key=key,
user=user,
)
results.append({"url_idx": idx, "success": True, "response": res})
except Exception as e:
log.exception(f"Failed to unload model on node {idx}: {e}")
errors.append({"url_idx": idx, "success": False, "error": str(e)})
if len(errors) > 0:
raise HTTPException(
status_code=500,
detail=f"Failed to unload model on {len(errors)} nodes: {errors}",
)
return {"status": True}
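The new `/api/unload` endpoint leans on Ollama's behaviour of evicting a model when a generate request arrives with `keep_alive` set to 0. A rough sketch of that underlying call against a single node; the base URL is an assumption (the typical local default):

```python
import requests

OLLAMA_URL = "http://localhost:11434"  # assumption: default local Ollama address

def unload(model: str) -> None:
    # An empty prompt with keep_alive=0 asks Ollama to release the model from memory.
    r = requests.post(
        f"{OLLAMA_URL}/api/generate",
        json={"model": model, "keep_alive": 0, "prompt": ""},
        timeout=30,
    )
    r.raise_for_status()

unload("llama3:latest")
```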
@router.post("/api/pull") @router.post("/api/pull")
@router.post("/api/pull/{url_idx}") @router.post("/api/pull/{url_idx}")
async def pull_model( async def pull_model(
@ -509,6 +708,7 @@ async def pull_model(
url=f"{url}/api/pull", url=f"{url}/api/pull",
payload=json.dumps(payload), payload=json.dumps(payload),
key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS), key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
user=user,
) )
@ -527,7 +727,7 @@ async def push_model(
user=Depends(get_admin_user), user=Depends(get_admin_user),
): ):
if url_idx is None: if url_idx is None:
await get_all_models(request) await get_all_models(request, user=user)
models = request.app.state.OLLAMA_MODELS models = request.app.state.OLLAMA_MODELS
if form_data.name in models: if form_data.name in models:
@ -545,6 +745,7 @@ async def push_model(
url=f"{url}/api/push", url=f"{url}/api/push",
payload=form_data.model_dump_json(exclude_none=True).encode(), payload=form_data.model_dump_json(exclude_none=True).encode(),
key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS), key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
user=user,
) )
@ -571,6 +772,7 @@ async def create_model(
url=f"{url}/api/create", url=f"{url}/api/create",
payload=form_data.model_dump_json(exclude_none=True).encode(), payload=form_data.model_dump_json(exclude_none=True).encode(),
key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS), key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
user=user,
) )
@ -588,7 +790,7 @@ async def copy_model(
user=Depends(get_admin_user), user=Depends(get_admin_user),
): ):
if url_idx is None: if url_idx is None:
await get_all_models(request) await get_all_models(request, user=user)
models = request.app.state.OLLAMA_MODELS models = request.app.state.OLLAMA_MODELS
if form_data.source in models: if form_data.source in models:
@ -609,6 +811,16 @@ async def copy_model(
headers={ headers={
"Content-Type": "application/json", "Content-Type": "application/json",
**({"Authorization": f"Bearer {key}"} if key else {}), **({"Authorization": f"Bearer {key}"} if key else {}),
**(
{
"X-OpenWebUI-User-Name": user.name,
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
}
if ENABLE_FORWARD_USER_INFO_HEADERS and user
else {}
),
}, },
data=form_data.model_dump_json(exclude_none=True).encode(), data=form_data.model_dump_json(exclude_none=True).encode(),
) )
@ -643,7 +855,7 @@ async def delete_model(
user=Depends(get_admin_user), user=Depends(get_admin_user),
): ):
if url_idx is None: if url_idx is None:
await get_all_models(request) await get_all_models(request, user=user)
models = request.app.state.OLLAMA_MODELS models = request.app.state.OLLAMA_MODELS
if form_data.name in models: if form_data.name in models:
@ -665,6 +877,16 @@ async def delete_model(
headers={ headers={
"Content-Type": "application/json", "Content-Type": "application/json",
**({"Authorization": f"Bearer {key}"} if key else {}), **({"Authorization": f"Bearer {key}"} if key else {}),
**(
{
"X-OpenWebUI-User-Name": user.name,
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
}
if ENABLE_FORWARD_USER_INFO_HEADERS and user
else {}
),
}, },
) )
r.raise_for_status() r.raise_for_status()
@ -693,7 +915,7 @@ async def delete_model(
async def show_model_info( async def show_model_info(
request: Request, form_data: ModelNameForm, user=Depends(get_verified_user) request: Request, form_data: ModelNameForm, user=Depends(get_verified_user)
): ):
await get_all_models(request) await get_all_models(request, user=user)
models = request.app.state.OLLAMA_MODELS models = request.app.state.OLLAMA_MODELS
if form_data.name not in models: if form_data.name not in models:
@ -714,6 +936,16 @@ async def show_model_info(
headers={ headers={
"Content-Type": "application/json", "Content-Type": "application/json",
**({"Authorization": f"Bearer {key}"} if key else {}), **({"Authorization": f"Bearer {key}"} if key else {}),
**(
{
"X-OpenWebUI-User-Name": user.name,
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
}
if ENABLE_FORWARD_USER_INFO_HEADERS and user
else {}
),
}, },
data=form_data.model_dump_json(exclude_none=True).encode(), data=form_data.model_dump_json(exclude_none=True).encode(),
) )
@ -757,7 +989,7 @@ async def embed(
log.info(f"generate_ollama_batch_embeddings {form_data}") log.info(f"generate_ollama_batch_embeddings {form_data}")
if url_idx is None: if url_idx is None:
await get_all_models(request) await get_all_models(request, user=user)
models = request.app.state.OLLAMA_MODELS models = request.app.state.OLLAMA_MODELS
model = form_data.model model = form_data.model
@ -774,8 +1006,16 @@ async def embed(
) )
url = request.app.state.config.OLLAMA_BASE_URLS[url_idx] url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
str(url_idx),
request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}), # Legacy support
)
key = get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS) key = get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS)
prefix_id = api_config.get("prefix_id", None)
if prefix_id:
form_data.model = form_data.model.replace(f"{prefix_id}.", "")
try: try:
r = requests.request( r = requests.request(
method="POST", method="POST",
@ -783,6 +1023,16 @@ async def embed(
headers={ headers={
"Content-Type": "application/json", "Content-Type": "application/json",
**({"Authorization": f"Bearer {key}"} if key else {}), **({"Authorization": f"Bearer {key}"} if key else {}),
**(
{
"X-OpenWebUI-User-Name": user.name,
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
}
if ENABLE_FORWARD_USER_INFO_HEADERS and user
else {}
),
}, },
data=form_data.model_dump_json(exclude_none=True).encode(), data=form_data.model_dump_json(exclude_none=True).encode(),
) )
@ -826,7 +1076,7 @@ async def embeddings(
log.info(f"generate_ollama_embeddings {form_data}") log.info(f"generate_ollama_embeddings {form_data}")
if url_idx is None: if url_idx is None:
await get_all_models(request) await get_all_models(request, user=user)
models = request.app.state.OLLAMA_MODELS models = request.app.state.OLLAMA_MODELS
model = form_data.model model = form_data.model
@ -843,8 +1093,16 @@ async def embeddings(
) )
url = request.app.state.config.OLLAMA_BASE_URLS[url_idx] url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
str(url_idx),
request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}), # Legacy support
)
key = get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS) key = get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS)
prefix_id = api_config.get("prefix_id", None)
if prefix_id:
form_data.model = form_data.model.replace(f"{prefix_id}.", "")
try: try:
r = requests.request( r = requests.request(
method="POST", method="POST",
@ -852,6 +1110,16 @@ async def embeddings(
headers={ headers={
"Content-Type": "application/json", "Content-Type": "application/json",
**({"Authorization": f"Bearer {key}"} if key else {}), **({"Authorization": f"Bearer {key}"} if key else {}),
**(
{
"X-OpenWebUI-User-Name": user.name,
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
}
if ENABLE_FORWARD_USER_INFO_HEADERS and user
else {}
),
}, },
data=form_data.model_dump_json(exclude_none=True).encode(), data=form_data.model_dump_json(exclude_none=True).encode(),
) )
@ -882,7 +1150,7 @@ class GenerateCompletionForm(BaseModel):
prompt: str prompt: str
suffix: Optional[str] = None suffix: Optional[str] = None
images: Optional[list[str]] = None images: Optional[list[str]] = None
format: Optional[str] = None format: Optional[Union[dict, str]] = None
options: Optional[dict] = None options: Optional[dict] = None
system: Optional[str] = None system: Optional[str] = None
template: Optional[str] = None template: Optional[str] = None
@ -901,7 +1169,7 @@ async def generate_completion(
user=Depends(get_verified_user), user=Depends(get_verified_user),
): ):
if url_idx is None: if url_idx is None:
await get_all_models(request) await get_all_models(request, user=user)
models = request.app.state.OLLAMA_MODELS models = request.app.state.OLLAMA_MODELS
model = form_data.model model = form_data.model
@ -931,20 +1199,34 @@ async def generate_completion(
url=f"{url}/api/generate", url=f"{url}/api/generate",
payload=form_data.model_dump_json(exclude_none=True).encode(), payload=form_data.model_dump_json(exclude_none=True).encode(),
key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS), key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
user=user,
) )
class ChatMessage(BaseModel): class ChatMessage(BaseModel):
role: str role: str
content: str content: Optional[str] = None
tool_calls: Optional[list[dict]] = None tool_calls: Optional[list[dict]] = None
images: Optional[list[str]] = None images: Optional[list[str]] = None
@validator("content", pre=True)
@classmethod
def check_at_least_one_field(cls, field_value, values, **kwargs):
# Raise an error if both 'content' and 'tool_calls' are None
if field_value is None and (
"tool_calls" not in values or values["tool_calls"] is None
):
raise ValueError(
"At least one of 'content' or 'tool_calls' must be provided"
)
return field_value
class GenerateChatCompletionForm(BaseModel): class GenerateChatCompletionForm(BaseModel):
model: str model: str
messages: list[ChatMessage] messages: list[ChatMessage]
format: Optional[dict] = None format: Optional[Union[dict, str]] = None
options: Optional[dict] = None options: Optional[dict] = None
template: Optional[str] = None template: Optional[str] = None
stream: Optional[bool] = True stream: Optional[bool] = True
@ -1001,13 +1283,14 @@ async def generate_chat_completion(
params = model_info.params.model_dump() params = model_info.params.model_dump()
if params: if params:
if payload.get("options") is None: system = params.pop("system", None)
payload["options"] = {}
# Unlike OpenAI, Ollama does not support params directly in the body
payload["options"] = apply_model_params_to_body_ollama( payload["options"] = apply_model_params_to_body_ollama(
params, payload["options"] params, (payload.get("options", {}) or {})
) )
payload = apply_model_system_prompt_to_body(params, payload, metadata, user)
payload = apply_model_system_prompt_to_body(system, payload, metadata, user)
# Check if user has access to the model # Check if user has access to the model
if not bypass_filter and user.role == "user": if not bypass_filter and user.role == "user":
@ -1040,13 +1323,14 @@ async def generate_chat_completion(
prefix_id = api_config.get("prefix_id", None) prefix_id = api_config.get("prefix_id", None)
if prefix_id: if prefix_id:
payload["model"] = payload["model"].replace(f"{prefix_id}.", "") payload["model"] = payload["model"].replace(f"{prefix_id}.", "")
# payload["keep_alive"] = -1 # keep alive forever
return await send_post_request( return await send_post_request(
url=f"{url}/api/chat", url=f"{url}/api/chat",
payload=json.dumps(payload), payload=json.dumps(payload),
stream=form_data.stream, stream=form_data.stream,
key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS), key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
content_type="application/x-ndjson", content_type="application/x-ndjson",
user=user,
) )
@ -1058,7 +1342,7 @@ class OpenAIChatMessageContent(BaseModel):
class OpenAIChatMessage(BaseModel): class OpenAIChatMessage(BaseModel):
role: str role: str
content: Union[str, list[OpenAIChatMessageContent]] content: Union[Optional[str], list[OpenAIChatMessageContent]]
model_config = ConfigDict(extra="allow") model_config = ConfigDict(extra="allow")
@ -1149,6 +1433,7 @@ async def generate_openai_completion(
payload=json.dumps(payload), payload=json.dumps(payload),
stream=payload.get("stream", False), stream=payload.get("stream", False),
key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS), key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
user=user,
) )
@ -1187,8 +1472,10 @@ async def generate_openai_chat_completion(
params = model_info.params.model_dump() params = model_info.params.model_dump()
if params: if params:
system = params.pop("system", None)
payload = apply_model_params_to_body_openai(params, payload) payload = apply_model_params_to_body_openai(params, payload)
payload = apply_model_system_prompt_to_body(params, payload, metadata, user) payload = apply_model_system_prompt_to_body(system, payload, metadata, user)
# Check if user has access to the model # Check if user has access to the model
if user.role == "user": if user.role == "user":
@ -1227,6 +1514,7 @@ async def generate_openai_chat_completion(
payload=json.dumps(payload), payload=json.dumps(payload),
stream=payload.get("stream", False), stream=payload.get("stream", False),
key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS), key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
user=user,
) )
@ -1240,7 +1528,7 @@ async def get_openai_models(
models = [] models = []
if url_idx is None: if url_idx is None:
model_list = await get_all_models(request) model_list = await get_all_models(request, user=user)
models = [ models = [
{ {
"id": model["model"], "id": model["model"],
@ -1341,7 +1629,9 @@ async def download_file_stream(
timeout = aiohttp.ClientTimeout(total=600) # Set the timeout timeout = aiohttp.ClientTimeout(total=600) # Set the timeout
async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session: async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
async with session.get(file_url, headers=headers) as response: async with session.get(
file_url, headers=headers, ssl=AIOHTTP_CLIENT_SESSION_SSL
) as response:
total_size = int(response.headers.get("content-length", 0)) + current_size total_size = int(response.headers.get("content-length", 0)) + current_size
with open(file_path, "ab+") as file: with open(file_path, "ab+") as file:
@ -1356,7 +1646,8 @@ async def download_file_stream(
if done: if done:
file.seek(0) file.seek(0)
hashed = calculate_sha256(file) chunk_size = 1024 * 1024 * 2
hashed = calculate_sha256(file, chunk_size)
file.seek(0) file.seek(0)
url = f"{ollama_url}/api/blobs/sha256:{hashed}" url = f"{ollama_url}/api/blobs/sha256:{hashed}"
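Passing an explicit chunk size into `calculate_sha256` suggests the hash is computed incrementally rather than by reading the whole blob at once. A generic sketch of chunked hashing with `hashlib`; the helper's exact signature in Open WebUI may differ:

```python
import hashlib

def sha256_of_file(path: str, chunk_size: int = 2 * 1024 * 1024) -> str:
    # Read the file in chunks so large model blobs never have to fit in memory at once.
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()

# print(sha256_of_file("model.gguf"))
```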
@ -1420,7 +1711,9 @@ async def upload_model(
if url_idx is None: if url_idx is None:
url_idx = 0 url_idx = 0
ollama_url = request.app.state.config.OLLAMA_BASE_URLS[url_idx] ollama_url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
file_path = os.path.join(UPLOAD_DIR, file.filename)
filename = os.path.basename(file.filename)
file_path = os.path.join(UPLOAD_DIR, filename)
os.makedirs(UPLOAD_DIR, exist_ok=True) os.makedirs(UPLOAD_DIR, exist_ok=True)
# --- P1: save file locally --- # --- P1: save file locally ---
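Taking `os.path.basename` of the uploaded filename guards against path traversal: a name like `../../etc/passwd` can no longer escape `UPLOAD_DIR`. A small illustration with a placeholder upload directory:

```python
import os

UPLOAD_DIR = "/tmp/open-webui-uploads"  # assumption: placeholder path for the example

for name in ("model.gguf", "../../etc/passwd"):
    safe = os.path.basename(name)
    print(os.path.join(UPLOAD_DIR, safe))
# both results stay inside UPLOAD_DIR: .../model.gguf and .../passwd
```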
@ -1465,13 +1758,13 @@ async def upload_model(
os.remove(file_path) os.remove(file_path)
# Create model in ollama # Create model in ollama
model_name, ext = os.path.splitext(file.filename) model_name, ext = os.path.splitext(filename)
log.info(f"Created Model: {model_name}") # DEBUG log.info(f"Created Model: {model_name}") # DEBUG
create_payload = { create_payload = {
"model": model_name, "model": model_name,
# Reference the file by its original name => the uploaded blob's digest # Reference the file by its original name => the uploaded blob's digest
"files": {file.filename: f"sha256:{file_hash}"}, "files": {filename: f"sha256:{file_hash}"},
} }
log.info(f"Model Payload: {create_payload}") # DEBUG log.info(f"Model Payload: {create_payload}") # DEBUG
@ -1488,7 +1781,7 @@ async def upload_model(
done_msg = { done_msg = {
"done": True, "done": True,
"blob": f"sha256:{file_hash}", "blob": f"sha256:{file_hash}",
"name": file.filename, "name": filename,
"model_created": model_name, "model_created": model_name,
} }
yield f"data: {json.dumps(done_msg)}\n\n" yield f"data: {json.dumps(done_msg)}\n\n"
@ -21,11 +21,13 @@ from open_webui.config import (
CACHE_DIR, CACHE_DIR,
) )
from open_webui.env import ( from open_webui.env import (
AIOHTTP_CLIENT_SESSION_SSL,
AIOHTTP_CLIENT_TIMEOUT, AIOHTTP_CLIENT_TIMEOUT,
AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST, AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST,
ENABLE_FORWARD_USER_INFO_HEADERS, ENABLE_FORWARD_USER_INFO_HEADERS,
BYPASS_MODEL_ACCESS_CONTROL, BYPASS_MODEL_ACCESS_CONTROL,
) )
from open_webui.models.users import UserModel
from open_webui.constants import ERROR_MESSAGES from open_webui.constants import ERROR_MESSAGES
from open_webui.env import ENV, SRC_LOG_LEVELS from open_webui.env import ENV, SRC_LOG_LEVELS
@ -35,6 +37,9 @@ from open_webui.utils.payload import (
apply_model_params_to_body_openai, apply_model_params_to_body_openai,
apply_model_system_prompt_to_body, apply_model_system_prompt_to_body,
) )
from open_webui.utils.misc import (
convert_logit_bias_input_to_json,
)
from open_webui.utils.auth import get_admin_user, get_verified_user from open_webui.utils.auth import get_admin_user, get_verified_user
from open_webui.utils.access_control import has_access from open_webui.utils.access_control import has_access
@ -51,12 +56,26 @@ log.setLevel(SRC_LOG_LEVELS["OPENAI"])
########################################## ##########################################
async def send_get_request(url, key=None): async def send_get_request(url, key=None, user: UserModel = None):
timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST) timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST)
try: try:
async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session: async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
async with session.get( async with session.get(
url, headers={**({"Authorization": f"Bearer {key}"} if key else {})} url,
headers={
**({"Authorization": f"Bearer {key}"} if key else {}),
**(
{
"X-OpenWebUI-User-Name": user.name,
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
}
if ENABLE_FORWARD_USER_INFO_HEADERS and user
else {}
),
},
ssl=AIOHTTP_CLIENT_SESSION_SSL,
) as response: ) as response:
return await response.json() return await response.json()
except Exception as e: except Exception as e:
@ -75,18 +94,23 @@ async def cleanup_response(
await session.close() await session.close()
def openai_o1_o3_handler(payload): def openai_o_series_handler(payload):
""" """
Handle o1, o3 specific parameters Handle "o" series specific parameters
""" """
if "max_tokens" in payload: if "max_tokens" in payload:
# Remove "max_tokens" from the payload # Convert "max_tokens" to "max_completion_tokens" for all o-series models
payload["max_completion_tokens"] = payload["max_tokens"] payload["max_completion_tokens"] = payload["max_tokens"]
del payload["max_tokens"] del payload["max_tokens"]
# Fix: O1 does not support the "system" parameter, Modify "system" to "user" # Handle system role conversion based on model type
if payload["messages"][0]["role"] == "system": if payload["messages"][0]["role"] == "system":
model_lower = payload["model"].lower()
# Legacy models use "user" role instead of "system"
if model_lower.startswith("o1-mini") or model_lower.startswith("o1-preview"):
payload["messages"][0]["role"] = "user" payload["messages"][0]["role"] = "user"
else:
payload["messages"][0]["role"] = "developer"
return payload return payload
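The renamed `openai_o_series_handler` now distinguishes legacy o1-mini/o1-preview models (system becomes user) from newer o-series models (system becomes developer), and always renames `max_tokens`. A self-contained sketch of the transformation using the same logic and an illustrative payload:

```python
def openai_o_series_handler(payload):
    # max_tokens is not accepted by o-series models; the equivalent is max_completion_tokens.
    if "max_tokens" in payload:
        payload["max_completion_tokens"] = payload.pop("max_tokens")

    # Legacy o1-mini / o1-preview expect "user" instead of "system";
    # newer o-series models use the "developer" role.
    if payload["messages"][0]["role"] == "system":
        model_lower = payload["model"].lower()
        if model_lower.startswith("o1-mini") or model_lower.startswith("o1-preview"):
            payload["messages"][0]["role"] = "user"
        else:
            payload["messages"][0]["role"] = "developer"
    return payload

example = {
    "model": "o3-mini",
    "max_tokens": 256,
    "messages": [{"role": "system", "content": "Be terse."}, {"role": "user", "content": "Hi"}],
}
print(openai_o_series_handler(example))
# -> max_completion_tokens=256 and the first message role becomes "developer"
```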
@ -172,7 +196,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
body = await request.body() body = await request.body()
name = hashlib.sha256(body).hexdigest() name = hashlib.sha256(body).hexdigest()
SPEECH_CACHE_DIR = Path(CACHE_DIR).joinpath("./audio/speech/") SPEECH_CACHE_DIR = CACHE_DIR / "audio" / "speech"
SPEECH_CACHE_DIR.mkdir(parents=True, exist_ok=True) SPEECH_CACHE_DIR.mkdir(parents=True, exist_ok=True)
file_path = SPEECH_CACHE_DIR.joinpath(f"{name}.mp3") file_path = SPEECH_CACHE_DIR.joinpath(f"{name}.mp3")
file_body_path = SPEECH_CACHE_DIR.joinpath(f"{name}.json") file_body_path = SPEECH_CACHE_DIR.joinpath(f"{name}.json")
@ -247,7 +271,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
raise HTTPException(status_code=401, detail=ERROR_MESSAGES.OPENAI_NOT_FOUND) raise HTTPException(status_code=401, detail=ERROR_MESSAGES.OPENAI_NOT_FOUND)
async def get_all_models_responses(request: Request) -> list: async def get_all_models_responses(request: Request, user: UserModel) -> list:
if not request.app.state.config.ENABLE_OPENAI_API: if not request.app.state.config.ENABLE_OPENAI_API:
return [] return []
@ -271,7 +295,9 @@ async def get_all_models_responses(request: Request) -> list:
): ):
request_tasks.append( request_tasks.append(
send_get_request( send_get_request(
f"{url}/models", request.app.state.config.OPENAI_API_KEYS[idx] f"{url}/models",
request.app.state.config.OPENAI_API_KEYS[idx],
user=user,
) )
) )
else: else:
@ -291,6 +317,7 @@ async def get_all_models_responses(request: Request) -> list:
send_get_request( send_get_request(
f"{url}/models", f"{url}/models",
request.app.state.config.OPENAI_API_KEYS[idx], request.app.state.config.OPENAI_API_KEYS[idx],
user=user,
) )
) )
else: else:
@ -326,14 +353,22 @@ async def get_all_models_responses(request: Request) -> list:
), # Legacy support ), # Legacy support
) )
connection_type = api_config.get("connection_type", "external")
prefix_id = api_config.get("prefix_id", None) prefix_id = api_config.get("prefix_id", None)
tags = api_config.get("tags", [])
if prefix_id:
for model in ( for model in (
response if isinstance(response, list) else response.get("data", []) response if isinstance(response, list) else response.get("data", [])
): ):
if prefix_id:
model["id"] = f"{prefix_id}.{model['id']}" model["id"] = f"{prefix_id}.{model['id']}"
if tags:
model["tags"] = tags
if connection_type:
model["connection_type"] = connection_type
log.debug(f"get_all_models:responses() {responses}") log.debug(f"get_all_models:responses() {responses}")
return responses return responses
@ -351,14 +386,14 @@ async def get_filtered_models(models, user):
return filtered_models return filtered_models
@cached(ttl=3) @cached(ttl=1)
async def get_all_models(request: Request) -> dict[str, list]: async def get_all_models(request: Request, user: UserModel) -> dict[str, list]:
log.info("get_all_models()") log.info("get_all_models()")
if not request.app.state.config.ENABLE_OPENAI_API: if not request.app.state.config.ENABLE_OPENAI_API:
return {"data": []} return {"data": []}
responses = await get_all_models_responses(request) responses = await get_all_models_responses(request, user=user)
def extract_data(response): def extract_data(response):
if response and "data" in response: if response and "data" in response:
@ -373,6 +408,7 @@ async def get_all_models(request: Request) -> dict[str, list]:
for idx, models in enumerate(model_lists): for idx, models in enumerate(model_lists):
if models is not None and "error" not in models: if models is not None and "error" not in models:
merged_list.extend( merged_list.extend(
[ [
{ {
@ -380,10 +416,13 @@ async def get_all_models(request: Request) -> dict[str, list]:
"name": model.get("name", model["id"]), "name": model.get("name", model["id"]),
"owned_by": "openai", "owned_by": "openai",
"openai": model, "openai": model,
"connection_type": model.get("connection_type", "external"),
"urlIdx": idx, "urlIdx": idx,
} }
for model in models for model in models
if "api.openai.com" if (model.get("id") or model.get("name"))
and (
"api.openai.com"
not in request.app.state.config.OPENAI_API_BASE_URLS[idx] not in request.app.state.config.OPENAI_API_BASE_URLS[idx]
or not any( or not any(
name in model["id"] name in model["id"]
@ -396,6 +435,7 @@ async def get_all_models(request: Request) -> dict[str, list]:
"whisper", "whisper",
] ]
) )
)
] ]
) )
@ -418,22 +458,23 @@ async def get_models(
} }
if url_idx is None: if url_idx is None:
models = await get_all_models(request) models = await get_all_models(request, user=user)
else: else:
url = request.app.state.config.OPENAI_API_BASE_URLS[url_idx] url = request.app.state.config.OPENAI_API_BASE_URLS[url_idx]
key = request.app.state.config.OPENAI_API_KEYS[url_idx] key = request.app.state.config.OPENAI_API_KEYS[url_idx]
api_config = request.app.state.config.OPENAI_API_CONFIGS.get(
str(url_idx),
request.app.state.config.OPENAI_API_CONFIGS.get(url, {}), # Legacy support
)
r = None r = None
async with aiohttp.ClientSession( async with aiohttp.ClientSession(
timeout=aiohttp.ClientTimeout( trust_env=True,
total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST),
)
) as session: ) as session:
try: try:
async with session.get(
f"{url}/models",
headers = { headers = {
"Authorization": f"Bearer {key}",
"Content-Type": "application/json", "Content-Type": "application/json",
**( **(
{ {
@ -445,7 +486,20 @@ async def get_models(
if ENABLE_FORWARD_USER_INFO_HEADERS if ENABLE_FORWARD_USER_INFO_HEADERS
else {} else {}
), ),
}, }
if api_config.get("azure", False):
models = {
"data": api_config.get("model_ids", []) or [],
"object": "list",
}
else:
headers["Authorization"] = f"Bearer {key}"
async with session.get(
f"{url}/models",
headers=headers,
ssl=AIOHTTP_CLIENT_SESSION_SSL,
) as r: ) as r:
if r.status != 200: if r.status != 200:
# Extract response error details if available # Extract response error details if available
@ -498,6 +552,8 @@ class ConnectionVerificationForm(BaseModel):
url: str url: str
key: str key: str
config: Optional[dict] = None
@router.post("/verify") @router.post("/verify")
async def verify_connection( async def verify_connection(
@ -506,16 +562,53 @@ async def verify_connection(
url = form_data.url url = form_data.url
key = form_data.key key = form_data.key
api_config = form_data.config or {}
async with aiohttp.ClientSession( async with aiohttp.ClientSession(
timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST) trust_env=True,
timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST),
) as session: ) as session:
try: try:
headers = {
"Content-Type": "application/json",
**(
{
"X-OpenWebUI-User-Name": user.name,
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
}
if ENABLE_FORWARD_USER_INFO_HEADERS
else {}
),
}
if api_config.get("azure", False):
headers["api-key"] = key
api_version = api_config.get("api_version", "") or "2023-03-15-preview"
async with session.get(
url=f"{url}/openai/models?api-version={api_version}",
headers=headers,
ssl=AIOHTTP_CLIENT_SESSION_SSL,
) as r:
if r.status != 200:
# Extract response error details if available
error_detail = f"HTTP Error: {r.status}"
res = await r.json()
if "error" in res:
error_detail = f"External Error: {res['error']}"
raise Exception(error_detail)
response_data = await r.json()
return response_data
else:
headers["Authorization"] = f"Bearer {key}"
async with session.get( async with session.get(
f"{url}/models", f"{url}/models",
headers={ headers=headers,
"Authorization": f"Bearer {key}", ssl=AIOHTTP_CLIENT_SESSION_SSL,
"Content-Type": "application/json",
},
) as r: ) as r:
if r.status != 200: if r.status != 200:
# Extract response error details if available # Extract response error details if available
@ -540,6 +633,63 @@ async def verify_connection(
raise HTTPException(status_code=500, detail=error_detail) raise HTTPException(status_code=500, detail=error_detail)
def convert_to_azure_payload(
url,
payload: dict,
):
model = payload.get("model", "")
# Filter allowed parameters based on Azure OpenAI API
allowed_params = {
"messages",
"temperature",
"role",
"content",
"contentPart",
"contentPartImage",
"enhancements",
"dataSources",
"n",
"stream",
"stop",
"max_tokens",
"presence_penalty",
"frequency_penalty",
"logit_bias",
"user",
"function_call",
"functions",
"tools",
"tool_choice",
"top_p",
"log_probs",
"top_logprobs",
"response_format",
"seed",
"max_completion_tokens",
}
# Special handling for o-series models
if model.startswith("o") and model.endswith("-mini"):
# Convert max_tokens to max_completion_tokens for o-series models
if "max_tokens" in payload:
payload["max_completion_tokens"] = payload["max_tokens"]
del payload["max_tokens"]
# Remove temperature if not 1 for o-series models
if "temperature" in payload and payload["temperature"] != 1:
log.debug(
f"Removing temperature parameter for o-series model {model} as only default value (1) is supported"
)
del payload["temperature"]
# Filter out unsupported parameters
payload = {k: v for k, v in payload.items() if k in allowed_params}
url = f"{url}/openai/deployments/{model}"
return url, payload
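Azure OpenAI connections route through a deployment-scoped URL with an `api-key` header and an `api-version` query parameter, while non-Azure connections keep the Bearer token. A rough sketch of how the final request pieces fit together; the helper here is a stand-in for `convert_to_azure_payload` above, and the key, version, and payload values are placeholders:

```python
import json

def convert_to_azure_payload(url, payload):
    # Stand-in for the helper above: the real version also filters unsupported parameters.
    return f"{url}/openai/deployments/{payload.get('model', '')}", payload

def build_request(url, payload, key, api_config):
    # Headers common to both paths; user-forwarding headers omitted for brevity.
    headers = {"Content-Type": "application/json"}

    if api_config.get("azure", False):
        request_url, payload = convert_to_azure_payload(url, payload)
        api_version = api_config.get("api_version", "") or "2023-03-15-preview"
        headers["api-key"] = key
        headers["api-version"] = api_version
        request_url = f"{request_url}/chat/completions?api-version={api_version}"
    else:
        request_url = f"{url}/chat/completions"
        headers["Authorization"] = f"Bearer {key}"

    return request_url, headers, json.dumps(payload)

# Hypothetical usage:
# build_request("https://my-resource.openai.azure.com", {"model": "gpt-4o", "messages": []},
#               "AZURE_KEY", {"azure": True, "api_version": "2024-02-01"})
```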
@router.post("/chat/completions") @router.post("/chat/completions")
async def generate_chat_completion( async def generate_chat_completion(
request: Request, request: Request,
@ -565,8 +715,12 @@ async def generate_chat_completion(
model_id = model_info.base_model_id model_id = model_info.base_model_id
params = model_info.params.model_dump() params = model_info.params.model_dump()
if params:
system = params.pop("system", None)
payload = apply_model_params_to_body_openai(params, payload) payload = apply_model_params_to_body_openai(params, payload)
payload = apply_model_system_prompt_to_body(params, payload, metadata, user) payload = apply_model_system_prompt_to_body(system, payload, metadata, user)
# Check if user has access to the model # Check if user has access to the model
if not bypass_filter and user.role == "user": if not bypass_filter and user.role == "user":
@ -587,7 +741,7 @@ async def generate_chat_completion(
detail="Model not found", detail="Model not found",
) )
await get_all_models(request) await get_all_models(request, user=user)
model = request.app.state.OPENAI_MODELS.get(model_id) model = request.app.state.OPENAI_MODELS.get(model_id)
if model: if model:
idx = model["urlIdx"] idx = model["urlIdx"]
@ -621,10 +775,10 @@ async def generate_chat_completion(
url = request.app.state.config.OPENAI_API_BASE_URLS[idx] url = request.app.state.config.OPENAI_API_BASE_URLS[idx]
key = request.app.state.config.OPENAI_API_KEYS[idx] key = request.app.state.config.OPENAI_API_KEYS[idx]
# Fix: o1,o3 does not support the "max_tokens" parameter, Modify "max_tokens" to "max_completion_tokens" # Check if model is from "o" series
is_o1_o3 = payload["model"].lower().startswith(("o1", "o3-")) is_o_series = payload["model"].lower().startswith(("o1", "o3", "o4"))
if is_o1_o3: if is_o_series:
payload = openai_o1_o3_handler(payload) payload = openai_o_series_handler(payload)
elif "api.openai.com" not in url: elif "api.openai.com" not in url:
# Remove "max_completion_tokens" from the payload for backward compatibility # Remove "max_completion_tokens" from the payload for backward compatibility
if "max_completion_tokens" in payload: if "max_completion_tokens" in payload:
@ -635,24 +789,12 @@ async def generate_chat_completion(
del payload["max_tokens"] del payload["max_tokens"]
# Convert the modified body back to JSON # Convert the modified body back to JSON
payload = json.dumps(payload) if "logit_bias" in payload:
payload["logit_bias"] = json.loads(
r = None convert_logit_bias_input_to_json(payload["logit_bias"])
session = None
streaming = False
response = None
try:
session = aiohttp.ClientSession(
trust_env=True, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
) )
r = await session.request(
method="POST",
url=f"{url}/chat/completions",
data=payload,
headers = { headers = {
"Authorization": f"Bearer {key}",
"Content-Type": "application/json", "Content-Type": "application/json",
**( **(
{ {
@ -672,7 +814,36 @@ async def generate_chat_completion(
if ENABLE_FORWARD_USER_INFO_HEADERS if ENABLE_FORWARD_USER_INFO_HEADERS
else {} else {}
), ),
}, }
if api_config.get("azure", False):
request_url, payload = convert_to_azure_payload(url, payload)
api_version = api_config.get("api_version", "") or "2023-03-15-preview"
headers["api-key"] = key
headers["api-version"] = api_version
request_url = f"{request_url}/chat/completions?api-version={api_version}"
else:
request_url = f"{url}/chat/completions"
headers["Authorization"] = f"Bearer {key}"
payload = json.dumps(payload)
r = None
session = None
streaming = False
response = None
try:
session = aiohttp.ClientSession(
trust_env=True, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
)
r = await session.request(
method="POST",
url=request_url,
data=payload,
headers=headers,
ssl=AIOHTTP_CLIENT_SESSION_SSL,
) )
# Check if response is SSE # Check if response is SSE
@ -801,19 +972,19 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
idx = 0 idx = 0
url = request.app.state.config.OPENAI_API_BASE_URLS[idx] url = request.app.state.config.OPENAI_API_BASE_URLS[idx]
key = request.app.state.config.OPENAI_API_KEYS[idx] key = request.app.state.config.OPENAI_API_KEYS[idx]
api_config = request.app.state.config.OPENAI_API_CONFIGS.get(
str(idx),
request.app.state.config.OPENAI_API_CONFIGS.get(
request.app.state.config.OPENAI_API_BASE_URLS[idx], {}
), # Legacy support
)
r = None r = None
session = None session = None
streaming = False streaming = False
try: try:
session = aiohttp.ClientSession(trust_env=True)
r = await session.request(
method=request.method,
url=f"{url}/{path}",
data=body,
headers = { headers = {
"Authorization": f"Bearer {key}",
"Content-Type": "application/json", "Content-Type": "application/json",
**( **(
{ {
@ -825,7 +996,30 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
if ENABLE_FORWARD_USER_INFO_HEADERS if ENABLE_FORWARD_USER_INFO_HEADERS
else {} else {}
), ),
}, }
if api_config.get("azure", False):
headers["api-key"] = key
headers["api-version"] = (
api_config.get("api_version", "") or "2023-03-15-preview"
)
payload = json.loads(body)
url, payload = convert_to_azure_payload(url, payload)
body = json.dumps(payload).encode()
request_url = f"{url}/{path}?api-version={api_config.get('api_version', '2023-03-15-preview')}"
else:
headers["Authorization"] = f"Bearer {key}"
request_url = f"{url}/{path}"
session = aiohttp.ClientSession(trust_env=True)
r = await session.request(
method=request.method,
url=request_url,
data=body,
headers=headers,
ssl=AIOHTTP_CLIENT_SESSION_SSL,
) )
r.raise_for_status() r.raise_for_status()
@ -851,7 +1045,7 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
if r is not None: if r is not None:
try: try:
res = await r.json() res = await r.json()
print(res) log.error(res)
if "error" in res: if "error" in res:
detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}" detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}"
except Exception: except Exception:
@ -9,6 +9,7 @@ from fastapi import (
status, status,
APIRouter, APIRouter,
) )
import aiohttp
import os import os
import logging import logging
import shutil import shutil
@ -17,7 +18,7 @@ from pydantic import BaseModel
from starlette.responses import FileResponse from starlette.responses import FileResponse
from typing import Optional from typing import Optional
from open_webui.env import SRC_LOG_LEVELS from open_webui.env import SRC_LOG_LEVELS, AIOHTTP_CLIENT_SESSION_SSL
from open_webui.config import CACHE_DIR from open_webui.config import CACHE_DIR
from open_webui.constants import ERROR_MESSAGES from open_webui.constants import ERROR_MESSAGES
@ -56,96 +57,111 @@ def get_sorted_filters(model_id, models):
return sorted_filters return sorted_filters
def process_pipeline_inlet_filter(request, payload, user, models): async def process_pipeline_inlet_filter(request, payload, user, models):
user = {"id": user.id, "email": user.email, "name": user.name, "role": user.role} user = {"id": user.id, "email": user.email, "name": user.name, "role": user.role}
model_id = payload["model"] model_id = payload["model"]
sorted_filters = get_sorted_filters(model_id, models) sorted_filters = get_sorted_filters(model_id, models)
model = models[model_id] model = models[model_id]
if "pipeline" in model: if "pipeline" in model:
sorted_filters.append(model) sorted_filters.append(model)
async with aiohttp.ClientSession(trust_env=True) as session:
for filter in sorted_filters: for filter in sorted_filters:
r = None urlIdx = filter.get("urlIdx")
try: try:
urlIdx = filter["urlIdx"] urlIdx = int(urlIdx)
except:
continue
url = request.app.state.config.OPENAI_API_BASE_URLS[urlIdx] url = request.app.state.config.OPENAI_API_BASE_URLS[urlIdx]
key = request.app.state.config.OPENAI_API_KEYS[urlIdx] key = request.app.state.config.OPENAI_API_KEYS[urlIdx]
if key == "": if not key:
continue continue
headers = {"Authorization": f"Bearer {key}"} headers = {"Authorization": f"Bearer {key}"}
r = requests.post( request_data = {
f"{url}/{filter['id']}/filter/inlet",
headers=headers,
json={
"user": user, "user": user,
"body": payload, "body": payload,
}, }
try:
async with session.post(
f"{url}/{filter['id']}/filter/inlet",
headers=headers,
json=request_data,
ssl=AIOHTTP_CLIENT_SESSION_SSL,
) as response:
payload = await response.json()
response.raise_for_status()
except aiohttp.ClientResponseError as e:
res = (
await response.json()
if response.content_type == "application/json"
else {}
) )
r.raise_for_status()
payload = r.json()
except Exception as e:
# Handle connection error here
print(f"Connection error: {e}")
if r is not None:
res = r.json()
if "detail" in res: if "detail" in res:
raise Exception(r.status_code, res["detail"]) raise Exception(response.status, res["detail"])
except Exception as e:
log.exception(f"Connection error: {e}")
return payload return payload
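The inlet and outlet filters are now called with an aiohttp session instead of blocking `requests` calls. A stripped-down sketch of one filter round-trip; the pipeline URL, filter id, and key are placeholders:

```python
import asyncio
import aiohttp

async def call_inlet_filter(url, filter_id, key, user, payload):
    # POST the user + body to the pipeline's inlet filter and use its response as the new payload.
    async with aiohttp.ClientSession(trust_env=True) as session:
        async with session.post(
            f"{url}/{filter_id}/filter/inlet",
            headers={"Authorization": f"Bearer {key}"},
            json={"user": user, "body": payload},
        ) as response:
            response.raise_for_status()
            return await response.json()

# Hypothetical usage:
# asyncio.run(call_inlet_filter("http://localhost:9099", "rate_limit", "PIPELINES_API_KEY",
#                               {"id": "u1", "role": "user"}, {"model": "llama3", "messages": []}))
```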
def process_pipeline_outlet_filter(request, payload, user, models): async def process_pipeline_outlet_filter(request, payload, user, models):
user = {"id": user.id, "email": user.email, "name": user.name, "role": user.role} user = {"id": user.id, "email": user.email, "name": user.name, "role": user.role}
model_id = payload["model"] model_id = payload["model"]
sorted_filters = get_sorted_filters(model_id, models) sorted_filters = get_sorted_filters(model_id, models)
model = models[model_id] model = models[model_id]
if "pipeline" in model: if "pipeline" in model:
sorted_filters = [model] + sorted_filters sorted_filters = [model] + sorted_filters
async with aiohttp.ClientSession(trust_env=True) as session:
for filter in sorted_filters: for filter in sorted_filters:
r = None urlIdx = filter.get("urlIdx")
try: try:
urlIdx = filter["urlIdx"] urlIdx = int(urlIdx)
except:
continue
url = request.app.state.config.OPENAI_API_BASE_URLS[urlIdx] url = request.app.state.config.OPENAI_API_BASE_URLS[urlIdx]
key = request.app.state.config.OPENAI_API_KEYS[urlIdx] key = request.app.state.config.OPENAI_API_KEYS[urlIdx]
if key != "": if not key:
r = requests.post( continue
f"{url}/{filter['id']}/filter/outlet",
headers={"Authorization": f"Bearer {key}"}, headers = {"Authorization": f"Bearer {key}"}
json={ request_data = {
"user": user, "user": user,
"body": payload, "body": payload,
}, }
)
r.raise_for_status()
data = r.json()
payload = data
except Exception as e:
# Handle connection error here
print(f"Connection error: {e}")
if r is not None:
try: try:
res = r.json() async with session.post(
f"{url}/{filter['id']}/filter/outlet",
headers=headers,
json=request_data,
ssl=AIOHTTP_CLIENT_SESSION_SSL,
) as response:
payload = await response.json()
response.raise_for_status()
except aiohttp.ClientResponseError as e:
try:
res = (
await response.json()
if "application/json" in response.content_type
else {}
)
if "detail" in res: if "detail" in res:
return Exception(r.status_code, res) raise Exception(response.status, res)
except Exception: except Exception:
pass pass
except Exception as e:
else: log.exception(f"Connection error: {e}")
pass
return payload return payload
@ -161,7 +177,7 @@ router = APIRouter()
@router.get("/list") @router.get("/list")
async def get_pipelines_list(request: Request, user=Depends(get_admin_user)): async def get_pipelines_list(request: Request, user=Depends(get_admin_user)):
responses = await get_all_models_responses(request) responses = await get_all_models_responses(request, user)
log.debug(f"get_pipelines_list: get_openai_models_responses returned {responses}") log.debug(f"get_pipelines_list: get_openai_models_responses returned {responses}")
urlIdxs = [ urlIdxs = [
@ -188,9 +204,11 @@ async def upload_pipeline(
file: UploadFile = File(...), file: UploadFile = File(...),
user=Depends(get_admin_user), user=Depends(get_admin_user),
): ):
print("upload_pipeline", urlIdx, file.filename) log.info(f"upload_pipeline: urlIdx={urlIdx}, filename={file.filename}")
filename = os.path.basename(file.filename)
# Check if the uploaded file is a python file # Check if the uploaded file is a python file
if not (file.filename and file.filename.endswith(".py")): if not (filename and filename.endswith(".py")):
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,
detail="Only Python (.py) files are allowed.", detail="Only Python (.py) files are allowed.",
@ -198,7 +216,7 @@ async def upload_pipeline(
upload_folder = f"{CACHE_DIR}/pipelines" upload_folder = f"{CACHE_DIR}/pipelines"
os.makedirs(upload_folder, exist_ok=True) os.makedirs(upload_folder, exist_ok=True)
file_path = os.path.join(upload_folder, file.filename) file_path = os.path.join(upload_folder, filename)
r = None r = None
try: try:
@ -223,7 +241,7 @@ async def upload_pipeline(
return {**data} return {**data}
except Exception as e: except Exception as e:
# Handle connection error here # Handle connection error here
print(f"Connection error: {e}") log.exception(f"Connection error: {e}")
detail = None detail = None
status_code = status.HTTP_404_NOT_FOUND status_code = status.HTTP_404_NOT_FOUND
@ -274,7 +292,7 @@ async def add_pipeline(
return {**data} return {**data}
except Exception as e: except Exception as e:
# Handle connection error here # Handle connection error here
print(f"Connection error: {e}") log.exception(f"Connection error: {e}")
detail = None detail = None
if r is not None: if r is not None:
@ -319,7 +337,7 @@ async def delete_pipeline(
return {**data} return {**data}
except Exception as e: except Exception as e:
# Handle connection error here # Handle connection error here
print(f"Connection error: {e}") log.exception(f"Connection error: {e}")
detail = None detail = None
if r is not None: if r is not None:
@ -353,7 +371,7 @@ async def get_pipelines(
return {**data} return {**data}
except Exception as e: except Exception as e:
# Handle connection error here # Handle connection error here
print(f"Connection error: {e}") log.exception(f"Connection error: {e}")
detail = None detail = None
if r is not None: if r is not None:
@ -392,7 +410,7 @@ async def get_pipeline_valves(
return {**data} return {**data}
except Exception as e: except Exception as e:
# Handle connection error here # Handle connection error here
print(f"Connection error: {e}") log.exception(f"Connection error: {e}")
detail = None detail = None
if r is not None: if r is not None:
@ -432,7 +450,7 @@ async def get_pipeline_valves_spec(
return {**data} return {**data}
except Exception as e: except Exception as e:
# Handle connection error here # Handle connection error here
print(f"Connection error: {e}") log.exception(f"Connection error: {e}")
detail = None detail = None
if r is not None: if r is not None:
@ -474,7 +492,7 @@ async def update_pipeline_valves(
return {**data} return {**data}
except Exception as e: except Exception as e:
# Handle connection error here # Handle connection error here
print(f"Connection error: {e}") log.exception(f"Connection error: {e}")
detail = None detail = None
File diff suppressed because it is too large.
@ -20,6 +20,7 @@ from open_webui.utils.auth import get_admin_user, get_verified_user
from open_webui.constants import TASKS from open_webui.constants import TASKS
from open_webui.routers.pipelines import process_pipeline_inlet_filter from open_webui.routers.pipelines import process_pipeline_inlet_filter
from open_webui.utils.task import get_task_model_id from open_webui.utils.task import get_task_model_id
from open_webui.config import ( from open_webui.config import (
@ -182,35 +183,28 @@ async def generate_title(
else: else:
template = DEFAULT_TITLE_GENERATION_PROMPT_TEMPLATE template = DEFAULT_TITLE_GENERATION_PROMPT_TEMPLATE
messages = form_data["messages"]
# Remove reasoning details from the messages
for message in messages:
message["content"] = re.sub(
r"<details\s+type=\"reasoning\"[^>]*>.*?<\/details>",
"",
message["content"],
flags=re.S,
).strip()
content = title_generation_template( content = title_generation_template(
template, template,
messages, form_data["messages"],
{ {
"name": user.name, "name": user.name,
"location": user.info.get("location") if user.info else None, "location": user.info.get("location") if user.info else None,
}, },
) )
max_tokens = (
models[task_model_id].get("info", {}).get("params", {}).get("max_tokens", 1000)
)
payload = { payload = {
"model": task_model_id, "model": task_model_id,
"messages": [{"role": "user", "content": content}], "messages": [{"role": "user", "content": content}],
"stream": False, "stream": False,
**( **(
{"max_tokens": 1000} {"max_tokens": max_tokens}
if models[task_model_id]["owned_by"] == "ollama" if models[task_model_id].get("owned_by") == "ollama"
else { else {
"max_completion_tokens": 1000, "max_completion_tokens": max_tokens,
} }
), ),
"metadata": { "metadata": {
@ -221,6 +215,12 @@ async def generate_title(
}, },
} }
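The title task now respects a per-model `max_tokens` from the model's params instead of a hard-coded 1000, falling back through nested `.get` calls. A tiny illustration of that lookup with a made-up models dict:

```python
# Hypothetical models registry; only the nested lookup matters here.
models = {
    "llama3:latest": {"owned_by": "ollama", "info": {"params": {"max_tokens": 256}}},
    "gpt-4o": {"owned_by": "openai"},  # no explicit limit configured
}

for model_id in models:
    max_tokens = models[model_id].get("info", {}).get("params", {}).get("max_tokens", 1000)
    print(model_id, max_tokens)
# -> llama3:latest 256, gpt-4o 1000
```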
# Process the payload through the pipeline
try:
payload = await process_pipeline_inlet_filter(request, payload, user, models)
except Exception as e:
raise e
try: try:
return await generate_chat_completion(request, form_data=payload, user=user) return await generate_chat_completion(request, form_data=payload, user=user)
except Exception as e: except Exception as e:
@ -290,6 +290,12 @@ async def generate_chat_tags(
}, },
} }
# Process the payload through the pipeline
try:
payload = await process_pipeline_inlet_filter(request, payload, user, models)
except Exception as e:
raise e
try: try:
return await generate_chat_completion(request, form_data=payload, user=user) return await generate_chat_completion(request, form_data=payload, user=user)
except Exception as e: except Exception as e:
@ -356,6 +362,12 @@ async def generate_image_prompt(
}, },
} }
# Process the payload through the pipeline
try:
payload = await process_pipeline_inlet_filter(request, payload, user, models)
except Exception as e:
raise e
try: try:
return await generate_chat_completion(request, form_data=payload, user=user) return await generate_chat_completion(request, form_data=payload, user=user)
except Exception as e: except Exception as e:
@ -433,6 +445,12 @@ async def generate_queries(
}, },
} }
# Process the payload through the pipeline
try:
payload = await process_pipeline_inlet_filter(request, payload, user, models)
except Exception as e:
raise e
try: try:
return await generate_chat_completion(request, form_data=payload, user=user) return await generate_chat_completion(request, form_data=payload, user=user)
except Exception as e: except Exception as e:
@ -514,6 +532,12 @@ async def generate_autocompletion(
}, },
} }
# Process the payload through the pipeline
try:
payload = await process_pipeline_inlet_filter(request, payload, user, models)
except Exception as e:
raise e
try: try:
return await generate_chat_completion(request, form_data=payload, user=user) return await generate_chat_completion(request, form_data=payload, user=user)
except Exception as e: except Exception as e:
@ -571,7 +595,7 @@ async def generate_emoji(
"stream": False, "stream": False,
**( **(
{"max_tokens": 4} {"max_tokens": 4}
if models[task_model_id]["owned_by"] == "ollama" if models[task_model_id].get("owned_by") == "ollama"
else { else {
"max_completion_tokens": 4, "max_completion_tokens": 4,
} }
@ -584,6 +608,12 @@ async def generate_emoji(
}, },
} }
# Process the payload through the pipeline
try:
payload = await process_pipeline_inlet_filter(request, payload, user, models)
except Exception as e:
raise e
try: try:
return await generate_chat_completion(request, form_data=payload, user=user) return await generate_chat_completion(request, form_data=payload, user=user)
except Exception as e: except Exception as e:
@ -613,17 +643,6 @@ async def generate_moa_response(
detail="Model not found", detail="Model not found",
) )
# Check if the user has a custom task model
# If the user has a custom task model, use that model
task_model_id = get_task_model_id(
model_id,
request.app.state.config.TASK_MODEL,
request.app.state.config.TASK_MODEL_EXTERNAL,
models,
)
log.debug(f"generating MOA model {task_model_id} for user {user.email} ")
template = DEFAULT_MOA_GENERATION_PROMPT_TEMPLATE template = DEFAULT_MOA_GENERATION_PROMPT_TEMPLATE
content = moa_response_generation_template( content = moa_response_generation_template(
@ -633,7 +652,7 @@ async def generate_moa_response(
) )
payload = { payload = {
"model": task_model_id, "model": model_id,
"messages": [{"role": "user", "content": content}], "messages": [{"role": "user", "content": content}],
"stream": form_data.get("stream", False), "stream": form_data.get("stream", False),
"metadata": { "metadata": {
@ -644,6 +663,12 @@ async def generate_moa_response(
}, },
} }
# Process the payload through the pipeline
try:
payload = await process_pipeline_inlet_filter(request, payload, user, models)
except Exception as e:
raise e
try: try:
return await generate_chat_completion(request, form_data=payload, user=user) return await generate_chat_completion(request, form_data=payload, user=user)
except Exception as e: except Exception as e:
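
Every task endpoint above now shares the same two steps: pick the token-limit key based on the model backend, then run the payload through the pipeline inlet filter before calling generate_chat_completion. A minimal standalone sketch of the first step; the sample model dicts are invented for this illustration:

# Illustration only: the backend-dependent token-limit key used by the task
# endpoints above. The sample model dicts are made up for this sketch.
def token_limit_params(model: dict, max_tokens: int) -> dict:
    # Ollama-backed models take "max_tokens"; OpenAI-compatible ones take
    # "max_completion_tokens".
    if model.get("owned_by") == "ollama":
        return {"max_tokens": max_tokens}
    return {"max_completion_tokens": max_tokens}


print(token_limit_params({"owned_by": "ollama"}, 1000))  # {'max_tokens': 1000}
print(token_limit_params({"owned_by": "openai"}, 1000))  # {'max_completion_tokens': 1000}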

View file

@@ -1,5 +1,10 @@
+import logging
 from pathlib import Path
 from typing import Optional
+import time
+import re
+import aiohttp
+
+from pydantic import BaseModel, HttpUrl

 from open_webui.models.tools import (
     ToolForm,

@@ -8,13 +13,20 @@ from open_webui.models.tools import (
     ToolUserResponse,
     Tools,
 )
-from open_webui.utils.plugin import load_tools_module_by_id, replace_imports
+from open_webui.utils.plugin import load_tool_module_by_id, replace_imports
 from open_webui.config import CACHE_DIR
 from open_webui.constants import ERROR_MESSAGES
 from fastapi import APIRouter, Depends, HTTPException, Request, status
-from open_webui.utils.tools import get_tools_specs
+from open_webui.utils.tools import get_tool_specs
 from open_webui.utils.auth import get_admin_user, get_verified_user
 from open_webui.utils.access_control import has_access, has_permission
+from open_webui.env import SRC_LOG_LEVELS
+from open_webui.utils.tools import get_tool_servers_data
+
+log = logging.getLogger(__name__)
+log.setLevel(SRC_LOG_LEVELS["MAIN"])

 router = APIRouter()

@@ -25,11 +37,51 @@ router = APIRouter()
 @router.get("/", response_model=list[ToolUserResponse])
-async def get_tools(user=Depends(get_verified_user)):
+async def get_tools(request: Request, user=Depends(get_verified_user)):
-    if user.role == "admin":
-        tools = Tools.get_tools()
-    else:
-        tools = Tools.get_tools_by_user_id(user.id, "read")
+    if not request.app.state.TOOL_SERVERS:
+        # If the tool servers are not set, we need to set them
+        # This is done only once when the server starts
+        # This is done to avoid loading the tool servers every time
+        request.app.state.TOOL_SERVERS = await get_tool_servers_data(
+            request.app.state.config.TOOL_SERVER_CONNECTIONS
+        )
+
+    tools = Tools.get_tools()
+    for server in request.app.state.TOOL_SERVERS:
+        tools.append(
+            ToolUserResponse(
+                **{
+                    "id": f"server:{server['idx']}",
+                    "user_id": f"server:{server['idx']}",
+                    "name": server.get("openapi", {})
+                    .get("info", {})
+                    .get("title", "Tool Server"),
+                    "meta": {
+                        "description": server.get("openapi", {})
+                        .get("info", {})
+                        .get("description", ""),
+                    },
+                    "access_control": request.app.state.config.TOOL_SERVER_CONNECTIONS[
+                        server["idx"]
+                    ]
+                    .get("config", {})
+                    .get("access_control", None),
+                    "updated_at": int(time.time()),
+                    "created_at": int(time.time()),
+                }
+            )
+        )
+
+    if user.role != "admin":
+        tools = [
+            tool
+            for tool in tools
+            if tool.user_id == user.id
+            or has_access(user.id, "read", tool.access_control)
+        ]
+
     return tools

@@ -47,6 +99,81 @@ async def get_tool_list(user=Depends(get_verified_user)):
     return tools

+############################
+# LoadFunctionFromLink
+############################
+
+
+class LoadUrlForm(BaseModel):
+    url: HttpUrl
+
+
+def github_url_to_raw_url(url: str) -> str:
+    # Handle 'tree' (folder) URLs (add main.py at the end)
+    m1 = re.match(r"https://github\.com/([^/]+)/([^/]+)/tree/([^/]+)/(.*)", url)
+    if m1:
+        org, repo, branch, path = m1.groups()
+        return f"https://raw.githubusercontent.com/{org}/{repo}/refs/heads/{branch}/{path.rstrip('/')}/main.py"
+
+    # Handle 'blob' (file) URLs
+    m2 = re.match(r"https://github\.com/([^/]+)/([^/]+)/blob/([^/]+)/(.*)", url)
+    if m2:
+        org, repo, branch, path = m2.groups()
+        return (
+            f"https://raw.githubusercontent.com/{org}/{repo}/refs/heads/{branch}/{path}"
+        )

+    # No match; return as-is
+    return url
+
+
+@router.post("/load/url", response_model=Optional[dict])
+async def load_tool_from_url(
+    request: Request, form_data: LoadUrlForm, user=Depends(get_admin_user)
+):
+    # NOTE: This is NOT a SSRF vulnerability:
+    # This endpoint is admin-only (see get_admin_user), meant for *trusted* internal use,
+    # and does NOT accept untrusted user input. Access is enforced by authentication.
+
+    url = str(form_data.url)
+    if not url:
+        raise HTTPException(status_code=400, detail="Please enter a valid URL")
+
+    url = github_url_to_raw_url(url)
+    url_parts = url.rstrip("/").split("/")
+
+    file_name = url_parts[-1]
+    tool_name = (
+        file_name[:-3]
+        if (
+            file_name.endswith(".py")
+            and (not file_name.startswith(("main.py", "index.py", "__init__.py")))
+        )
+        else url_parts[-2] if len(url_parts) > 1 else "function"
+    )
+
+    try:
+        async with aiohttp.ClientSession() as session:
+            async with session.get(
+                url, headers={"Content-Type": "application/json"}
+            ) as resp:
+                if resp.status != 200:
+                    raise HTTPException(
+                        status_code=resp.status, detail="Failed to fetch the tool"
+                    )
+                data = await resp.text()
+                if not data:
+                    raise HTTPException(
+                        status_code=400, detail="No data received from the URL"
+                    )
+                return {
+                    "name": tool_name,
+                    "content": data,
+                }
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=f"Error importing tool: {e}")
+
+
 ############################
 # ExportTools
 ############################

@@ -89,18 +216,18 @@ async def create_new_tools(
     if tools is None:
         try:
             form_data.content = replace_imports(form_data.content)
-            tools_module, frontmatter = load_tools_module_by_id(
+            tool_module, frontmatter = load_tool_module_by_id(
                 form_data.id, content=form_data.content
             )
             form_data.meta.manifest = frontmatter

             TOOLS = request.app.state.TOOLS
-            TOOLS[form_data.id] = tools_module
+            TOOLS[form_data.id] = tool_module

-            specs = get_tools_specs(TOOLS[form_data.id])
+            specs = get_tool_specs(TOOLS[form_data.id])
             tools = Tools.insert_new_tool(user.id, form_data, specs)

-            tool_cache_dir = Path(CACHE_DIR) / "tools" / form_data.id
+            tool_cache_dir = CACHE_DIR / "tools" / form_data.id
             tool_cache_dir.mkdir(parents=True, exist_ok=True)

             if tools:

@@ -111,7 +238,7 @@ async def create_new_tools(
                 detail=ERROR_MESSAGES.DEFAULT("Error creating tools"),
             )
         except Exception as e:
-            print(e)
+            log.exception(f"Failed to load the tool by id {form_data.id}: {e}")
             raise HTTPException(
                 status_code=status.HTTP_400_BAD_REQUEST,
                 detail=ERROR_MESSAGES.DEFAULT(str(e)),

@@ -178,22 +305,20 @@ async def update_tools_by_id(
     try:
         form_data.content = replace_imports(form_data.content)
-        tools_module, frontmatter = load_tools_module_by_id(
-            id, content=form_data.content
-        )
+        tool_module, frontmatter = load_tool_module_by_id(id, content=form_data.content)

         form_data.meta.manifest = frontmatter

         TOOLS = request.app.state.TOOLS
-        TOOLS[id] = tools_module
+        TOOLS[id] = tool_module

-        specs = get_tools_specs(TOOLS[id])
+        specs = get_tool_specs(TOOLS[id])

         updated = {
             **form_data.model_dump(exclude={"id"}),
             "specs": specs,
         }
-        print(updated)
+        log.debug(updated)

         tools = Tools.update_tool_by_id(id, updated)

         if tools:

@@ -284,7 +409,7 @@ async def get_tools_valves_spec_by_id(
         if id in request.app.state.TOOLS:
             tools_module = request.app.state.TOOLS[id]
         else:
-            tools_module, _ = load_tools_module_by_id(id)
+            tools_module, _ = load_tool_module_by_id(id)
             request.app.state.TOOLS[id] = tools_module

         if hasattr(tools_module, "Valves"):

@@ -327,7 +452,7 @@ async def update_tools_valves_by_id(
         if id in request.app.state.TOOLS:
             tools_module = request.app.state.TOOLS[id]
         else:
-            tools_module, _ = load_tools_module_by_id(id)
+            tools_module, _ = load_tool_module_by_id(id)
             request.app.state.TOOLS[id] = tools_module

         if not hasattr(tools_module, "Valves"):

@@ -343,7 +468,7 @@ async def update_tools_valves_by_id(
             Tools.update_tool_valves_by_id(id, valves.model_dump())
             return valves.model_dump()
         except Exception as e:
-            print(e)
+            log.exception(f"Failed to update tool valves by id {id}: {e}")
             raise HTTPException(
                 status_code=status.HTTP_400_BAD_REQUEST,
                 detail=ERROR_MESSAGES.DEFAULT(str(e)),

@@ -383,7 +508,7 @@ async def get_tools_user_valves_spec_by_id(
         if id in request.app.state.TOOLS:
             tools_module = request.app.state.TOOLS[id]
         else:
-            tools_module, _ = load_tools_module_by_id(id)
+            tools_module, _ = load_tool_module_by_id(id)
             request.app.state.TOOLS[id] = tools_module

         if hasattr(tools_module, "UserValves"):

@@ -407,7 +532,7 @@ async def update_tools_user_valves_by_id(
         if id in request.app.state.TOOLS:
             tools_module = request.app.state.TOOLS[id]
         else:
-            tools_module, _ = load_tools_module_by_id(id)
+            tools_module, _ = load_tool_module_by_id(id)
             request.app.state.TOOLS[id] = tools_module

         if hasattr(tools_module, "UserValves"):

@@ -421,7 +546,7 @@ async def update_tools_user_valves_by_id(
             )
             return user_valves.model_dump()
         except Exception as e:
-            print(e)
+            log.exception(f"Failed to update user valves by id {id}: {e}")
             raise HTTPException(
                 status_code=status.HTTP_400_BAD_REQUEST,
                 detail=ERROR_MESSAGES.DEFAULT(str(e)),
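
A rough client-side sketch of the new admin-only load-from-URL route. The base URL, the /api/v1/tools prefix, and the bearer token are assumptions about a typical deployment; the request body and response keys follow the route above:

# Hypothetical client call; endpoint prefix and token are placeholders.
import requests

BASE_URL = "http://localhost:8080"   # assumed local Open WebUI instance
TOKEN = "YOUR_ADMIN_API_TOKEN"       # placeholder

resp = requests.post(
    f"{BASE_URL}/api/v1/tools/load/url",
    headers={"Authorization": f"Bearer {TOKEN}"},
    json={"url": "https://github.com/org/repo/blob/main/tools/my_tool.py"},
    timeout=30,
)
resp.raise_for_status()
data = resp.json()
print(data["name"])          # e.g. "my_tool", derived from the file name
print(len(data["content"]))  # raw source text fetched from the rewritten URL

Per github_url_to_raw_url above, a blob URL like this one is rewritten to raw.githubusercontent.com before fetching, and tree (folder) URLs get /main.py appended.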

View file

@@ -2,9 +2,11 @@ import logging
 from typing import Optional

 from open_webui.models.auths import Auths
+from open_webui.models.groups import Groups
 from open_webui.models.chats import Chats
 from open_webui.models.users import (
     UserModel,
+    UserListResponse,
     UserRoleUpdateForm,
     Users,
     UserSettings,

@@ -17,7 +19,10 @@ from open_webui.constants import ERROR_MESSAGES
 from open_webui.env import SRC_LOG_LEVELS
 from fastapi import APIRouter, Depends, HTTPException, Request, status
 from pydantic import BaseModel

 from open_webui.utils.auth import get_admin_user, get_password_hash, get_verified_user
+from open_webui.utils.access_control import get_permissions, has_permission

 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["MODELS"])

@@ -29,13 +34,38 @@ router = APIRouter()
 ############################

-@router.get("/", response_model=list[UserModel])
+PAGE_ITEM_COUNT = 30
+
+
+@router.get("/", response_model=UserListResponse)
 async def get_users(
-    skip: Optional[int] = None,
-    limit: Optional[int] = None,
+    query: Optional[str] = None,
+    order_by: Optional[str] = None,
+    direction: Optional[str] = None,
+    page: Optional[int] = 1,
     user=Depends(get_admin_user),
 ):
-    return Users.get_users(skip, limit)
+    limit = PAGE_ITEM_COUNT
+
+    page = max(1, page)
+    skip = (page - 1) * limit
+
+    filter = {}
+    if query:
+        filter["query"] = query
+    if order_by:
+        filter["order_by"] = order_by
+    if direction:
+        filter["direction"] = direction
+
+    return Users.get_users(filter=filter, skip=skip, limit=limit)
+
+
+@router.get("/all", response_model=UserListResponse)
+async def get_all_users(
+    user=Depends(get_admin_user),
+):
+    return Users.get_users()

 ############################

@@ -45,7 +75,7 @@ async def get_users(
 @router.get("/groups")
 async def get_user_groups(user=Depends(get_verified_user)):
-    return Users.get_user_groups(user.id)
+    return Groups.get_groups_by_member_id(user.id)

 ############################

@@ -54,8 +84,12 @@ async def get_user_groups(user=Depends(get_verified_user)):
 @router.get("/permissions")
-async def get_user_permissisions(user=Depends(get_verified_user)):
+async def get_user_permissisions(request: Request, user=Depends(get_verified_user)):
-    return Users.get_user_groups(user.id)
+    user_permissions = get_permissions(
+        user.id, request.app.state.config.USER_PERMISSIONS
+    )
+
+    return user_permissions

 ############################

@@ -68,32 +102,52 @@ class WorkspacePermissions(BaseModel):
     tools: bool = False


+class SharingPermissions(BaseModel):
+    public_models: bool = True
+    public_knowledge: bool = True
+    public_prompts: bool = True
+    public_tools: bool = True
+
+
 class ChatPermissions(BaseModel):
     controls: bool = True
     file_upload: bool = True
     delete: bool = True
     edit: bool = True
+    share: bool = True
+    export: bool = True
+    stt: bool = True
+    tts: bool = True
+    call: bool = True
+    multiple_models: bool = True
     temporary: bool = True
+    temporary_enforced: bool = False


 class FeaturesPermissions(BaseModel):
+    direct_tool_servers: bool = False
     web_search: bool = True
     image_generation: bool = True
     code_interpreter: bool = True
+    notes: bool = True


 class UserPermissions(BaseModel):
     workspace: WorkspacePermissions
+    sharing: SharingPermissions
     chat: ChatPermissions
     features: FeaturesPermissions


 @router.get("/default/permissions", response_model=UserPermissions)
-async def get_user_permissions(request: Request, user=Depends(get_admin_user)):
+async def get_default_user_permissions(request: Request, user=Depends(get_admin_user)):
     return {
         "workspace": WorkspacePermissions(
             **request.app.state.config.USER_PERMISSIONS.get("workspace", {})
         ),
+        "sharing": SharingPermissions(
+            **request.app.state.config.USER_PERMISSIONS.get("sharing", {})
+        ),
         "chat": ChatPermissions(
             **request.app.state.config.USER_PERMISSIONS.get("chat", {})
         ),

@@ -104,7 +158,7 @@ async def get_user_permissions(request: Request, user=Depends(get_admin_user)):
 @router.post("/default/permissions")
-async def update_user_permissions(
+async def update_default_user_permissions(
     request: Request, form_data: UserPermissions, user=Depends(get_admin_user)
 ):
     request.app.state.config.USER_PERMISSIONS = form_data.model_dump()

@@ -151,9 +205,22 @@ async def get_user_settings_by_session_user(user=Depends(get_verified_user)):
 @router.post("/user/settings/update", response_model=UserSettings)
 async def update_user_settings_by_session_user(
-    form_data: UserSettings, user=Depends(get_verified_user)
+    request: Request, form_data: UserSettings, user=Depends(get_verified_user)
 ):
-    user = Users.update_user_settings_by_id(user.id, form_data.model_dump())
+    updated_user_settings = form_data.model_dump()
+    if (
+        user.role != "admin"
+        and "toolServers" in updated_user_settings.get("ui").keys()
+        and not has_permission(
+            user.id,
+            "features.direct_tool_servers",
+            request.app.state.config.USER_PERMISSIONS,
+        )
+    ):
+        # If the user is not an admin and does not have permission to use tool servers, remove the key
+        updated_user_settings["ui"].pop("toolServers", None)
+
+    user = Users.update_user_settings_by_id(user.id, updated_user_settings)
     if user:
         return user.settings
     else:

@@ -263,6 +330,21 @@ async def update_user_by_id(
     form_data: UserUpdateForm,
     session_user=Depends(get_admin_user),
 ):
+    # Prevent modification of the primary admin user by other admins
+    try:
+        first_user = Users.get_first_user()
+        if first_user and user_id == first_user.id and session_user.id != user_id:
+            raise HTTPException(
+                status_code=status.HTTP_403_FORBIDDEN,
+                detail=ERROR_MESSAGES.ACTION_PROHIBITED,
+            )
+    except Exception as e:
+        log.error(f"Error checking primary admin status: {e}")
+        raise HTTPException(
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            detail="Could not verify primary admin status.",
+        )
+
     user = Users.get_user_by_id(user_id)

     if user:

@@ -310,6 +392,21 @@ async def update_user_by_id(
 @router.delete("/{user_id}", response_model=bool)
 async def delete_user_by_id(user_id: str, user=Depends(get_admin_user)):
+    # Prevent deletion of the primary admin user
+    try:
+        first_user = Users.get_first_user()
+        if first_user and user_id == first_user.id:
+            raise HTTPException(
+                status_code=status.HTTP_403_FORBIDDEN,
+                detail=ERROR_MESSAGES.ACTION_PROHIBITED,
+            )
+    except Exception as e:
+        log.error(f"Error checking primary admin status: {e}")
+        raise HTTPException(
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            detail="Could not verify primary admin status.",
+        )
+
     if user.id != user_id:
         result = Auths.delete_auth_by_id(user_id)

@@ -321,6 +418,7 @@ async def delete_user_by_id(user_id: str, user=Depends(get_admin_user)):
             detail=ERROR_MESSAGES.DELETE_USER_ERROR,
         )

+    # Prevent self-deletion
     raise HTTPException(
         status_code=status.HTTP_403_FORBIDDEN,
         detail=ERROR_MESSAGES.ACTION_PROHIBITED,
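
The listing route above replaces the old skip/limit query parameters with fixed-size pages of 30 users. A small sketch of the pagination arithmetic; PAGE_ITEM_COUNT is taken from the diff, and page_to_offset is an invented helper name:

# Illustration of the page-to-offset conversion used by GET /users/ above.
PAGE_ITEM_COUNT = 30


def page_to_offset(page: int) -> tuple[int, int]:
    page = max(1, page)  # pages below 1 are clamped to the first page
    skip = (page - 1) * PAGE_ITEM_COUNT
    return skip, PAGE_ITEM_COUNT


print(page_to_offset(1))  # (0, 30)  -> users 1-30
print(page_to_offset(3))  # (60, 30) -> users 61-90
print(page_to_offset(0))  # (0, 30)  -> clamped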

View file

@@ -1,48 +1,84 @@
 import black
+import logging
 import markdown

 from open_webui.models.chats import ChatTitleMessagesForm
 from open_webui.config import DATA_DIR, ENABLE_ADMIN_EXPORT
 from open_webui.constants import ERROR_MESSAGES
-from fastapi import APIRouter, Depends, HTTPException, Response, status
+from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
 from pydantic import BaseModel
 from starlette.responses import FileResponse

 from open_webui.utils.misc import get_gravatar_url
 from open_webui.utils.pdf_generator import PDFGenerator
-from open_webui.utils.auth import get_admin_user
+from open_webui.utils.auth import get_admin_user, get_verified_user
+from open_webui.utils.code_interpreter import execute_code_jupyter
+from open_webui.env import SRC_LOG_LEVELS
+
+log = logging.getLogger(__name__)
+log.setLevel(SRC_LOG_LEVELS["MAIN"])

 router = APIRouter()


 @router.get("/gravatar")
-async def get_gravatar(
-    email: str,
-):
+async def get_gravatar(email: str, user=Depends(get_verified_user)):
     return get_gravatar_url(email)


-class CodeFormatRequest(BaseModel):
+class CodeForm(BaseModel):
     code: str


 @router.post("/code/format")
-async def format_code(request: CodeFormatRequest):
+async def format_code(form_data: CodeForm, user=Depends(get_verified_user)):
     try:
-        formatted_code = black.format_str(request.code, mode=black.Mode())
+        formatted_code = black.format_str(form_data.code, mode=black.Mode())
         return {"code": formatted_code}
     except black.NothingChanged:
-        return {"code": request.code}
+        return {"code": form_data.code}
     except Exception as e:
         raise HTTPException(status_code=400, detail=str(e))


+@router.post("/code/execute")
+async def execute_code(
+    request: Request, form_data: CodeForm, user=Depends(get_verified_user)
+):
+    if request.app.state.config.CODE_EXECUTION_ENGINE == "jupyter":
+        output = await execute_code_jupyter(
+            request.app.state.config.CODE_EXECUTION_JUPYTER_URL,
+            form_data.code,
+            (
+                request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_TOKEN
+                if request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH == "token"
+                else None
+            ),
+            (
+                request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_PASSWORD
+                if request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH == "password"
+                else None
+            ),
+            request.app.state.config.CODE_EXECUTION_JUPYTER_TIMEOUT,
+        )
+
+        return output
+    else:
+        raise HTTPException(
+            status_code=400,
+            detail="Code execution engine not supported",
+        )
+
+
 class MarkdownForm(BaseModel):
     md: str


 @router.post("/markdown")
 async def get_html_from_markdown(
-    form_data: MarkdownForm,
+    form_data: MarkdownForm, user=Depends(get_verified_user)
 ):
     return {"html": markdown.markdown(form_data.md)}

@@ -54,7 +90,7 @@ class ChatForm(BaseModel):
 @router.post("/pdf")
 async def download_chat_as_pdf(
-    form_data: ChatTitleMessagesForm,
+    form_data: ChatTitleMessagesForm, user=Depends(get_verified_user)
 ):
     try:
         pdf_bytes = PDFGenerator(form_data).generate_chat_pdf()

@@ -65,7 +101,7 @@ async def download_chat_as_pdf(
             headers={"Content-Disposition": "attachment;filename=chat.pdf"},
         )
     except Exception as e:
-        print(e)
+        log.exception(f"Error generating PDF: {e}")
         raise HTTPException(status_code=400, detail=str(e))
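
In the new /code/execute route only the Jupyter engine is wired up, and exactly one credential is forwarded depending on CODE_EXECUTION_JUPYTER_AUTH. A self-contained sketch of that selection logic; the Config dataclass here is an invented stand-in for request.app.state.config:

# Sketch of the token-vs-password selection made before calling execute_code_jupyter.
from dataclasses import dataclass
from typing import Optional


@dataclass
class Config:
    CODE_EXECUTION_JUPYTER_AUTH: Optional[str] = None  # "token", "password", or None
    CODE_EXECUTION_JUPYTER_AUTH_TOKEN: Optional[str] = None
    CODE_EXECUTION_JUPYTER_AUTH_PASSWORD: Optional[str] = None


def jupyter_credentials(config: Config) -> tuple[Optional[str], Optional[str]]:
    token = (
        config.CODE_EXECUTION_JUPYTER_AUTH_TOKEN
        if config.CODE_EXECUTION_JUPYTER_AUTH == "token"
        else None
    )
    password = (
        config.CODE_EXECUTION_JUPYTER_AUTH_PASSWORD
        if config.CODE_EXECUTION_JUPYTER_AUTH == "password"
        else None
    )
    return token, password


print(jupyter_credentials(Config("token", "secret-token", None)))    # ('secret-token', None)
print(jupyter_credentials(Config("password", None, "secret-pass")))  # (None, 'secret-pass')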

View file

@@ -3,15 +3,23 @@ import socketio
 import logging
 import sys
 import time

+from redis import asyncio as aioredis
+
 from open_webui.models.users import Users, UserNameResponse
 from open_webui.models.channels import Channels
 from open_webui.models.chats import Chats

+from open_webui.utils.redis import (
+    get_sentinels_from_env,
+    get_sentinel_url_from_env,
+)
+
 from open_webui.env import (
     ENABLE_WEBSOCKET_SUPPORT,
     WEBSOCKET_MANAGER,
     WEBSOCKET_REDIS_URL,
+    WEBSOCKET_REDIS_LOCK_TIMEOUT,
+    WEBSOCKET_SENTINEL_PORT,
+    WEBSOCKET_SENTINEL_HOSTS,
 )
 from open_webui.utils.auth import decode_token
 from open_webui.socket.utils import RedisDict, RedisLock

@@ -28,6 +36,13 @@ log.setLevel(SRC_LOG_LEVELS["SOCKET"])

 if WEBSOCKET_MANAGER == "redis":
+    if WEBSOCKET_SENTINEL_HOSTS:
+        mgr = socketio.AsyncRedisManager(
+            get_sentinel_url_from_env(
+                WEBSOCKET_REDIS_URL, WEBSOCKET_SENTINEL_HOSTS, WEBSOCKET_SENTINEL_PORT
+            )
+        )
+    else:
-    mgr = socketio.AsyncRedisManager(WEBSOCKET_REDIS_URL)
+        mgr = socketio.AsyncRedisManager(WEBSOCKET_REDIS_URL)
     sio = socketio.AsyncServer(
         cors_allowed_origins=[],

@@ -54,14 +69,30 @@ TIMEOUT_DURATION = 3

 if WEBSOCKET_MANAGER == "redis":
     log.debug("Using Redis to manage websockets.")
-    SESSION_POOL = RedisDict("open-webui:session_pool", redis_url=WEBSOCKET_REDIS_URL)
-    USER_POOL = RedisDict("open-webui:user_pool", redis_url=WEBSOCKET_REDIS_URL)
-    USAGE_POOL = RedisDict("open-webui:usage_pool", redis_url=WEBSOCKET_REDIS_URL)
+    redis_sentinels = get_sentinels_from_env(
+        WEBSOCKET_SENTINEL_HOSTS, WEBSOCKET_SENTINEL_PORT
+    )
+    SESSION_POOL = RedisDict(
+        "open-webui:session_pool",
+        redis_url=WEBSOCKET_REDIS_URL,
+        redis_sentinels=redis_sentinels,
+    )
+    USER_POOL = RedisDict(
+        "open-webui:user_pool",
+        redis_url=WEBSOCKET_REDIS_URL,
+        redis_sentinels=redis_sentinels,
+    )
+    USAGE_POOL = RedisDict(
+        "open-webui:usage_pool",
+        redis_url=WEBSOCKET_REDIS_URL,
+        redis_sentinels=redis_sentinels,
+    )

     clean_up_lock = RedisLock(
         redis_url=WEBSOCKET_REDIS_URL,
         lock_name="usage_cleanup_lock",
-        timeout_secs=TIMEOUT_DURATION * 2,
+        timeout_secs=WEBSOCKET_REDIS_LOCK_TIMEOUT,
+        redis_sentinels=redis_sentinels,
     )
     aquire_func = clean_up_lock.aquire_lock
     renew_func = clean_up_lock.renew_lock

@@ -128,6 +159,7 @@ def get_models_in_use():
 @sio.on("usage")
 async def usage(sid, data):
+    if sid in SESSION_POOL:
         model_id = data["model"]
         # Record the timestamp for the last update
         current_time = int(time.time())

@@ -247,6 +279,7 @@ async def channel_events(sid, data):
 @sio.on("user-list")
 async def user_list(sid):
+    if sid in SESSION_POOL:
         await sio.emit("user-list", {"user_ids": list(USER_POOL.keys())})

@@ -268,15 +301,23 @@ async def disconnect(sid):
     # print(f"Unknown session ID {sid} disconnected")


-def get_event_emitter(request_info):
+def get_event_emitter(request_info, update_db=True):
     async def __event_emitter__(event_data):
         user_id = request_info["user_id"]

         session_ids = list(
-            set(USER_POOL.get(user_id, []) + [request_info["session_id"]])
+            set(
+                USER_POOL.get(user_id, [])
+                + (
+                    [request_info.get("session_id")]
+                    if request_info.get("session_id")
+                    else []
+                )
+            )
         )

-        for session_id in session_ids:
-            await sio.emit(
+        emit_tasks = [
+            sio.emit(
                 "chat-events",
                 {
                     "chat_id": request_info.get("chat_id", None),

@@ -285,7 +326,12 @@ def get_event_emitter(request_info):
                 },
                 to=session_id,
             )
+            for session_id in session_ids
+        ]
+
+        await asyncio.gather(*emit_tasks)

+        if update_db:
             if "type" in event_data and event_data["type"] == "status":
                 Chats.add_message_status_to_chat_by_id_and_message_id(
                     request_info["chat_id"],

@@ -299,6 +345,7 @@ def get_event_emitter(request_info):
                     request_info["message_id"],
                 )

+                if message:
                     content = message.get("content", "")
                     content += event_data.get("data", {}).get("content", "")
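
The emitter above now builds one sio.emit coroutine per session id and awaits them together instead of emitting sequentially. A runnable toy version of that fan-out pattern, with fake_emit standing in for sio.emit:

# Minimal standalone sketch of the concurrent emit pattern used in get_event_emitter.
import asyncio


async def fake_emit(event: str, data: dict, to: str) -> None:
    await asyncio.sleep(0.01)  # simulate network I/O
    print(f"emit {event} -> {to}: {data}")


async def main() -> None:
    session_ids = ["sess-1", "sess-2", "sess-3"]
    emit_tasks = [
        fake_emit("chat-events", {"chat_id": "abc"}, to=session_id)
        for session_id in session_ids
    ]
    # All emits run concurrently and are awaited as a group.
    await asyncio.gather(*emit_tasks)


asyncio.run(main())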

View file

@@ -1,15 +1,17 @@
 import json
-import redis
 import uuid

+from open_webui.utils.redis import get_redis_connection
+

 class RedisLock:
-    def __init__(self, redis_url, lock_name, timeout_secs):
+    def __init__(self, redis_url, lock_name, timeout_secs, redis_sentinels=[]):
         self.lock_name = lock_name
         self.lock_id = str(uuid.uuid4())
         self.timeout_secs = timeout_secs
         self.lock_obtained = False
-        self.redis = redis.Redis.from_url(redis_url, decode_responses=True)
+        self.redis = get_redis_connection(
+            redis_url, redis_sentinels, decode_responses=True
+        )

     def aquire_lock(self):
         # nx=True will only set this key if it _hasn't_ already been set

@@ -31,9 +33,11 @@ class RedisLock:

 class RedisDict:
-    def __init__(self, name, redis_url):
+    def __init__(self, name, redis_url, redis_sentinels=[]):
         self.name = name
-        self.redis = redis.Redis.from_url(redis_url, decode_responses=True)
+        self.redis = get_redis_connection(
+            redis_url, redis_sentinels, decode_responses=True
+        )

     def __setitem__(self, key, value):
         serialized_value = json.dumps(value)
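
RedisDict and RedisLock now accept a redis_sentinels list and obtain their client through get_redis_connection instead of Redis.from_url. In plain redis-py terms, a sentinel-aware helper looks roughly like the sketch below; the host list and the "mymaster" service name are placeholders, not values from this diff:

# Rough standalone equivalent of a sentinel-aware connection helper, using redis-py.
from redis import Redis
from redis.sentinel import Sentinel


def get_connection(redis_url: str, sentinels: list[tuple[str, int]]) -> Redis:
    if sentinels:
        # Ask the sentinels for the current master of the named service.
        return Sentinel(sentinels).master_for("mymaster", decode_responses=True)
    # No sentinels configured: connect to the single Redis instance directly.
    return Redis.from_url(redis_url, decode_responses=True)


conn = get_connection("redis://localhost:6379/0", [])  # falls back to a direct client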

Binary file added (image, 7.3 KiB); contents not shown in this diff.

View file

@@ -269,11 +269,6 @@ tbody + tbody {
   margin-bottom: 0;
 }

-/* Add a rule to reset margin-bottom for <p> not followed by <ul> */
-.markdown-section p + ul {
-  margin-top: 0;
-}
-
 /* List item styles */
 .markdown-section li {
   padding: 2px;

Binary file added (image, 3.7 KiB); contents not shown in this diff.

Binary file added (image, 16 KiB); contents not shown in this diff.

Some files were not shown because too many files have changed in this diff.